diff --git a/.github/CODEOWNERS b/.github/CODEOWNERS index a0528e4011..fc4fcd458b 100755 --- a/.github/CODEOWNERS +++ b/.github/CODEOWNERS @@ -18,3 +18,4 @@ conda/ @rapidsai/ops-codeowners **/Dockerfile @rapidsai/ops-codeowners **/.dockerignore @rapidsai/ops-codeowners docker/ @rapidsai/ops-codeowners +dependencies.yaml @rapidsai/ops-codeowners diff --git a/.github/labeler.yml b/.github/labeler.yml index 9809e2cc2e..56f77e69c0 100644 --- a/.github/labeler.yml +++ b/.github/labeler.yml @@ -12,5 +12,5 @@ CMake: - '**/CMakeLists.txt' - '**/cmake/**' -gpuCI: +ci: - 'ci/**' diff --git a/.github/ops-bot.yaml b/.github/ops-bot.yaml index 236696d948..2d1444c595 100644 --- a/.github/ops-bot.yaml +++ b/.github/ops-bot.yaml @@ -6,3 +6,4 @@ branch_checker: true label_checker: true release_drafter: true copy_prs: true +recently_updated: true diff --git a/.github/workflows/build.yaml b/.github/workflows/build.yaml new file mode 100644 index 0000000000..7a780e3de1 --- /dev/null +++ b/.github/workflows/build.yaml @@ -0,0 +1,97 @@ +name: build + +on: + push: + branches: + - "branch-*" + tags: + - v[0-9][0-9].[0-9][0-9].[0-9][0-9] + workflow_dispatch: + inputs: + branch: + required: true + type: string + date: + required: true + type: string + sha: + required: true + type: string + build_type: + type: string + default: nightly + +concurrency: + group: ${{ github.workflow }}-${{ github.ref }} + cancel-in-progress: true + +jobs: + cpp-build: + secrets: inherit + uses: rapidsai/shared-action-workflows/.github/workflows/conda-cpp-build.yaml@branch-23.02 + with: + build_type: ${{ inputs.build_type || 'branch' }} + branch: ${{ inputs.branch }} + date: ${{ inputs.date }} + sha: ${{ inputs.sha }} + python-build: + needs: [cpp-build] + secrets: inherit + uses: rapidsai/shared-action-workflows/.github/workflows/conda-python-build.yaml@branch-23.02 + with: + build_type: ${{ inputs.build_type || 'branch' }} + branch: ${{ inputs.branch }} + date: ${{ inputs.date }} + sha: ${{ inputs.sha }} + 
upload-conda: + needs: [cpp-build, python-build] + secrets: inherit + uses: rapidsai/shared-action-workflows/.github/workflows/conda-upload-packages.yaml@branch-23.02 + with: + build_type: ${{ inputs.build_type || 'branch' }} + branch: ${{ inputs.branch }} + date: ${{ inputs.date }} + sha: ${{ inputs.sha }} + wheel-build-pylibraft: + secrets: inherit + uses: rapidsai/shared-action-workflows/.github/workflows/wheels-manylinux-build.yml@branch-23.02 + with: + build_type: ${{ inputs.build_type || 'branch' }} + branch: ${{ inputs.branch }} + sha: ${{ inputs.sha }} + date: ${{ inputs.date }} + package-name: pylibraft + package-dir: python/pylibraft + skbuild-configure-options: "-DRAFT_BUILD_WHEELS=ON -DDETECT_CONDA_ENV=OFF -DFIND_RAFT_CPP=OFF" + wheel-publish-pylibraft: + needs: wheel-build-pylibraft + secrets: inherit + uses: rapidsai/shared-action-workflows/.github/workflows/wheels-manylinux-publish.yml@branch-23.02 + with: + build_type: ${{ inputs.build_type || 'branch' }} + branch: ${{ inputs.branch }} + sha: ${{ inputs.sha }} + date: ${{ inputs.date }} + package-name: pylibraft + wheel-build-raft-dask: + needs: wheel-publish-pylibraft + secrets: inherit + uses: rapidsai/shared-action-workflows/.github/workflows/wheels-manylinux-build.yml@branch-23.02 + with: + build_type: ${{ inputs.build_type || 'branch' }} + branch: ${{ inputs.branch }} + sha: ${{ inputs.sha }} + date: ${{ inputs.date }} + package-name: raft_dask + package-dir: python/raft-dask + skbuild-configure-options: "-DRAFT_BUILD_WHEELS=ON -DDETECT_CONDA_ENV=OFF -DFIND_RAFT_CPP=OFF" + wheel-publish-raft-dask: + needs: wheel-build-raft-dask + secrets: inherit + uses: rapidsai/shared-action-workflows/.github/workflows/wheels-manylinux-publish.yml@branch-23.02 + with: + build_type: ${{ inputs.build_type || 'branch' }} + branch: ${{ inputs.branch }} + sha: ${{ inputs.sha }} + date: ${{ inputs.date }} + package-name: raft_dask diff --git a/.github/workflows/pr.yaml b/.github/workflows/pr.yaml new file mode 
100644 index 0000000000..03b66debb8 --- /dev/null +++ b/.github/workflows/pr.yaml @@ -0,0 +1,96 @@ +name: pr + +on: + push: + branches: + - "pull-request/[0-9]+" + +concurrency: + group: ${{ github.workflow }}-${{ github.ref }} + cancel-in-progress: true + +jobs: + pr-builder: + needs: + - checks + - conda-cpp-build + - conda-cpp-tests + - conda-python-build + - conda-python-tests + - wheel-build-pylibraft + - wheel-tests-pylibraft + - wheel-build-raft-dask + - wheel-tests-raft-dask + secrets: inherit + uses: rapidsai/shared-action-workflows/.github/workflows/pr-builder.yaml@branch-23.02 + checks: + secrets: inherit + uses: rapidsai/shared-action-workflows/.github/workflows/checks.yaml@branch-23.02 + conda-cpp-build: + needs: checks + secrets: inherit + uses: rapidsai/shared-action-workflows/.github/workflows/conda-cpp-build.yaml@branch-23.02 + with: + build_type: pull-request + node_type: cpu16 + conda-cpp-tests: + needs: conda-cpp-build + secrets: inherit + uses: rapidsai/shared-action-workflows/.github/workflows/conda-cpp-tests.yaml@branch-23.02 + with: + build_type: pull-request + conda-python-build: + needs: conda-cpp-build + secrets: inherit + uses: rapidsai/shared-action-workflows/.github/workflows/conda-python-build.yaml@branch-23.02 + with: + build_type: pull-request + conda-python-tests: + needs: conda-python-build + secrets: inherit + uses: rapidsai/shared-action-workflows/.github/workflows/conda-python-tests.yaml@branch-23.02 + with: + build_type: pull-request + wheel-build-pylibraft: + needs: checks + secrets: inherit + uses: rapidsai/shared-action-workflows/.github/workflows/wheels-manylinux-build.yml@branch-23.02 + with: + build_type: pull-request + package-name: pylibraft + package-dir: python/pylibraft + skbuild-configure-options: "-DRAFT_BUILD_WHEELS=ON -DDETECT_CONDA_ENV=OFF -DFIND_RAFT_CPP=OFF" + wheel-tests-pylibraft: + needs: wheel-build-pylibraft + secrets: inherit + uses: 
rapidsai/shared-action-workflows/.github/workflows/wheels-manylinux-test.yml@branch-23.02 + with: + build_type: pull-request + package-name: pylibraft + test-before-amd64: "pip install cupy-cuda11x" + # On arm64, cupy must additionally be installed from the CuPy aarch64 wheel index. + test-before-arm64: "pip install cupy-cuda11x -f https://pip.cupy.dev/aarch64" + test-unittest: "python -m pytest -v ./python/pylibraft/pylibraft/test" + test-smoketest: "python ./ci/wheel_smoke_test_pylibraft.py" + wheel-build-raft-dask: + needs: wheel-tests-pylibraft + secrets: inherit + uses: rapidsai/shared-action-workflows/.github/workflows/wheels-manylinux-build.yml@branch-23.02 + with: + build_type: pull-request + package-name: raft_dask + package-dir: python/raft-dask + before-wheel: "RAPIDS_PY_WHEEL_NAME=pylibraft_cu11 rapids-download-wheels-from-s3 ./local-wheelhouse" + skbuild-configure-options: "-DRAFT_BUILD_WHEELS=ON -DDETECT_CONDA_ENV=OFF -DFIND_RAFT_CPP=OFF" + wheel-tests-raft-dask: + needs: wheel-build-raft-dask + secrets: inherit + uses: rapidsai/shared-action-workflows/.github/workflows/wheels-manylinux-test.yml@branch-23.02 + with: + build_type: pull-request + package-name: raft_dask + # Test against the dask/distributed versions pinned for this release. 
+ test-before-amd64: "RAPIDS_PY_WHEEL_NAME=pylibraft_cu11 rapids-download-wheels-from-s3 ./local-pylibraft-dep && pip install --no-deps ./local-pylibraft-dep/pylibraft*.whl && pip install git+https://github.com/dask/dask.git@2023.1.1 git+https://github.com/dask/distributed.git@2023.1.1 git+https://github.com/rapidsai/dask-cuda.git@branch-23.02" + test-before-arm64: "RAPIDS_PY_WHEEL_NAME=pylibraft_cu11 rapids-download-wheels-from-s3 ./local-pylibraft-dep && pip install --no-deps ./local-pylibraft-dep/pylibraft*.whl && pip install git+https://github.com/dask/dask.git@2023.1.1 git+https://github.com/dask/distributed.git@2023.1.1 git+https://github.com/rapidsai/dask-cuda.git@branch-23.02" + test-unittest: "python -m pytest -v ./python/raft-dask/raft_dask/test" + test-smoketest: "python ./ci/wheel_smoke_test_raft_dask.py" diff --git a/.github/workflows/test.yaml b/.github/workflows/test.yaml new file mode 100644 index 0000000000..739f50861e --- /dev/null +++ b/.github/workflows/test.yaml @@ -0,0 +1,56 @@ +name: test + +on: + workflow_dispatch: + inputs: + branch: + required: true + type: string + date: + required: true + type: string + sha: + required: true + type: string + +jobs: + conda-cpp-tests: + secrets: inherit + uses: rapidsai/shared-action-workflows/.github/workflows/conda-cpp-tests.yaml@branch-23.02 + with: + build_type: nightly + branch: ${{ inputs.branch }} + date: ${{ inputs.date }} + sha: ${{ inputs.sha }} + conda-python-tests: + secrets: inherit + uses: rapidsai/shared-action-workflows/.github/workflows/conda-python-tests.yaml@branch-23.02 + with: + build_type: nightly + branch: ${{ inputs.branch }} + date: ${{ inputs.date }} + sha: ${{ inputs.sha }} + wheel-tests-pylibraft: + secrets: inherit + uses: rapidsai/shared-action-workflows/.github/workflows/wheels-manylinux-test.yml@branch-23.02 + with: + build_type: nightly + branch: ${{ inputs.branch }} + date: ${{ inputs.date }} + sha: ${{ inputs.sha }} + package-name: pylibraft + test-before-amd64: "pip 
install cupy-cuda11x" + test-before-arm64: "pip install cupy-cuda11x -f https://pip.cupy.dev/aarch64" + test-unittest: "python -m pytest -v ./python/pylibraft/pylibraft/test" + wheel-tests-raft-dask: + secrets: inherit + uses: rapidsai/shared-action-workflows/.github/workflows/wheels-manylinux-test.yml@branch-23.02 + with: + build_type: nightly + branch: ${{ inputs.branch }} + date: ${{ inputs.date }} + sha: ${{ inputs.sha }} + package-name: raft_dask + test-before-amd64: "pip install git+https://github.com/dask/dask.git@2023.1.1 git+https://github.com/dask/distributed.git@2023.1.1 git+https://github.com/rapidsai/dask-cuda.git@branch-23.02" + test-before-arm64: "pip install git+https://github.com/dask/dask.git@2023.1.1 git+https://github.com/dask/distributed.git@2023.1.1 git+https://github.com/rapidsai/dask-cuda.git@branch-23.02" + test-unittest: "python -m pytest -v ./python/raft-dask/raft_dask/test" diff --git a/.github/workflows/wheels.yml b/.github/workflows/wheels.yml deleted file mode 100644 index 0a681b864b..0000000000 --- a/.github/workflows/wheels.yml +++ /dev/null @@ -1,72 +0,0 @@ -name: RAFT wheels - -on: - workflow_call: - inputs: - versioneer-override: - type: string - default: '' - build-tag: - type: string - default: '' - branch: - required: true - type: string - date: - required: true - type: string - sha: - required: true - type: string - build-type: - type: string - default: nightly - -concurrency: - group: "raft-${{ github.workflow }}-${{ github.ref }}" - cancel-in-progress: true - -jobs: - pylibraft-wheel: - uses: rapidsai/shared-action-workflows/.github/workflows/wheels-manylinux.yml@main - with: - repo: rapidsai/raft - - build-type: ${{ inputs.build-type }} - branch: ${{ inputs.branch }} - sha: ${{ inputs.sha }} - date: ${{ inputs.date }} - - package-dir: python/pylibraft - package-name: pylibraft - - python-package-versioneer-override: ${{ inputs.versioneer-override }} - python-package-build-tag: ${{ inputs.build-tag }} - - 
skbuild-configure-options: "-DRAFT_BUILD_WHEELS=ON -DDETECT_CONDA_ENV=OFF -DFIND_RAFT_CPP=OFF" - - test-extras: test - test-unittest: "python -m pytest -v ./python/pylibraft/pylibraft/test" - secrets: inherit - raft-dask-wheel: - needs: pylibraft-wheel - uses: rapidsai/shared-action-workflows/.github/workflows/wheels-manylinux.yml@main - with: - repo: rapidsai/raft - - build-type: ${{ inputs.build-type }} - branch: ${{ inputs.branch }} - sha: ${{ inputs.sha }} - date: ${{ inputs.date }} - - package-dir: python/raft-dask - package-name: raft_dask - - python-package-versioneer-override: ${{ inputs.versioneer-override }} - python-package-build-tag: ${{ inputs.build-tag }} - - skbuild-configure-options: "-DRAFT_BUILD_WHEELS=ON -DDETECT_CONDA_ENV=OFF -DFIND_RAFT_CPP=OFF" - - test-extras: test - test-unittest: "python -m pytest -v ./python/raft-dask/raft_dask/test" - secrets: inherit diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 1c244200d1..b766bfc066 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -2,7 +2,7 @@ repos: - repo: https://github.com/PyCQA/isort - rev: 5.10.1 + rev: 5.12.0 hooks: - id: isort # Use the config file specific to each subproject so that each @@ -97,6 +97,8 @@ repos: rev: v2.1.0 hooks: - id: codespell + exclude: (?x)^(^CHANGELOG.md$) + default_language_version: python: python3 diff --git a/CHANGELOG.md b/CHANGELOG.md index 4ff454f7a0..c4701f587f 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,3 +1,105 @@ +# raft 23.02.00 (9 Feb 2023) + +## 🚨 Breaking Changes + +- Remove faiss ANN code from knnIndex ([#1121](https://github.com/rapidsai/raft/pull/1121)) [@benfred](https://github.com/benfred) +- Use `GenPC` (Permuted Congruential) as the default random number generator everywhere ([#1099](https://github.com/rapidsai/raft/pull/1099)) [@Nyrio](https://github.com/Nyrio) + +## 🐛 Bug Fixes + +- Reverting a few commits from 23.02 and speeding up end-to-end build time 
([#1232](https://github.com/rapidsai/raft/pull/1232)) [@cjnolet](https://github.com/cjnolet) +- Update README.md: fix a missing word ([#1185](https://github.com/rapidsai/raft/pull/1185)) [@achirkin](https://github.com/achirkin) +- balanced-k-means: fix a too large initial memory pool size ([#1148](https://github.com/rapidsai/raft/pull/1148)) [@achirkin](https://github.com/achirkin) +- Catch signal handler change error ([#1147](https://github.com/rapidsai/raft/pull/1147)) [@tfeher](https://github.com/tfeher) +- Squared norm fix follow-up (change was lost in merge conflict) ([#1144](https://github.com/rapidsai/raft/pull/1144)) [@Nyrio](https://github.com/Nyrio) +- IVF-Flat bug fix: the *squared* norm is required for expanded distance calculations ([#1141](https://github.com/rapidsai/raft/pull/1141)) [@Nyrio](https://github.com/Nyrio) +- build.sh switch to use `RAPIDS` magic value ([#1132](https://github.com/rapidsai/raft/pull/1132)) [@robertmaynard](https://github.com/robertmaynard) +- Fix `euclidean_dist` in IVF-Flat search ([#1122](https://github.com/rapidsai/raft/pull/1122)) [@Nyrio](https://github.com/Nyrio) +- Update handle docstring ([#1103](https://github.com/rapidsai/raft/pull/1103)) [@dantegd](https://github.com/dantegd) +- Pin libcusparse and libcusolver to avoid CUDA 12 ([#1095](https://github.com/rapidsai/raft/pull/1095)) [@wphicks](https://github.com/wphicks) +- Fix race condition in `raft::random::discrete` ([#1094](https://github.com/rapidsai/raft/pull/1094)) [@Nyrio](https://github.com/Nyrio) +- Fixing libraft conda recipes ([#1084](https://github.com/rapidsai/raft/pull/1084)) [@cjnolet](https://github.com/cjnolet) +- Ensure that we get the cuda version of faiss. 
([#1078](https://github.com/rapidsai/raft/pull/1078)) [@vyasr](https://github.com/vyasr) +- Fix double definition error in ANN refinement header ([#1067](https://github.com/rapidsai/raft/pull/1067)) [@tfeher](https://github.com/tfeher) +- Specify correct global targets names to raft_export ([#1054](https://github.com/rapidsai/raft/pull/1054)) [@robertmaynard](https://github.com/robertmaynard) +- Fix concurrency issues in k-means++ initialization ([#1048](https://github.com/rapidsai/raft/pull/1048)) [@Nyrio](https://github.com/Nyrio) + +## 📖 Documentation + +- Adding small comms tutorial to docs ([#1204](https://github.com/rapidsai/raft/pull/1204)) [@cjnolet](https://github.com/cjnolet) +- Separating more namespaces into easier-to-consume sections ([#1091](https://github.com/rapidsai/raft/pull/1091)) [@cjnolet](https://github.com/cjnolet) +- Paying down some tech debt on docs, runtime API, and cython ([#1055](https://github.com/rapidsai/raft/pull/1055)) [@cjnolet](https://github.com/cjnolet) + +## 🚀 New Features + +- Add function to convert mdspan to a const view ([#1188](https://github.com/rapidsai/raft/pull/1188)) [@lowener](https://github.com/lowener) +- Internal library to share headers between test and bench ([#1162](https://github.com/rapidsai/raft/pull/1162)) [@achirkin](https://github.com/achirkin) +- Add public API and tests for hierarchical balanced k-means ([#1113](https://github.com/rapidsai/raft/pull/1113)) [@Nyrio](https://github.com/Nyrio) +- Export NCCL dependency as part of raft::distributed. 
([#1077](https://github.com/rapidsai/raft/pull/1077)) [@vyasr](https://github.com/vyasr) +- Serialization of IVF Flat and IVF PQ ([#919](https://github.com/rapidsai/raft/pull/919)) [@tfeher](https://github.com/tfeher) + +## 🛠️ Improvements + +- Pin `dask` and `distributed` for release ([#1242](https://github.com/rapidsai/raft/pull/1242)) [@galipremsagar](https://github.com/galipremsagar) +- Update shared workflow branches ([#1241](https://github.com/rapidsai/raft/pull/1241)) [@ajschmidt8](https://github.com/ajschmidt8) +- Removing interruptible from basic handle sync. ([#1224](https://github.com/rapidsai/raft/pull/1224)) [@cjnolet](https://github.com/cjnolet) +- pre-commit: Update isort version to 5.12.0 ([#1215](https://github.com/rapidsai/raft/pull/1215)) [@wence-](https://github.com/wence-) +- Pin wheel dependencies to same RAPIDS release ([#1200](https://github.com/rapidsai/raft/pull/1200)) [@sevagh](https://github.com/sevagh) +- Serializer for mdspans ([#1173](https://github.com/rapidsai/raft/pull/1173)) [@hcho3](https://github.com/hcho3) +- Use CTK 118/cp310 branch of wheel workflows ([#1169](https://github.com/rapidsai/raft/pull/1169)) [@sevagh](https://github.com/sevagh) +- Enable shallow copy of `handle_t`'s resources with different workspace_resource ([#1165](https://github.com/rapidsai/raft/pull/1165)) [@cjnolet](https://github.com/cjnolet) +- Protect balanced k-means out-of-memory in some cases ([#1161](https://github.com/rapidsai/raft/pull/1161)) [@achirkin](https://github.com/achirkin) +- Use squeuclidean for metric name in ivf_pq python bindings ([#1160](https://github.com/rapidsai/raft/pull/1160)) [@benfred](https://github.com/benfred) +- ANN tests: make the min_recall check strict ([#1156](https://github.com/rapidsai/raft/pull/1156)) [@achirkin](https://github.com/achirkin) +- Make cutlass use static ctk ([#1155](https://github.com/rapidsai/raft/pull/1155)) [@sevagh](https://github.com/sevagh) +- Fix various build errors 
([#1152](https://github.com/rapidsai/raft/pull/1152)) [@hcho3](https://github.com/hcho3) +- Remove faiss bfKnn call from fused_l2_knn unittest ([#1150](https://github.com/rapidsai/raft/pull/1150)) [@benfred](https://github.com/benfred) +- Fix `unary_op` docs and add `map_offset` as an improved version of `write_only_unary_op` ([#1149](https://github.com/rapidsai/raft/pull/1149)) [@Nyrio](https://github.com/Nyrio) +- Improvement of the math API wrappers ([#1146](https://github.com/rapidsai/raft/pull/1146)) [@Nyrio](https://github.com/Nyrio) +- Changing handle_t to device_resources everywhere ([#1140](https://github.com/rapidsai/raft/pull/1140)) [@cjnolet](https://github.com/cjnolet) +- Add L2SqrtExpanded support to ivf_pq ([#1138](https://github.com/rapidsai/raft/pull/1138)) [@benfred](https://github.com/benfred) +- Adding workspace resource ([#1137](https://github.com/rapidsai/raft/pull/1137)) [@cjnolet](https://github.com/cjnolet) +- Add raft::void_op functor ([#1136](https://github.com/rapidsai/raft/pull/1136)) [@ahendriksen](https://github.com/ahendriksen) +- IVF-PQ: tighten the test criteria ([#1135](https://github.com/rapidsai/raft/pull/1135)) [@achirkin](https://github.com/achirkin) +- Fix documentation author ([#1134](https://github.com/rapidsai/raft/pull/1134)) [@bdice](https://github.com/bdice) +- Add L2SqrtExpanded support to ivf_flat ANN indices ([#1133](https://github.com/rapidsai/raft/pull/1133)) [@benfred](https://github.com/benfred) +- Improvements in `matrix::gather`: test coverage, compilation errors, performance ([#1126](https://github.com/rapidsai/raft/pull/1126)) [@Nyrio](https://github.com/Nyrio) +- Adding ability to use an existing stream in the pylibraft Handle ([#1125](https://github.com/rapidsai/raft/pull/1125)) [@cjnolet](https://github.com/cjnolet) +- Remove faiss ANN code from knnIndex ([#1121](https://github.com/rapidsai/raft/pull/1121)) [@benfred](https://github.com/benfred) +- Update builds for CUDA `11.8` and Python `3.10` 
([#1120](https://github.com/rapidsai/raft/pull/1120)) [@ajschmidt8](https://github.com/ajschmidt8) +- Update workflows for nightly tests ([#1119](https://github.com/rapidsai/raft/pull/1119)) [@ajschmidt8](https://github.com/ajschmidt8) +- Enable `Recently Updated` Check ([#1117](https://github.com/rapidsai/raft/pull/1117)) [@ajschmidt8](https://github.com/ajschmidt8) +- Build wheels alongside conda CI ([#1116](https://github.com/rapidsai/raft/pull/1116)) [@sevagh](https://github.com/sevagh) +- Allow host dataset for IVF-PQ ([#1114](https://github.com/rapidsai/raft/pull/1114)) [@tfeher](https://github.com/tfeher) +- Decoupling raft handle from underlying resources ([#1111](https://github.com/rapidsai/raft/pull/1111)) [@cjnolet](https://github.com/cjnolet) +- Fixing an index error introduced in PR #1109 ([#1110](https://github.com/rapidsai/raft/pull/1110)) [@vinaydes](https://github.com/vinaydes) +- Fixing the sample-without-replacement test failures ([#1109](https://github.com/rapidsai/raft/pull/1109)) [@vinaydes](https://github.com/vinaydes) +- Remove faiss dependency from fused_l2_knn.cuh, selection_faiss.cuh, ball_cover.cuh and haversine_distance.cuh ([#1108](https://github.com/rapidsai/raft/pull/1108)) [@benfred](https://github.com/benfred) +- Remove redundant operators in sparse/distance and move others to raft/core ([#1105](https://github.com/rapidsai/raft/pull/1105)) [@Nyrio](https://github.com/Nyrio) +- Speedup `make_blobs` by up to 2x by fixing inefficient kernel launch configuration ([#1100](https://github.com/rapidsai/raft/pull/1100)) [@Nyrio](https://github.com/Nyrio) +- Use `GenPC` (Permuted Congruential) as the default random number generator everywhere ([#1099](https://github.com/rapidsai/raft/pull/1099)) [@Nyrio](https://github.com/Nyrio) +- Cleanup faiss includes ([#1098](https://github.com/rapidsai/raft/pull/1098)) [@benfred](https://github.com/benfred) +- matrix::select_k: move selection and warp-sort primitives 
([#1085](https://github.com/rapidsai/raft/pull/1085)) [@achirkin](https://github.com/achirkin) +- Exclude changelog from pre-commit spellcheck ([#1083](https://github.com/rapidsai/raft/pull/1083)) [@benfred](https://github.com/benfred) +- Add GitHub Actions Workflows. ([#1076](https://github.com/rapidsai/raft/pull/1076)) [@bdice](https://github.com/bdice) +- Adding uninstall option to build.sh ([#1075](https://github.com/rapidsai/raft/pull/1075)) [@cjnolet](https://github.com/cjnolet) +- Use doctest for testing python example docstrings ([#1073](https://github.com/rapidsai/raft/pull/1073)) [@benfred](https://github.com/benfred) +- Minor cython fixes / cleanup ([#1072](https://github.com/rapidsai/raft/pull/1072)) [@benfred](https://github.com/benfred) +- IVF-PQ: tweak launch configuration ([#1069](https://github.com/rapidsai/raft/pull/1069)) [@achirkin](https://github.com/achirkin) +- Unpin `dask` and `distributed` for development ([#1068](https://github.com/rapidsai/raft/pull/1068)) [@galipremsagar](https://github.com/galipremsagar) +- Bifurcate Dependency Lists ([#1065](https://github.com/rapidsai/raft/pull/1065)) [@ajschmidt8](https://github.com/ajschmidt8) +- Add support for 64bit svdeig ([#1060](https://github.com/rapidsai/raft/pull/1060)) [@lowener](https://github.com/lowener) +- switch mma instruction shape to 1684 from current 1688 for 3xTF32 L2/cosine kernel ([#1057](https://github.com/rapidsai/raft/pull/1057)) [@mdoijade](https://github.com/mdoijade) +- Make IVF-PQ build index in batches when necessary ([#1056](https://github.com/rapidsai/raft/pull/1056)) [@achirkin](https://github.com/achirkin) +- Remove unused setuputils modules ([#1053](https://github.com/rapidsai/raft/pull/1053)) [@vyasr](https://github.com/vyasr) +- Branch 23.02 merge 22.12 ([#1051](https://github.com/rapidsai/raft/pull/1051)) [@benfred](https://github.com/benfred) +- Shared-memory-cached kernel for `reduce_cols_by_key` to limit atomic conflicts 
([#1050](https://github.com/rapidsai/raft/pull/1050)) [@Nyrio](https://github.com/Nyrio) +- Unify use of common functors ([#1049](https://github.com/rapidsai/raft/pull/1049)) [@Nyrio](https://github.com/Nyrio) +- Replace k-means++ CPU bottleneck with a `random::discrete` prim ([#1039](https://github.com/rapidsai/raft/pull/1039)) [@Nyrio](https://github.com/Nyrio) +- Add python bindings for kmeans fit ([#1016](https://github.com/rapidsai/raft/pull/1016)) [@benfred](https://github.com/benfred) +- Add MaskedL2NN ([#838](https://github.com/rapidsai/raft/pull/838)) [@ahendriksen](https://github.com/ahendriksen) +- Move contractions tiling logic outside of Contractions_NT ([#837](https://github.com/rapidsai/raft/pull/837)) [@ahendriksen](https://github.com/ahendriksen) + # raft 22.12.00 (8 Dec 2022) ## 🚨 Breaking Changes diff --git a/README.md b/README.md index e48a1b6193..ccd0df4926 100755 --- a/README.md +++ b/README.md @@ -25,8 +25,8 @@ While not exhaustive, the following general categories help summarize the accele | Category | Examples | | --- | --- | | **Data Formats** | sparse & dense, conversions, data generation | -| **Dense Operations** | linear algebra, matrix and vector operations, slicing, norms, factorization, least squares, svd & eigenvalue problems | -| **Sparse Operations** | linear algebra, eigenvalue problems, slicing, symmetrization, components & labeling | +| **Dense Operations** | linear algebra, matrix and vector operations, reductions, slicing, norms, factorization, least squares, svd & eigenvalue problems | +| **Sparse Operations** | linear algebra, eigenvalue problems, slicing, norms, reductions, factorization, symmetrization, components & labeling | | **Spatial** | pairwise distances, nearest neighbors, neighborhood graph construction | | **Basic Clustering** | spectral clustering, hierarchical clustering, k-means | | **Solvers** | combinatorial optimization, iterative solvers | @@ -37,7 +37,7 @@ While not exhaustive, the following general 
categories help summarize the accele All of RAFT's C++ APIs can be accessed header-only and optional pre-compiled shared libraries can 1) speed up compile times and 2) enable the APIs to be used without CUDA-enabled compilers. In addition to the C++ library, RAFT also provides 2 Python libraries: -- `pylibraft` - lightweight low-level Python wrappers around RAFT's host-accessible APIs. +- `pylibraft` - lightweight low-level Python wrappers around RAFT's host-accessible "runtime" APIs. - `raft-dask` - multi-node multi-GPU communicator infrastructure for building distributed algorithms on the GPU with Dask. ## Getting started @@ -65,17 +65,17 @@ auto matrix = raft::make_device_matrix(handle, n_rows, n_cols); ### C++ Example -Most of the primitives in RAFT accept a `raft::handle_t` object for the management of resources which are expensive to create, such CUDA streams, stream pools, and handles to other CUDA libraries like `cublas` and `cusolver`. +Most of the primitives in RAFT accept a `raft::device_resources` object for the management of resources which are expensive to create, such as CUDA streams, stream pools, and handles to other CUDA libraries like `cublas` and `cusolver`. 
The example below demonstrates creating a RAFT handle and using it with `device_matrix` and `device_vector` to allocate memory, generating random clusters, and computing pairwise Euclidean distances: ```c++ -#include +#include #include #include #include -raft::handle_t handle; +raft::device_resources handle; int n_samples = 5000; int n_features = 50; @@ -93,12 +93,12 @@ raft::distance::pairwise_distance(handle, input.view(), input.view(), output.vie It's also possible to create `raft::device_mdspan` views to invoke the same API with raw pointers and shape information: ```c++ -#include +#include #include #include #include -raft::handle_t handle; +raft::device_resources handle; int n_samples = 5000; int n_features = 50; @@ -142,7 +142,7 @@ in2 = cp.random.random_sample((n_samples, n_features), dtype=cp.float32) output = pairwise_distance(in1, in2, metric="euclidean") ``` -The `output` array supports [__cuda_array_interface__](https://numba.pydata.org/numba-doc/dev/cuda/cuda_array_interface.html#cuda-array-interface-version-2) so it is interoperable with other libraries like CuPy, Numba, and PyTorch that also support it. +The `output` array in the above example is of type `raft.common.device_ndarray`, which supports [__cuda_array_interface__](https://numba.pydata.org/numba-doc/dev/cuda/cuda_array_interface.html#cuda-array-interface-version-2) making it interoperable with other libraries like CuPy, Numba, and PyTorch that also support it. CuPy supports DLPack, which also enables zero-copy conversion from `raft.common.device_ndarray` to JAX and Tensorflow. 
Below is an example of converting the output `pylibraft.device_ndarray` to a CuPy array: ```python @@ -156,6 +156,18 @@ import torch torch_tensor = torch.as_tensor(output, device='cuda') ``` +When the corresponding library is installed and available in your environment, this conversion can also be done automatically by all RAFT compute APIs by setting a global configuration option: +```python +import pylibraft.config +pylibraft.config.set_output_as("cupy") # All compute APIs will return cupy arrays +pylibraft.config.set_output_as("torch") # All compute APIs will return torch tensors +``` + +You can also specify a `callable` that accepts a `pylibraft.common.device_ndarray` and performs a custom conversion. The following example converts all output to `numpy` arrays: +```python +pylibraft.config.set_output_as(lambda device_ndarray: device_ndarray.copy_to_host()) +``` + `pylibraft` also supports writing to a pre-allocated output array so any `__cuda_array_interface__` supported array can be written to in-place: ```python @@ -176,7 +188,7 @@ pairwise_distance(in1, in2, out=output, metric="euclidean") ## Installing -RAFT itself can be installed through conda, [Cmake Package Manager (CPM)](https://github.com/cpm-cmake/CPM.cmake), pip, or by building the repository from source. Please refer to the [build instructions](docs/source/build.md) for more a comprehensive guide on building RAFT and using it in downstream projects. +RAFT itself can be installed through conda, [Cmake Package Manager (CPM)](https://github.com/cpm-cmake/CPM.cmake), pip, or by building the repository from source. Please refer to the [build instructions](docs/source/build.md) for a more comprehensive guide on installing and building RAFT and using it in downstream projects. 
### Conda @@ -257,14 +269,15 @@ Several CMake targets can be made available by adding components in the table be | --- | --- | --- | --- | | n/a | `raft::raft` | Full RAFT header library | CUDA toolkit library, RMM, Thrust (optional), NVTools (optional) | | distance | `raft::distance` | Pre-compiled template specializations for raft::distance | raft::raft, cuCollections (optional) | -| nn | `raft::nn` | Pre-compiled template specializations for raft::spatial::knn | raft::raft, FAISS (optional) | +| nn | `raft::nn` | Pre-compiled template specializations for raft::neighbors | raft::raft, FAISS (optional) | +| distributed | `raft::distributed` | No specializations | raft::raft, UCX, NCCL | ### Source The easiest way to build RAFT from source is to use the `build.sh` script at the root of the repository: -1. Create an environment with the needed dependencies: +1. Create an environment with the needed dependencies: ``` -mamba env create --name raft_dev_env -f conda/environments/raft_dev_cuda11.5.yml +mamba env create --name raft_dev_env -f conda/environments/all_cuda-118_arch-x86_64.yaml mamba activate raft_dev_env ``` ``` @@ -302,6 +315,7 @@ The folder structure mirrors other RAPIDS repos, with the following folders: - `solver`: Sparse solvers for optimization and approximation - `stats`: Moments, summary statistics, model performance measures - `util`: Various reusable tools and utilities for accelerated algorithm development + - `internal`: A private header-only component that hosts the code shared between benchmarks and tests. - `scripts`: Helpful scripts for development - `src`: Compiled APIs and template specializations for the shared libraries - `test`: Googletests source code diff --git a/build.sh b/build.sh index 0708c1b89e..b47e1ed862 100755 --- a/build.sh +++ b/build.sh @@ -1,6 +1,6 @@ #!/bin/bash -# Copyright (c) 2020-2022, NVIDIA CORPORATION. +# Copyright (c) 2020-2023, NVIDIA CORPORATION. 
# raft build script @@ -18,7 +18,7 @@ ARGS=$* # script, and that this script resides in the repo dir! REPODIR=$(cd $(dirname $0); pwd) -VALIDARGS="clean libraft pylibraft raft-dask docs tests bench clean -v -g -n --compile-libs --compile-nn --compile-dist --allgpuarch --no-nvtx --show_depr_warn -h --buildfaiss --minimal-deps" +VALIDARGS="clean libraft pylibraft raft-dask docs tests bench clean --uninstall -v -g -n --compile-libs --compile-nn --compile-dist --allgpuarch --no-nvtx --show_depr_warn -h --buildfaiss --minimal-deps" HELP="$0 [ ...] [ ...] [--cmake-args=\"\"] [--cache-tool=] [--limit-tests=] [--limit-bench=] where is: clean - remove all existing build artifacts and configuration (start over) @@ -34,6 +34,7 @@ HELP="$0 [ ...] [ ...] [--cmake-args=\"\"] [--cache-tool=\"] [--cache-tool= /dev/null 2>&1 + fi + fi + + if hasArg pylibraft || (( ${NUMARGS} == 1 )); then + echo "Uninstalling pylibraft package..." + if [ -e ${PYLIBRAFT_BUILD_DIR}/install_manifest.txt ]; then + xargs rm -fv < ${PYLIBRAFT_BUILD_DIR}/install_manifest.txt > /dev/null 2>&1 + fi + + # Try to uninstall via pip if it is installed + if [ -x "$(command -v pip)" ]; then + echo "Using pip to uninstall pylibraft" + pip uninstall -y pylibraft + + # Otherwise, try to uninstall through conda if that's where things are installed + elif [ -x "$(command -v conda)" ] && [ "$INSTALL_PREFIX" == "$CONDA_PREFIX" ]; then + echo "Using conda to uninstall pylibraft" + conda uninstall -y pylibraft + + # Otherwise, fail + else + echo "Could not uninstall pylibraft from pip or conda. pylibraft package will need to be manually uninstalled" + fi + fi + + if hasArg raft-dask || (( ${NUMARGS} == 1 )); then + echo "Uninstalling raft-dask package..." 
+ if [ -e ${RAFT_DASK_BUILD_DIR}/install_manifest.txt ]; then + xargs rm -fv < ${RAFT_DASK_BUILD_DIR}/install_manifest.txt > /dev/null 2>&1 + fi + + # Try to uninstall via pip if it is installed + if [ -x "$(command -v pip)" ]; then + echo "Using pip to uninstall raft-dask" + pip uninstall -y raft-dask + + # Otherwise, try to uninstall through conda if that's where things are installed + elif [ -x "$(command -v conda)" ] && [ "$INSTALL_PREFIX" == "$CONDA_PREFIX" ]; then + echo "Using conda to uninstall raft-dask" + conda uninstall -y raft-dask + + # Otherwise, fail + else + echo "Could not uninstall raft-dask from pip or conda. raft-dask package will need to be manually uninstalled." + fi + fi + exit 0 +fi + + # Process flags if hasArg -n; then INSTALL_TARGET="" @@ -286,9 +347,8 @@ fi if hasArg clean; then CLEAN=1 fi -if hasArg uninstall; then - UNINSTALL=1 -fi + + if [[ ${CMAKE_TARGET} == "" ]]; then CMAKE_TARGET="all" @@ -328,7 +388,7 @@ if (( ${NUMARGS} == 0 )) || hasArg libraft || hasArg docs || hasArg tests || has RAFT_CMAKE_CUDA_ARCHITECTURES="NATIVE" echo "Building for the architecture of the GPU in the system..." else - RAFT_CMAKE_CUDA_ARCHITECTURES="ALL" + RAFT_CMAKE_CUDA_ARCHITECTURES="RAPIDS" echo "Building for *ALL* supported GPU architectures..." 
fi @@ -370,7 +430,7 @@ if (( ${NUMARGS} == 0 )) || hasArg raft-dask; then fi cd ${REPODIR}/python/raft-dask - python setup.py build_ext --inplace -- -DCMAKE_PREFIX_PATH="${LIBRAFT_BUILD_DIR};${INSTALL_PREFIX}" -DCMAKE_LIBRARY_PATH=${LIBRAFT_BUILD_DIR} ${EXTRA_CMAKE_ARGS} -- -j${PARALLEL_LEVEL:-1} + python setup.py build_ext --inplace -- -DCMAKE_PREFIX_PATH="${RAFT_DASK_BUILD_DIR};${INSTALL_PREFIX}" -DCMAKE_LIBRARY_PATH=${LIBRAFT_BUILD_DIR} ${EXTRA_CMAKE_ARGS} -- -j${PARALLEL_LEVEL:-1} if [[ ${INSTALL_TARGET} != "" ]]; then python setup.py install --single-version-externally-managed --record=record.txt -- -DCMAKE_PREFIX_PATH=${INSTALL_PREFIX} ${EXTRA_CMAKE_ARGS} fi @@ -384,7 +444,7 @@ if (( ${NUMARGS} == 0 )) || hasArg pylibraft; then fi cd ${REPODIR}/python/pylibraft - python setup.py build_ext --inplace -- -DCMAKE_PREFIX_PATH="${LIBRAFT_BUILD_DIR};${INSTALL_PREFIX}" -DCMAKE_LIBRARY_PATH=${LIBRAFT_BUILD_DIR} ${EXTRA_CMAKE_ARGS} -- -j${PARALLEL_LEVEL:-1} + python setup.py build_ext --inplace -- -DCMAKE_PREFIX_PATH="${RAFT_DASK_BUILD_DIR};${INSTALL_PREFIX}" -DCMAKE_LIBRARY_PATH=${LIBRAFT_BUILD_DIR} ${EXTRA_CMAKE_ARGS} -- -j${PARALLEL_LEVEL:-1} if [[ ${INSTALL_TARGET} != "" ]]; then python setup.py install --single-version-externally-managed --record=record.txt -- -DCMAKE_PREFIX_PATH=${INSTALL_PREFIX} ${EXTRA_CMAKE_ARGS} fi diff --git a/ci/build_cpp.sh b/ci/build_cpp.sh new file mode 100755 index 0000000000..853ae095d3 --- /dev/null +++ b/ci/build_cpp.sh @@ -0,0 +1,16 @@ +#!/bin/bash +# Copyright (c) 2022, NVIDIA CORPORATION. + +set -euo pipefail + +source rapids-env-update + +export CMAKE_GENERATOR=Ninja + +rapids-print-env + +rapids-logger "Begin cpp build" + +rapids-mamba-retry mambabuild conda/recipes/libraft + +rapids-upload-conda-to-s3 cpp diff --git a/ci/build_python.sh b/ci/build_python.sh new file mode 100755 index 0000000000..b20fd51bca --- /dev/null +++ b/ci/build_python.sh @@ -0,0 +1,29 @@ +#!/bin/bash +# Copyright (c) 2022, NVIDIA CORPORATION. 
+ +set -euo pipefail + +source rapids-env-update + +export CMAKE_GENERATOR=Ninja + +rapids-print-env + +rapids-logger "Begin py build" + +CPP_CHANNEL=$(rapids-download-conda-from-s3 cpp) + +# TODO: Remove `--no-test` flags once importing on a CPU +# node works correctly +rapids-mamba-retry mambabuild \ + --no-test \ + --channel "${CPP_CHANNEL}" \ + conda/recipes/pylibraft + +rapids-mamba-retry mambabuild \ + --no-test \ + --channel "${CPP_CHANNEL}" \ + --channel "${RAPIDS_CONDA_BLD_OUTPUT_DIR}" \ + conda/recipes/raft-dask + +rapids-upload-conda-to-s3 python diff --git a/ci/check_style.sh b/ci/check_style.sh new file mode 100755 index 0000000000..be3ac3f4b8 --- /dev/null +++ b/ci/check_style.sh @@ -0,0 +1,18 @@ +#!/bin/bash +# Copyright (c) 2020-2022, NVIDIA CORPORATION. + +set -euo pipefail + +rapids-logger "Create checks conda environment" +. /opt/conda/etc/profile.d/conda.sh + +rapids-dependency-file-generator \ + --output conda \ + --file_key checks \ + --matrix "cuda=${RAPIDS_CUDA_VERSION%.*};arch=$(arch);py=${RAPIDS_PY_VERSION}" | tee env.yaml + +rapids-mamba-retry env create --force -f env.yaml -n checks +conda activate checks + +# Run pre-commit checks +pre-commit run --hook-stage manual --all-files --show-diff-on-failure diff --git a/ci/checks/copyright.py b/ci/checks/copyright.py index bfef5392f5..43a4a186f8 100644 --- a/ci/checks/copyright.py +++ b/ci/checks/copyright.py @@ -1,4 +1,4 @@ -# Copyright (c) 2020-2022, NVIDIA CORPORATION. +# Copyright (c) 2020-2023, NVIDIA CORPORATION. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
@@ -37,7 +37,7 @@ re.compile(r"setup[.]cfg$"), re.compile(r"meta[.]yaml$") ] -ExemptFiles = ["cpp/include/raft/spatial/knn/detail/warp_select_faiss.cuh"] +ExemptFiles = ["cpp/include/raft/spatial/knn/detail/faiss_select/"] # this will break starting at year 10000, which is probably OK :) CheckSimple = re.compile( diff --git a/ci/cpu/build.sh b/ci/cpu/build.sh index b87b16d138..2f0e2b94ca 100755 --- a/ci/cpu/build.sh +++ b/ci/cpu/build.sh @@ -31,12 +31,19 @@ fi export GPUCI_CONDA_RETRY_MAX=1 export GPUCI_CONDA_RETRY_SLEEP=30 +# Workaround to keep Jenkins builds working +# until we migrate fully to GitHub Actions +export RAPIDS_CUDA_VERSION="${CUDA}" +export SCCACHE_BUCKET=rapids-sccache +export SCCACHE_REGION=us-west-2 +export SCCACHE_IDLE_TIMEOUT=32768 + # Use Ninja to build export CMAKE_GENERATOR="Ninja" export CONDA_BLD_DIR="${WORKSPACE}/.conda-bld" # ucx-py version -export UCX_PY_VERSION='0.29.*' +export UCX_PY_VERSION='0.30.*' ################################################################################ # SETUP - Check environment @@ -123,5 +130,6 @@ fi # UPLOAD - Conda packages ################################################################################ -gpuci_logger "Upload conda packages" -source ci/cpu/upload.sh +# Uploads disabled due to new GH Actions implementation +# gpuci_logger "Upload conda packages" +# source ci/cpu/upload.sh diff --git a/ci/gpu/build.sh b/ci/gpu/build.sh index eb0117cdc3..154edbc7f2 100644 --- a/ci/gpu/build.sh +++ b/ci/gpu/build.sh @@ -21,6 +21,13 @@ export PARALLEL_LEVEL=${PARALLEL_LEVEL:-8} export CUDA_REL=${CUDA_VERSION%.*} CONDA_ARTIFACT_PATH=${WORKSPACE}/ci/artifacts/raft/cpu/.conda-bld/ # notice there is no `linux-64` here +# Workaround to keep Jenkins builds working +# until we migrate fully to GitHub Actions +export RAPIDS_CUDA_VERSION="${CUDA}" +export SCCACHE_BUCKET=rapids-sccache +export SCCACHE_REGION=us-west-2 +export SCCACHE_IDLE_TIMEOUT=32768 + # Set home to the job's workspace export HOME=$WORKSPACE @@ 
-31,13 +38,13 @@ export MINOR_VERSION=`echo $GIT_DESCRIBE_TAG | grep -o -E '([0-9]+\.[0-9]+)'` unset GIT_DESCRIBE_TAG # ucx-py version -export UCX_PY_VERSION='0.29.*' +export UCX_PY_VERSION='0.30.*' # Whether to install dask nightly or stable packages. export INSTALL_DASK_MAIN=0 # Dask version to install when `INSTALL_DASK_MAIN=0` -export DASK_STABLE_VERSION="2022.11.1" +export DASK_STABLE_VERSION="2023.1.1" ################################################################################ # SETUP - Check environment diff --git a/ci/release/update-version.sh b/ci/release/update-version.sh index 4e0ecd8e15..bd0ff1db7b 100755 --- a/ci/release/update-version.sh +++ b/ci/release/update-version.sh @@ -1,5 +1,5 @@ #!/bin/bash -# Copyright (c) 2020-2022, NVIDIA CORPORATION. +# Copyright (c) 2020-2023, NVIDIA CORPORATION. ######################## # RAFT Version Updater # ######################## @@ -17,12 +17,14 @@ CURRENT_MAJOR=$(echo $CURRENT_TAG | awk '{split($0, a, "."); print a[1]}') CURRENT_MINOR=$(echo $CURRENT_TAG | awk '{split($0, a, "."); print a[2]}') CURRENT_PATCH=$(echo $CURRENT_TAG | awk '{split($0, a, "."); print a[3]}') CURRENT_SHORT_TAG=${CURRENT_MAJOR}.${CURRENT_MINOR} +CURRENT_UCX_PY_VERSION="$(curl -sL https://version.gpuci.io/rapids/${CURRENT_SHORT_TAG}).*" #Get . 
for next version NEXT_MAJOR=$(echo $NEXT_FULL_TAG | awk '{split($0, a, "."); print a[1]}') NEXT_MINOR=$(echo $NEXT_FULL_TAG | awk '{split($0, a, "."); print a[2]}') NEXT_SHORT_TAG=${NEXT_MAJOR}.${NEXT_MINOR} -NEXT_UCX_PY_VERSION="$(curl -sL https://version.gpuci.io/rapids/${NEXT_SHORT_TAG}).*" +NEXT_UCX_PY_SHORT_TAG="$(curl -sL https://version.gpuci.io/rapids/${NEXT_SHORT_TAG})" +NEXT_UCX_PY_VERSION="${NEXT_UCX_PY_SHORT_TAG}.*" echo "Preparing release $CURRENT_TAG => $NEXT_FULL_TAG" @@ -41,7 +43,7 @@ sed_runner 's/'"branch-.*\/RAPIDS.cmake"'/'"branch-${NEXT_SHORT_TAG}\/RAPIDS.cma sed_runner 's/version = .*/version = '"'${NEXT_SHORT_TAG}'"'/g' docs/source/conf.py sed_runner 's/release = .*/release = '"'${NEXT_FULL_TAG}'"'/g' docs/source/conf.py -for FILE in conda/environments/*.yml; do +for FILE in conda/environments/*.yaml dependencies.yaml; do sed_runner "s/dask-cuda=${CURRENT_SHORT_TAG}/dask-cuda=${NEXT_SHORT_TAG}/g" ${FILE}; sed_runner "s/rapids-build-env=${CURRENT_SHORT_TAG}/rapids-build-env=${NEXT_SHORT_TAG}/g" ${FILE}; sed_runner "s/rapids-doc-env=${CURRENT_SHORT_TAG}/rapids-doc-env=${NEXT_SHORT_TAG}/g" ${FILE}; @@ -52,3 +54,22 @@ done sed_runner "s/export UCX_PY_VERSION=.*/export UCX_PY_VERSION='${NEXT_UCX_PY_VERSION}'/g" ci/gpu/build.sh sed_runner "s/export UCX_PY_VERSION=.*/export UCX_PY_VERSION='${NEXT_UCX_PY_VERSION}'/g" ci/cpu/build.sh +sed_runner "/^ucx_py_version:$/ {n;s/.*/ - \"${NEXT_UCX_PY_VERSION}\"/}" conda/recipes/raft-dask/conda_build_config.yaml + +# Wheel builds install dask-cuda from source, update its branch +sed_runner "s/dask-cuda.git@branch-[^\"\s]\+/dask-cuda.git@branch-${NEXT_SHORT_TAG}/g" .github/workflows/*.yaml + +# Need to distutils-normalize the original version +NEXT_SHORT_TAG_PEP440=$(python -c "from setuptools.extern import packaging; print(packaging.version.Version('${NEXT_SHORT_TAG}'))") +NEXT_UCX_PY_SHORT_TAG_PEP440=$(python -c "from setuptools.extern import packaging; 
print(packaging.version.Version('${NEXT_UCX_PY_SHORT_TAG}'))") + +# Wheel builds install intra-RAPIDS dependencies from same release +sed_runner "s/{cuda_suffix}[^\"].*\",/{cuda_suffix}==${NEXT_SHORT_TAG_PEP440}.*\",/g" python/pylibraft/setup.py +sed_runner "s/{cuda_suffix}.*\"\]/{cuda_suffix}==${NEXT_SHORT_TAG_PEP440}.*\"\]/g" python/pylibraft/_custom_build/backend.py +sed_runner "s/dask-cuda==.*\",/dask-cuda==${NEXT_SHORT_TAG_PEP440}.*\",/g" python/raft-dask/setup.py +sed_runner "s/pylibraft{cuda_suffix}.*\",/pylibraft{cuda_suffix}==${NEXT_SHORT_TAG_PEP440}.*\",/g" python/raft-dask/setup.py +sed_runner "s/ucx-py{cuda_suffix}.*\",/ucx-py{cuda_suffix}==${NEXT_UCX_PY_SHORT_TAG_PEP440}.*\",/g" python/raft-dask/setup.py + +for FILE in .github/workflows/*.yaml; do + sed_runner "/shared-action-workflows/ s/@.*/@branch-${NEXT_SHORT_TAG}/g" "${FILE}" +done diff --git a/ci/test_cpp.sh b/ci/test_cpp.sh new file mode 100755 index 0000000000..d8538bdf47 --- /dev/null +++ b/ci/test_cpp.sh @@ -0,0 +1,54 @@ +#!/bin/bash +# Copyright (c) 2022, NVIDIA CORPORATION. + +set -euo pipefail + +. /opt/conda/etc/profile.d/conda.sh + +rapids-logger "Generate C++ testing dependencies" +rapids-dependency-file-generator \ + --output conda \ + --file_key test_cpp \ + --matrix "cuda=${RAPIDS_CUDA_VERSION%.*};arch=$(arch)" | tee env.yaml + +rapids-mamba-retry env create --force -f env.yaml -n test + +# Temporarily allow unbound variables for conda activation. +set +u +conda activate test +set -u + +CPP_CHANNEL=$(rapids-download-conda-from-s3 cpp) +RAPIDS_TESTS_DIR=${RAPIDS_TESTS_DIR:-"${PWD}/test-results"}/ +mkdir -p "${RAPIDS_TESTS_DIR}" +SUITEERROR=0 + +rapids-print-env + +rapids-mamba-retry install \ + --channel "${CPP_CHANNEL}" \ + libraft-headers libraft-distance libraft-nn libraft-tests + +rapids-logger "Check GPU usage" +nvidia-smi + +set +e + +# Run libraft gtests from libraft-tests package +rapids-logger "Run gtests" + +# TODO: exit code handling is too verbose. Find a cleaner solution. 
+ +for gt in "$CONDA_PREFIX"/bin/gtests/libraft/* ; do + test_name=$(basename ${gt}) + echo "Running gtest $test_name" + ${gt} --gtest_output=xml:${RAPIDS_TESTS_DIR} + + exitcode=$? + if (( ${exitcode} != 0 )); then + SUITEERROR=${exitcode} + echo "FAILED: GTest ${gt}" + fi +done + +exit ${SUITEERROR} diff --git a/ci/test_python.sh b/ci/test_python.sh new file mode 100755 index 0000000000..eb458d2a5a --- /dev/null +++ b/ci/test_python.sh @@ -0,0 +1,78 @@ +#!/bin/bash +# Copyright (c) 2022, NVIDIA CORPORATION. + +set -euo pipefail + +. /opt/conda/etc/profile.d/conda.sh + +rapids-logger "Generate Python testing dependencies" +rapids-dependency-file-generator \ + --output conda \ + --file_key test_python \ + --matrix "cuda=${RAPIDS_CUDA_VERSION%.*};arch=$(arch);py=${RAPIDS_PY_VERSION}" | tee env.yaml + +rapids-mamba-retry env create --force -f env.yaml -n test + +# Temporarily allow unbound variables for conda activation. +set +u +conda activate test +set -u + +rapids-logger "Downloading artifacts from previous jobs" +CPP_CHANNEL=$(rapids-download-conda-from-s3 cpp) +PYTHON_CHANNEL=$(rapids-download-conda-from-s3 python) + +RAPIDS_TESTS_DIR=${RAPIDS_TESTS_DIR:-"${PWD}/test-results"} +RAPIDS_COVERAGE_DIR=${RAPIDS_COVERAGE_DIR:-"${PWD}/coverage-results"} +mkdir -p "${RAPIDS_TESTS_DIR}" "${RAPIDS_COVERAGE_DIR}" +SUITEERROR=0 + +rapids-print-env + +rapids-mamba-retry install \ + --channel "${CPP_CHANNEL}" \ + --channel "${PYTHON_CHANNEL}" \ + libraft-distance libraft-headers pylibraft raft-dask + +rapids-logger "Check GPU usage" +nvidia-smi + +set +e + +rapids-logger "pytest pylibraft" +pushd python/pylibraft/pylibraft +pytest \ + --cache-clear \ + --junitxml="${RAPIDS_TESTS_DIR}/junit-pylibraft.xml" \ + --cov-config=../.coveragerc \ + --cov=pylibraft \ + --cov-report=xml:"${RAPIDS_COVERAGE_DIR}/pylibraft-coverage.xml" \ + --cov-report=term \ + test +exitcode=$? 
+ +if (( ${exitcode} != 0 )); then + SUITEERROR=${exitcode} + echo "FAILED: 1 or more tests in pylibraft" +fi +popd + +rapids-logger "pytest raft-dask" +pushd python/raft-dask/raft_dask +pytest \ + --cache-clear \ + --junitxml="${RAPIDS_TESTS_DIR}/junit-raft-dask.xml" \ + --cov-config=../.coveragerc \ + --cov=raft_dask \ + --cov-report=xml:"${RAPIDS_COVERAGE_DIR}/raft-dask-coverage.xml" \ + --cov-report=term \ + test +exitcode=$? + +if (( ${exitcode} != 0 )); then + SUITEERROR=${exitcode} + echo "FAILED: 1 or more tests in raft-dask" +fi +popd + +exit ${SUITEERROR} diff --git a/ci/wheel_smoke_test_pylibraft.py b/ci/wheel_smoke_test_pylibraft.py new file mode 100644 index 0000000000..7fee674691 --- /dev/null +++ b/ci/wheel_smoke_test_pylibraft.py @@ -0,0 +1,38 @@ +import numpy as np +from scipy.spatial.distance import cdist + +from pylibraft.common import Handle, Stream, device_ndarray +from pylibraft.distance import pairwise_distance + + +if __name__ == "__main__": + metric = "euclidean" + n_rows = 1337 + n_cols = 1337 + + input1 = np.random.random_sample((n_rows, n_cols)) + input1 = np.asarray(input1, order="C").astype(np.float64) + + output = np.zeros((n_rows, n_rows), dtype=np.float64) + + expected = cdist(input1, input1, metric) + + expected[expected <= 1e-5] = 0.0 + + input1_device = device_ndarray(input1) + output_device = None + + s2 = Stream() + handle = Handle(stream=s2) + ret_output = pairwise_distance( + input1_device, input1_device, output_device, metric, handle=handle + ) + handle.sync() + + output_device = ret_output + + actual = output_device.copy_to_host() + + actual[actual <= 1e-5] = 0.0 + + assert np.allclose(expected, actual, rtol=1e-4) diff --git a/ci/wheel_smoke_test_raft_dask.py b/ci/wheel_smoke_test_raft_dask.py new file mode 100644 index 0000000000..32c13e61ca --- /dev/null +++ b/ci/wheel_smoke_test_raft_dask.py @@ -0,0 +1,92 @@ +from dask.distributed import Client, wait +from dask_cuda import LocalCUDACluster, initialize + +from 
raft_dask.common import ( + Comms, + local_handle, + perform_test_comm_split, + perform_test_comms_allgather, + perform_test_comms_allreduce, + perform_test_comms_bcast, + perform_test_comms_device_multicast_sendrecv, + perform_test_comms_device_send_or_recv, + perform_test_comms_device_sendrecv, + perform_test_comms_gather, + perform_test_comms_gatherv, + perform_test_comms_reduce, + perform_test_comms_reducescatter, + perform_test_comms_send_recv, +) + +import os +os.environ["UCX_LOG_LEVEL"] = "error" + + +def func_test_send_recv(sessionId, n_trials): + handle = local_handle(sessionId) + return perform_test_comms_send_recv(handle, n_trials) + + +def func_test_collective(func, sessionId, root): + handle = local_handle(sessionId) + return func(handle, root) + + +if __name__ == "__main__": + # initial setup + cluster = LocalCUDACluster(protocol="tcp", scheduler_port=0) + client = Client(cluster) + + n_trials = 5 + root_location = "client" + + # p2p test for ucx + cb = Comms(comms_p2p=True, verbose=True) + cb.init() + + dfs = [ + client.submit( + func_test_send_recv, + cb.sessionId, + n_trials, + pure=False, + workers=[w], + ) + for w in cb.worker_addresses + ] + + wait(dfs, timeout=5) + + assert list(map(lambda x: x.result(), dfs)) + + cb.destroy() + + # collectives test for nccl + + cb = Comms( + verbose=True, client=client, nccl_root_location=root_location + ) + cb.init() + + for k, v in cb.worker_info(cb.worker_addresses).items(): + + dfs = [ + client.submit( + func_test_collective, + perform_test_comms_allgather, + cb.sessionId, + v["rank"], + pure=False, + workers=[w], + ) + for w in cb.worker_addresses + ] + wait(dfs, timeout=5) + + assert all([x.result() for x in dfs]) + + cb.destroy() + + # final client and cluster teardown + client.close() + cluster.close() diff --git a/conda/environments/all_cuda-118_arch-x86_64.yaml b/conda/environments/all_cuda-118_arch-x86_64.yaml new file mode 100644 index 0000000000..7dc305bf97 --- /dev/null +++ 
b/conda/environments/all_cuda-118_arch-x86_64.yaml @@ -0,0 +1,48 @@ +# This file is generated by `rapids-dependency-file-generator`. +# To make changes, edit ../../dependencies.yaml and run `rapids-dependency-file-generator`. +channels: +- rapidsai +- rapidsai-nightly +- dask/label/dev +- conda-forge +- nvidia +dependencies: +- breathe +- c-compiler +- clang-tools=11.1.0 +- clang=11.1.0 +- cmake>=3.23.1,!=3.25.0 +- cuda-profiler-api=11.8.86 +- cuda-python >=11.7.1,<12.0 +- cudatoolkit=11.8 +- cupy +- cxx-compiler +- cython>=0.29,<0.30 +- dask-cuda=23.02 +- dask==2023.1.1 +- distributed==2023.1.1 +- doxygen>=1.8.20 +- faiss-proc=*=cuda +- gcc_linux-64=9 +- libcublas-dev=11.11.3.6 +- libcublas=11.11.3.6 +- libcurand-dev=10.3.0.86 +- libcurand=10.3.0.86 +- libcusolver-dev=11.4.1.48 +- libcusolver=11.4.1.48 +- libcusparse-dev=11.7.5.86 +- libcusparse=11.7.5.86 +- libfaiss>=1.7.1=cuda* +- ninja +- pytest +- pytest-cov +- rmm=23.02 +- scikit-build>=0.13.1 +- scikit-learn +- scipy +- sphinx-markdown-tables +- sysroot_linux-64==2.17 +- ucx-proc=*=gpu +- ucx-py=0.30 +- ucx>=1.13.0 +name: all_cuda-118_arch-x86_64 diff --git a/conda/environments/raft_dev_cuda11.2.yml b/conda/environments/raft_dev_cuda11.2.yml deleted file mode 100644 index 5330227aa4..0000000000 --- a/conda/environments/raft_dev_cuda11.2.yml +++ /dev/null @@ -1,44 +0,0 @@ -name: raft_dev -channels: -- rapidsai -- rapidsai-nightly -- dask/label/dev -- conda-forge -- nvidia -dependencies: -- c-compiler -- cxx-compiler -- cudatoolkit=11.2 -- cuda-python >=11.7.1,<12.0 -- ninja -- clang=11.1.0 -- clang-tools=11.1.0 -- cython>=0.29,<0.30 -- cmake>=3.23.1,!=3.25.0 -- dask==2022.11.1 -- distributed==2022.11.1 -- scikit-build>=0.13.1 -- rapids-build-env=22.12.* -- rapids-notebook-env=22.12.* -- rapids-doc-env=22.12.* -- rmm=22.12.* -- dask-cuda=22.12.* -- ucx>=1.13.0 -- ucx-py=0.29.* -- ucx-proc=*=gpu -- doxygen>=1.8.20 -- libfaiss>=1.7.0 -- faiss-proc=*=cuda -- ccache -- pip -- pip: - - sphinx_markdown_tables - - 
breathe - -# rapids-build-env, notebook-env and doc-env are defined in -# https://docs.rapids.ai/maintainers/depmgmt/ - -# To install different versions of packages contained in those meta packages, -# it is recommended to remove those meta packages (without removing the actual -# packages contained in the environment) first with: -# conda remove --force rapids-build-env rapids-notebook-env rapids-doc-env diff --git a/conda/environments/raft_dev_cuda11.4.yml b/conda/environments/raft_dev_cuda11.4.yml deleted file mode 100644 index 83eca86a7f..0000000000 --- a/conda/environments/raft_dev_cuda11.4.yml +++ /dev/null @@ -1,44 +0,0 @@ -name: raft_dev -channels: -- rapidsai -- rapidsai-nightly -- dask/label/dev -- conda-forge -- nvidia -dependencies: -- c-compiler -- cxx-compiler -- cudatoolkit=11.4 -- cuda-python >=11.7.1,<12.0 -- ninja -- clang=11.1.0 -- clang-tools=11.1.0 -- cython>=0.29,<0.30 -- cmake>=3.23.1,!=3.25.0 -- dask==2022.11.1 -- distributed==2022.11.1 -- scikit-build>=0.13.1 -- rapids-build-env=22.12.* -- rapids-notebook-env=22.12.* -- rapids-doc-env=22.12.* -- rmm=22.12.* -- dask-cuda=22.12.* -- ucx>=1.13.0 -- ucx-py=0.29.* -- ucx-proc=*=gpu -- doxygen>=1.8.20 -- libfaiss>=1.7.0 -- faiss-proc=*=cuda -- ccache -- pip -- pip: - - sphinx_markdown_tables - - breathe - -# rapids-build-env, notebook-env and doc-env are defined in -# https://docs.rapids.ai/maintainers/depmgmt/ - -# To install different versions of packages contained in those meta packages, -# it is recommended to remove those meta packages (without removing the actual -# packages contained in the environment) first with: -# conda remove --force rapids-build-env rapids-notebook-env rapids-doc-env diff --git a/conda/environments/raft_dev_cuda11.5.yml b/conda/environments/raft_dev_cuda11.5.yml deleted file mode 100644 index f8ef71bac2..0000000000 --- a/conda/environments/raft_dev_cuda11.5.yml +++ /dev/null @@ -1,44 +0,0 @@ -name: raft_dev -channels: -- rapidsai -- rapidsai-nightly -- dask/label/dev 
-- conda-forge -- nvidia -dependencies: -- c-compiler -- cxx-compiler -- cudatoolkit=11.5 -- cuda-python >=11.7.1,<12.0 -- ninja -- clang=11.1.0 -- clang-tools=11.1.0 -- cython>=0.29,<0.30 -- cmake>=3.23.1,!=3.25.0 -- dask==2022.11.1 -- distributed==2022.11.1 -- scikit-build>=0.13.1 -- rapids-build-env=22.12.* -- rapids-notebook-env=22.12.* -- rapids-doc-env=22.12.* -- rmm=22.12.* -- dask-cuda=22.12.* -- ucx>=1.13.0 -- ucx-py=0.29.* -- ucx-proc=*=gpu -- doxygen>=1.8.20 -- libfaiss>=1.7.0 -- faiss-proc=*=cuda -- ccache -- pip -- pip: - - sphinx_markdown_tables - - breathe - -# rapids-build-env, notebook-env and doc-env are defined in -# https://docs.rapids.ai/maintainers/depmgmt/ - -# To install different versions of packages contained in those meta packages, -# it is recommended to remove those meta packages (without removing the actual -# packages contained in the environment) first with: -# conda remove --force rapids-build-env rapids-notebook-env rapids-doc-env diff --git a/conda/recipes/libraft/build_libraft_distance.sh b/conda/recipes/libraft/build_libraft_distance.sh index 35bf354e9b..d7e995fc03 100644 --- a/conda/recipes/libraft/build_libraft_distance.sh +++ b/conda/recipes/libraft/build_libraft_distance.sh @@ -1,4 +1,4 @@ #!/usr/bin/env bash -# Copyright (c) 2022, NVIDIA CORPORATION. +# Copyright (c) 2022-2023, NVIDIA CORPORATION. ./build.sh libraft -v --allgpuarch --compile-dist --no-nvtx diff --git a/conda/recipes/libraft/build_libraft_nn.sh b/conda/recipes/libraft/build_libraft_nn.sh index 773d6ab02e..9865922cd0 100644 --- a/conda/recipes/libraft/build_libraft_nn.sh +++ b/conda/recipes/libraft/build_libraft_nn.sh @@ -1,4 +1,4 @@ #!/usr/bin/env bash -# Copyright (c) 2022, NVIDIA CORPORATION. +# Copyright (c) 2022-2023, NVIDIA CORPORATION. 
./build.sh libraft -v --allgpuarch --compile-nn --no-nvtx diff --git a/conda/recipes/libraft/build_libraft_tests.sh b/conda/recipes/libraft/build_libraft_tests.sh index 040a2f8b8c..6adbbe78e1 100644 --- a/conda/recipes/libraft/build_libraft_tests.sh +++ b/conda/recipes/libraft/build_libraft_tests.sh @@ -1,5 +1,5 @@ #!/usr/bin/env bash -# Copyright (c) 2022, NVIDIA CORPORATION. +# Copyright (c) 2022-2023, NVIDIA CORPORATION. ./build.sh tests bench -v --allgpuarch --no-nvtx cmake --install cpp/build --component testing diff --git a/conda/recipes/libraft/conda_build_config.yaml b/conda/recipes/libraft/conda_build_config.yaml index facb478562..1012bddb40 100644 --- a/conda/recipes/libraft/conda_build_config.yaml +++ b/conda/recipes/libraft/conda_build_config.yaml @@ -19,11 +19,43 @@ nccl_version: gtest_version: - "=1.10.0" -libcusolver_version: - - ">=11.2.1,<=11.4.1.48" +libfaiss_version: + - "1.7.2 *_cuda" -libcusparse_version: - - ">=11.5.0,<12.0" +# The CTK libraries below are missing from the conda-forge::cudatoolkit +# package. The "*_host_*" version specifiers correspond to `11.8` packages and the +# "*_run_*" version specifiers correspond to `11.x` packages. -libfaiss_version: - - "1.7.0 *_cuda" +libcublas_host_version: + - "=11.11.3.6" + +libcublas_run_version: + - ">=11.5.2.43,<12.0.0" + +libcurand_host_version: + - "=10.3.0.86" + +libcurand_run_version: + - ">=10.2.5.43,<10.3.1" + +libcusolver_host_version: + - "=11.4.1.48" + +libcusolver_run_version: + - ">=11.2.0.43,<11.4.2" + +libcusparse_host_version: + - "=11.7.5.86" + +libcusparse_run_version: + - ">=11.6.0.43,<12.0.0" + +# `cuda-profiler-api` only has `11.8.0` and `12.0.0` packages for all +# architectures. The "*_host_*" version specifiers correspond to `11.8` packages and the +# "*_run_*" version specifiers correspond to `11.x` packages. 
+ +cuda_profiler_api_host_version: + - "=11.8.86" + +cuda_profiler_api_run_version: + - ">=11.4.240,<12" diff --git a/conda/recipes/libraft/meta.yaml b/conda/recipes/libraft/meta.yaml index 339fa76065..b0d6c47ee9 100644 --- a/conda/recipes/libraft/meta.yaml +++ b/conda/recipes/libraft/meta.yaml @@ -4,9 +4,8 @@ # conda build . -c conda-forge -c nvidia -c rapidsai {% set version = environ.get('GIT_DESCRIBE_TAG', '0.0.0.dev').lstrip('v') + environ.get('VERSION_SUFFIX', '') %} {% set minor_version = version.split('.')[0] + '.' + version.split('.')[1] %} -{% set cuda_version = '.'.join(environ.get('CUDA', '9.2').split('.')[:2]) %} +{% set cuda_version = '.'.join(environ['RAPIDS_CUDA_VERSION'].split('.')[:2]) %} {% set cuda_major = cuda_version.split('.')[0] %} -{% set ucx_py_version = environ.get('UCX_PY_VERSION') %} {% set cuda_spec = ">=" + cuda_major ~ ",<" + (cuda_major | int + 1) ~ ".0a0" %} # i.e. >=11,<12.0a0 package: @@ -21,18 +20,18 @@ outputs: script: build_libraft_headers.sh build: script_env: &script_env - - PARALLEL_LEVEL - - VERSION_SUFFIX - - PROJECT_FLASH - - CMAKE_GENERATOR + - AWS_ACCESS_KEY_ID + - AWS_SECRET_ACCESS_KEY - CMAKE_C_COMPILER_LAUNCHER - - CMAKE_CXX_COMPILER_LAUNCHER - CMAKE_CUDA_COMPILER_LAUNCHER + - CMAKE_CXX_COMPILER_LAUNCHER + - CMAKE_GENERATOR + - PARALLEL_LEVEL + - SCCACHE_BUCKET + - SCCACHE_IDLE_TIMEOUT + - SCCACHE_REGION - SCCACHE_S3_KEY_PREFIX=libraft-aarch64 # [aarch64] - SCCACHE_S3_KEY_PREFIX=libraft-linux64 # [linux64] - - SCCACHE_BUCKET=rapids-sccache - - SCCACHE_REGION=us-west-2 - - SCCACHE_IDLE_TIMEOUT=32768 number: {{ GIT_DESCRIBE_NUMBER }} string: cuda{{ cuda_major }}_{{ GIT_DESCRIBE_HASH }}_{{ GIT_DESCRIBE_NUMBER }} ignore_run_exports_from: @@ -42,26 +41,35 @@ outputs: - {{ compiler('c') }} - {{ compiler('cxx') }} - {{ compiler('cuda') }} {{ cuda_version }} - - sysroot_{{ target_platform }} {{ sysroot_version }} - cmake {{ cmake_version }} + - ninja + - sysroot_{{ target_platform }} {{ sysroot_version }} host: - - 
cudatoolkit {{ cuda_version }}.* - - libcusolver {{ libcusolver_version }} - - libcusparse {{ libcusparse_version }} - - librmm {{ minor_version }} - - nccl {{ nccl_version }} - - ucx-proc=*=gpu - - ucx-py {{ ucx_py_version }} + - cuda-profiler-api {{ cuda_profiler_api_host_version }} + - cudatoolkit ={{ cuda_version }} + - libcublas {{ libcublas_host_version }} + - libcublas-dev {{ libcublas_host_version }} + - libcurand {{ libcurand_host_version }} + - libcurand-dev {{ libcurand_host_version }} + - libcusolver {{ libcusolver_host_version }} + - libcusolver-dev {{ libcusolver_host_version }} + - libcusparse {{ libcusparse_host_version }} + - libcusparse-dev {{ libcusparse_host_version }} + - librmm ={{ minor_version }} run: - - cudatoolkit {{ cuda_spec }} - - libcusolver {{ libcusolver_version }} - - libcusparse {{ libcusparse_version }} - - librmm {{ minor_version }} - - nccl {{ nccl_version }} - - ucx-proc=*=gpu - - ucx-py {{ ucx_py_version }} + - {{ pin_compatible('cudatoolkit', max_pin='x', min_pin='x') }} + - cuda-profiler-api {{ cuda_profiler_api_run_version }} + - libcublas {{ libcublas_run_version }} + - libcublas-dev {{ libcublas_run_version }} + - libcurand {{ libcurand_run_version }} + - libcurand-dev {{ libcurand_run_version }} + - libcusolver {{ libcusolver_run_version }} + - libcusolver-dev {{ libcusolver_run_version }} + - libcusparse {{ libcusparse_run_version }} + - libcusparse-dev {{ libcusparse_run_version }} + - librmm ={{ minor_version }} about: - home: http://rapids.ai/ + home: https://rapids.ai/ license: Apache-2.0 summary: libraft-headers library - name: libraft-distance @@ -75,29 +83,27 @@ outputs: - {{ compiler('cuda') }} requirements: build: - - cmake {{ cmake_version }} - {{ compiler('c') }} - - {{ compiler('cxx') }} - {{ compiler('cuda') }} {{ cuda_version }} + - {{ compiler('cxx') }} + - cmake {{ cmake_version }} + - ninja - sysroot_{{ target_platform }} {{ sysroot_version }} host: - - cudatoolkit {{ cuda_version }}.* - - librmm {{ 
minor_version }} - - nccl {{ nccl_version }} - - ucx-proc=*=gpu - - ucx-py {{ ucx_py_version }} - {{ pin_subpackage('libraft-headers', exact=True) }} + - cuda-profiler-api {{ cuda_profiler_api_host_version }} + - libcublas {{ libcublas_host_version }} + - libcublas-dev {{ libcublas_host_version }} + - libcurand {{ libcurand_host_version }} + - libcurand-dev {{ libcurand_host_version }} + - libcusolver {{ libcusolver_host_version }} + - libcusolver-dev {{ libcusolver_host_version }} + - libcusparse {{ libcusparse_host_version }} + - libcusparse-dev {{ libcusparse_host_version }} run: - - cudatoolkit {{ cuda_spec }} - - librmm {{ minor_version }} - - nccl {{ nccl_version }} - - ucx-proc=*=gpu - - ucx-py {{ ucx_py_version }} - - libcusolver {{ libcusolver_version }} - - libcusparse {{ libcusparse_version }} - {{ pin_subpackage('libraft-headers', exact=True) }} about: - home: http://rapids.ai/ + home: https://rapids.ai/ license: Apache-2.0 summary: libraft-distance library - name: libraft-nn @@ -111,28 +117,32 @@ outputs: - {{ compiler('cuda') }} requirements: build: - - cmake {{ cmake_version }} - {{ compiler('c') }} - - {{ compiler('cxx') }} - {{ compiler('cuda') }} {{ cuda_version }} + - {{ compiler('cxx') }} + - cmake {{ cmake_version }} + - ninja - sysroot_{{ target_platform }} {{ sysroot_version }} host: - - cudatoolkit {{ cuda_version }}.* + - {{ pin_subpackage('libraft-headers', exact=True) }} + - cuda-profiler-api {{ cuda_profiler_api_host_version }} - faiss-proc=*=cuda - lapack + - libcublas {{ libcublas_host_version }} + - libcublas-dev {{ libcublas_host_version }} + - libcurand {{ libcurand_host_version }} + - libcurand-dev {{ libcurand_host_version }} + - libcusolver {{ libcusolver_host_version }} + - libcusolver-dev {{ libcusolver_host_version }} + - libcusparse {{ libcusparse_host_version }} + - libcusparse-dev {{ libcusparse_host_version }} - libfaiss {{ libfaiss_version }} - - librmm {{ minor_version }} - - {{ pin_subpackage('libraft-headers', 
exact=True) }} run: - - cudatoolkit {{ cuda_spec }} - faiss-proc=*=cuda - - libcusolver {{ libcusolver_version }} - - libcusparse {{ libcusparse_version }} - libfaiss {{ libfaiss_version }} - - librmm {{ minor_version }} - {{ pin_subpackage('libraft-headers', exact=True) }} about: - home: http://rapids.ai/ + home: https://rapids.ai/ license: Apache-2.0 summary: libraft-nn library - name: libraft-tests @@ -146,28 +156,34 @@ outputs: - {{ compiler('cuda') }} requirements: build: - - cmake {{ cmake_version }} - {{ compiler('c') }} - - {{ compiler('cxx') }} - {{ compiler('cuda') }} {{ cuda_version }} + - {{ compiler('cxx') }} + - cmake {{ cmake_version }} + - ninja - sysroot_{{ target_platform }} {{ sysroot_version }} host: - - cudatoolkit {{ cuda_version }}.* - - gmock {{ gtest_version }} - - gtest {{ gtest_version }} - {{ pin_subpackage('libraft-distance', exact=True) }} - {{ pin_subpackage('libraft-headers', exact=True) }} - {{ pin_subpackage('libraft-nn', exact=True) }} - run: - - cudatoolkit {{ cuda_spec }} + - cuda-profiler-api {{ cuda_profiler_api_host_version }} - gmock {{ gtest_version }} - gtest {{ gtest_version }} - - libcusolver {{ libcusolver_version }} - - libcusparse {{ libcusparse_version }} + - libcublas {{ libcublas_host_version }} + - libcublas-dev {{ libcublas_host_version }} + - libcurand {{ libcurand_host_version }} + - libcurand-dev {{ libcurand_host_version }} + - libcusolver {{ libcusolver_host_version }} + - libcusolver-dev {{ libcusolver_host_version }} + - libcusparse {{ libcusparse_host_version }} + - libcusparse-dev {{ libcusparse_host_version }} + run: - {{ pin_subpackage('libraft-distance', exact=True) }} - {{ pin_subpackage('libraft-headers', exact=True) }} - {{ pin_subpackage('libraft-nn', exact=True) }} + - gmock {{ gtest_version }} + - gtest {{ gtest_version }} about: - home: http://rapids.ai/ + home: https://rapids.ai/ license: Apache-2.0 summary: libraft tests diff --git a/conda/recipes/pylibraft/meta.yaml 
b/conda/recipes/pylibraft/meta.yaml index 68e2d5952d..6bc091a219 100644 --- a/conda/recipes/pylibraft/meta.yaml +++ b/conda/recipes/pylibraft/meta.yaml @@ -3,10 +3,10 @@ # Usage: # conda build . -c conda-forge -c numba -c rapidsai -c pytorch {% set version = environ.get('GIT_DESCRIBE_TAG', '0.0.0.dev').lstrip('v') + environ.get('VERSION_SUFFIX', '') %} -{% set minor_version = version.split('.')[0] + '.' + version.split('.')[1] %} -{% set cuda_version='.'.join(environ.get('CUDA', 'unknown').split('.')[:2]) %} -{% set cuda_major=cuda_version.split('.')[0] %} -{% set py_version=environ.get('CONDA_PY', 36) %} +{% set minor_version = version.split('.')[0] + '.' + version.split('.')[1] %} +{% set py_version = environ['CONDA_PY'] %} +{% set cuda_version = '.'.join(environ['RAPIDS_CUDA_VERSION'].split('.')[:2]) %} +{% set cuda_major = cuda_version.split('.')[0] %} package: name: pylibraft @@ -23,36 +23,38 @@ build: requirements: build: - - cmake {{ cmake_version }} - {{ compiler('c') }} - {{ compiler('cxx') }} - {{ compiler('cuda') }} {{ cuda_version }} + - cmake {{ cmake_version }} + - ninja - sysroot_{{ target_platform }} {{ sysroot_version }} host: + - cuda-python >=11.7.1,<12.0 + - cudatoolkit ={{ cuda_version }} + - cython >=0.29,<0.30 + - libraft-distance {{ version }} + - libraft-headers {{ version }} - python x.x + - rmm ={{ minor_version }} + - scikit-build >=0.13.1 - setuptools - - cython>=0.29,<0.30 - - scikit-build>=0.13.1 - - rmm {{ minor_version }} - - libraft-headers {{ version }} - - libraft-distance {{ version }} - - cudatoolkit {{ cuda_version }}.* - - cuda-python >=11.7.1,<12.0 run: - - python x.x - - libraft-headers {{ version }} - - libraft-distance {{ version }} - - cuda-python >=11.7.1,<12.0 - {{ pin_compatible('cudatoolkit', max_pin='x', min_pin='x') }} + - cuda-python >=11.7.1,<12.0 + - libraft-distance {{ version }} + - libraft-headers {{ version }} + - python x.x +# TODO: Remove the linux64 tags on tests after disabling gpuCI / Jenkins tests: # 
[linux64] requirements: # [linux64] - - cudatoolkit {{ cuda_version }}.* # [linux64] + - cudatoolkit ={{ cuda_version }} # [linux64] imports: # [linux64] - pylibraft # [linux64] about: - home: http://rapids.ai/ + home: https://rapids.ai/ license: Apache-2.0 # license_file: LICENSE summary: pylibraft library diff --git a/conda/recipes/raft-dask/conda_build_config.yaml b/conda/recipes/raft-dask/conda_build_config.yaml index 3b42dab182..42d7e3a900 100644 --- a/conda/recipes/raft-dask/conda_build_config.yaml +++ b/conda/recipes/raft-dask/conda_build_config.yaml @@ -13,5 +13,8 @@ sysroot_version: ucx_version: - "1.13.0" +ucx_py_version: + - "0.30.*" + cmake_version: - ">=3.23.1,!=3.25.0" diff --git a/conda/recipes/raft-dask/meta.yaml b/conda/recipes/raft-dask/meta.yaml index f9a7c58e24..a8bc626eaa 100644 --- a/conda/recipes/raft-dask/meta.yaml +++ b/conda/recipes/raft-dask/meta.yaml @@ -3,11 +3,10 @@ # Usage: # conda build . -c conda-forge -c numba -c rapidsai -c pytorch {% set version = environ.get('GIT_DESCRIBE_TAG', '0.0.0.dev').lstrip('v') + environ.get('VERSION_SUFFIX', '') %} -{% set minor_version = version.split('.')[0] + '.' + version.split('.')[1] %} -{% set cuda_version='.'.join(environ.get('CUDA', 'unknown').split('.')[:2]) %} -{% set cuda_major=cuda_version.split('.')[0] %} -{% set py_version=environ.get('CONDA_PY', 36) %} -{% set ucx_py_version=environ.get('UCX_PY_VERSION') %} +{% set minor_version = version.split('.')[0] + '.' 
+ version.split('.')[1] %} +{% set py_version = environ['CONDA_PY'] %} +{% set cuda_version = '.'.join(environ['RAPIDS_CUDA_VERSION'].split('.')[:2]) %} +{% set cuda_major = cuda_version.split('.')[0] %} package: name: raft-dask @@ -24,47 +23,49 @@ build: requirements: build: - - cmake {{ cmake_version }} - {{ compiler('c') }} - {{ compiler('cxx') }} - {{ compiler('cuda') }} {{ cuda_version }} + - cmake {{ cmake_version }} + - ninja - sysroot_{{ target_platform }} {{ sysroot_version }} host: + - cuda-python >=11.7.1,<12.0 + - cudatoolkit ={{ cuda_version }} + - cython >=0.29,<0.30 + - nccl >=2.9.9 + - pylibraft {{ version }} - python x.x + - rmm ={{ minor_version }} + - scikit-build >=0.13.1 - setuptools - - cython>=0.29,<0.30 - - scikit-build>=0.13.1 - - rmm {{ minor_version }} - - pylibraft {{ version }} - - cudatoolkit {{ cuda_version }}.* - - cuda-python >=11.7.1,<12.0 - - nccl>=2.9.9 - ucx {{ ucx_version }} - - ucx-py {{ ucx_py_version }} - ucx-proc=*=gpu + - ucx-py {{ ucx_py_version }} run: - - python x.x - - dask-cuda {{ minor_version }} + - {{ pin_compatible('cudatoolkit', max_pin='x', min_pin='x') }} + - cuda-python >=11.7.1,<12.0 + - dask ==2023.1.1 + - dask-cuda ={{ minor_version }} + - distributed ==2023.1.1 + - joblib >=0.11 + - nccl >=2.9.9 - pylibraft {{ version }} - - nccl>=2.9.9 - - rmm {{ minor_version }} + - python x.x + - rmm ={{ minor_version }} - ucx >={{ ucx_version }} - - ucx-py {{ ucx_py_version }} - ucx-proc=*=gpu - - dask==2022.11.1 - - distributed==2022.11.1 - - cuda-python >=11.7.1,<12.0 - - joblib >=0.11 - - {{ pin_compatible('cudatoolkit', max_pin='x', min_pin='x') }} + - ucx-py {{ ucx_py_version }} +# TODO: Remove the linux64 tags on tests after disabling gpuCI / Jenkins tests: # [linux64] requirements: # [linux64] - - cudatoolkit {{ cuda_version }}.* # [linux64] + - cudatoolkit ={{ cuda_version }} # [linux64] imports: # [linux64] - raft_dask # [linux64] about: - home: http://rapids.ai/ + home: https://rapids.ai/ license: Apache-2.0 
# license_file: LICENSE summary: raft-dask library diff --git a/cpp/CMakeLists.txt b/cpp/CMakeLists.txt index 2fd10fe067..1d54409ae6 100644 --- a/cpp/CMakeLists.txt +++ b/cpp/CMakeLists.txt @@ -1,5 +1,5 @@ # ============================================================================= -# Copyright (c) 2020-2022, NVIDIA CORPORATION. +# Copyright (c) 2020-2023, NVIDIA CORPORATION. # # Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except # in compliance with the License. You may obtain a copy of the License at @@ -10,9 +10,8 @@ # is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express # or implied. See the License for the specific language governing permissions and limitations under # the License. -# ============================================================================= -set(RAPIDS_VERSION "22.12") -set(RAFT_VERSION "22.12.01") +set(RAPIDS_VERSION "23.02") +set(RAFT_VERSION "23.02.00") cmake_minimum_required(VERSION 3.23.1 FATAL_ERROR) include(../fetch_rapids.cmake) @@ -219,6 +218,15 @@ target_link_libraries( target_compile_features(raft INTERFACE cxx_std_17 $) +# Endian detection +include(TestBigEndian) +test_big_endian(BIG_ENDIAN) +if(BIG_ENDIAN) + target_compile_definitions(raft INTERFACE RAFT_SYSTEM_LITTLE_ENDIAN=0) +else() + target_compile_definitions(raft INTERFACE RAFT_SYSTEM_LITTLE_ENDIAN=1) +endif() + if(RAFT_COMPILE_DIST_LIBRARY OR RAFT_COMPILE_NN_LIBRARY) file( WRITE "${CMAKE_CURRENT_BINARY_DIR}/fatbin.ld" @@ -279,79 +287,98 @@ set_target_properties(raft_distance PROPERTIES EXPORT_NAME distance) if(RAFT_COMPILE_DIST_LIBRARY) add_library( raft_distance_lib - src/distance/pairwise_distance.cu - src/distance/fused_l2_min_arg.cu - src/distance/update_centroids_float.cu - src/distance/update_centroids_double.cu - src/distance/cluster_cost_float.cu - src/distance/cluster_cost_double.cu - src/distance/specializations/detail/canberra.cu - 
src/distance/specializations/detail/chebyshev.cu - src/distance/specializations/detail/correlation.cu - src/distance/specializations/detail/cosine.cu - src/distance/specializations/detail/cosine.cu - src/distance/specializations/detail/hamming_unexpanded.cu - src/distance/specializations/detail/hellinger_expanded.cu - src/distance/specializations/detail/jensen_shannon_float_float_float_int.cu - src/distance/specializations/detail/jensen_shannon_float_float_float_uint32.cu - src/distance/specializations/detail/jensen_shannon_double_double_double_int.cu - src/distance/specializations/detail/kernels/gram_matrix_base_double.cu - src/distance/specializations/detail/kernels/gram_matrix_base_float.cu - src/distance/specializations/detail/kernels/polynomial_kernel_double_int.cu - src/distance/specializations/detail/kernels/polynomial_kernel_float_int.cu + src/distance/distance/pairwise_distance.cu + src/distance/distance/fused_l2_min_arg.cu + src/distance/cluster/update_centroids_float.cu + src/distance/cluster/update_centroids_double.cu + src/distance/cluster/cluster_cost_float.cu + src/distance/cluster/cluster_cost_double.cu + src/distance/neighbors/refine_d_uint64_t_float.cu + src/distance/neighbors/refine_d_uint64_t_int8_t.cu + src/distance/neighbors/refine_d_uint64_t_uint8_t.cu + src/distance/neighbors/refine_h_uint64_t_float.cu + src/distance/neighbors/refine_h_uint64_t_int8_t.cu + src/distance/neighbors/refine_h_uint64_t_uint8_t.cu + src/distance/neighbors/specializations/refine_d_uint64_t_float.cu + src/distance/neighbors/specializations/refine_d_uint64_t_int8_t.cu + src/distance/neighbors/specializations/refine_d_uint64_t_uint8_t.cu + src/distance/neighbors/specializations/refine_h_uint64_t_float.cu + src/distance/neighbors/specializations/refine_h_uint64_t_int8_t.cu + src/distance/neighbors/specializations/refine_h_uint64_t_uint8_t.cu + src/distance/cluster/kmeans_fit_float.cu + src/distance/cluster/kmeans_fit_double.cu + 
src/distance/distance/specializations/detail/canberra.cu + src/distance/distance/specializations/detail/chebyshev.cu + src/distance/distance/specializations/detail/correlation.cu + src/distance/distance/specializations/detail/cosine.cu + src/distance/distance/specializations/detail/cosine.cu + src/distance/distance/specializations/detail/hamming_unexpanded.cu + src/distance/distance/specializations/detail/hellinger_expanded.cu + src/distance/distance/specializations/detail/jensen_shannon_float_float_float_int.cu + src/distance/distance/specializations/detail/jensen_shannon_float_float_float_uint32.cu + src/distance/distance/specializations/detail/jensen_shannon_double_double_double_int.cu + src/distance/distance/specializations/detail/kernels/gram_matrix_base_double.cu + src/distance/distance/specializations/detail/kernels/gram_matrix_base_float.cu + src/distance/distance/specializations/detail/kernels/polynomial_kernel_double_int.cu + src/distance/distance/specializations/detail/kernels/polynomial_kernel_float_int.cu # These are somehow missing a kernel definition which is causing a compile error. 
# src/distance/specializations/detail/kernels/rbf_kernel_double.cu # src/distance/specializations/detail/kernels/rbf_kernel_float.cu - src/distance/specializations/detail/kernels/tanh_kernel_double.cu - src/distance/specializations/detail/kernels/tanh_kernel_float.cu - src/distance/specializations/detail/kl_divergence_float_float_float_int.cu - src/distance/specializations/detail/kl_divergence_float_float_float_uint32.cu - src/distance/specializations/detail/kl_divergence_double_double_double_int.cu - src/distance/specializations/detail/l1_float_float_float_int.cu - src/distance/specializations/detail/l1_float_float_float_uint32.cu - src/distance/specializations/detail/l1_double_double_double_int.cu - src/distance/specializations/detail/l2_expanded_float_float_float_int.cu - src/distance/specializations/detail/l2_expanded_float_float_float_uint32.cu - src/distance/specializations/detail/l2_expanded_double_double_double_int.cu - src/distance/specializations/detail/l2_sqrt_expanded_float_float_float_int.cu - src/distance/specializations/detail/l2_sqrt_expanded_float_float_float_uint32.cu - src/distance/specializations/detail/l2_sqrt_expanded_double_double_double_int.cu - src/distance/specializations/detail/l2_sqrt_unexpanded_float_float_float_int.cu - src/distance/specializations/detail/l2_sqrt_unexpanded_float_float_float_uint32.cu - src/distance/specializations/detail/l2_sqrt_unexpanded_double_double_double_int.cu - src/distance/specializations/detail/l2_unexpanded_double_double_double_int.cu - src/distance/specializations/detail/l2_unexpanded_float_float_float_uint32.cu - src/distance/specializations/detail/l2_unexpanded_float_float_float_int.cu - src/distance/specializations/detail/lp_unexpanded_double_double_double_int.cu - src/distance/specializations/detail/lp_unexpanded_float_float_float_uint32.cu - src/distance/specializations/detail/lp_unexpanded_float_float_float_int.cu - src/distance/specializations/detail/russel_rao_double_double_double_int.cu - 
src/distance/specializations/detail/russel_rao_float_float_float_uint32.cu - src/distance/specializations/detail/russel_rao_float_float_float_int.cu - src/distance/specializations/fused_l2_nn_double_int.cu - src/distance/specializations/fused_l2_nn_double_int64.cu - src/distance/specializations/fused_l2_nn_float_int.cu - src/distance/specializations/fused_l2_nn_float_int64.cu - src/nn/specializations/detail/ivfpq_build.cu - src/nn/specializations/detail/ivfpq_compute_similarity_float_fast.cu - src/nn/specializations/detail/ivfpq_compute_similarity_float_no_basediff.cu - src/nn/specializations/detail/ivfpq_compute_similarity_float_no_smem_lut.cu - src/nn/specializations/detail/ivfpq_compute_similarity_fp8s_fast.cu - src/nn/specializations/detail/ivfpq_compute_similarity_fp8s_no_basediff.cu - src/nn/specializations/detail/ivfpq_compute_similarity_fp8s_no_smem_lut.cu - src/nn/specializations/detail/ivfpq_compute_similarity_fp8u_fast.cu - src/nn/specializations/detail/ivfpq_compute_similarity_fp8u_no_basediff.cu - src/nn/specializations/detail/ivfpq_compute_similarity_fp8u_no_smem_lut.cu - src/nn/specializations/detail/ivfpq_compute_similarity_half_fast.cu - src/nn/specializations/detail/ivfpq_compute_similarity_half_no_basediff.cu - src/nn/specializations/detail/ivfpq_compute_similarity_half_no_smem_lut.cu - src/nn/specializations/detail/ivfpq_search.cu - src/nn/specializations/detail/ivfpq_search_float_uint64_t.cu - src/nn/specializations/refine.cu - src/random/specializations/rmat_rectangular_generator_int_double.cu - src/random/specializations/rmat_rectangular_generator_int64_double.cu - src/random/specializations/rmat_rectangular_generator_int_float.cu - src/random/specializations/rmat_rectangular_generator_int64_float.cu + src/distance/distance/specializations/detail/kernels/tanh_kernel_double.cu + src/distance/distance/specializations/detail/kernels/tanh_kernel_float.cu + src/distance/distance/specializations/detail/kl_divergence_float_float_float_int.cu + 
src/distance/distance/specializations/detail/kl_divergence_float_float_float_uint32.cu + src/distance/distance/specializations/detail/kl_divergence_double_double_double_int.cu + src/distance/distance/specializations/detail/l1_float_float_float_int.cu + src/distance/distance/specializations/detail/l1_float_float_float_uint32.cu + src/distance/distance/specializations/detail/l1_double_double_double_int.cu + src/distance/distance/specializations/detail/l2_expanded_float_float_float_int.cu + src/distance/distance/specializations/detail/l2_expanded_float_float_float_uint32.cu + src/distance/distance/specializations/detail/l2_expanded_double_double_double_int.cu + src/distance/distance/specializations/detail/l2_sqrt_expanded_float_float_float_int.cu + src/distance/distance/specializations/detail/l2_sqrt_expanded_float_float_float_uint32.cu + src/distance/distance/specializations/detail/l2_sqrt_expanded_double_double_double_int.cu + src/distance/distance/specializations/detail/l2_sqrt_unexpanded_float_float_float_int.cu + src/distance/distance/specializations/detail/l2_sqrt_unexpanded_float_float_float_uint32.cu + src/distance/distance/specializations/detail/l2_sqrt_unexpanded_double_double_double_int.cu + src/distance/distance/specializations/detail/l2_unexpanded_double_double_double_int.cu + src/distance/distance/specializations/detail/l2_unexpanded_float_float_float_uint32.cu + src/distance/distance/specializations/detail/l2_unexpanded_float_float_float_int.cu + src/distance/distance/specializations/detail/lp_unexpanded_double_double_double_int.cu + src/distance/distance/specializations/detail/lp_unexpanded_float_float_float_uint32.cu + src/distance/distance/specializations/detail/lp_unexpanded_float_float_float_int.cu + src/distance/distance/specializations/detail/russel_rao_double_double_double_int.cu + src/distance/distance/specializations/detail/russel_rao_float_float_float_uint32.cu + src/distance/distance/specializations/detail/russel_rao_float_float_float_int.cu 
+ src/distance/distance/specializations/fused_l2_nn_double_int.cu + src/distance/distance/specializations/fused_l2_nn_double_int64.cu + src/distance/distance/specializations/fused_l2_nn_float_int.cu + src/distance/distance/specializations/fused_l2_nn_float_int64.cu + src/distance/neighbors/ivfpq_build.cu + src/distance/neighbors/ivfpq_deserialize.cu + src/distance/neighbors/ivfpq_serialize.cu + src/distance/neighbors/ivfpq_search_float_uint64_t.cu + src/distance/neighbors/ivfpq_search_int8_t_uint64_t.cu + src/distance/neighbors/ivfpq_search_uint8_t_uint64_t.cu + src/distance/neighbors/specializations/detail/ivfpq_compute_similarity_float_fast.cu + src/distance/neighbors/specializations/detail/ivfpq_compute_similarity_float_no_basediff.cu + src/distance/neighbors/specializations/detail/ivfpq_compute_similarity_float_no_smem_lut.cu + src/distance/neighbors/specializations/detail/ivfpq_compute_similarity_fp8s_fast.cu + src/distance/neighbors/specializations/detail/ivfpq_compute_similarity_fp8s_no_basediff.cu + src/distance/neighbors/specializations/detail/ivfpq_compute_similarity_fp8s_no_smem_lut.cu + src/distance/neighbors/specializations/detail/ivfpq_compute_similarity_fp8u_fast.cu + src/distance/neighbors/specializations/detail/ivfpq_compute_similarity_fp8u_no_basediff.cu + src/distance/neighbors/specializations/detail/ivfpq_compute_similarity_fp8u_no_smem_lut.cu + src/distance/neighbors/specializations/detail/ivfpq_compute_similarity_half_fast.cu + src/distance/neighbors/specializations/detail/ivfpq_compute_similarity_half_no_basediff.cu + src/distance/neighbors/specializations/detail/ivfpq_compute_similarity_half_no_smem_lut.cu + src/distance/neighbors/specializations/detail/ivfpq_search_float_int64_t.cu + src/distance/neighbors/specializations/detail/ivfpq_search_float_uint64_t.cu + src/distance/neighbors/specializations/detail/ivfpq_search_float_uint32_t.cu + src/distance/random/rmat_rectangular_generator_int_double.cu + 
src/distance/random/rmat_rectangular_generator_int64_double.cu + src/distance/random/rmat_rectangular_generator_int_float.cu + src/distance/random/rmat_rectangular_generator_int64_float.cu ) set_target_properties( raft_distance_lib @@ -404,33 +431,21 @@ set_target_properties(raft_nn PROPERTIES EXPORT_NAME nn) if(RAFT_COMPILE_NN_LIBRARY) add_library( raft_nn_lib - src/nn/specializations/ball_cover.cu src/nn/specializations/detail/ball_cover_lowdim_pass_one_2d.cu src/nn/specializations/detail/ball_cover_lowdim_pass_two_2d.cu src/nn/specializations/detail/ball_cover_lowdim_pass_one_3d.cu src/nn/specializations/detail/ball_cover_lowdim_pass_two_3d.cu - src/nn/specializations/detail/ivfpq_compute_similarity_float_fast.cu - src/nn/specializations/detail/ivfpq_compute_similarity_float_no_basediff.cu - src/nn/specializations/detail/ivfpq_compute_similarity_float_no_smem_lut.cu - src/nn/specializations/detail/ivfpq_compute_similarity_fp8s_fast.cu - src/nn/specializations/detail/ivfpq_compute_similarity_fp8s_no_basediff.cu - src/nn/specializations/detail/ivfpq_compute_similarity_fp8s_no_smem_lut.cu - src/nn/specializations/detail/ivfpq_compute_similarity_fp8u_fast.cu - src/nn/specializations/detail/ivfpq_compute_similarity_fp8u_no_basediff.cu - src/nn/specializations/detail/ivfpq_compute_similarity_fp8u_no_smem_lut.cu - src/nn/specializations/detail/ivfpq_compute_similarity_half_fast.cu - src/nn/specializations/detail/ivfpq_compute_similarity_half_no_basediff.cu - src/nn/specializations/detail/ivfpq_compute_similarity_half_no_smem_lut.cu - src/nn/specializations/detail/ivfpq_build.cu - src/nn/specializations/detail/ivfpq_search.cu - src/nn/specializations/detail/ivfpq_search_float_int64_t.cu - src/nn/specializations/detail/ivfpq_search_float_uint32_t.cu - src/nn/specializations/detail/ivfpq_search_float_uint64_t.cu + src/nn/specializations/ball_cover_all_knn_query.cu + src/nn/specializations/ball_cover_build_index.cu + src/nn/specializations/ball_cover_knn_query.cu 
src/nn/specializations/fused_l2_knn_long_float_true.cu src/nn/specializations/fused_l2_knn_long_float_false.cu src/nn/specializations/fused_l2_knn_int_float_true.cu src/nn/specializations/fused_l2_knn_int_float_false.cu - src/nn/specializations/knn.cu + src/nn/specializations/brute_force_knn_long_float_int.cu + src/nn/specializations/brute_force_knn_long_float_uint.cu + src/nn/specializations/brute_force_knn_uint32_t_float_int.cu + src/nn/specializations/brute_force_knn_uint32_t_float_uint.cu ) set_target_properties( raft_nn_lib @@ -478,10 +493,20 @@ endif() set_target_properties(raft_distributed PROPERTIES EXPORT_NAME distributed) +rapids_find_generate_module( + NCCL + HEADER_NAMES nccl.h + LIBRARY_NAMES nccl + BUILD_EXPORT_SET raft-distributed-exports + INSTALL_EXPORT_SET raft-distributed-exports +) + rapids_export_package(BUILD ucx raft-distributed-exports) rapids_export_package(INSTALL ucx raft-distributed-exports) +rapids_export_package(BUILD NCCL raft-distributed-exports) +rapids_export_package(INSTALL NCCL raft-distributed-exports) -target_link_libraries(raft_distributed INTERFACE ucx::ucp) +target_link_libraries(raft_distributed INTERFACE ucx::ucp NCCL::NCCL) # ################################################################################################## # * install targets----------------------------------------------------------- @@ -518,7 +543,7 @@ if(TARGET raft_distance_lib) EXPORT raft-distance-lib-exports ) install( - DIRECTORY include/raft_distance + DIRECTORY include/raft_runtime DESTINATION ${CMAKE_INSTALL_INCLUDEDIR} COMPONENT distance ) @@ -665,27 +690,17 @@ raft_export( # ################################################################################################## # * build export ------------------------------------------------------------- raft_export( - BUILD - raft - EXPORT_SET - raft-exports - COMPONENTS - nn - distance - distributed - GLOBAL_TARGETS - raft - raft_distance - distributed - raft_nn - DOCUMENTATION - doc_string - 
NAMESPACE - raft:: - FINAL_CODE_BLOCK - code_string + BUILD raft EXPORT_SET raft-exports COMPONENTS nn distance distributed GLOBAL_TARGETS raft + distance distributed nn DOCUMENTATION doc_string NAMESPACE raft:: FINAL_CODE_BLOCK code_string ) +# ################################################################################################## +# * shared test/bench headers ------------------------------------------------ + +if(BUILD_TESTS OR BUILD_BENCH) + include(internal/CMakeLists.txt) +endif() + # ################################################################################################## # * build test executable ---------------------------------------------------- diff --git a/cpp/bench/CMakeLists.txt b/cpp/bench/CMakeLists.txt index 4e6b6ceb40..b1ffc72ba9 100644 --- a/cpp/bench/CMakeLists.txt +++ b/cpp/bench/CMakeLists.txt @@ -1,5 +1,5 @@ # ============================================================================= -# Copyright (c) 2022, NVIDIA CORPORATION. +# Copyright (c) 2022-2023, NVIDIA CORPORATION. # # Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except # in compliance with the License. 
You may obtain a copy of the License at @@ -30,6 +30,7 @@ function(ConfigureBench) target_link_libraries( ${BENCH_NAME} PRIVATE raft::raft + raft_internal $<$:raft::distance> $<$:raft::nn> benchmark::benchmark @@ -96,12 +97,16 @@ if(BUILD_BENCH) bench/linalg/matrix_vector_op.cu bench/linalg/norm.cu bench/linalg/normalize.cu + bench/linalg/reduce_cols_by_key.cu bench/linalg/reduce_rows_by_key.cu bench/linalg/reduce.cu bench/main.cpp ) - ConfigureBench(NAME MATRIX_BENCH PATH bench/matrix/argmin.cu bench/main.cpp) + ConfigureBench( + NAME MATRIX_BENCH PATH bench/matrix/argmin.cu bench/matrix/gather.cu bench/matrix/select_k.cu + bench/main.cpp + ) ConfigureBench( NAME RANDOM_BENCH PATH bench/random/make_blobs.cu bench/random/permute.cu bench/random/rng.cu @@ -125,7 +130,6 @@ if(BUILD_BENCH) bench/neighbors/knn/ivf_pq_int8_t_int64_t.cu bench/neighbors/knn/ivf_pq_uint8_t_uint32_t.cu bench/neighbors/refine.cu - bench/neighbors/selection.cu bench/main.cpp OPTIONAL DIST diff --git a/cpp/bench/cluster/kmeans_balanced.cu b/cpp/bench/cluster/kmeans_balanced.cu index 210b40ced8..9c53e86d8c 100644 --- a/cpp/bench/cluster/kmeans_balanced.cu +++ b/cpp/bench/cluster/kmeans_balanced.cu @@ -15,20 +15,19 @@ */ #include +#include #include -#include -#if defined RAFT_DISTANCE_COMPILED && defined RAFT_NN_COMPILED -#include +#if defined RAFT_DISTANCE_COMPILED +#include #endif namespace raft::bench::cluster { struct KMeansBalancedBenchParams { DatasetParams data; - uint32_t max_iter; uint32_t n_lists; - raft::distance::DistanceType metric; + raft::cluster::kmeans_balanced_params kb_params; }; template @@ -38,15 +37,10 @@ struct KMeansBalanced : public fixture { void run_benchmark(::benchmark::State& state) override { this->loop_on_state(state, [this]() { - raft::spatial::knn::detail::kmeans::build_hierarchical(this->handle, - this->params.max_iter, - (uint32_t)this->params.data.cols, - this->X.data_handle(), - this->params.data.rows, - this->centroids.data_handle(), - this->params.n_lists, 
- this->params.metric, - this->handle.get_stream()); + raft::device_matrix_view X_view = this->X.view(); + raft::device_matrix_view centroids_view = this->centroids.view(); + raft::cluster::kmeans_balanced::fit( + this->handle, this->params.kb_params, X_view, centroids_view); }); } @@ -84,8 +78,8 @@ std::vector getKMeansBalancedInputs() std::vector out; KMeansBalancedBenchParams p; p.data.row_major = true; - p.max_iter = 20; - p.metric = raft::distance::DistanceType::L2Expanded; + p.kb_params.n_iters = 20; + p.kb_params.metric = raft::distance::DistanceType::L2Expanded; std::vector> row_cols = { {100000, 128}, {1000000, 128}, {10000000, 128}, // The following dataset sizes are too large for most GPUs. @@ -104,7 +98,5 @@ std::vector getKMeansBalancedInputs() // Note: the datasets sizes are too large for 32-bit index types. RAFT_BENCH_REGISTER((KMeansBalanced), "", getKMeansBalancedInputs()); -RAFT_BENCH_REGISTER((KMeansBalanced), "", getKMeansBalancedInputs()); -RAFT_BENCH_REGISTER((KMeansBalanced), "", getKMeansBalancedInputs()); } // namespace raft::bench::cluster diff --git a/cpp/bench/common/benchmark.hpp b/cpp/bench/common/benchmark.hpp index 13ca40a033..85d5381e2c 100644 --- a/cpp/bench/common/benchmark.hpp +++ b/cpp/bench/common/benchmark.hpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2022, NVIDIA CORPORATION. + * Copyright (c) 2022-2023, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
@@ -20,7 +20,7 @@ #include #include -#include +#include #include #include #include @@ -110,7 +110,7 @@ class fixture { rmm::device_buffer scratch_buf_; public: - raft::handle_t handle; + raft::device_resources handle; rmm::cuda_stream_view stream; fixture() : stream{handle.get_stream()} diff --git a/cpp/bench/distance/distance_common.cuh b/cpp/bench/distance/distance_common.cuh index 73faacce37..7ddecd7579 100644 --- a/cpp/bench/distance/distance_common.cuh +++ b/cpp/bench/distance/distance_common.cuh @@ -1,5 +1,5 @@ /* - * Copyright (c) 2022, NVIDIA CORPORATION. + * Copyright (c) 2022-2023, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. diff --git a/cpp/bench/distance/kernels.cu b/cpp/bench/distance/kernels.cu index 5c9c2cc2ed..027f93171e 100644 --- a/cpp/bench/distance/kernels.cu +++ b/cpp/bench/distance/kernels.cu @@ -1,5 +1,5 @@ /* - * Copyright (c) 2019-2022, NVIDIA CORPORATION. + * Copyright (c) 2019-2023, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
@@ -19,7 +19,7 @@ #include #include -#include +#include #include #include #include @@ -77,7 +77,7 @@ struct GramMatrix : public fixture { } private: - const raft::handle_t handle; + const raft::device_resources handle; std::unique_ptr> kernel; GramTestParams params; diff --git a/cpp/bench/linalg/norm.cu b/cpp/bench/linalg/norm.cu index cce4195cf1..efecee88c9 100644 --- a/cpp/bench/linalg/norm.cu +++ b/cpp/bench/linalg/norm.cu @@ -18,6 +18,7 @@ #include #include #include +#include #include #include @@ -60,7 +61,7 @@ struct rowNorm : public fixture { output_view, raft::linalg::L2Norm, raft::linalg::Apply::ALONG_ROWS, - raft::SqrtOp()); + raft::sqrt_op()); }); } diff --git a/cpp/bench/linalg/reduce_cols_by_key.cu b/cpp/bench/linalg/reduce_cols_by_key.cu new file mode 100644 index 0000000000..43aeb69ab0 --- /dev/null +++ b/cpp/bench/linalg/reduce_cols_by_key.cu @@ -0,0 +1,78 @@ +/* + * Copyright (c) 2022, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include +#include +#include +#include + +#include + +namespace raft::bench::linalg { + +template +struct rcbk_params { + IdxT rows, cols; + IdxT keys; +}; + +template +inline auto operator<<(std::ostream& os, const rcbk_params& p) -> std::ostream& +{ + os << p.rows << "#" << p.cols << "#" << p.keys; + return os; +} + +template +struct reduce_cols_by_key : public fixture { + reduce_cols_by_key(const rcbk_params& p) + : params(p), in(p.rows * p.cols, stream), out(p.rows * p.keys, stream), keys(p.cols, stream) + { + raft::random::RngState rng{42}; + raft::random::uniformInt(rng, keys.data(), p.cols, (KeyT)0, (KeyT)p.keys, stream); + } + + void run_benchmark(::benchmark::State& state) override + { + std::ostringstream label_stream; + label_stream << params; + state.SetLabel(label_stream.str()); + + loop_on_state(state, [this]() { + raft::linalg::reduce_cols_by_key( + in.data(), keys.data(), out.data(), params.rows, params.cols, params.keys, stream, false); + }); + } + + protected: + rcbk_params params; + rmm::device_uvector in, out; + rmm::device_uvector keys; +}; // struct reduce_cols_by_key + +const std::vector> rcbk_inputs_i32 = + raft::util::itertools::product>( + {1, 10, 100, 1000}, {1000, 10000, 100000}, {8, 32, 128, 512, 2048}); +const std::vector> rcbk_inputs_i64 = + raft::util::itertools::product>( + {1, 10, 100, 1000}, {1000, 10000, 100000}, {8, 32, 128, 512, 2048}); + +RAFT_BENCH_REGISTER((reduce_cols_by_key), "", rcbk_inputs_i32); +RAFT_BENCH_REGISTER((reduce_cols_by_key), "", rcbk_inputs_i32); +RAFT_BENCH_REGISTER((reduce_cols_by_key), "", rcbk_inputs_i64); +RAFT_BENCH_REGISTER((reduce_cols_by_key), "", rcbk_inputs_i64); + +} // namespace raft::bench::linalg diff --git a/cpp/bench/matrix/argmin.cu b/cpp/bench/matrix/argmin.cu index 0d0dea0fdb..3869f0c5e1 100644 --- a/cpp/bench/matrix/argmin.cu +++ b/cpp/bench/matrix/argmin.cu @@ -1,5 +1,5 @@ /* - * Copyright (c) 2022, NVIDIA CORPORATION. + * Copyright (c) 2022-2023, NVIDIA CORPORATION. 
* * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -17,10 +17,11 @@ #include #include #include +#include #include -namespace raft::bench::linalg { +namespace raft::bench::matrix { template struct ArgminParams { @@ -45,9 +46,7 @@ struct Argmin : public fixture { void run_benchmark(::benchmark::State& state) override { loop_on_state(state, [this]() { - auto matrix_const_view = raft::make_device_matrix_view( - matrix.data_handle(), matrix.extent(0), matrix.extent(1)); - raft::matrix::argmin(handle, matrix_const_view, indices.view()); + raft::matrix::argmin(handle, raft::make_const_mdspan(matrix.view()), indices.view()); }); } @@ -57,15 +56,11 @@ struct Argmin : public fixture { raft::device_vector indices; }; // struct Argmin -const std::vector> argmin_inputs_i64{ - {1000, 64}, {1000, 128}, {1000, 256}, {1000, 512}, {1000, 1024}, - {10000, 64}, {10000, 128}, {10000, 256}, {10000, 512}, {10000, 1024}, - {100000, 64}, {100000, 128}, {100000, 256}, {100000, 512}, {100000, 1024}, - {1000000, 64}, {1000000, 128}, {1000000, 256}, {1000000, 512}, {1000000, 1024}, - {10000000, 64}, {10000000, 128}, {10000000, 256}, {10000000, 512}, {10000000, 1024}, -}; +const std::vector> argmin_inputs_i64 = + raft::util::itertools::product>({1000, 10000, 100000, 1000000, 10000000}, + {64, 128, 256, 512, 1024}); RAFT_BENCH_REGISTER((Argmin), "", argmin_inputs_i64); RAFT_BENCH_REGISTER((Argmin), "", argmin_inputs_i64); -} // namespace raft::bench::linalg +} // namespace raft::bench::matrix diff --git a/cpp/bench/matrix/gather.cu b/cpp/bench/matrix/gather.cu new file mode 100644 index 0000000000..c5d80744cd --- /dev/null +++ b/cpp/bench/matrix/gather.cu @@ -0,0 +1,98 @@ +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include +#include + +#include + +namespace raft::bench::matrix { + +template +struct GatherParams { + IdxT rows, cols, map_length; +}; + +template +inline auto operator<<(std::ostream& os, const GatherParams& p) -> std::ostream& +{ + os << p.rows << "#" << p.cols << "#" << p.map_length; + return os; +} + +template +struct Gather : public fixture { + Gather(const GatherParams& p) : params(p) {} + + void allocate_data(const ::benchmark::State& state) override + { + matrix = raft::make_device_matrix(handle, params.rows, params.cols); + map = raft::make_device_vector(handle, params.map_length); + out = raft::make_device_matrix(handle, params.map_length, params.cols); + stencil = raft::make_device_vector(handle, Conditional ? 
params.map_length : IdxT(0)); + + raft::random::RngState rng{1234}; + raft::random::uniform( + rng, matrix.data_handle(), params.rows * params.cols, T(-1), T(1), stream); + raft::random::uniformInt( + handle, rng, map.data_handle(), params.map_length, (MapT)0, (MapT)params.rows); + if constexpr (Conditional) { + raft::random::uniform(rng, stencil.data_handle(), params.map_length, T(-1), T(1), stream); + } + handle.sync_stream(stream); + } + + void run_benchmark(::benchmark::State& state) override + { + std::ostringstream label_stream; + label_stream << params; + state.SetLabel(label_stream.str()); + + loop_on_state(state, [this]() { + auto matrix_const_view = raft::make_const_mdspan(matrix.view()); + auto map_const_view = raft::make_const_mdspan(map.view()); + if constexpr (Conditional) { + auto stencil_const_view = raft::make_const_mdspan(stencil.view()); + auto pred_op = raft::plug_const_op(T(0.0), raft::greater_op()); + raft::matrix::gather_if( + handle, matrix_const_view, out.view(), map_const_view, stencil_const_view, pred_op); + } else { + raft::matrix::gather(handle, matrix_const_view, map_const_view, out.view()); + } + }); + } + + private: + GatherParams params; + raft::device_matrix matrix, out; + raft::device_vector stencil; + raft::device_vector map; +}; // struct Gather + +template +using GatherIf = Gather; + +const std::vector> gather_inputs_i64 = + raft::util::itertools::product>( + {1000000}, {10, 20, 50, 100, 200, 500}, {1000, 10000, 100000, 1000000}); + +RAFT_BENCH_REGISTER((Gather), "", gather_inputs_i64); +RAFT_BENCH_REGISTER((Gather), "", gather_inputs_i64); +RAFT_BENCH_REGISTER((GatherIf), "", gather_inputs_i64); +RAFT_BENCH_REGISTER((GatherIf), "", gather_inputs_i64); +} // namespace raft::bench::matrix diff --git a/cpp/bench/matrix/select_k.cu b/cpp/bench/matrix/select_k.cu new file mode 100644 index 0000000000..2c8b8bb67b --- /dev/null +++ b/cpp/bench/matrix/select_k.cu @@ -0,0 +1,128 @@ +/* + * Copyright (c) 2022-2023, NVIDIA CORPORATION. 
+ * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include + +#include + +#include +#include +#include +#include + +#include +#include +#include + +#include +#include +#include + +namespace raft::matrix { + +using namespace raft::bench; // NOLINT + +template +struct selection : public fixture { + explicit selection(const select::params& p) + : params_(p), + in_dists_(p.batch_size * p.len, stream), + in_ids_(p.batch_size * p.len, stream), + out_dists_(p.batch_size * p.k, stream), + out_ids_(p.batch_size * p.k, stream) + { + raft::sparse::iota_fill(in_ids_.data(), IdxT(p.batch_size), IdxT(p.len), stream); + raft::random::RngState state{42}; + raft::random::uniform(handle, state, in_dists_.data(), in_dists_.size(), KeyT(-1.0), KeyT(1.0)); + } + + void run_benchmark(::benchmark::State& state) override // NOLINT + { + device_resources handle{stream}; + using_pool_memory_res res; + try { + std::ostringstream label_stream; + label_stream << params_.batch_size << "#" << params_.len << "#" << params_.k; + state.SetLabel(label_stream.str()); + loop_on_state(state, [this, &handle]() { + select::select_k_impl(handle, + Algo, + in_dists_.data(), + in_ids_.data(), + params_.batch_size, + params_.len, + params_.k, + out_dists_.data(), + out_ids_.data(), + params_.select_min); + }); + } catch (raft::exception& e) { + state.SkipWithError(e.what()); + } + } + + private: + const select::params params_; + rmm::device_uvector in_dists_, out_dists_; + rmm::device_uvector 
in_ids_, out_ids_; +}; + +const std::vector kInputs{ + {20000, 500, 1, true}, {20000, 500, 2, true}, {20000, 500, 4, true}, + {20000, 500, 8, true}, {20000, 500, 16, true}, {20000, 500, 32, true}, + {20000, 500, 64, true}, {20000, 500, 128, true}, {20000, 500, 256, true}, + + {1000, 10000, 1, true}, {1000, 10000, 2, true}, {1000, 10000, 4, true}, + {1000, 10000, 8, true}, {1000, 10000, 16, true}, {1000, 10000, 32, true}, + {1000, 10000, 64, true}, {1000, 10000, 128, true}, {1000, 10000, 256, true}, + + {100, 100000, 1, true}, {100, 100000, 2, true}, {100, 100000, 4, true}, + {100, 100000, 8, true}, {100, 100000, 16, true}, {100, 100000, 32, true}, + {100, 100000, 64, true}, {100, 100000, 128, true}, {100, 100000, 256, true}, + + {10, 1000000, 1, true}, {10, 1000000, 2, true}, {10, 1000000, 4, true}, + {10, 1000000, 8, true}, {10, 1000000, 16, true}, {10, 1000000, 32, true}, + {10, 1000000, 64, true}, {10, 1000000, 128, true}, {10, 1000000, 256, true}, +}; + +#define SELECTION_REGISTER(KeyT, IdxT, A) \ + namespace BENCHMARK_PRIVATE_NAME(selection) \ + { \ + using SelectK = selection; \ + RAFT_BENCH_REGISTER(SelectK, #KeyT "/" #IdxT "/" #A, kInputs); \ + } + +SELECTION_REGISTER(float, int, kPublicApi); // NOLINT +SELECTION_REGISTER(float, int, kRadix8bits); // NOLINT +SELECTION_REGISTER(float, int, kRadix11bits); // NOLINT +SELECTION_REGISTER(float, int, kWarpAuto); // NOLINT +SELECTION_REGISTER(float, int, kWarpImmediate); // NOLINT +SELECTION_REGISTER(float, int, kWarpFiltered); // NOLINT +SELECTION_REGISTER(float, int, kWarpDistributed); // NOLINT +SELECTION_REGISTER(float, int, kWarpDistributedShm); // NOLINT + +SELECTION_REGISTER(double, int, kRadix8bits); // NOLINT +SELECTION_REGISTER(double, int, kRadix11bits); // NOLINT +SELECTION_REGISTER(double, int, kWarpAuto); // NOLINT + +SELECTION_REGISTER(double, size_t, kRadix8bits); // NOLINT +SELECTION_REGISTER(double, size_t, kRadix11bits); // NOLINT +SELECTION_REGISTER(double, size_t, kWarpImmediate); // NOLINT 
+SELECTION_REGISTER(double, size_t, kWarpFiltered); // NOLINT +SELECTION_REGISTER(double, size_t, kWarpDistributed); // NOLINT +SELECTION_REGISTER(double, size_t, kWarpDistributedShm); // NOLINT + +} // namespace raft::matrix diff --git a/cpp/bench/neighbors/knn.cuh b/cpp/bench/neighbors/knn.cuh index d38631b289..633ea33670 100644 --- a/cpp/bench/neighbors/knn.cuh +++ b/cpp/bench/neighbors/knn.cuh @@ -1,5 +1,5 @@ /* - * Copyright (c) 2022, NVIDIA CORPORATION. + * Copyright (c) 2022-2023, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -32,6 +32,7 @@ #include #if defined RAFT_DISTANCE_COMPILED #include +#include #endif #endif @@ -148,7 +149,7 @@ struct ivf_flat_knn { raft::neighbors::ivf_flat::search_params search_params; params ps; - ivf_flat_knn(const raft::handle_t& handle, const params& ps, const ValT* data) : ps(ps) + ivf_flat_knn(const raft::device_resources& handle, const params& ps, const ValT* data) : ps(ps) { index_params.n_lists = 4096; index_params.metric = raft::distance::DistanceType::L2Expanded; @@ -156,7 +157,7 @@ struct ivf_flat_knn { handle, index_params, data, IdxT(ps.n_samples), uint32_t(ps.n_dims))); } - void search(const raft::handle_t& handle, + void search(const raft::device_resources& handle, const ValT* search_items, dist_t* out_dists, IdxT* out_idxs) @@ -176,7 +177,7 @@ struct ivf_pq_knn { raft::neighbors::ivf_pq::search_params search_params; params ps; - ivf_pq_knn(const raft::handle_t& handle, const params& ps, const ValT* data) : ps(ps) + ivf_pq_knn(const raft::device_resources& handle, const params& ps, const ValT* data) : ps(ps) { index_params.n_lists = 4096; index_params.metric = raft::distance::DistanceType::L2Expanded; @@ -184,7 +185,7 @@ struct ivf_pq_knn { handle, index_params, data, IdxT(ps.n_samples), uint32_t(ps.n_dims))); } - void search(const raft::handle_t& handle, + void search(const raft::device_resources& handle, 
const ValT* search_items, dist_t* out_dists, IdxT* out_idxs) @@ -202,12 +203,12 @@ struct brute_force_knn { ValT* index; params ps; - brute_force_knn(const raft::handle_t& handle, const params& ps, const ValT* data) + brute_force_knn(const raft::device_resources& handle, const params& ps, const ValT* data) : index(const_cast(data)), ps(ps) { } - void search(const raft::handle_t& handle, + void search(const raft::device_resources& handle, const ValT* search_items, dist_t* out_dists, IdxT* out_idxs) @@ -287,7 +288,7 @@ struct knn : public fixture { std::ostringstream label_stream; label_stream << params_ << "#" << strategy_ << "#" << scope_; state.SetLabel(label_stream.str()); - raft::handle_t handle(stream); + raft::device_resources handle(stream); std::optional index; if (scope_ == Scope::SEARCH) { // also implies TransferStrategy::NO_COPY diff --git a/cpp/bench/neighbors/refine.cu b/cpp/bench/neighbors/refine.cu index a038905ace..255004361c 100644 --- a/cpp/bench/neighbors/refine.cu +++ b/cpp/bench/neighbors/refine.cu @@ -1,5 +1,5 @@ /* - * Copyright (c) 2022, NVIDIA CORPORATION. + * Copyright (c) 2022-2023, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -14,15 +14,16 @@ * limitations under the License. 
*/ -#include +#include -#include +#include #include -#include +#include #include #include #include +#include #if defined RAFT_DISTANCE_COMPILED #include @@ -36,12 +37,10 @@ #include #include -#include "../../test/neighbors/refine_helper.cuh" - #include #include -using namespace raft::neighbors::detail; +using namespace raft::neighbors; namespace raft::bench::neighbors { @@ -53,7 +52,7 @@ inline auto operator<<(std::ostream& os, const RefineInputs& p) -> std::os return os; } -RefineInputs p; +RefineInputs p; template class RefineAnn : public fixture { @@ -95,28 +94,28 @@ class RefineAnn : public fixture { } private: - raft::handle_t handle_; + raft::device_resources handle_; RefineHelper data; }; -std::vector> getInputs() +std::vector> getInputs() { - std::vector> out; + std::vector> out; raft::distance::DistanceType metric = raft::distance::DistanceType::L2Expanded; for (bool host_data : {true, false}) { - for (int64_t n_queries : {1000, 10000}) { - for (int64_t dim : {128, 512}) { - out.push_back(RefineInputs{n_queries, 2000000, dim, 32, 128, metric, host_data}); - out.push_back(RefineInputs{n_queries, 2000000, dim, 10, 40, metric, host_data}); + for (uint64_t n_queries : {1000, 10000}) { + for (uint64_t dim : {128, 512}) { + out.push_back(RefineInputs{n_queries, 2000000, dim, 32, 128, metric, host_data}); + out.push_back(RefineInputs{n_queries, 2000000, dim, 10, 40, metric, host_data}); } } } return out; } -using refine_float_int64 = RefineAnn; +using refine_float_int64 = RefineAnn; RAFT_BENCH_REGISTER(refine_float_int64, "", getInputs()); -using refine_uint8_int64 = RefineAnn; +using refine_uint8_int64 = RefineAnn; RAFT_BENCH_REGISTER(refine_uint8_int64, "", getInputs()); } // namespace raft::bench::neighbors diff --git a/cpp/bench/neighbors/selection.cu b/cpp/bench/neighbors/selection.cu deleted file mode 100644 index 1f116c199f..0000000000 --- a/cpp/bench/neighbors/selection.cu +++ /dev/null @@ -1,123 +0,0 @@ -/* - * Copyright (c) 2022, NVIDIA CORPORATION. 
- * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include -#include - -#if defined RAFT_NN_COMPILED -#include -#endif - -#include -#include - -#include -#include - -namespace raft::bench::spatial { - -struct params { - int n_inputs; - int input_len; - int k; - int select_min; -}; - -template -struct selection : public fixture { - explicit selection(const params& p) - : params_(p), - in_dists_(p.n_inputs * p.input_len, stream), - in_ids_(p.n_inputs * p.input_len, stream), - out_dists_(p.n_inputs * p.k, stream), - out_ids_(p.n_inputs * p.k, stream) - { - raft::sparse::iota_fill(in_ids_.data(), IdxT(p.n_inputs), IdxT(p.input_len), stream); - raft::random::RngState state{42}; - raft::random::uniform(handle, state, in_dists_.data(), in_dists_.size(), KeyT(-1.0), KeyT(1.0)); - } - - void run_benchmark(::benchmark::State& state) override - { - using_pool_memory_res res; - try { - std::ostringstream label_stream; - label_stream << params_.n_inputs << "#" << params_.input_len << "#" << params_.k; - state.SetLabel(label_stream.str()); - loop_on_state(state, [this]() { - raft::spatial::knn::select_k(in_dists_.data(), - in_ids_.data(), - params_.n_inputs, - params_.input_len, - out_dists_.data(), - out_ids_.data(), - params_.select_min, - params_.k, - stream, - Algo); - }); - } catch (raft::exception& e) { - state.SkipWithError(e.what()); - } - } - - private: - const params params_; - rmm::device_uvector in_dists_, out_dists_; - rmm::device_uvector in_ids_, 
out_ids_; -}; - -const std::vector kInputs{ - {20000, 500, 1, true}, {20000, 500, 2, true}, {20000, 500, 4, true}, - {20000, 500, 8, true}, {20000, 500, 16, true}, {20000, 500, 32, true}, - {20000, 500, 64, true}, {20000, 500, 128, true}, {20000, 500, 256, true}, - - {1000, 10000, 1, true}, {1000, 10000, 2, true}, {1000, 10000, 4, true}, - {1000, 10000, 8, true}, {1000, 10000, 16, true}, {1000, 10000, 32, true}, - {1000, 10000, 64, true}, {1000, 10000, 128, true}, {1000, 10000, 256, true}, - - {100, 100000, 1, true}, {100, 100000, 2, true}, {100, 100000, 4, true}, - {100, 100000, 8, true}, {100, 100000, 16, true}, {100, 100000, 32, true}, - {100, 100000, 64, true}, {100, 100000, 128, true}, {100, 100000, 256, true}, - - {10, 1000000, 1, true}, {10, 1000000, 2, true}, {10, 1000000, 4, true}, - {10, 1000000, 8, true}, {10, 1000000, 16, true}, {10, 1000000, 32, true}, - {10, 1000000, 64, true}, {10, 1000000, 128, true}, {10, 1000000, 256, true}, -}; - -#define SELECTION_REGISTER(KeyT, IdxT, Algo) \ - namespace BENCHMARK_PRIVATE_NAME(selection) \ - { \ - using SelectK = selection; \ - RAFT_BENCH_REGISTER(SelectK, #KeyT "/" #IdxT "/" #Algo, kInputs); \ - } - -SELECTION_REGISTER(float, int, FAISS); -SELECTION_REGISTER(float, int, RADIX_8_BITS); -SELECTION_REGISTER(float, int, RADIX_11_BITS); -SELECTION_REGISTER(float, int, WARP_SORT); - -SELECTION_REGISTER(double, int, FAISS); -SELECTION_REGISTER(double, int, RADIX_8_BITS); -SELECTION_REGISTER(double, int, RADIX_11_BITS); -SELECTION_REGISTER(double, int, WARP_SORT); - -SELECTION_REGISTER(double, size_t, FAISS); -SELECTION_REGISTER(double, size_t, RADIX_8_BITS); -SELECTION_REGISTER(double, size_t, RADIX_11_BITS); -SELECTION_REGISTER(double, size_t, WARP_SORT); - -} // namespace raft::bench::spatial diff --git a/cpp/bench/random/make_blobs.cu b/cpp/bench/random/make_blobs.cu index fdd4ef61d2..950d80c499 100644 --- a/cpp/bench/random/make_blobs.cu +++ b/cpp/bench/random/make_blobs.cu @@ -25,6 +25,12 @@ struct 
make_blobs_inputs { bool row_major; }; // struct make_blobs_inputs +inline auto operator<<(std::ostream& os, const make_blobs_inputs& p) -> std::ostream& +{ + os << p.rows << "#" << p.cols << "#" << p.clusters << "#" << p.row_major; + return os; +} + template struct make_blobs : public fixture { make_blobs(const make_blobs_inputs& p) @@ -34,6 +40,10 @@ struct make_blobs : public fixture { void run_benchmark(::benchmark::State& state) override { + std::ostringstream label_stream; + label_stream << params; + state.SetLabel(label_stream.str()); + loop_on_state(state, [this]() { raft::random::make_blobs(data.data(), labels.data(), diff --git a/cpp/bench/random/permute.cu b/cpp/bench/random/permute.cu index 5364bb44e3..cb9e21868b 100644 --- a/cpp/bench/random/permute.cu +++ b/cpp/bench/random/permute.cu @@ -1,5 +1,5 @@ /* - * Copyright (c) 2022, NVIDIA CORPORATION. + * Copyright (c) 2022-2023, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -50,7 +50,7 @@ struct permute : public fixture { } private: - raft::handle_t handle; + raft::device_resources handle; permute_inputs params; rmm::device_uvector out, in; rmm::device_uvector perms; diff --git a/cpp/bench/sparse/convert_csr.cu b/cpp/bench/sparse/convert_csr.cu index 830fab13cc..c9dcae6985 100644 --- a/cpp/bench/sparse/convert_csr.cu +++ b/cpp/bench/sparse/convert_csr.cu @@ -1,5 +1,5 @@ /* - * Copyright (c) 2022, NVIDIA CORPORATION. + * Copyright (c) 2022-2023, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
@@ -107,7 +107,7 @@ struct bench_base : public fixture { } protected: - raft::handle_t handle; + raft::device_resources handle; bench_param params; rmm::device_uvector adj; rmm::device_uvector row_ind; diff --git a/cpp/cmake/thirdparty/get_cutlass.cmake b/cpp/cmake/thirdparty/get_cutlass.cmake index 811a5466c3..3e02ce064e 100644 --- a/cpp/cmake/thirdparty/get_cutlass.cmake +++ b/cpp/cmake/thirdparty/get_cutlass.cmake @@ -30,6 +30,10 @@ function(find_and_configure_cutlass) CACHE BOOL "Disable CUTLASS to build with cuBLAS library." ) + if (CUDA_STATIC_RUNTIME) + set(CUDART_LIBRARY "${CUDA_cudart_static_LIBRARY}" CACHE FILEPATH "fixing cutlass cmake code" FORCE) + endif() + rapids_cpm_find( NvidiaCutlass ${PKG_VERSION} GLOBAL_TARGETS nvidia::cutlass::cutlass diff --git a/cpp/include/raft/cluster/detail/agglomerative.cuh b/cpp/include/raft/cluster/detail/agglomerative.cuh index 618f852bba..f4b2ecf051 100644 --- a/cpp/include/raft/cluster/detail/agglomerative.cuh +++ b/cpp/include/raft/cluster/detail/agglomerative.cuh @@ -1,5 +1,5 @@ /* - * Copyright (c) 2021-2022, NVIDIA CORPORATION. + * Copyright (c) 2021-2023, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
@@ -16,7 +16,7 @@ #pragma once -#include +#include #include #include @@ -100,7 +100,7 @@ class UnionFind { * @param[out] out_size cluster sizes of output */ template -void build_dendrogram_host(const handle_t& handle, +void build_dendrogram_host(raft::device_resources const& handle, const value_idx* rows, const value_idx* cols, const value_t* data, @@ -236,7 +236,7 @@ struct init_label_roots { * @param n_leaves */ template -void extract_flattened_clusters(const raft::handle_t& handle, +void extract_flattened_clusters(raft::device_resources const& handle, value_idx* labels, const value_idx* children, size_t n_clusters, diff --git a/cpp/include/raft/cluster/detail/connectivities.cuh b/cpp/include/raft/cluster/detail/connectivities.cuh index a07045f0d2..163670f29a 100644 --- a/cpp/include/raft/cluster/detail/connectivities.cuh +++ b/cpp/include/raft/cluster/detail/connectivities.cuh @@ -1,5 +1,5 @@ /* - * Copyright (c) 2021-2022, NVIDIA CORPORATION. + * Copyright (c) 2021-2023, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
@@ -16,7 +16,7 @@ #pragma once -#include +#include #include #include @@ -24,6 +24,7 @@ #include #include +#include #include #include #include @@ -39,7 +40,7 @@ namespace raft::cluster::detail { template struct distance_graph_impl { - void run(const raft::handle_t& handle, + void run(raft::device_resources const& handle, const value_t* X, size_t m, size_t n, @@ -57,7 +58,7 @@ struct distance_graph_impl { */ template struct distance_graph_impl { - void run(const raft::handle_t& handle, + void run(raft::device_resources const& handle, const value_t* X, size_t m, size_t n, @@ -103,6 +104,98 @@ struct distance_graph_impl +__global__ void fill_indices2(value_idx* indices, size_t m, size_t nnz) +{ + value_idx tid = (blockIdx.x * blockDim.x) + threadIdx.x; + if (tid >= nnz) return; + value_idx v = tid % m; + indices[tid] = v; +} + +/** + * Compute connected CSR of pairwise distances + * @tparam value_idx + * @tparam value_t + * @param handle + * @param X + * @param m + * @param n + * @param metric + * @param[out] indptr + * @param[out] indices + * @param[out] data + */ +template +void pairwise_distances(const raft::device_resources& handle, + const value_t* X, + size_t m, + size_t n, + raft::distance::DistanceType metric, + value_idx* indptr, + value_idx* indices, + value_t* data) +{ + auto stream = handle.get_stream(); + auto exec_policy = handle.get_thrust_policy(); + + value_idx nnz = m * m; + + value_idx blocks = raft::ceildiv(nnz, (value_idx)256); + fill_indices2<<>>(indices, m, nnz); + + thrust::sequence(exec_policy, indptr, indptr + m, 0, (int)m); + + raft::update_device(indptr + m, &nnz, 1, stream); + + // TODO: It would ultimately be nice if the MST could accept + // dense inputs directly so we don't need to double the memory + // usage to hand it a sparse array here. 
+ distance::pairwise_distance(handle, X, X, data, m, m, n, metric); + // self-loops get max distance + auto transform_in = + thrust::make_zip_iterator(thrust::make_tuple(thrust::make_counting_iterator(0), data)); + + thrust::transform(exec_policy, + transform_in, + transform_in + nnz, + data, + [=] __device__(const thrust::tuple& tup) { + value_idx idx = thrust::get<0>(tup); + bool self_loop = idx % m == idx / m; + return (self_loop * std::numeric_limits::max()) + + (!self_loop * thrust::get<1>(tup)); + }); +} + +/** + * Connectivities specialization for pairwise distances + * @tparam value_idx + * @tparam value_t + */ +template +struct distance_graph_impl { + void run(const raft::device_resources& handle, + const value_t* X, + size_t m, + size_t n, + raft::distance::DistanceType metric, + rmm::device_uvector& indptr, + rmm::device_uvector& indices, + rmm::device_uvector& data, + int c) + { + auto stream = handle.get_stream(); + + size_t nnz = m * m; + + indices.resize(nnz, stream); + data.resize(nnz, stream); + + pairwise_distances(handle, X, m, n, metric, indptr.data(), indices.data(), data.data()); + } +}; + /** * Returns a CSR connectivities graph based on the given linkage distance. * @tparam value_idx @@ -120,7 +213,7 @@ struct distance_graph_impl -void get_distance_graph(const raft::handle_t& handle, +void get_distance_graph(raft::device_resources const& handle, const value_t* X, size_t m, size_t n, diff --git a/cpp/include/raft/cluster/detail/kmeans.cuh b/cpp/include/raft/cluster/detail/kmeans.cuh index 5aa9870b46..9632fedb9d 100644 --- a/cpp/include/raft/cluster/detail/kmeans.cuh +++ b/cpp/include/raft/cluster/detail/kmeans.cuh @@ -1,5 +1,5 @@ /* - * Copyright (c) 2020-2022, NVIDIA CORPORATION. + * Copyright (c) 2020-2023, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
@@ -31,17 +31,19 @@ #include #include #include -#include +#include #include #include #include #include +#include #include #include #include #include #include #include +#include #include #include #include @@ -57,7 +59,7 @@ namespace detail { // Selects 'n_clusters' samples randomly from X template -void initRandom(const raft::handle_t& handle, +void initRandom(raft::device_resources const& handle, const KMeansParams& params, raft::device_matrix_view X, raft::device_matrix_view centroids) @@ -83,7 +85,7 @@ void initRandom(const raft::handle_t& handle, * 5: end for */ template -void kmeansPlusPlus(const raft::handle_t& handle, +void kmeansPlusPlus(raft::device_resources const& handle, const KMeansParams& params, raft::device_matrix_view X, raft::device_matrix_view centroidsRawData, @@ -109,7 +111,7 @@ void kmeansPlusPlus(const raft::handle_t& handle, auto dataBatchSize = getDataBatchSize(params.batch_samples, n_samples); // temporary buffers - std::vector h_wt(n_samples); + auto indices = raft::make_device_vector(handle, n_trials); auto centroidCandidates = raft::make_device_matrix(handle, n_trials, n_features); auto costPerCandidate = raft::make_device_vector(handle, n_trials); auto minClusterDistance = raft::make_device_vector(handle, n_samples); @@ -119,6 +121,17 @@ void kmeansPlusPlus(const raft::handle_t& handle, rmm::device_scalar clusterCost(stream); rmm::device_scalar> minClusterIndexAndDistance(stream); + // Device and matrix views + raft::device_vector_view indices_view(indices.data_handle(), n_trials); + auto const_weights_view = + raft::make_device_vector_view(minClusterDistance.data_handle(), n_samples); + auto const_indices_view = + raft::make_device_vector_view(indices.data_handle(), n_trials); + auto const_X_view = + raft::make_device_matrix_view(X.data_handle(), n_samples, n_features); + raft::device_matrix_view candidates_view( + centroidCandidates.data_handle(), n_trials, n_features); + // L2 norm of X: ||c||^2 auto L2NormX = 
raft::make_device_vector(handle, n_samples); @@ -133,6 +146,7 @@ void kmeansPlusPlus(const raft::handle_t& handle, stream); } + raft::random::RngState rng(params.rng_state.seed, params.rng_state.type); std::mt19937 gen(params.rng_state.seed); std::uniform_int_distribution<> dis(0, n_samples - 1); @@ -169,20 +183,9 @@ void kmeansPlusPlus(const raft::handle_t& handle, // <<< Step-3 >>> : Sample x in X with probability p_x = d^2(x, C) / phi_X (C) // Choose 'n_trials' centroid candidates from X with probability proportional to the squared // distance to the nearest existing cluster - raft::copy(h_wt.data(), minClusterDistance.data_handle(), minClusterDistance.size(), stream); - handle.sync_stream(stream); - // Note - n_trials is relative small here, we don't need raft::gather call - std::discrete_distribution<> d(h_wt.begin(), h_wt.end()); - for (int cIdx = 0; cIdx < n_trials; ++cIdx) { - auto rand_idx = d(gen); - auto randCentroid = raft::make_device_matrix_view( - X.data_handle() + n_features * rand_idx, 1, n_features); - raft::copy(centroidCandidates.data_handle() + cIdx * n_features, - randCentroid.data_handle(), - randCentroid.size(), - stream); - } + raft::random::discrete(handle, rng, indices_view, const_weights_view); + raft::matrix::gather(handle, const_X_view, const_indices_view, candidates_view); // Calculate pairwise distance between X and the centroid candidates // Output - pwd [n_trials x n_samples] @@ -195,16 +198,15 @@ void kmeansPlusPlus(const raft::handle_t& handle, // Outputs minDistanceBuf[n_trials x n_samples] where minDistance[i, :] contains updated // minClusterDistance that includes candidate-i auto minDistBuf = distBuffer.view(); - raft::linalg::matrixVectorOp( - minDistBuf.data_handle(), - pwd.data_handle(), - minClusterDistance.data_handle(), - pwd.extent(1), - pwd.extent(0), - true, - true, - [=] __device__(DataT mat, DataT vec) { return vec <= mat ? 
vec : mat; }, - stream); + raft::linalg::matrixVectorOp(minDistBuf.data_handle(), + pwd.data_handle(), + minClusterDistance.data_handle(), + pwd.extent(1), + pwd.extent(0), + true, + true, + raft::min_op{}, + stream); // Calculate costPerCandidate[n_trials] where costPerCandidate[i] is the cluster cost when using // centroid candidate-i @@ -226,7 +228,8 @@ void kmeansPlusPlus(const raft::handle_t& handle, temp_storage_bytes, costPerCandidate.data_handle(), minClusterIndexAndDistance.data(), - costPerCandidate.extent(0)); + costPerCandidate.extent(0), + stream); // Allocate temporary storage workspace.resize(temp_storage_bytes, stream); @@ -236,10 +239,12 @@ void kmeansPlusPlus(const raft::handle_t& handle, temp_storage_bytes, costPerCandidate.data_handle(), minClusterIndexAndDistance.data(), - costPerCandidate.extent(0)); + costPerCandidate.extent(0), + stream); int bestCandidateIdx = -1; raft::copy(&bestCandidateIdx, &minClusterIndexAndDistance.data()->key, 1, stream); + handle.sync_stream(); /// <<< End of Step-3 >>> /// <<< Step-4 >>>: C = C U {x} @@ -277,7 +282,7 @@ void kmeansPlusPlus(const raft::handle_t& handle, * @param[inout] workspace */ template -void update_centroids(const raft::handle_t& handle, +void update_centroids(raft::device_resources const& handle, raft::device_matrix_view X, raft::device_vector_view sample_weights, raft::device_matrix_view centroids, @@ -321,21 +326,15 @@ void update_centroids(const raft::handle_t& handle, // weight_per_cluster[n_clusters] - 1D array, weight_per_cluster[i] contains sum of weights in // cluster-i. 
// Note - when weight_per_cluster[i] is 0, new_centroids[i] is reset to 0 - raft::linalg::matrixVectorOp( - new_centroids.data_handle(), - new_centroids.data_handle(), - weight_per_cluster.data_handle(), - new_centroids.extent(1), - new_centroids.extent(0), - true, - false, - [=] __device__(DataT mat, DataT vec) { - if (vec == 0) - return DataT(0); - else - return mat / vec; - }, - handle.get_stream()); + raft::linalg::matrixVectorOp(new_centroids.data_handle(), + new_centroids.data_handle(), + weight_per_cluster.data_handle(), + new_centroids.extent(1), + new_centroids.extent(0), + true, + false, + raft::div_checkzero_op{}, + handle.get_stream()); // copy centroids[i] to new_centroids[i] when weight_per_cluster[i] is 0 cub::ArgIndexInputIterator itr_wt(weight_per_cluster.data_handle()); @@ -351,15 +350,13 @@ void update_centroids(const raft::handle_t& handle, // copy when the sum of weights in the cluster is 0 return map.value == 0; }, - [=] __device__(raft::KeyValuePair map) { // map - return map.key; - }, + raft::key_op{}, handle.get_stream()); } // TODO: Resizing is needed to use mdarray instead of rmm::device_uvector template -void kmeans_fit_main(const raft::handle_t& handle, +void kmeans_fit_main(raft::device_resources const& handle, const KMeansParams& params, raft::device_matrix_view X, raft::device_vector_view weight, @@ -394,7 +391,7 @@ void kmeans_fit_main(const raft::handle_t& handle, // resource auto wtInCluster = raft::make_device_vector(handle, n_clusters); - rmm::device_scalar> clusterCostD(stream); + rmm::device_scalar clusterCostD(stream); // L2 norm of X: ||x||^2 auto L2NormX = raft::make_device_vector(handle, n_samples); @@ -465,16 +462,12 @@ void kmeans_fit_main(const raft::handle_t& handle, // compute the squared norm between the newCentroids and the original // centroids, destructor releases the resource auto sqrdNorm = raft::make_device_scalar(handle, DataT(0)); - raft::linalg::mapThenSumReduce( - sqrdNorm.data_handle(), - 
newCentroids.size(), - [=] __device__(const DataT a, const DataT b) { - DataT diff = a - b; - return diff * diff; - }, - stream, - centroids.data_handle(), - newCentroids.data_handle()); + raft::linalg::mapThenSumReduce(sqrdNorm.data_handle(), + newCentroids.size(), + raft::sqdiff_op{}, + stream, + centroids.data_handle(), + newCentroids.data_handle()); DataT sqrdNormError = 0; raft::copy(&sqrdNormError, sqrdNorm.data_handle(), sqrdNorm.size(), stream); @@ -489,18 +482,11 @@ void kmeans_fit_main(const raft::handle_t& handle, minClusterAndDistance.view(), workspace, raft::make_device_scalar_view(clusterCostD.data()), - [] __device__(const raft::KeyValuePair& a, - const raft::KeyValuePair& b) { - raft::KeyValuePair res; - res.key = 0; - res.value = a.value + b.value; - return res; - }); - - DataT curClusteringCost = 0; - raft::copy(&curClusteringCost, &(clusterCostD.data()->value), 1, stream); - - handle.sync_stream(stream); + raft::value_op{}, + raft::add_op{}); + + DataT curClusteringCost = clusterCostD.value(stream); + ASSERT(curClusteringCost != (DataT)0.0, "Too few points and centroids being found is getting 0 cost from " "centers"); @@ -553,15 +539,10 @@ void kmeans_fit_main(const raft::handle_t& handle, minClusterAndDistance.view(), workspace, raft::make_device_scalar_view(clusterCostD.data()), - [] __device__(const raft::KeyValuePair& a, - const raft::KeyValuePair& b) { - raft::KeyValuePair res; - res.key = 0; - res.value = a.value + b.value; - return res; - }); + raft::value_op{}, + raft::add_op{}); - raft::copy(inertia.data_handle(), &(clusterCostD.data()->value), 1, stream); + inertia[0] = clusterCostD.value(stream); RAFT_LOG_DEBUG("KMeans.fit: completed after %d iterations with %f inertia[0] ", n_iter[0] > params.max_iter ? 
n_iter[0] - 1 : n_iter[0], @@ -592,7 +573,7 @@ void kmeans_fit_main(const raft::handle_t& handle, */ template -void initScalableKMeansPlusPlus(const raft::handle_t& handle, +void initScalableKMeansPlusPlus(raft::device_resources const& handle, const KMeansParams& params, raft::device_matrix_view X, raft::device_matrix_view centroidsRawData, @@ -673,7 +654,8 @@ void initScalableKMeansPlusPlus(const raft::handle_t& handle, minClusterDistanceVec.view(), workspace, raft::make_device_scalar_view(clusterCost.data()), - [] __device__(const DataT& a, const DataT& b) { return a + b; }); + raft::identity_op{}, + raft::add_op{}); auto psi = clusterCost.value(stream); @@ -705,7 +687,8 @@ void initScalableKMeansPlusPlus(const raft::handle_t& handle, minClusterDistanceVec.view(), workspace, raft::make_device_scalar_view(clusterCost.data()), - [] __device__(const DataT& a, const DataT& b) { return a + b; }); + raft::identity_op{}, + raft::add_op{}); psi = clusterCost.value(stream); @@ -833,7 +816,7 @@ void initScalableKMeansPlusPlus(const raft::handle_t& handle, * @param[out] n_iter Number of iterations run. 
*/ template -void kmeans_fit(handle_t const& handle, +void kmeans_fit(raft::device_resources const& handle, const KMeansParams& params, raft::device_matrix_view X, std::optional> sample_weight, @@ -972,7 +955,7 @@ void kmeans_fit(handle_t const& handle, } template -void kmeans_fit(handle_t const& handle, +void kmeans_fit(raft::device_resources const& handle, const KMeansParams& params, const DataT* X, const DataT* sample_weight, @@ -997,7 +980,7 @@ void kmeans_fit(handle_t const& handle, } template -void kmeans_predict(handle_t const& handle, +void kmeans_predict(raft::device_resources const& handle, const KMeansParams& params, raft::device_matrix_view X, std::optional> sample_weight, @@ -1074,7 +1057,7 @@ void kmeans_predict(handle_t const& handle, workspace); // calculate cluster cost phi_x(C) - rmm::device_scalar> clusterCostD(stream); + rmm::device_scalar clusterCostD(stream); // TODO: add different templates for InType of binaryOp to avoid thrust transform thrust::transform(handle.get_thrust_policy(), minClusterAndDistance.data_handle(), @@ -1092,25 +1075,20 @@ void kmeans_predict(handle_t const& handle, minClusterAndDistance.view(), workspace, raft::make_device_scalar_view(clusterCostD.data()), - [] __device__(const raft::KeyValuePair& a, - const raft::KeyValuePair& b) { - raft::KeyValuePair res; - res.key = 0; - res.value = a.value + b.value; - return res; - }); - - raft::copy(inertia.data_handle(), &(clusterCostD.data()->value), 1, stream); + raft::value_op{}, + raft::add_op{}); thrust::transform(handle.get_thrust_policy(), minClusterAndDistance.data_handle(), minClusterAndDistance.data_handle() + minClusterAndDistance.size(), labels.data_handle(), - [=] __device__(raft::KeyValuePair pair) { return pair.key; }); + raft::key_op{}); + + inertia[0] = clusterCostD.value(stream); } template -void kmeans_predict(handle_t const& handle, +void kmeans_predict(raft::device_resources const& handle, const KMeansParams& params, const DataT* X, const DataT* 
sample_weight, @@ -1142,7 +1120,7 @@ void kmeans_predict(handle_t const& handle, } template -void kmeans_fit_predict(handle_t const& handle, +void kmeans_fit_predict(raft::device_resources const& handle, const KMeansParams& params, raft::device_matrix_view X, std::optional> sample_weight, @@ -1169,7 +1147,7 @@ void kmeans_fit_predict(handle_t const& handle, } template -void kmeans_fit_predict(handle_t const& handle, +void kmeans_fit_predict(raft::device_resources const& handle, const KMeansParams& params, const DataT* X, const DataT* sample_weight, @@ -1209,7 +1187,7 @@ void kmeans_fit_predict(handle_t const& handle, * @param[out] X_new X transformed in the new space.. */ template -void kmeans_transform(const raft::handle_t& handle, +void kmeans_transform(raft::device_resources const& handle, const KMeansParams& params, raft::device_matrix_view X, raft::device_matrix_view centroids, @@ -1250,7 +1228,7 @@ void kmeans_transform(const raft::handle_t& handle, } template -void kmeans_transform(const raft::handle_t& handle, +void kmeans_transform(raft::device_resources const& handle, const KMeansParams& params, const DataT* X, const DataT* centroids, diff --git a/cpp/include/raft/cluster/detail/kmeans_balanced.cuh b/cpp/include/raft/cluster/detail/kmeans_balanced.cuh new file mode 100644 index 0000000000..3d23c809c3 --- /dev/null +++ b/cpp/include/raft/cluster/detail/kmeans_balanced.cuh @@ -0,0 +1,1095 @@ +/* + * Copyright (c) 2022-2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include + +#include +#include + +#include + +namespace raft::cluster::detail { + +constexpr static inline float kAdjustCentersWeight = 7.0f; + +/** + * @brief Predict labels for the dataset; floating-point types only. + * + * NB: no minibatch splitting is done here, it may require large amount of temporary memory (n_rows + * * n_cluster * sizeof(MathT)). + * + * @tparam MathT type of the centroids and mapped data + * @tparam IdxT index type + * @tparam LabelT label type + * + * @param[in] handle The raft handle. + * @param[in] params Structure containing the hyper-parameters + * @param[in] centers Pointer to the row-major matrix of cluster centers [n_clusters, dim] + * @param[in] n_clusters Number of clusters/centers + * @param[in] dim Dimensionality of the data + * @param[in] dataset Pointer to the data [n_rows, dim] + * @param[in] dataset_norm Pointer to the precomputed norm (for L2 metrics only) [n_rows] + * @param[in] n_rows Number samples in the `dataset` + * @param[out] labels Output predictions [n_rows] + * @param[inout] mr (optional) Memory resource to use for temporary allocations + */ +template +inline std::enable_if_t> predict_core( + const raft::device_resources& handle, + const kmeans_balanced_params& params, + const MathT* centers, + IdxT n_clusters, + IdxT dim, + const MathT* dataset, + const MathT* dataset_norm, + IdxT n_rows, + LabelT* labels, + rmm::mr::device_memory_resource* mr) +{ + auto stream = handle.get_stream(); + switch (params.metric) { + case raft::distance::DistanceType::L2Expanded: + case 
raft::distance::DistanceType::L2SqrtExpanded: { + auto workspace = raft::make_device_mdarray( + handle, mr, make_extents((sizeof(int)) * n_rows)); + + auto minClusterAndDistance = raft::make_device_mdarray, IdxT>( + handle, mr, make_extents(n_rows)); + raft::KeyValuePair initial_value(0, std::numeric_limits::max()); + thrust::fill(handle.get_thrust_policy(), + minClusterAndDistance.data_handle(), + minClusterAndDistance.data_handle() + minClusterAndDistance.size(), + initial_value); + + auto centroidsNorm = + raft::make_device_mdarray(handle, mr, make_extents(n_clusters)); + raft::linalg::rowNorm( + centroidsNorm.data_handle(), centers, dim, n_clusters, raft::linalg::L2Norm, true, stream); + + raft::distance::fusedL2NNMinReduce, IdxT>( + minClusterAndDistance.data_handle(), + dataset, + centers, + dataset_norm, + centroidsNorm.data_handle(), + n_rows, + n_clusters, + dim, + (void*)workspace.data_handle(), + (params.metric == raft::distance::DistanceType::L2Expanded) ? false : true, + false, + stream); + + // todo(lsugy): use KVP + iterator in caller. 
+ // Copy keys to output labels + thrust::transform(handle.get_thrust_policy(), + minClusterAndDistance.data_handle(), + minClusterAndDistance.data_handle() + n_rows, + labels, + raft::compose_op, raft::key_op>()); + break; + } + case raft::distance::DistanceType::InnerProduct: { + // TODO: pass buffer + rmm::device_uvector distances(n_rows * n_clusters, stream, mr); + + MathT alpha = -1.0; + MathT beta = 0.0; + + linalg::gemm(handle, + true, + false, + n_clusters, + n_rows, + dim, + &alpha, + centers, + dim, + dataset, + dim, + &beta, + distances.data(), + n_clusters, + stream); + + auto distances_const_view = raft::make_device_matrix_view( + distances.data(), n_rows, n_clusters); + auto labels_view = raft::make_device_vector_view(labels, n_rows); + raft::matrix::argmin(handle, distances_const_view, labels_view); + break; + } + default: { + RAFT_FAIL("The chosen distance metric is not supported (%d)", int(params.metric)); + } + } +} + +/** + * @brief Suggest a minibatch size for kmeans prediction. + * + * This function is used as a heuristic to split the work over a large dataset + * to reduce the size of temporary memory allocations. + * + * @tparam MathT type of the centroids and mapped data + * @tparam IdxT index type + * + * @param[in] n_clusters number of clusters in kmeans clustering + * @param[in] n_rows Number of samples in the dataset + * @param[in] dim Number of features in the dataset + * @param[in] metric Distance metric + * @param[in] needs_conversion Whether the data needs to be converted to MathT + * @return A suggested minibatch size and the expected memory cost per-row (in bytes) + */ +template +constexpr auto calc_minibatch_size(IdxT n_clusters, + IdxT n_rows, + IdxT dim, + raft::distance::DistanceType metric, + bool needs_conversion) -> std::tuple +{ + n_clusters = std::max(1, n_clusters); + + // Estimate memory needs per row (i.e element of the batch). 
+ size_t mem_per_row = 0; + switch (metric) { + // fusedL2NN needs a mutex and a key-value pair for each row. + case distance::DistanceType::L2Expanded: + case distance::DistanceType::L2SqrtExpanded: { + mem_per_row += sizeof(int); + mem_per_row += sizeof(raft::KeyValuePair); + } break; + // Other metrics require storing a distance matrix. + default: { + mem_per_row += sizeof(MathT) * n_clusters; + } + } + + // If we need to convert to MathT, space required for the converted batch. + if (!needs_conversion) { mem_per_row += sizeof(MathT) * dim; } + + // Heuristic: calculate the minibatch size in order to use at most 1GB of memory. + IdxT minibatch_size = (1 << 30) / mem_per_row; + minibatch_size = 64 * div_rounding_up_safe(minibatch_size, IdxT{64}); + minibatch_size = std::min(minibatch_size, n_rows); + return std::make_tuple(minibatch_size, mem_per_row); +} + +/** + * @brief Given the data and labels, calculate cluster centers and sizes in one sweep. + * + * @note all pointers must be accessible on the device. + * + * @tparam T element type + * @tparam MathT type of the centroids and mapped data + * @tparam IdxT index type + * @tparam LabelT label type + * @tparam CounterT counter type supported by CUDA's native atomicAdd + * @tparam MappingOpT type of the mapping operation + * + * @param[in] handle The raft handle. + * @param[inout] centers Pointer to the output [n_clusters, dim] + * @param[inout] cluster_sizes Number of rows in each cluster [n_clusters] + * @param[in] n_clusters Number of clusters/centers + * @param[in] dim Dimensionality of the data + * @param[in] dataset Pointer to the data [n_rows, dim] + * @param[in] n_rows Number of samples in the `dataset` + * @param[in] labels Output predictions [n_rows] + * @param[in] reset_counters Whether to clear the output arrays before calculating. + * When set to `false`, this function may be used to update existing centers and sizes using + * the weighted average principle. 
+ * @param[in] mapping_op Mapping operation from T to MathT + * @param[inout] mr (optional) Memory resource to use for temporary allocations on the device + */ +template +void calc_centers_and_sizes(const raft::device_resources& handle, + MathT* centers, + CounterT* cluster_sizes, + IdxT n_clusters, + IdxT dim, + const T* dataset, + IdxT n_rows, + const LabelT* labels, + bool reset_counters, + MappingOpT mapping_op, + rmm::mr::device_memory_resource* mr = nullptr) +{ + auto stream = handle.get_stream(); + if (mr == nullptr) { mr = handle.get_workspace_resource(); } + + if (!reset_counters) { + raft::linalg::matrixVectorOp( + centers, centers, cluster_sizes, dim, n_clusters, true, false, raft::mul_op(), stream); + } + + rmm::device_uvector workspace(0, stream, mr); + + // If we reset the counters, we can compute directly the new sizes in cluster_sizes. + // If we don't reset, we compute in a temporary buffer and add in a separate step. + rmm::device_uvector temp_cluster_sizes(0, stream, mr); + CounterT* temp_sizes = cluster_sizes; + if (!reset_counters) { + temp_cluster_sizes.resize(n_clusters, stream); + temp_sizes = temp_cluster_sizes.data(); + } + + // Apply mapping only when the data and math types are different. 
+ if constexpr (std::is_same_v) { + raft::linalg::reduce_rows_by_key( + dataset, dim, labels, nullptr, n_rows, dim, n_clusters, centers, stream, reset_counters); + } else { + // todo(lsugy): use iterator from KV output of fusedL2NN + cub::TransformInputIterator mapping_itr(dataset, mapping_op); + raft::linalg::reduce_rows_by_key( + mapping_itr, dim, labels, nullptr, n_rows, dim, n_clusters, centers, stream, reset_counters); + } + + // Compute weight of each cluster + raft::cluster::detail::countLabels(handle, labels, temp_sizes, n_rows, n_clusters, workspace); + + // Add previous sizes if necessary + if (!reset_counters) { + raft::linalg::add(cluster_sizes, cluster_sizes, temp_sizes, n_clusters, stream); + } + + raft::linalg::matrixVectorOp(centers, + centers, + cluster_sizes, + dim, + n_clusters, + true, + false, + raft::div_checkzero_op(), + stream); +} + +/** Computes the L2 norm of the dataset, converting to MathT if necessary */ +template +void compute_norm(const raft::device_resources& handle, + MathT* dataset_norm, + const T* dataset, + IdxT dim, + IdxT n_rows, + MappingOpT mapping_op, + rmm::mr::device_memory_resource* mr = nullptr) +{ + common::nvtx::range fun_scope("compute_norm"); + auto stream = handle.get_stream(); + if (mr == nullptr) { mr = handle.get_workspace_resource(); } + rmm::device_uvector mapped_dataset(0, stream, mr); + + const MathT* dataset_ptr = nullptr; + + if (std::is_same_v) { + dataset_ptr = reinterpret_cast(dataset); + } else { + mapped_dataset.resize(n_rows * dim, stream); + + linalg::unaryOp(mapped_dataset.data(), dataset, n_rows * dim, mapping_op, stream); + + dataset_ptr = (const MathT*)mapped_dataset.data(); + } + + raft::linalg::rowNorm( + dataset_norm, dataset_ptr, dim, n_rows, raft::linalg::L2Norm, true, stream); +} + +/** + * @brief Predict labels for the dataset. 
+ * + * @tparam T element type + * @tparam MathT type of the centroids and mapped data + * @tparam IdxT index type + * @tparam LabelT label type + * @tparam MappingOpT type of the mapping operation + * + * @param[in] handle The raft handle + * @param[in] params Structure containing the hyper-parameters + * @param[in] centers Pointer to the row-major matrix of cluster centers [n_clusters, dim] + * @param[in] n_clusters Number of clusters/centers + * @param[in] dim Dimensionality of the data + * @param[in] dataset Pointer to the data [n_rows, dim] + * @param[in] n_rows Number samples in the `dataset` + * @param[out] labels Output predictions [n_rows] + * @param[in] mapping_op Mapping operation from T to MathT + * @param[inout] mr (optional) memory resource to use for temporary allocations + * @param[in] dataset_norm (optional) Pre-computed norms of each row in the dataset [n_rows] + */ +template +void predict(const raft::device_resources& handle, + const kmeans_balanced_params& params, + const MathT* centers, + IdxT n_clusters, + IdxT dim, + const T* dataset, + IdxT n_rows, + LabelT* labels, + MappingOpT mapping_op, + rmm::mr::device_memory_resource* mr = nullptr, + const MathT* dataset_norm = nullptr) +{ + auto stream = handle.get_stream(); + common::nvtx::range fun_scope( + "predict(%zu, %u)", static_cast(n_rows), n_clusters); + if (mr == nullptr) { mr = handle.get_workspace_resource(); } + auto [max_minibatch_size, _mem_per_row] = + calc_minibatch_size(n_clusters, n_rows, dim, params.metric, std::is_same_v); + rmm::device_uvector cur_dataset( + std::is_same_v ? 0 : max_minibatch_size * dim, stream, mr); + bool need_compute_norm = + dataset_norm == nullptr && (params.metric == raft::distance::DistanceType::L2Expanded || + params.metric == raft::distance::DistanceType::L2SqrtExpanded); + rmm::device_uvector cur_dataset_norm( + need_compute_norm ? 
max_minibatch_size : 0, stream, mr); + const MathT* dataset_norm_ptr = nullptr; + auto cur_dataset_ptr = cur_dataset.data(); + for (IdxT offset = 0; offset < n_rows; offset += max_minibatch_size) { + IdxT minibatch_size = std::min(max_minibatch_size, n_rows - offset); + + if constexpr (std::is_same_v) { + cur_dataset_ptr = const_cast(dataset + offset * dim); + } else { + linalg::unaryOp( + cur_dataset_ptr, dataset + offset * dim, minibatch_size * dim, mapping_op, stream); + } + + // Compute the norm now if it hasn't been pre-computed. + if (need_compute_norm) { + compute_norm( + handle, cur_dataset_norm.data(), cur_dataset_ptr, dim, minibatch_size, mapping_op, mr); + dataset_norm_ptr = cur_dataset_norm.data(); + } else if (dataset_norm != nullptr) { + dataset_norm_ptr = dataset_norm + offset; + } + + predict_core(handle, + params, + centers, + n_clusters, + dim, + cur_dataset_ptr, + dataset_norm_ptr, + minibatch_size, + labels + offset, + mr); + } +} + +template +__global__ void __launch_bounds__((WarpSize * BlockDimY)) + adjust_centers_kernel(MathT* centers, // [n_clusters, dim] + IdxT n_clusters, + IdxT dim, + const T* dataset, // [n_rows, dim] + IdxT n_rows, + const LabelT* labels, // [n_rows] + const CounterT* cluster_sizes, // [n_clusters] + MathT threshold, + IdxT average, + IdxT seed, + IdxT* count, + MappingOpT mapping_op) +{ + IdxT l = threadIdx.y + BlockDimY * static_cast(blockIdx.y); + if (l >= n_clusters) return; + auto csize = static_cast(cluster_sizes[l]); + // skip big clusters + if (csize > static_cast(average * threshold)) return; + + // choose a "random" i that belongs to a rather large cluster + IdxT i; + IdxT j = laneId(); + if (j == 0) { + do { + auto old = atomicAdd(count, IdxT{1}); + i = (seed * (old + 1)) % n_rows; + } while (static_cast(cluster_sizes[labels[i]]) < average); + } + i = raft::shfl(i, 0); + + // Adjust the center of the selected smaller cluster to gravitate towards + // a sample from the selected larger cluster. 
+ const IdxT li = static_cast(labels[i]); + // Weight of the current center for the weighted average. + // We dump it for anomalously small clusters, but keep constant otherwise. + const MathT wc = min(static_cast(csize), static_cast(kAdjustCentersWeight)); + // Weight for the datapoint used to shift the center. + const MathT wd = 1.0; + for (; j < dim; j += WarpSize) { + MathT val = 0; + val += wc * centers[j + dim * li]; + val += wd * mapping_op(dataset[j + dim * i]); + val /= wc + wd; + centers[j + dim * l] = val; + } +} + +/** + * @brief Adjust centers for clusters that have small number of entries. + * + * For each cluster, where the cluster size is not bigger than a threshold, the center is moved + * towards a data point that belongs to a large cluster. + * + * NB: if this function returns `true`, you should update the labels. + * + * NB: all pointers must be on the device side. + * + * @tparam T element type + * @tparam MathT type of the centroids and mapped data + * @tparam IdxT index type + * @tparam LabelT label type + * @tparam CounterT counter type supported by CUDA's native atomicAdd + * @tparam MappingOpT type of the mapping operation + * + * @param[inout] centers cluster centers [n_clusters, dim] + * @param[in] n_clusters number of rows in `centers` + * @param[in] dim number of columns in `centers` and `dataset` + * @param[in] dataset a host pointer to the row-major data matrix [n_rows, dim] + * @param[in] n_rows number of rows in `dataset` + * @param[in] labels a host pointer to the cluster indices [n_rows] + * @param[in] cluster_sizes number of rows in each cluster [n_clusters] + * @param[in] threshold defines a criterion for adjusting a cluster + * (cluster_sizes <= average_size * threshold) + * 0 <= threshold < 1 + * @param[in] mapping_op Mapping operation from T to MathT + * @param[in] stream CUDA stream + * @param[inout] device_memory memory resource to use for temporary allocations + * + * @return whether any of the centers has been updated 
(and thus, `labels` need to be recalculated). + */ +template +auto adjust_centers(MathT* centers, + IdxT n_clusters, + IdxT dim, + const T* dataset, + IdxT n_rows, + const LabelT* labels, + const CounterT* cluster_sizes, + MathT threshold, + MappingOpT mapping_op, + rmm::cuda_stream_view stream, + rmm::mr::device_memory_resource* device_memory) -> bool +{ + common::nvtx::range fun_scope( + "adjust_centers(%zu, %u)", static_cast(n_rows), n_clusters); + if (n_clusters == 0) { return false; } + constexpr static std::array kPrimes{29, 71, 113, 173, 229, 281, 349, 409, 463, 541, + 601, 659, 733, 809, 863, 941, 1013, 1069, 1151, 1223, + 1291, 1373, 1451, 1511, 1583, 1657, 1733, 1811, 1889, 1987, + 2053, 2129, 2213, 2287, 2357, 2423, 2531, 2617, 2687, 2741}; + static IdxT i = 0; + static IdxT i_primes = 0; + + bool adjusted = false; + IdxT average = n_rows / n_clusters; + IdxT ofst; + do { + i_primes = (i_primes + 1) % kPrimes.size(); + ofst = kPrimes[i_primes]; + } while (n_rows % ofst == 0); + + constexpr uint32_t kBlockDimY = 4; + const dim3 block_dim(WarpSize, kBlockDimY, 1); + const dim3 grid_dim(1, raft::ceildiv(n_clusters, static_cast(kBlockDimY)), 1); + rmm::device_scalar update_count(0, stream, device_memory); + adjust_centers_kernel<<>>(centers, + n_clusters, + dim, + dataset, + n_rows, + labels, + cluster_sizes, + threshold, + average, + ofst, + update_count.data(), + mapping_op); + adjusted = update_count.value(stream) > 0; // NB: rmm scalar performs the sync + + return adjusted; +} + +/** + * @brief Expectation-maximization-balancing combined in an iterative process. + * + * Note, the `cluster_centers` is assumed to be already initialized here. + * Thus, this function can be used for fine-tuning existing clusters; + * to train from scratch, use `build_clusters` function below. 
+ * + * @tparam T element type + * @tparam MathT type of the centroids and mapped data + * @tparam IdxT index type + * @tparam LabelT label type + * @tparam CounterT counter type supported by CUDA's native atomicAdd + * @tparam MappingOpT type of the mapping operation + * + * @param[in] handle The raft handle + * @param[in] params Structure containing the hyper-parameters + * @param[in] n_iters Requested number of iterations (can differ from params.n_iter!) + * @param[in] dim Dimensionality of the dataset + * @param[in] dataset Pointer to a managed row-major array [n_rows, dim] + * @param[in] dataset_norm Pointer to the precomputed norm (for L2 metrics only) [n_rows] + * @param[in] n_rows Number of rows in the dataset + * @param[in] n_cluster Requested number of clusters + * @param[inout] cluster_centers Pointer to a managed row-major array [n_clusters, dim] + * @param[out] cluster_labels Pointer to a managed row-major array [n_rows] + * @param[out] cluster_sizes Pointer to a managed row-major array [n_clusters] + * @param[in] balancing_pullback + * if the cluster centers are rebalanced on this number of iterations, + * one extra iteration is performed (this could happen several times) (default should be `2`). + * In other words, the first and then every `ballancing_pullback`-th rebalancing operation adds + * one more iteration to the main cycle. + * @param[in] balancing_threshold + * the rebalancing takes place if any cluster is smaller than `avg_size * balancing_threshold` + * on a given iteration (default should be `~ 0.25`). 
+ * @param[in] mapping_op Mapping operation from T to MathT + * @param[inout] device_memory + * A memory resource for device allocations (makes sense to provide a memory pool here) + */ +template +void balancing_em_iters(const raft::device_resources& handle, + const kmeans_balanced_params& params, + uint32_t n_iters, + IdxT dim, + const T* dataset, + const MathT* dataset_norm, + IdxT n_rows, + IdxT n_clusters, + MathT* cluster_centers, + LabelT* cluster_labels, + CounterT* cluster_sizes, + uint32_t balancing_pullback, + MathT balancing_threshold, + MappingOpT mapping_op, + rmm::mr::device_memory_resource* device_memory) +{ + auto stream = handle.get_stream(); + uint32_t balancing_counter = balancing_pullback; + for (uint32_t iter = 0; iter < n_iters; iter++) { + // Balancing step - move the centers around to equalize cluster sizes + // (but not on the first iteration) + if (iter > 0 && adjust_centers(cluster_centers, + n_clusters, + dim, + dataset, + n_rows, + cluster_labels, + cluster_sizes, + balancing_threshold, + mapping_op, + stream, + device_memory)) { + if (balancing_counter++ >= balancing_pullback) { + balancing_counter -= balancing_pullback; + n_iters++; + } + } + switch (params.metric) { + // For some metrics, cluster calculation and adjustment tends to favor zero center vectors. + // To avoid converging to zero, we normalize the center vectors on every iteration. 
+ case raft::distance::DistanceType::InnerProduct: + case raft::distance::DistanceType::CosineExpanded: + case raft::distance::DistanceType::CorrelationExpanded: { + auto clusters_in_view = raft::make_device_matrix_view( + cluster_centers, n_clusters, dim); + auto clusters_out_view = raft::make_device_matrix_view( + cluster_centers, n_clusters, dim); + raft::linalg::row_normalize( + handle, clusters_in_view, clusters_out_view, raft::linalg::L2Norm); + break; + } + default: break; + } + // E: Expectation step - predict labels + predict(handle, + params, + cluster_centers, + n_clusters, + dim, + dataset, + n_rows, + cluster_labels, + mapping_op, + device_memory, + dataset_norm); + // M: Maximization step - calculate optimal cluster centers + calc_centers_and_sizes(handle, + cluster_centers, + cluster_sizes, + n_clusters, + dim, + dataset, + n_rows, + cluster_labels, + true, + mapping_op, + device_memory); + } +} + +/** Randomly initialize cluster centers and then call `balancing_em_iters`. */ +template +void build_clusters(const raft::device_resources& handle, + const kmeans_balanced_params& params, + IdxT dim, + const T* dataset, + IdxT n_rows, + IdxT n_clusters, + MathT* cluster_centers, + LabelT* cluster_labels, + CounterT* cluster_sizes, + MappingOpT mapping_op, + rmm::mr::device_memory_resource* device_memory, + const MathT* dataset_norm = nullptr) +{ + auto stream = handle.get_stream(); + + // "randomly" initialize labels + auto labels_view = raft::make_device_vector_view(cluster_labels, n_rows); + linalg::map_offset( + handle, + labels_view, + raft::compose_op(raft::cast_op(), raft::mod_const_op(n_clusters))); + + // update centers to match the initialized labels. 
+ calc_centers_and_sizes(handle, + cluster_centers, + cluster_sizes, + n_clusters, + dim, + dataset, + n_rows, + cluster_labels, + true, + mapping_op, + device_memory); + + // run EM + balancing_em_iters(handle, + params, + params.n_iters, + dim, + dataset, + dataset_norm, + n_rows, + n_clusters, + cluster_centers, + cluster_labels, + cluster_sizes, + 2, + MathT{0.25}, + mapping_op, + device_memory); +} + +/** Calculate how many fine clusters should belong to each mesocluster. */ +template +inline auto arrange_fine_clusters(IdxT n_clusters, + IdxT n_mesoclusters, + IdxT n_rows, + const CounterT* mesocluster_sizes) +{ + std::vector fine_clusters_nums(n_mesoclusters); + std::vector fine_clusters_csum(n_mesoclusters + 1); + fine_clusters_csum[0] = 0; + + IdxT n_lists_rem = n_clusters; + IdxT n_nonempty_ms_rem = 0; + for (IdxT i = 0; i < n_mesoclusters; i++) { + n_nonempty_ms_rem += mesocluster_sizes[i] > CounterT{0} ? 1 : 0; + } + IdxT n_rows_rem = n_rows; + CounterT mesocluster_size_sum = 0; + CounterT mesocluster_size_max = 0; + IdxT fine_clusters_nums_max = 0; + for (IdxT i = 0; i < n_mesoclusters; i++) { + if (i < n_mesoclusters - 1) { + // Although the algorithm is meant to produce balanced clusters, when something + // goes wrong, we may get empty clusters (e.g. during development/debugging). + // The code below ensures a proportional arrangement of fine cluster numbers + // per mesocluster, even if some clusters are empty. 
+ if (mesocluster_sizes[i] == 0) { + fine_clusters_nums[i] = 0; + } else { + n_nonempty_ms_rem--; + auto s = static_cast( + static_cast(n_lists_rem * mesocluster_sizes[i]) / n_rows_rem + .5); + s = std::min(s, n_lists_rem - n_nonempty_ms_rem); + fine_clusters_nums[i] = std::max(s, IdxT{1}); + } + } else { + fine_clusters_nums[i] = n_lists_rem; + } + n_lists_rem -= fine_clusters_nums[i]; + n_rows_rem -= mesocluster_sizes[i]; + mesocluster_size_max = max(mesocluster_size_max, mesocluster_sizes[i]); + mesocluster_size_sum += mesocluster_sizes[i]; + fine_clusters_nums_max = max(fine_clusters_nums_max, fine_clusters_nums[i]); + fine_clusters_csum[i + 1] = fine_clusters_csum[i] + fine_clusters_nums[i]; + } + + RAFT_EXPECTS(static_cast(mesocluster_size_sum) == n_rows, + "mesocluster sizes do not add up (%zu) to the total trainset size (%zu)", + static_cast(mesocluster_size_sum), + static_cast(n_rows)); + RAFT_EXPECTS(fine_clusters_csum[n_mesoclusters] == n_clusters, + "fine cluster numbers do not add up (%zu) to the total number of clusters (%zu)", + static_cast(fine_clusters_csum[n_mesoclusters]), + static_cast(n_clusters)); + + return std::make_tuple(static_cast(mesocluster_size_max), + fine_clusters_nums_max, + std::move(fine_clusters_nums), + std::move(fine_clusters_csum)); +} + +/** + * Given the (coarse) mesoclusters and the distribution of fine clusters within them, + * build the fine clusters. + * + * Processing one mesocluster at a time: + * 1. Copy mesocluster data into a separate buffer + * 2. Predict fine cluster + * 3. Refine the fine cluster centers + * + * As a result, the fine clusters are what is returned by `build_hierarchical`; + * this function returns the total number of fine clusters, which can be checked to be + * the same as the requested number of clusters. 
+ * + * Note: this function uses at most `mesocluster_size_max` points per mesocluster for training; + * if one of the clusters is larger than that (as given by `mesocluster_sizes`), the extra data + * is ignored and a warning is reported. + */ +template +auto build_fine_clusters(const raft::device_resources& handle, + const kmeans_balanced_params& params, + IdxT dim, + const T* dataset_mptr, + const MathT* dataset_norm_mptr, + const LabelT* labels_mptr, + IdxT n_rows, + const IdxT* fine_clusters_nums, + const IdxT* fine_clusters_csum, + const CounterT* mesocluster_sizes, + IdxT n_mesoclusters, + IdxT mesocluster_size_max, + IdxT fine_clusters_nums_max, + MathT* cluster_centers, + MappingOpT mapping_op, + rmm::mr::device_memory_resource* managed_memory, + rmm::mr::device_memory_resource* device_memory) -> IdxT +{ + auto stream = handle.get_stream(); + rmm::device_uvector mc_trainset_ids_buf(mesocluster_size_max, stream, managed_memory); + rmm::device_uvector mc_trainset_buf(mesocluster_size_max * dim, stream, device_memory); + rmm::device_uvector mc_trainset_norm_buf(mesocluster_size_max, stream, device_memory); + auto mc_trainset_ids = mc_trainset_ids_buf.data(); + auto mc_trainset = mc_trainset_buf.data(); + auto mc_trainset_norm = mc_trainset_norm_buf.data(); + + // label (cluster ID) of each vector + rmm::device_uvector mc_trainset_labels(mesocluster_size_max, stream, device_memory); + + rmm::device_uvector mc_trainset_ccenters( + fine_clusters_nums_max * dim, stream, device_memory); + // number of vectors in each cluster + rmm::device_uvector mc_trainset_csizes_tmp( + fine_clusters_nums_max, stream, device_memory); + + // Training clusters in each meso-cluster + IdxT n_clusters_done = 0; + for (IdxT i = 0; i < n_mesoclusters; i++) { + IdxT k = 0; + for (IdxT j = 0; j < n_rows && k < mesocluster_size_max; j++) { + if (labels_mptr[j] == LabelT(i)) { mc_trainset_ids[k++] = j; } + } + if (k != static_cast(mesocluster_sizes[i])) + RAFT_LOG_WARN("Incorrect 
mesocluster size at %d. %zu vs %zu", + static_cast(i), + static_cast(k), + static_cast(mesocluster_sizes[i])); + if (k == 0) { + RAFT_LOG_DEBUG("Empty cluster %d", i); + RAFT_EXPECTS(fine_clusters_nums[i] == 0, + "Number of fine clusters must be zero for the empty mesocluster (got %d)", + static_cast(fine_clusters_nums[i])); + continue; + } else { + RAFT_EXPECTS(fine_clusters_nums[i] > 0, + "Number of fine clusters must be non-zero for a non-empty mesocluster"); + } + + cub::TransformInputIterator mapping_itr(dataset_mptr, mapping_op); + raft::matrix::gather(mapping_itr, dim, n_rows, mc_trainset_ids, k, mc_trainset, stream); + if (params.metric == raft::distance::DistanceType::L2Expanded || + params.metric == raft::distance::DistanceType::L2SqrtExpanded) { + thrust::gather(handle.get_thrust_policy(), + mc_trainset_ids, + mc_trainset_ids + k, + dataset_norm_mptr, + mc_trainset_norm); + } + + build_clusters(handle, + params, + dim, + mc_trainset, + k, + fine_clusters_nums[i], + mc_trainset_ccenters.data(), + mc_trainset_labels.data(), + mc_trainset_csizes_tmp.data(), + mapping_op, + device_memory, + mc_trainset_norm); + + raft::copy(cluster_centers + (dim * fine_clusters_csum[i]), + mc_trainset_ccenters.data(), + fine_clusters_nums[i] * dim, + stream); + handle.sync_stream(stream); + n_clusters_done += fine_clusters_nums[i]; + } + return n_clusters_done; +} + +/** + * @brief Hierarchical balanced k-means + * + * @tparam T element type + * @tparam MathT type of the centroids and mapped data + * @tparam IdxT index type + * @tparam LabelT label type + * @tparam MappingOpT type of the mapping operation + * + * @param[in] handle The raft handle. 
+ * @param[in] params Structure containing the hyper-parameters + * @param dim number of columns in `centers` and `dataset` + * @param[in] dataset a device pointer to the source dataset [n_rows, dim] + * @param n_rows number of rows in the input + * @param[out] cluster_centers a device pointer to the found cluster centers [n_cluster, dim] + * @param n_cluster + * @param metric the distance type + * @param mapping_op Mapping operation from T to MathT + * @param stream + */ +template +void build_hierarchical(const raft::device_resources& handle, + const kmeans_balanced_params& params, + IdxT dim, + const T* dataset, + IdxT n_rows, + MathT* cluster_centers, + IdxT n_clusters, + MappingOpT mapping_op) +{ + auto stream = handle.get_stream(); + using LabelT = uint32_t; + + common::nvtx::range fun_scope( + "build_hierarchical(%zu, %u)", static_cast(n_rows), n_clusters); + + IdxT n_mesoclusters = std::min(n_clusters, static_cast(std::sqrt(n_clusters) + 0.5)); + RAFT_LOG_DEBUG("build_hierarchical: n_mesoclusters: %u", n_mesoclusters); + + rmm::mr::managed_memory_resource managed_memory; + rmm::mr::device_memory_resource* device_memory = handle.get_workspace_resource(); + auto [max_minibatch_size, mem_per_row] = + calc_minibatch_size(n_clusters, n_rows, dim, params.metric, std::is_same_v); + auto pool_guard = + raft::get_pool_memory_resource(device_memory, mem_per_row * size_t(max_minibatch_size)); + if (pool_guard) { + RAFT_LOG_DEBUG("build_hierarchical: using pool memory resource with initial size %zu bytes", + pool_guard->pool_size()); + } + + // Precompute the L2 norm of the dataset if relevant. 
+ const MathT* dataset_norm = nullptr; + rmm::device_uvector dataset_norm_buf(0, stream, device_memory); + if (params.metric == raft::distance::DistanceType::L2Expanded || + params.metric == raft::distance::DistanceType::L2SqrtExpanded) { + dataset_norm_buf.resize(n_rows, stream); + for (IdxT offset = 0; offset < n_rows; offset += max_minibatch_size) { + IdxT minibatch_size = std::min(max_minibatch_size, n_rows - offset); + compute_norm(handle, + dataset_norm_buf.data() + offset, + dataset + dim * offset, + dim, + minibatch_size, + mapping_op, + device_memory); + } + dataset_norm = (const MathT*)dataset_norm_buf.data(); + } + + /* Temporary workaround to cub::DeviceHistogram not supporting any type that isn't natively + * supported by atomicAdd: find a supported CounterT based on the IdxT. */ + typedef typename std::conditional_t + CounterT; + + // build coarse clusters (mesoclusters) + rmm::device_uvector mesocluster_labels_buf(n_rows, stream, &managed_memory); + rmm::device_uvector mesocluster_sizes_buf(n_mesoclusters, stream, &managed_memory); + { + rmm::device_uvector mesocluster_centers_buf(n_mesoclusters * dim, stream, device_memory); + build_clusters(handle, + params, + dim, + dataset, + n_rows, + n_mesoclusters, + mesocluster_centers_buf.data(), + mesocluster_labels_buf.data(), + mesocluster_sizes_buf.data(), + mapping_op, + device_memory, + dataset_norm); + } + + auto mesocluster_sizes = mesocluster_sizes_buf.data(); + auto mesocluster_labels = mesocluster_labels_buf.data(); + + handle.sync_stream(stream); + + // build fine clusters + auto [mesocluster_size_max, fine_clusters_nums_max, fine_clusters_nums, fine_clusters_csum] = + arrange_fine_clusters(n_clusters, n_mesoclusters, n_rows, mesocluster_sizes); + + const IdxT mesocluster_size_max_balanced = div_rounding_up_safe( + 2lu * size_t(n_rows), std::max(size_t(n_mesoclusters), 1lu)); + if (mesocluster_size_max > mesocluster_size_max_balanced) { + RAFT_LOG_WARN( + "build_hierarchical: built unbalanced 
mesoclusters (max_mesocluster_size == %u > %u). " + "At most %u points will be used for training within each mesocluster. " + "Consider increasing the number of training iterations `n_iters`.", + mesocluster_size_max, + mesocluster_size_max_balanced, + mesocluster_size_max_balanced); + RAFT_LOG_TRACE_VEC(mesocluster_sizes, n_mesoclusters); + RAFT_LOG_TRACE_VEC(fine_clusters_nums.data(), n_mesoclusters); + mesocluster_size_max = mesocluster_size_max_balanced; + } + + auto n_clusters_done = build_fine_clusters(handle, + params, + dim, + dataset, + dataset_norm, + mesocluster_labels, + n_rows, + fine_clusters_nums.data(), + fine_clusters_csum.data(), + mesocluster_sizes, + n_mesoclusters, + mesocluster_size_max, + fine_clusters_nums_max, + cluster_centers, + mapping_op, + &managed_memory, + device_memory); + RAFT_EXPECTS(n_clusters_done == n_clusters, "Didn't process all clusters."); + + rmm::device_uvector cluster_sizes(n_clusters, stream, device_memory); + rmm::device_uvector labels(n_rows, stream, device_memory); + + // Fine-tuning k-means for all clusters + // + // (*) Since the likely cluster centroids have been calculated hierarchically already, the number + // of iterations for fine-tuning kmeans for whole clusters should be reduced. However, there is a + // possibility that the clusters could be unbalanced here, in which case the actual number of + // iterations would be increased. 
+ // + balancing_em_iters(handle, + params, + std::max(params.n_iters / 10, 2), + dim, + dataset, + dataset_norm, + n_rows, + n_clusters, + cluster_centers, + labels.data(), + cluster_sizes.data(), + 5, + MathT{0.2}, + mapping_op, + device_memory); +} + +} // namespace raft::cluster::detail diff --git a/cpp/include/raft/cluster/detail/kmeans_common.cuh b/cpp/include/raft/cluster/detail/kmeans_common.cuh index 2973be8c23..76fc22e99e 100644 --- a/cpp/include/raft/cluster/detail/kmeans_common.cuh +++ b/cpp/include/raft/cluster/detail/kmeans_common.cuh @@ -1,5 +1,5 @@ /* - * Copyright (c) 2022, NVIDIA CORPORATION. + * Copyright (c) 2022-2023, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -30,14 +30,14 @@ #include #include #include -#include +#include #include #include #include +#include #include #include #include -#include #include #include #include @@ -88,7 +88,7 @@ struct KeyValueIndexOp { // Computes the intensity histogram from a sequence of labels template -void countLabels(const raft::handle_t& handle, +void countLabels(raft::device_resources const& handle, SampleIteratorT labels, CounterT* count, IndexT n_samples, @@ -96,9 +96,13 @@ void countLabels(const raft::handle_t& handle, rmm::device_uvector& workspace) { cudaStream_t stream = handle.get_stream(); - IndexT num_levels = n_clusters + 1; - IndexT lower_level = 0; - IndexT upper_level = n_clusters; + + // CUB::DeviceHistogram requires a signed index type + typedef typename std::make_signed_t CubIndexT; + + CubIndexT num_levels = n_clusters + 1; + CubIndexT lower_level = 0; + CubIndexT upper_level = n_clusters; size_t temp_storage_bytes = 0; RAFT_CUDA_TRY(cub::DeviceHistogram::HistogramEven(nullptr, @@ -108,7 +112,7 @@ void countLabels(const raft::handle_t& handle, num_levels, lower_level, upper_level, - n_samples, + static_cast(n_samples), stream)); workspace.resize(temp_storage_bytes, stream); @@ 
-120,12 +124,12 @@ void countLabels(const raft::handle_t& handle, num_levels, lower_level, upper_level, - n_samples, + static_cast(n_samples), stream)); } template -void checkWeight(const raft::handle_t& handle, +void checkWeight(raft::device_resources const& handle, raft::device_vector_view weight, rmm::device_uvector& workspace) { @@ -156,12 +160,11 @@ void checkWeight(const raft::handle_t& handle, n_samples); auto scale = static_cast(n_samples) / wt_sum; - raft::linalg::unaryOp( - weight.data_handle(), - weight.data_handle(), - n_samples, - [=] __device__(const DataT& wt) { return wt * scale; }, - stream); + raft::linalg::unaryOp(weight.data_handle(), + weight.data_handle(), + n_samples, + raft::mul_const_op{scale}, + stream); } } @@ -179,38 +182,47 @@ IndexT getCentroidsBatchSize(int batch_centroids, IndexT n_local_clusters) return (minVal == 0) ? n_local_clusters : minVal; } -template -void computeClusterCost(const raft::handle_t& handle, - raft::device_vector_view minClusterDistance, +template +void computeClusterCost(raft::device_resources const& handle, + raft::device_vector_view minClusterDistance, rmm::device_uvector& workspace, - raft::device_scalar_view clusterCost, + raft::device_scalar_view clusterCost, + MainOpT main_op, ReductionOpT reduction_op) { - cudaStream_t stream = handle.get_stream(); + cudaStream_t stream = handle.get_stream(); + + cub::TransformInputIterator itr(minClusterDistance.data_handle(), + main_op); + size_t temp_storage_bytes = 0; RAFT_CUDA_TRY(cub::DeviceReduce::Reduce(nullptr, temp_storage_bytes, - minClusterDistance.data_handle(), + itr, clusterCost.data_handle(), minClusterDistance.size(), reduction_op, - DataT(), + OutputT(), stream)); workspace.resize(temp_storage_bytes, stream); RAFT_CUDA_TRY(cub::DeviceReduce::Reduce(workspace.data(), temp_storage_bytes, - minClusterDistance.data_handle(), + itr, clusterCost.data_handle(), minClusterDistance.size(), reduction_op, - DataT(), + OutputT(), stream)); } template -void 
sampleCentroids(const raft::handle_t& handle, +void sampleCentroids(raft::device_resources const& handle, raft::device_matrix_view X, raft::device_vector_view minClusterDistance, raft::device_vector_view isSampleCentroid, @@ -267,16 +279,14 @@ void sampleCentroids(const raft::handle_t& handle, sampledMinClusterDistance.data_handle(), nPtsSampledInRank, inRankCp.data(), - [=] __device__(raft::KeyValuePair val) { // MapTransformOp - return val.key; - }, + raft::key_op{}, stream); } // calculate pairwise distance between 'dataset[n x d]' and 'centroids[k x d]', // result will be stored in 'pairwiseDistance[n x k]' template -void pairwise_distance_kmeans(const raft::handle_t& handle, +void pairwise_distance_kmeans(raft::device_resources const& handle, raft::device_matrix_view X, raft::device_matrix_view centroids, raft::device_matrix_view pairwiseDistance, @@ -304,7 +314,7 @@ void pairwise_distance_kmeans(const raft::handle_t& handle, // shuffle and randomly select 'n_samples_to_gather' from input 'in' and stores // in 'out' does not modify the input template -void shuffleAndGather(const raft::handle_t& handle, +void shuffleAndGather(raft::device_resources const& handle, raft::device_matrix_view in, raft::device_matrix_view out, uint32_t n_samples_to_gather, @@ -329,7 +339,7 @@ void shuffleAndGather(const raft::handle_t& handle, in.extent(1), in.extent(0), indices.data_handle(), - n_samples_to_gather, + static_cast(n_samples_to_gather), out.data_handle(), stream); } @@ -339,7 +349,7 @@ void shuffleAndGather(const raft::handle_t& handle, // is the distance between the sample and the 'centroid[key]' template void minClusterAndDistanceCompute( - const raft::handle_t& handle, + raft::device_resources const& handle, raft::device_matrix_view X, raft::device_matrix_view centroids, raft::device_vector_view, IndexT> minClusterAndDistance, @@ -464,17 +474,15 @@ void minClusterAndDistanceCompute( pair.value = val; return pair; }, - [=] __device__(raft::KeyValuePair a, 
raft::KeyValuePair b) { - return (b.value < a.value) ? b : a; - }, - [=] __device__(raft::KeyValuePair pair) { return pair; }); + raft::argmin_op{}, + raft::identity_op{}); } } } } template -void minClusterDistanceCompute(const raft::handle_t& handle, +void minClusterDistanceCompute(raft::device_resources const& handle, raft::device_matrix_view X, raft::device_matrix_view centroids, raft::device_vector_view minClusterDistance, @@ -542,7 +550,6 @@ void minClusterDistanceCompute(const raft::handle_t& handle, if (is_fused) { workspace.resize((sizeof(IndexT)) * ns, stream); - // todo(lsugy): remove cIdx raft::distance::fusedL2NNMinReduce( minClusterDistanceView.data_handle(), datasetView.data_handle(), @@ -577,30 +584,23 @@ void minClusterDistanceCompute(const raft::handle_t& handle, pairwise_distance_kmeans( handle, datasetView, centroidsView, pairwiseDistanceView, workspace, metric); - raft::linalg::coalescedReduction( - minClusterDistanceView.data_handle(), - pairwiseDistanceView.data_handle(), - pairwiseDistanceView.extent(1), - pairwiseDistanceView.extent(0), - std::numeric_limits::max(), - stream, - true, - [=] __device__(DataT val, IndexT i) { // MainLambda - return val; - }, - [=] __device__(DataT a, DataT b) { // ReduceLambda - return (b < a) ? 
b : a; - }, - [=] __device__(DataT val) { // FinalLambda - return val; - }); + raft::linalg::coalescedReduction(minClusterDistanceView.data_handle(), + pairwiseDistanceView.data_handle(), + pairwiseDistanceView.extent(1), + pairwiseDistanceView.extent(0), + std::numeric_limits::max(), + stream, + true, + raft::identity_op{}, + raft::min_op{}, + raft::identity_op{}); } } } } template -void countSamplesInCluster(const raft::handle_t& handle, +void countSamplesInCluster(raft::device_resources const& handle, const KMeansParams& params, raft::device_matrix_view X, raft::device_vector_view L2NormX, diff --git a/cpp/include/raft/cluster/detail/kmeans_deprecated.cuh b/cpp/include/raft/cluster/detail/kmeans_deprecated.cuh index 2746b6f657..a9d8777304 100644 --- a/cpp/include/raft/cluster/detail/kmeans_deprecated.cuh +++ b/cpp/include/raft/cluster/detail/kmeans_deprecated.cuh @@ -1,5 +1,5 @@ /* - * Copyright (c) 2020-2022, NVIDIA CORPORATION. + * Copyright (c) 2020-2023, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -42,7 +42,7 @@ #include #include -#include +#include #include #include #include @@ -360,7 +360,7 @@ static __global__ void divideCentroids(index_type_t d, * @return Zero if successful. Otherwise non-zero. */ template -static int chooseNewCentroid(handle_t const& handle, +static int chooseNewCentroid(raft::device_resources const& handle, index_type_t n, index_type_t d, value_type_t rand, @@ -457,7 +457,7 @@ static int chooseNewCentroid(handle_t const& handle, * @return Zero if successful. Otherwise non-zero. */ template -static int initializeCentroids(handle_t const& handle, +static int initializeCentroids(raft::device_resources const& handle, index_type_t n, index_type_t d, index_type_t k, @@ -568,7 +568,7 @@ static int initializeCentroids(handle_t const& handle, * @return Zero if successful. Otherwise non-zero. 
*/ template -static int assignCentroids(handle_t const& handle, +static int assignCentroids(raft::device_resources const& handle, index_type_t n, index_type_t d, index_type_t k, @@ -640,7 +640,7 @@ static int assignCentroids(handle_t const& handle, * @return Zero if successful. Otherwise non-zero. */ template -static int updateCentroids(handle_t const& handle, +static int updateCentroids(raft::device_resources const& handle, index_type_t n, index_type_t d, index_type_t k, @@ -783,7 +783,7 @@ static int updateCentroids(handle_t const& handle, * @return error flag. */ template -int kmeans(handle_t const& handle, +int kmeans(raft::device_resources const& handle, index_type_t n, index_type_t d, index_type_t k, @@ -950,7 +950,7 @@ int kmeans(handle_t const& handle, * @return error flag */ template -int kmeans(handle_t const& handle, +int kmeans(raft::device_resources const& handle, index_type_t n, index_type_t d, index_type_t k, diff --git a/cpp/include/raft/cluster/detail/mst.cuh b/cpp/include/raft/cluster/detail/mst.cuh index 8143d21641..46e31b672e 100644 --- a/cpp/include/raft/cluster/detail/mst.cuh +++ b/cpp/include/raft/cluster/detail/mst.cuh @@ -1,5 +1,5 @@ /* - * Copyright (c) 2021-2022, NVIDIA CORPORATION. + * Copyright (c) 2021-2023, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
@@ -67,7 +67,7 @@ void merge_msts(sparse::solver::Graph_COO& coo1, */ template void connect_knn_graph( - const raft::handle_t& handle, + raft::device_resources const& handle, const value_t* X, sparse::solver::Graph_COO& msf, size_t m, @@ -130,7 +130,7 @@ void connect_knn_graph( */ template void build_sorted_mst( - const raft::handle_t& handle, + raft::device_resources const& handle, const value_t* X, const value_idx* indptr, const value_idx* indices, diff --git a/cpp/include/raft/cluster/detail/single_linkage.cuh b/cpp/include/raft/cluster/detail/single_linkage.cuh index d12db85e1b..473d858827 100644 --- a/cpp/include/raft/cluster/detail/single_linkage.cuh +++ b/cpp/include/raft/cluster/detail/single_linkage.cuh @@ -1,5 +1,5 @@ /* - * Copyright (c) 2021-2022, NVIDIA CORPORATION. + * Copyright (c) 2021-2023, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -49,7 +49,7 @@ static const size_t EMPTY = 0; * @param[in] n_clusters number of clusters to assign data samples */ template -void single_linkage(const raft::handle_t& handle, +void single_linkage(raft::device_resources const& handle, const value_t* X, size_t m, size_t n, diff --git a/cpp/include/raft/cluster/kmeans.cuh b/cpp/include/raft/cluster/kmeans.cuh index d64815244b..ac9e66d5da 100644 --- a/cpp/include/raft/cluster/kmeans.cuh +++ b/cpp/include/raft/cluster/kmeans.cuh @@ -1,5 +1,5 @@ /* - * Copyright (c) 2020-2022, NVIDIA CORPORATION. + * Copyright (c) 2020-2023, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -20,6 +20,7 @@ #include #include #include +#include namespace raft::cluster::kmeans { @@ -43,12 +44,12 @@ using KeyValueIndexOp = detail::KeyValueIndexOp; * k-means++ algorithm. * * @code{.cpp} - * #include + * #include * #include * #include * using namespace raft::cluster; * ... 
- * raft::handle_t handle; + * raft::device_resources handle; * raft::cluster::KMeansParams params; * int n_features = 15, inertia, n_iter; * auto centroids = raft::make_device_matrix(handle, params.n_clusters, n_features); @@ -82,7 +83,7 @@ using KeyValueIndexOp = detail::KeyValueIndexOp; * @param[out] n_iter Number of iterations run. */ template -void fit(handle_t const& handle, +void fit(raft::device_resources const& handle, const KMeansParams& params, raft::device_matrix_view X, std::optional> sample_weight, @@ -97,12 +98,12 @@ void fit(handle_t const& handle, * @brief Predict the closest cluster each sample in X belongs to. * * @code{.cpp} - * #include + * #include * #include * #include * using namespace raft::cluster; * ... - * raft::handle_t handle; + * raft::device_resources handle; * raft::cluster::KMeansParams params; * int n_features = 15, inertia, n_iter; * auto centroids = raft::make_device_matrix(handle, params.n_clusters, n_features); @@ -146,7 +147,7 @@ void fit(handle_t const& handle, * their closest cluster center. */ template -void predict(handle_t const& handle, +void predict(raft::device_resources const& handle, const KMeansParams& params, raft::device_matrix_view X, std::optional> sample_weight, @@ -164,12 +165,12 @@ void predict(handle_t const& handle, * in the input. * * @code{.cpp} - * #include + * #include * #include * #include * using namespace raft::cluster; * ... - * raft::handle_t handle; + * raft::device_resources handle; * raft::cluster::KMeansParams params; * int n_features = 15, inertia, n_iter; * auto centroids = raft::make_device_matrix(handle, params.n_clusters, n_features); @@ -209,7 +210,7 @@ void predict(handle_t const& handle, * @param[out] n_iter Number of iterations run. 
*/ template -void fit_predict(handle_t const& handle, +void fit_predict(raft::device_resources const& handle, const KMeansParams& params, raft::device_matrix_view X, std::optional> sample_weight, @@ -238,7 +239,7 @@ void fit_predict(handle_t const& handle, * [dim = n_samples x n_features] */ template -void transform(const raft::handle_t& handle, +void transform(raft::device_resources const& handle, const KMeansParams& params, raft::device_matrix_view X, raft::device_matrix_view centroids, @@ -248,7 +249,7 @@ void transform(const raft::handle_t& handle, } template -void transform(const raft::handle_t& handle, +void transform(raft::device_resources const& handle, const KMeansParams& params, const DataT* X, const DataT* centroids, @@ -280,7 +281,7 @@ void transform(const raft::handle_t& handle, * */ template -void sample_centroids(const raft::handle_t& handle, +void sample_centroids(raft::device_resources const& handle, raft::device_matrix_view X, raft::device_vector_view minClusterDistance, raft::device_vector_view isSampleCentroid, @@ -307,13 +308,14 @@ void sample_centroids(const raft::handle_t& handle, * */ template -void cluster_cost(const raft::handle_t& handle, +void cluster_cost(raft::device_resources const& handle, raft::device_vector_view minClusterDistance, rmm::device_uvector& workspace, raft::device_scalar_view clusterCost, ReductionOpT reduction_op) { - detail::computeClusterCost(handle, minClusterDistance, workspace, clusterCost, reduction_op); + detail::computeClusterCost( + handle, minClusterDistance, workspace, clusterCost, raft::identity_op{}, reduction_op); } /** @@ -332,7 +334,7 @@ void cluster_cost(const raft::handle_t& handle, * @param[out] new_centroids: output matrix of updated centroids (size n_clusters, n_features) */ template -void update_centroids(const raft::handle_t& handle, +void update_centroids(raft::device_resources const& handle, raft::device_matrix_view X, raft::device_vector_view sample_weights, raft::device_matrix_view centroids, 
@@ -373,7 +375,7 @@ void update_centroids(const raft::handle_t& handle, * */ template -void min_cluster_distance(const raft::handle_t& handle, +void min_cluster_distance(raft::device_resources const& handle, raft::device_matrix_view X, raft::device_matrix_view centroids, raft::device_vector_view minClusterDistance, @@ -424,7 +426,7 @@ void min_cluster_distance(const raft::handle_t& handle, */ template void min_cluster_and_distance( - const raft::handle_t& handle, + raft::device_resources const& handle, raft::device_matrix_view X, raft::device_matrix_view centroids, raft::device_vector_view, IndexT> minClusterAndDistance, @@ -464,7 +466,7 @@ void min_cluster_and_distance( * */ template -void shuffle_and_gather(const raft::handle_t& handle, +void shuffle_and_gather(raft::device_resources const& handle, raft::device_matrix_view in, raft::device_matrix_view out, uint32_t n_samples_to_gather, @@ -493,7 +495,7 @@ void shuffle_and_gather(const raft::handle_t& handle, * */ template -void count_samples_in_cluster(const raft::handle_t& handle, +void count_samples_in_cluster(raft::device_resources const& handle, const KMeansParams& params, raft::device_matrix_view X, raft::device_vector_view L2NormX, @@ -523,7 +525,7 @@ void count_samples_in_cluster(const raft::handle_t& handle, * @param[in] workspace Temporary workspace buffer which can get resized */ template -void init_plus_plus(const raft::handle_t& handle, +void init_plus_plus(raft::device_resources const& handle, const KMeansParams& params, raft::device_matrix_view X, raft::device_matrix_view centroids, @@ -556,7 +558,7 @@ void init_plus_plus(const raft::handle_t& handle, * @param[in] workspace Temporary workspace buffer which can get resized */ template -void fit_main(const raft::handle_t& handle, +void fit_main(raft::device_resources const& handle, const KMeansParams& params, raft::device_matrix_view X, raft::device_vector_view sample_weights, @@ -603,7 +605,7 @@ namespace raft::cluster { * @param[out] n_iter Number 
of iterations run. */ template -void kmeans_fit(handle_t const& handle, +void kmeans_fit(raft::device_resources const& handle, const KMeansParams& params, raft::device_matrix_view X, std::optional> sample_weight, @@ -615,7 +617,7 @@ void kmeans_fit(handle_t const& handle, } template -void kmeans_fit(handle_t const& handle, +void kmeans_fit(raft::device_resources const& handle, const KMeansParams& params, const DataT* X, const DataT* sample_weight, @@ -650,7 +652,7 @@ void kmeans_fit(handle_t const& handle, * their closest cluster center. */ template -void kmeans_predict(handle_t const& handle, +void kmeans_predict(raft::device_resources const& handle, const KMeansParams& params, raft::device_matrix_view X, std::optional> sample_weight, @@ -664,7 +666,7 @@ void kmeans_predict(handle_t const& handle, } template -void kmeans_predict(handle_t const& handle, +void kmeans_predict(raft::device_resources const& handle, const KMeansParams& params, const DataT* X, const DataT* sample_weight, @@ -715,7 +717,7 @@ void kmeans_predict(handle_t const& handle, * @param[out] n_iter Number of iterations run. 
*/ template -void kmeans_fit_predict(handle_t const& handle, +void kmeans_fit_predict(raft::device_resources const& handle, const KMeansParams& params, raft::device_matrix_view X, std::optional> sample_weight, @@ -729,7 +731,7 @@ void kmeans_fit_predict(handle_t const& handle, } template -void kmeans_fit_predict(handle_t const& handle, +void kmeans_fit_predict(raft::device_resources const& handle, const KMeansParams& params, const DataT* X, const DataT* sample_weight, @@ -760,7 +762,7 @@ void kmeans_fit_predict(handle_t const& handle, * [dim = n_samples x n_features] */ template -void kmeans_transform(const raft::handle_t& handle, +void kmeans_transform(raft::device_resources const& handle, const KMeansParams& params, raft::device_matrix_view X, raft::device_matrix_view centroids, @@ -770,7 +772,7 @@ void kmeans_transform(const raft::handle_t& handle, } template -void kmeans_transform(const raft::handle_t& handle, +void kmeans_transform(raft::device_resources const& handle, const KMeansParams& params, const DataT* X, const DataT* centroids, @@ -807,7 +809,7 @@ using KeyValueIndexOp = kmeans::KeyValueIndexOp; * */ template -void sampleCentroids(const raft::handle_t& handle, +void sampleCentroids(raft::device_resources const& handle, raft::device_matrix_view X, raft::device_vector_view minClusterDistance, raft::device_vector_view isSampleCentroid, @@ -834,7 +836,7 @@ void sampleCentroids(const raft::handle_t& handle, * */ template -void computeClusterCost(const raft::handle_t& handle, +void computeClusterCost(raft::device_resources const& handle, raft::device_vector_view minClusterDistance, rmm::device_uvector& workspace, raft::device_scalar_view clusterCost, @@ -865,7 +867,7 @@ void computeClusterCost(const raft::handle_t& handle, * */ template -void minClusterDistanceCompute(const raft::handle_t& handle, +void minClusterDistanceCompute(raft::device_resources const& handle, const KMeansParams& params, raft::device_matrix_view X, raft::device_matrix_view centroids, 
@@ -912,7 +914,7 @@ void minClusterDistanceCompute(const raft::handle_t& handle, */ template void minClusterAndDistanceCompute( - const raft::handle_t& handle, + raft::device_resources const& handle, const KMeansParams& params, raft::device_matrix_view X, raft::device_matrix_view centroids, @@ -950,7 +952,7 @@ void minClusterAndDistanceCompute( * */ template -void shuffleAndGather(const raft::handle_t& handle, +void shuffleAndGather(raft::device_resources const& handle, raft::device_matrix_view in, raft::device_matrix_view out, uint32_t n_samples_to_gather, @@ -979,7 +981,7 @@ void shuffleAndGather(const raft::handle_t& handle, * */ template -void countSamplesInCluster(const raft::handle_t& handle, +void countSamplesInCluster(raft::device_resources const& handle, const KMeansParams& params, raft::device_matrix_view X, raft::device_vector_view L2NormX, @@ -1010,7 +1012,7 @@ void countSamplesInCluster(const raft::handle_t& handle, * @param[in] workspace Temporary workspace buffer which can get resized */ template -void kmeansPlusPlus(const raft::handle_t& handle, +void kmeansPlusPlus(raft::device_resources const& handle, const KMeansParams& params, raft::device_matrix_view X, raft::device_matrix_view centroidsRawData, @@ -1043,7 +1045,7 @@ void kmeansPlusPlus(const raft::handle_t& handle, * @param[in] workspace Temporary workspace buffer which can get resized */ template -void kmeans_fit_main(const raft::handle_t& handle, +void kmeans_fit_main(raft::device_resources const& handle, const KMeansParams& params, raft::device_matrix_view X, raft::device_vector_view weight, diff --git a/cpp/include/raft/cluster/kmeans_balanced.cuh b/cpp/include/raft/cluster/kmeans_balanced.cuh new file mode 100644 index 0000000000..405c7a8018 --- /dev/null +++ b/cpp/include/raft/cluster/kmeans_balanced.cuh @@ -0,0 +1,365 @@ +/* + * Copyright (c) 2022, NVIDIA CORPORATION. 
+ * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include + +#include +#include +#include + +namespace raft::cluster::kmeans_balanced { + +/** + * @brief Find clusters of balanced sizes with a hierarchical k-means algorithm. + * + * This variant of the k-means algorithm first clusters the dataset in mesoclusters, then clusters + * the subsets associated to each mesocluster into fine clusters, and finally runs a few k-means + * iterations over the whole dataset and with all the centroids to obtain the final clusters. + * + * Each k-means iteration applies expectation-maximization-balancing: + * - Balancing: adjust centers for clusters that have a small number of entries. If the size of a + * cluster is below a threshold, the center is moved towards a bigger cluster. + * - Expectation: predict the labels (i.e find closest cluster centroid to each point) + * - Maximization: calculate optimal centroids (i.e find the center of gravity of each cluster) + * + * The number of mesoclusters is chosen by rounding the square root of the number of clusters. E.g + * for 512 clusters, we would have 23 mesoclusters. The number of fine clusters per mesocluster is + * chosen proportionally to the number of points in each mesocluster. + * + * This variant of k-means uses random initialization and a fixed number of iterations, though + * iterations can be repeated if the balancing step moved the centroids. 
+ * + * Additionally, this algorithm supports quantized datasets in arbitrary types but the core part of + * the algorithm will work with a floating-point type, hence a conversion function can be provided + * to map the data type to the math type. + * + * @code{.cpp} + * #include + * #include + * #include + * ... + * raft::handle_t handle; + * raft::cluster::kmeans_balanced_params params; + * auto centroids = raft::make_device_matrix(handle, n_clusters, n_features); + * raft::cluster::kmeans_balanced::fit(handle, params, X, centroids.view()); + * @endcode + * + * @tparam DataT Type of the input data. + * @tparam MathT Type of the centroids and mapped data. + * @tparam IndexT Type used for indexing. + * @tparam MappingOpT Type of the mapping function. + * @param[in] handle The raft resources + * @param[in] params Structure containing the hyper-parameters + * @param[in] X Training instances to cluster. The data must be in row-major format. + * [dim = n_samples x n_features] + * @param[out] centroids The generated centroids [dim = n_clusters x n_features] + * @param[in] mapping_op (optional) Functor to convert from the input datatype to the arithmetic + * datatype. If DataT == MathT, this must be the identity. 
+ */ +template +void fit(const raft::device_resources& handle, + kmeans_balanced_params const& params, + raft::device_matrix_view X, + raft::device_matrix_view centroids, + MappingOpT mapping_op = raft::identity_op()) +{ + RAFT_EXPECTS(X.extent(1) == centroids.extent(1), + "Number of features in dataset and centroids are different"); + RAFT_EXPECTS(static_cast(X.extent(0)) * static_cast(X.extent(1)) <= + static_cast(std::numeric_limits::max()), + "The chosen index type cannot represent all indices for the given dataset"); + RAFT_EXPECTS(centroids.extent(0) > IndexT{0} && centroids.extent(0) <= X.extent(0), + "The number of centroids must be strictly positive and cannot exceed the number of " + "points in the training dataset."); + + detail::build_hierarchical(handle, + params, + X.extent(1), + X.data_handle(), + X.extent(0), + centroids.data_handle(), + centroids.extent(0), + mapping_op); +} + +/** + * @brief Predict the closest cluster each sample in X belongs to. + * + * @code{.cpp} + * #include + * #include + * #include + * ... + * raft::handle_t handle; + * raft::cluster::kmeans_balanced_params params; + * auto labels = raft::make_device_vector(handle, n_rows); + * raft::cluster::kmeans_balanced::predict(handle, params, X, centroids, labels); + * @endcode + * + * @tparam DataT Type of the input data. + * @tparam MathT Type of the centroids and mapped data. + * @tparam IndexT Type used for indexing. + * @tparam LabelT Type of the output labels. + * @tparam MappingOpT Type of the mapping function. + * @param[in] handle The raft resources + * @param[in] params Structure containing the hyper-parameters + * @param[in] X Dataset for which to infer the closest clusters. + * [dim = n_samples x n_features] + * @param[in] centroids The input centroids [dim = n_clusters x n_features] + * @param[out] labels The output labels [dim = n_samples] + * @param[in] mapping_op (optional) Functor to convert from the input datatype to the arithmetic + * datatype. 
If DataT == MathT, this must be the identity. + */ +template +void predict(const raft::device_resources& handle, + kmeans_balanced_params const& params, + raft::device_matrix_view X, + raft::device_matrix_view centroids, + raft::device_vector_view labels, + MappingOpT mapping_op = raft::identity_op()) +{ + RAFT_EXPECTS(X.extent(0) == labels.extent(0), + "Number of rows in dataset and labels are different"); + RAFT_EXPECTS(X.extent(1) == centroids.extent(1), + "Number of features in dataset and centroids are different"); + RAFT_EXPECTS(static_cast(X.extent(0)) * static_cast(X.extent(1)) <= + static_cast(std::numeric_limits::max()), + "The chosen index type cannot represent all indices for the given dataset"); + RAFT_EXPECTS(static_cast(centroids.extent(0)) <= + static_cast(std::numeric_limits::max()), + "The chosen label type cannot represent all cluster labels"); + + detail::predict(handle, + params, + centroids.data_handle(), + centroids.extent(0), + X.extent(1), + X.data_handle(), + X.extent(0), + labels.data_handle(), + mapping_op); +} + +/** + * @brief Compute hierarchical balanced k-means clustering and predict cluster index for each sample + * in the input. + * + * @code{.cpp} + * #include + * #include + * #include + * ... + * raft::handle_t handle; + * raft::cluster::kmeans_balanced_params params; + * auto centroids = raft::make_device_matrix(handle, n_clusters, n_features); + * auto labels = raft::make_device_vector(handle, n_rows); + * raft::cluster::kmeans_balanced::fit_predict( + * handle, params, X, centroids.view(), labels.view()); + * @endcode + * + * @tparam DataT Type of the input data. + * @tparam MathT Type of the centroids and mapped data. + * @tparam IndexT Type used for indexing. + * @tparam LabelT Type of the output labels. + * @tparam MappingOpT Type of the mapping function. + * @param[in] handle The raft resources + * @param[in] params Structure containing the hyper-parameters + * @param[in] X Training instances to cluster. 
The data must be in row-major format. + * [dim = n_samples x n_features] + * @param[out] centroids The output centroids [dim = n_clusters x n_features] + * @param[out] labels The output labels [dim = n_samples] + * @param[in] mapping_op (optional) Functor to convert from the input datatype to the arithmetic + * datatype. If DataT and MathT are the same, this must be the identity. + */ +template +void fit_predict(const raft::device_resources& handle, + kmeans_balanced_params const& params, + raft::device_matrix_view X, + raft::device_matrix_view centroids, + raft::device_vector_view labels, + MappingOpT mapping_op = raft::identity_op()) +{ + auto centroids_const = raft::make_device_matrix_view( + centroids.data_handle(), centroids.extent(0), centroids.extent(1)); + raft::cluster::kmeans_balanced::fit(handle, params, X, centroids, mapping_op); + raft::cluster::kmeans_balanced::predict(handle, params, X, centroids_const, labels, mapping_op); +} + +namespace helpers { + +/** + * @brief Randomly initialize centers and apply expectation-maximization-balancing iterations + * + * This is essentially the non-hierarchical balanced k-means algorithm which is used by the + * hierarchical algorithm once to build the mesoclusters and once per mesocluster to build the fine + * clusters. + * + * @code{.cpp} + * #include + * #include + * #include + * ... + * raft::handle_t handle; + * raft::cluster::kmeans_balanced_params params; + * auto centroids = raft::make_device_matrix(handle, n_clusters, n_features); + * auto labels = raft::make_device_vector(handle, n_samples); + * auto sizes = raft::make_device_vector(handle, n_clusters); + * raft::cluster::kmeans_balanced::helpers::build_clusters( + * handle, params, X, centroids.view(), labels.view(), sizes.view()); + * @endcode + * + * @tparam DataT Type of the input data. + * @tparam MathT Type of the centroids and mapped data. + * @tparam IndexT Type used for indexing. + * @tparam LabelT Type of the output labels.
+ * @tparam CounterT Counter type supported by CUDA's native atomicAdd. + * @tparam MappingOpT Type of the mapping function. + * @param[in] handle The raft resources + * @param[in] params Structure containing the hyper-parameters + * @param[in] X Training instances to cluster. The data must be in row-major format. + * [dim = n_samples x n_features] + * @param[out] centroids The output centroids [dim = n_clusters x n_features] + * @param[out] labels The output labels [dim = n_samples] + * @param[out] cluster_sizes Size of each cluster [dim = n_clusters] + * @param[in] mapping_op (optional) Functor to convert from the input datatype to the + * arithmetic datatype. If DataT == MathT, this must be the identity. + * @param[in] X_norm (optional) Dataset's row norms [dim = n_samples] + */ +template +void build_clusters(const raft::device_resources& handle, + const kmeans_balanced_params& params, + raft::device_matrix_view X, + raft::device_matrix_view centroids, + raft::device_vector_view labels, + raft::device_vector_view cluster_sizes, + MappingOpT mapping_op = raft::identity_op(), + std::optional> X_norm = std::nullopt) +{ + RAFT_EXPECTS(X.extent(0) == labels.extent(0), + "Number of rows in dataset and labels are different"); + RAFT_EXPECTS(X.extent(1) == centroids.extent(1), + "Number of features in dataset and centroids are different"); + RAFT_EXPECTS(centroids.extent(0) == cluster_sizes.extent(0), + "Number of rows in centroids and cluster_sizes are different"); + + detail::build_clusters(handle, + params, + X.extent(1), + X.data_handle(), + X.extent(0), + centroids.extent(0), + centroids.data_handle(), + labels.data_handle(), + cluster_sizes.data_handle(), + mapping_op, + handle.get_workspace_resource(), + X_norm.has_value() ? X_norm.value().data_handle() : nullptr); +} + +/** + * @brief Given the data and labels, calculate cluster centers and sizes in one sweep. + * + * Let `S_i = {x_k | x_k \in X & labels[k] == i}` be the vectors in the dataset with label i.
+ * + * On exit, + * `centers_i = (\sum_{x \in S_i} x + w_i * center_i) / (|S_i| + w_i)`, + * where `w_i = reset_counters ? 0 : cluster_sizes[i]`. + * + * In other words, the updated cluster centers are a weighted average of the existing cluster + * center, and the coordinates of the points labeled with i. _This allows calling this function + * multiple times with different datasets with the same effect as if calling this function once + * on the combined dataset_. + * + * @code{.cpp} + * #include + * #include + * ... + * raft::handle_t handle; + * auto centroids = raft::make_device_matrix(handle, n_clusters, n_features); + * auto sizes = raft::make_device_vector(handle, n_clusters); + * raft::cluster::kmeans_balanced::helpers::calc_centers_and_sizes( + * handle, X, labels, centroids.view(), sizes.view(), true); + * @endcode + * + * @tparam DataT Type of the input data. + * @tparam MathT Type of the centroids and mapped data. + * @tparam IndexT Type used for indexing. + * @tparam LabelT Type of the input labels. + * @tparam CounterT Counter type supported by CUDA's native atomicAdd. + * @tparam MappingOpT Type of the mapping function. + * @param[in] handle The raft resources + * @param[in] X Dataset for which to calculate cluster centers. The data must be in + * row-major format. [dim = n_samples x n_features] + * @param[in] labels The input labels [dim = n_samples] + * @param[out] centroids The output centroids [dim = n_clusters x n_features] + * @param[out] cluster_sizes Size of each cluster [dim = n_clusters] + * @param[in] reset_counters Whether to clear the output arrays before calculating. + * When set to `false`, this function may be used to update existing + * centers and sizes using the weighted average principle. + * @param[in] mapping_op (optional) Functor to convert from the input datatype to the + * arithmetic datatype. If DataT == MathT, this must be the identity.
+ */ +template +void calc_centers_and_sizes(const raft::device_resources& handle, + raft::device_matrix_view X, + raft::device_vector_view labels, + raft::device_matrix_view centroids, + raft::device_vector_view cluster_sizes, + bool reset_counters = true, + MappingOpT mapping_op = raft::identity_op()) +{ + RAFT_EXPECTS(X.extent(0) == labels.extent(0), + "Number of rows in dataset and labels are different"); + RAFT_EXPECTS(X.extent(1) == centroids.extent(1), + "Number of features in dataset and centroids are different"); + RAFT_EXPECTS(centroids.extent(0) == cluster_sizes.extent(0), + "Number of rows in centroids and cluster_sizes are different"); + + detail::calc_centers_and_sizes(handle, + centroids.data_handle(), + cluster_sizes.data_handle(), + centroids.extent(0), + X.extent(1), + X.data_handle(), + X.extent(0), + labels.data_handle(), + reset_counters, + mapping_op); +} + +} // namespace helpers + +} // namespace raft::cluster::kmeans_balanced diff --git a/cpp/include/raft/cluster/kmeans_balanced_types.hpp b/cpp/include/raft/cluster/kmeans_balanced_types.hpp new file mode 100644 index 0000000000..11b77e288a --- /dev/null +++ b/cpp/include/raft/cluster/kmeans_balanced_types.hpp @@ -0,0 +1,47 @@ +/* + * Copyright (c) 2022, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License.
+ */ + +#pragma once + +#include +#include +#include +#include + +namespace raft::cluster::kmeans_balanced { + +/** + * Simple object to specify hyper-parameters to the balanced k-means algorithm. + * + * The following metrics are currently supported in k-means balanced: + * - InnerProduct + * - L2Expanded + * - L2SqrtExpanded + */ +struct kmeans_balanced_params : kmeans_base_params { + /** + * Number of training iterations + */ + uint32_t n_iters = 20; +}; + +} // namespace raft::cluster::kmeans_balanced + +namespace raft::cluster { + +using kmeans_balanced::kmeans_balanced_params; + +} // namespace raft::cluster diff --git a/cpp/include/raft/cluster/kmeans_deprecated.cuh b/cpp/include/raft/cluster/kmeans_deprecated.cuh index a4cac4cb0f..8e0861ada1 100644 --- a/cpp/include/raft/cluster/kmeans_deprecated.cuh +++ b/cpp/include/raft/cluster/kmeans_deprecated.cuh @@ -1,5 +1,5 @@ /* - * Copyright (c) 2020-2022, NVIDIA CORPORATION. + * Copyright (c) 2020-2023, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -46,7 +46,7 @@ namespace cluster { * @return error flag */ template -int kmeans(handle_t const& handle, +int kmeans(raft::device_resources const& handle, index_type_t n, index_type_t d, index_type_t k, diff --git a/cpp/include/raft/cluster/kmeans_types.hpp b/cpp/include/raft/cluster/kmeans_types.hpp index f411b12b5c..4d956ad7a0 100644 --- a/cpp/include/raft/cluster/kmeans_types.hpp +++ b/cpp/include/raft/cluster/kmeans_types.hpp @@ -18,12 +18,24 @@ #include #include +namespace raft::cluster { + +/** Base structure for parameters that are common to all k-means algorithms */ +struct kmeans_base_params { + /** + * Metric to use for distance computation. The supported metrics can vary per algorithm. 
+ */ + raft::distance::DistanceType metric = raft::distance::DistanceType::L2Expanded; +}; + +} // namespace raft::cluster + namespace raft::cluster::kmeans { /** * Simple object to specify hyper-parameters to the kmeans algorithm. */ -struct KMeansParams { +struct KMeansParams : kmeans_base_params { enum InitMethod { /** @@ -75,13 +87,7 @@ struct KMeansParams { /** * Seed to the random number generator. */ - raft::random::RngState rng_state = - raft::random::RngState(0, raft::random::GeneratorType::GenPhilox); - - /** - * Metric to use for distance computation. - */ - raft::distance::DistanceType metric = raft::distance::DistanceType::L2Expanded; + raft::random::RngState rng_state{0}; /** * Number of instance k-means algorithm will be run with different seeds. diff --git a/cpp/include/raft/cluster/single_linkage.cuh b/cpp/include/raft/cluster/single_linkage.cuh index 2d74c364b2..91241b853b 100644 --- a/cpp/include/raft/cluster/single_linkage.cuh +++ b/cpp/include/raft/cluster/single_linkage.cuh @@ -1,5 +1,5 @@ /* - * Copyright (c) 2021-2022, NVIDIA CORPORATION. + * Copyright (c) 2021-2023, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -50,7 +50,7 @@ namespace raft::cluster { template -void single_linkage(const raft::handle_t& handle, +void single_linkage(raft::device_resources const& handle, const value_t* X, size_t m, size_t n, @@ -87,7 +87,7 @@ constexpr int DEFAULT_CONST_C = 15; control of k. 
The algorithm will set `k = log(n) + c` */ template -void single_linkage(const raft::handle_t& handle, +void single_linkage(raft::device_resources const& handle, raft::device_matrix_view X, raft::device_matrix_view dendrogram, raft::device_vector_view labels, diff --git a/cpp/include/raft/comms/comms_test.hpp b/cpp/include/raft/comms/comms_test.hpp index c7e5dd3ab6..c61bb32f79 100644 --- a/cpp/include/raft/comms/comms_test.hpp +++ b/cpp/include/raft/comms/comms_test.hpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2021-2022, NVIDIA CORPORATION. + * Copyright (c) 2021-2023, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -19,7 +19,7 @@ #include #include -#include +#include namespace raft { namespace comms { @@ -31,7 +31,7 @@ namespace comms { * initialized comms instance. * @param[in] root the root rank id */ -bool test_collective_allreduce(const handle_t& handle, int root) +bool test_collective_allreduce(raft::device_resources const& handle, int root) { return detail::test_collective_allreduce(handle, root); } @@ -43,7 +43,7 @@ bool test_collective_allreduce(const handle_t& handle, int root) * initialized comms instance. * @param[in] root the root rank id */ -bool test_collective_broadcast(const handle_t& handle, int root) +bool test_collective_broadcast(raft::device_resources const& handle, int root) { return detail::test_collective_broadcast(handle, root); } @@ -55,7 +55,7 @@ bool test_collective_broadcast(const handle_t& handle, int root) * initialized comms instance. * @param[in] root the root rank id */ -bool test_collective_reduce(const handle_t& handle, int root) +bool test_collective_reduce(raft::device_resources const& handle, int root) { return detail::test_collective_reduce(handle, root); } @@ -67,7 +67,7 @@ bool test_collective_reduce(const handle_t& handle, int root) * initialized comms instance. 
* @param[in] root the root rank id */ -bool test_collective_allgather(const handle_t& handle, int root) +bool test_collective_allgather(raft::device_resources const& handle, int root) { return detail::test_collective_allgather(handle, root); } @@ -79,7 +79,7 @@ bool test_collective_allgather(const handle_t& handle, int root) * initialized comms instance. * @param[in] root the root rank id */ -bool test_collective_gather(const handle_t& handle, int root) +bool test_collective_gather(raft::device_resources const& handle, int root) { return detail::test_collective_gather(handle, root); } @@ -91,7 +91,7 @@ bool test_collective_gather(const handle_t& handle, int root) * initialized comms instance. * @param[in] root the root rank id */ -bool test_collective_gatherv(const handle_t& handle, int root) +bool test_collective_gatherv(raft::device_resources const& handle, int root) { return detail::test_collective_gatherv(handle, root); } @@ -103,7 +103,7 @@ bool test_collective_gatherv(const handle_t& handle, int root) * initialized comms instance. * @param[in] root the root rank id */ -bool test_collective_reducescatter(const handle_t& handle, int root) +bool test_collective_reducescatter(raft::device_resources const& handle, int root) { return detail::test_collective_reducescatter(handle, root); } @@ -115,7 +115,7 @@ bool test_collective_reducescatter(const handle_t& handle, int root) * initialized comms instance. * @param[in] numTrials number of iterations of all-to-all messaging to perform */ -bool test_pointToPoint_simple_send_recv(const handle_t& h, int numTrials) +bool test_pointToPoint_simple_send_recv(raft::device_resources const& h, int numTrials) { return detail::test_pointToPoint_simple_send_recv(h, numTrials); } @@ -127,7 +127,7 @@ bool test_pointToPoint_simple_send_recv(const handle_t& h, int numTrials) * initialized comms instance. 
* @param numTrials number of iterations of send or receive messaging to perform */ -bool test_pointToPoint_device_send_or_recv(const handle_t& h, int numTrials) +bool test_pointToPoint_device_send_or_recv(raft::device_resources const& h, int numTrials) { return detail::test_pointToPoint_device_send_or_recv(h, numTrials); } @@ -139,7 +139,7 @@ bool test_pointToPoint_device_send_or_recv(const handle_t& h, int numTrials) * initialized comms instance. * @param numTrials number of iterations of send or receive messaging to perform */ -bool test_pointToPoint_device_sendrecv(const handle_t& h, int numTrials) +bool test_pointToPoint_device_sendrecv(raft::device_resources const& h, int numTrials) { return detail::test_pointToPoint_device_sendrecv(h, numTrials); } @@ -151,7 +151,7 @@ bool test_pointToPoint_device_sendrecv(const handle_t& h, int numTrials) * initialized comms instance. * @param numTrials number of iterations of send or receive messaging to perform */ -bool test_pointToPoint_device_multicast_sendrecv(const handle_t& h, int numTrials) +bool test_pointToPoint_device_multicast_sendrecv(raft::device_resources const& h, int numTrials) { return detail::test_pointToPoint_device_multicast_sendrecv(h, numTrials); } @@ -163,6 +163,9 @@ bool test_pointToPoint_device_multicast_sendrecv(const handle_t& h, int numTrial * initialized comms instance. * @param n_colors number of different colors to test */ -bool test_commsplit(const handle_t& h, int n_colors) { return detail::test_commsplit(h, n_colors); } +bool test_commsplit(raft::device_resources const& h, int n_colors) +{ + return detail::test_commsplit(h, n_colors); +} } // namespace comms }; // namespace raft diff --git a/cpp/include/raft/comms/detail/mpi_comms.hpp b/cpp/include/raft/comms/detail/mpi_comms.hpp index 508a9ce717..4062389eea 100644 --- a/cpp/include/raft/comms/detail/mpi_comms.hpp +++ b/cpp/include/raft/comms/detail/mpi_comms.hpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2019-2022, NVIDIA CORPORATION. 
+ * Copyright (c) 2019-2023, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -28,8 +28,8 @@ #include #include +#include #include -#include #include #include #include diff --git a/cpp/include/raft/comms/detail/std_comms.hpp b/cpp/include/raft/comms/detail/std_comms.hpp index 33892597d8..0db27f0a45 100644 --- a/cpp/include/raft/comms/detail/std_comms.hpp +++ b/cpp/include/raft/comms/detail/std_comms.hpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2020-2022, NVIDIA CORPORATION. + * Copyright (c) 2020-2023, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -20,7 +20,7 @@ #include #include -#include +#include #include #include diff --git a/cpp/include/raft/comms/detail/test.hpp b/cpp/include/raft/comms/detail/test.hpp index 6ba4be3886..2b12bf2d2a 100644 --- a/cpp/include/raft/comms/detail/test.hpp +++ b/cpp/include/raft/comms/detail/test.hpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2021-2022, NVIDIA CORPORATION. + * Copyright (c) 2021-2023, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -17,7 +17,7 @@ #pragma once #include -#include +#include #include #include @@ -38,7 +38,7 @@ namespace detail { * initialized comms instance. * @param[in] root the root rank id */ -bool test_collective_allreduce(const handle_t& handle, int root) +bool test_collective_allreduce(raft::device_resources const& handle, int root) { comms_t const& communicator = handle.get_comms(); @@ -69,7 +69,7 @@ bool test_collective_allreduce(const handle_t& handle, int root) * initialized comms instance. 
* @param[in] root the root rank id */ -bool test_collective_broadcast(const handle_t& handle, int root) +bool test_collective_broadcast(raft::device_resources const& handle, int root) { comms_t const& communicator = handle.get_comms(); @@ -104,7 +104,7 @@ bool test_collective_broadcast(const handle_t& handle, int root) * initialized comms instance. * @param[in] root the root rank id */ -bool test_collective_reduce(const handle_t& handle, int root) +bool test_collective_reduce(raft::device_resources const& handle, int root) { comms_t const& communicator = handle.get_comms(); @@ -140,7 +140,7 @@ bool test_collective_reduce(const handle_t& handle, int root) * initialized comms instance. * @param[in] root the root rank id */ -bool test_collective_allgather(const handle_t& handle, int root) +bool test_collective_allgather(raft::device_resources const& handle, int root) { comms_t const& communicator = handle.get_comms(); @@ -177,7 +177,7 @@ bool test_collective_allgather(const handle_t& handle, int root) * initialized comms instance. * @param[in] root the root rank id */ -bool test_collective_gather(const handle_t& handle, int root) +bool test_collective_gather(raft::device_resources const& handle, int root) { comms_t const& communicator = handle.get_comms(); @@ -214,7 +214,7 @@ bool test_collective_gather(const handle_t& handle, int root) * initialized comms instance. * @param[in] root the root rank id */ -bool test_collective_gatherv(const handle_t& handle, int root) +bool test_collective_gatherv(raft::device_resources const& handle, int root) { comms_t const& communicator = handle.get_comms(); @@ -273,7 +273,7 @@ bool test_collective_gatherv(const handle_t& handle, int root) * initialized comms instance. 
* @param[in] root the root rank id */ -bool test_collective_reducescatter(const handle_t& handle, int root) +bool test_collective_reducescatter(raft::device_resources const& handle, int root) { comms_t const& communicator = handle.get_comms(); @@ -308,7 +308,7 @@ bool test_collective_reducescatter(const handle_t& handle, int root) * initialized comms instance. * @param[in] numTrials number of iterations of all-to-all messaging to perform */ -bool test_pointToPoint_simple_send_recv(const handle_t& h, int numTrials) +bool test_pointToPoint_simple_send_recv(raft::device_resources const& h, int numTrials) { comms_t const& communicator = h.get_comms(); int const rank = communicator.get_rank(); @@ -373,7 +373,7 @@ bool test_pointToPoint_simple_send_recv(const handle_t& h, int numTrials) * initialized comms instance. * @param numTrials number of iterations of send or receive messaging to perform */ -bool test_pointToPoint_device_send_or_recv(const handle_t& h, int numTrials) +bool test_pointToPoint_device_send_or_recv(raft::device_resources const& h, int numTrials) { comms_t const& communicator = h.get_comms(); int const rank = communicator.get_rank(); @@ -415,7 +415,7 @@ bool test_pointToPoint_device_send_or_recv(const handle_t& h, int numTrials) * initialized comms instance. * @param numTrials number of iterations of send or receive messaging to perform */ -bool test_pointToPoint_device_sendrecv(const handle_t& h, int numTrials) +bool test_pointToPoint_device_sendrecv(raft::device_resources const& h, int numTrials) { comms_t const& communicator = h.get_comms(); int const rank = communicator.get_rank(); @@ -461,7 +461,7 @@ bool test_pointToPoint_device_sendrecv(const handle_t& h, int numTrials) * initialized comms instance. 
* @param numTrials number of iterations of send or receive messaging to perform */ -bool test_pointToPoint_device_multicast_sendrecv(const handle_t& h, int numTrials) +bool test_pointToPoint_device_multicast_sendrecv(raft::device_resources const& h, int numTrials) { comms_t const& communicator = h.get_comms(); int const rank = communicator.get_rank(); @@ -520,7 +520,7 @@ bool test_pointToPoint_device_multicast_sendrecv(const handle_t& h, int numTrial * initialized comms instance. * @param n_colors number of different colors to test */ -bool test_commsplit(const handle_t& h, int n_colors) +bool test_commsplit(raft::device_resources const& h, int n_colors) { comms_t const& communicator = h.get_comms(); int const rank = communicator.get_rank(); diff --git a/cpp/include/raft/comms/helper.hpp b/cpp/include/raft/comms/helper.hpp deleted file mode 100644 index f6b63ac971..0000000000 --- a/cpp/include/raft/comms/helper.hpp +++ /dev/null @@ -1,102 +0,0 @@ -/* - * Copyright (c) 2020-2022, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#pragma once - -#include -#include - -#include -#include -#include - -namespace raft { -namespace comms { - -/** - * Function to construct comms_t and inject it on a handle_t. This - * is used for convenience in the Python layer. 
- * - * @param handle raft::handle_t for injecting the comms - * @param nccl_comm initialized NCCL communicator to use for collectives - * @param num_ranks number of ranks in communicator clique - * @param rank rank of local instance - */ -void build_comms_nccl_only(handle_t* handle, ncclComm_t nccl_comm, int num_ranks, int rank) -{ - cudaStream_t stream = handle->get_stream(); - - auto communicator = std::make_shared( - std::unique_ptr(new raft::comms::std_comms(nccl_comm, num_ranks, rank, stream))); - handle->set_comms(communicator); -} - -/** - * Function to construct comms_t and inject it on a handle_t. This - * is used for convenience in the Python layer. - * - * @param handle raft::handle_t for injecting the comms - * @param nccl_comm initialized NCCL communicator to use for collectives - * @param ucp_worker of local process - * Note: This is purposefully left as void* so that the ucp_worker_h - * doesn't need to be exposed through the cython layer - * @param eps array of ucp_ep_h instances. - * Note: This is purposefully left as void* so that - * the ucp_ep_h doesn't need to be exposed through the cython layer. 
- * @param num_ranks number of ranks in communicator clique - * @param rank rank of local instance - */ -void build_comms_nccl_ucx( - handle_t* handle, ncclComm_t nccl_comm, void* ucp_worker, void* eps, int num_ranks, int rank) -{ - auto eps_sp = std::make_shared(new ucp_ep_h[num_ranks]); - - auto size_t_ep_arr = reinterpret_cast(eps); - - for (int i = 0; i < num_ranks; i++) { - size_t ptr = size_t_ep_arr[i]; - auto ucp_ep_v = reinterpret_cast(*eps_sp); - - if (ptr != 0) { - auto eps_ptr = reinterpret_cast(size_t_ep_arr[i]); - ucp_ep_v[i] = eps_ptr; - } else { - ucp_ep_v[i] = nullptr; - } - } - - cudaStream_t stream = handle->get_stream(); - - auto communicator = - std::make_shared(std::unique_ptr(new raft::comms::std_comms( - nccl_comm, (ucp_worker_h)ucp_worker, eps_sp, num_ranks, rank, stream))); - handle->set_comms(communicator); -} - -inline void nccl_unique_id_from_char(ncclUniqueId* id, char* uniqueId, int size) -{ - memcpy(id->internal, uniqueId, size); -} - -inline void get_unique_id(char* uid, int size) -{ - ncclUniqueId id; - ncclGetUniqueId(&id); - - memcpy(uid, id.internal, size); -} -}; // namespace comms -}; // end namespace raft diff --git a/cpp/include/raft/comms/mpi_comms.hpp b/cpp/include/raft/comms/mpi_comms.hpp index ca5275cd06..9076176ea6 100644 --- a/cpp/include/raft/comms/mpi_comms.hpp +++ b/cpp/include/raft/comms/mpi_comms.hpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2021-2022, NVIDIA CORPORATION. + * Copyright (c) 2021-2023, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
@@ -24,12 +24,47 @@ namespace comms { using mpi_comms = detail::mpi_comms; -inline void initialize_mpi_comms(handle_t* handle, MPI_Comm comm) +/** + * @defgroup mpi_comms_factory MPI Comms Factory Functions + * @{ + */ + +/** + * Given a properly initialized MPI_Comm, construct an instance of RAFT's + * MPI Communicator and inject it into the given RAFT handle instance + * @param handle raft handle for managing expensive resources + * @param comm an initialized MPI communicator + * + * @code{.cpp} + * #include + * #include + * + * MPI_Comm mpi_comm; + * raft::raft::device_resources handle; + * + * initialize_mpi_comms(&handle, mpi_comm); + * ... + * const auto& comm = handle.get_comms(); + * auto gather_data = raft::make_device_vector(handle, comm.get_size()); + * ... + * comm.allgather((gather_data.data_handle())[comm.get_rank()], + * gather_data.data_handle(), + * 1, + * handle.get_stream()); + * + * comm.sync_stream(handle.get_stream()); + * @endcode + */ +inline void initialize_mpi_comms(device_resources* handle, MPI_Comm comm) { auto communicator = std::make_shared( std::unique_ptr(new mpi_comms(comm, false, handle->get_stream()))); handle->set_comms(communicator); }; +/** + * @} + */ + }; // namespace comms }; // end namespace raft diff --git a/cpp/include/raft/comms/std_comms.hpp b/cpp/include/raft/comms/std_comms.hpp index edace60fbd..6370d4a8e6 100644 --- a/cpp/include/raft/comms/std_comms.hpp +++ b/cpp/include/raft/comms/std_comms.hpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2021-2022, NVIDIA CORPORATION. + * Copyright (c) 2021-2023, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -16,7 +16,7 @@ #pragma once -#include +#include #include #include @@ -31,15 +31,40 @@ namespace comms { using std_comms = detail::std_comms; /** - * Function to construct comms_t and inject it on a handle_t. This - * is used for convenience in the Python layer. 
+ * @defgroup std_comms_factory std_comms Factory functions + * @{ + */ + +/** + * Factory function to construct a RAFT NCCL communicator and inject it into a + * RAFT handle. * - * @param handle raft::handle_t for injecting the comms + * @param handle raft::device_resources for injecting the comms * @param nccl_comm initialized NCCL communicator to use for collectives * @param num_ranks number of ranks in communicator clique * @param rank rank of local instance + * + * @code{.cpp} + * #include + * #include + * + * ncclComm_t nccl_comm; + * raft::raft::device_resources handle; + * + * build_comms_nccl_only(&handle, nccl_comm, 5, 0); + * ... + * const auto& comm = handle.get_comms(); + * auto gather_data = raft::make_device_vector(handle, comm.get_size()); + * ... + * comm.allgather((gather_data.data_handle())[comm.get_rank()], + * gather_data.data_handle(), + * 1, + * handle.get_stream()); + * + * comm.sync_stream(handle.get_stream()); + * @endcode */ -void build_comms_nccl_only(handle_t* handle, ncclComm_t nccl_comm, int num_ranks, int rank) +void build_comms_nccl_only(device_resources* handle, ncclComm_t nccl_comm, int num_ranks, int rank) { cudaStream_t stream = handle->get_stream(); @@ -49,10 +74,10 @@ void build_comms_nccl_only(handle_t* handle, ncclComm_t nccl_comm, int num_ranks } /** - * Function to construct comms_t and inject it on a handle_t. This - * is used for convenience in the Python layer. + * Factory function to construct a RAFT NCCL+UCX and inject it into a RAFT + * handle. * - * @param handle raft::handle_t for injecting the comms + * @param handle raft::device_resources for injecting the comms * @param nccl_comm initialized NCCL communicator to use for collectives * @param ucp_worker of local process * Note: This is purposefully left as void* so that the ucp_worker_h @@ -62,9 +87,35 @@ void build_comms_nccl_only(handle_t* handle, ncclComm_t nccl_comm, int num_ranks * the ucp_ep_h doesn't need to be exposed through the cython layer. 
* @param num_ranks number of ranks in communicator clique * @param rank rank of local instance + * + * @code{.cpp} + * #include + * #include + * + * ncclComm_t nccl_comm; + * raft::raft::device_resources handle; + * ucp_worker_h ucp_worker; + * ucp_ep_h *ucp_endpoints_arr; + * + * build_comms_nccl_ucx(&handle, nccl_comm, &ucp_worker, ucp_endpoints_arr, 5, 0); + * ... + * const auto& comm = handle.get_comms(); + * auto gather_data = raft::make_device_vector(handle, comm.get_size()); + * ... + * comm.allgather((gather_data.data_handle())[comm.get_rank()], + * gather_data.data_handle(), + * 1, + * handle.get_stream()); + * + * comm.sync_stream(handle.get_stream()); + * @endcode */ -void build_comms_nccl_ucx( - handle_t* handle, ncclComm_t nccl_comm, void* ucp_worker, void* eps, int num_ranks, int rank) +void build_comms_nccl_ucx(device_resources* handle, + ncclComm_t nccl_comm, + void* ucp_worker, + void* eps, + int num_ranks, + int rank) { auto eps_sp = std::make_shared(new ucp_ep_h[num_ranks]); @@ -90,6 +141,10 @@ void build_comms_nccl_ucx( handle->set_comms(communicator); } +/** + * @} + */ + inline void nccl_unique_id_from_char(ncclUniqueId* id, char* uniqueId, int size) { memcpy(id->internal, uniqueId, size); diff --git a/cpp/include/raft/core/comms.hpp b/cpp/include/raft/core/comms.hpp index 78ce91dbf2..463c17f2f6 100644 --- a/cpp/include/raft/core/comms.hpp +++ b/cpp/include/raft/core/comms.hpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2021-2022, NVIDIA CORPORATION. + * Copyright (c) 2021-2023, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
@@ -16,6 +16,7 @@ #pragma once +#include #include #include #include @@ -23,6 +24,11 @@ namespace raft { namespace comms { +/** + * @defgroup comms_types Common mnmg comms types + * @{ + */ + typedef unsigned int request_t; enum class datatype_t { CHAR, UINT8, INT32, UINT32, INT64, UINT64, FLOAT32, FLOAT64 }; enum class op_t { SUM, PROD, MIN, MAX }; @@ -105,6 +111,15 @@ get_type() return datatype_t::FLOAT64; } +/** + * @} + */ + +/** + * @defgroup comms_iface MNMG Communicator Interface + * @{ + */ + class comms_iface { public: virtual ~comms_iface() {} @@ -215,6 +230,15 @@ class comms_iface { virtual void group_end() const = 0; }; +/** + * @} + */ + +/** + * @defgroup comms_t Base Communicator Proxy + * @{ + */ + class comms_t { public: comms_t(std::unique_ptr impl) : impl_(impl.release()) @@ -647,5 +671,9 @@ class comms_t { std::unique_ptr impl_; }; +/** + * @} + */ + } // namespace comms } // namespace raft diff --git a/cpp/include/raft/core/cublas_macros.hpp b/cpp/include/raft/core/cublas_macros.hpp index d2456433ab..855c1228f7 100644 --- a/cpp/include/raft/core/cublas_macros.hpp +++ b/cpp/include/raft/core/cublas_macros.hpp @@ -32,6 +32,11 @@ namespace raft { +/** + * @ingroup error_handling + * @{ + */ + /** * @brief Exception thrown when a cuBLAS error is encountered. */ @@ -40,6 +45,10 @@ struct cublas_error : public raft::exception { explicit cublas_error(std::string const& message) : raft::exception(message) {} }; +/** + * @} + */ + namespace linalg { namespace detail { @@ -66,6 +75,11 @@ inline const char* cublas_error_to_string(cublasStatus_t err) #undef _CUBLAS_ERR_TO_STR +/** + * @ingroup assertion + * @{ + */ + /** * @brief Error checking macro for cuBLAS runtime API functions. 
* @@ -108,6 +122,9 @@ inline const char* cublas_error_to_string(cublasStatus_t err) } \ } while (0) +/** + * @} + */ /** FIXME: remove after cuml rename */ #ifndef CUBLAS_CHECK #define CUBLAS_CHECK(call) CUBLAS_TRY(call) diff --git a/cpp/include/raft/core/cusolver_macros.hpp b/cpp/include/raft/core/cusolver_macros.hpp index 505485e6a0..8f7caf65f3 100644 --- a/cpp/include/raft/core/cusolver_macros.hpp +++ b/cpp/include/raft/core/cusolver_macros.hpp @@ -31,6 +31,11 @@ namespace raft { +/** + * @ingroup error_handling + * @{ + */ + /** * @brief Exception thrown when a cuSOLVER error is encountered. */ @@ -39,6 +44,10 @@ struct cusolver_error : public raft::exception { explicit cusolver_error(std::string const& message) : raft::exception(message) {} }; +/** + * @} + */ + namespace linalg { namespace detail { @@ -65,6 +74,11 @@ inline const char* cusolver_error_to_string(cusolverStatus_t err) #undef _CUSOLVER_ERR_TO_STR +/** + * @ingroup assertion + * @{ + */ + /** * @brief Error checking macro for cuSOLVER runtime API functions. * @@ -107,6 +121,10 @@ inline const char* cusolver_error_to_string(cusolverStatus_t err) } \ } while (0) +/** + * @} + */ + // FIXME: remove after cuml rename #ifndef CUSOLVER_CHECK #define CUSOLVER_CHECK(call) CUSOLVER_TRY(call) diff --git a/cpp/include/raft/core/cusparse_macros.hpp b/cpp/include/raft/core/cusparse_macros.hpp index cf5195582b..8a9aab55f7 100644 --- a/cpp/include/raft/core/cusparse_macros.hpp +++ b/cpp/include/raft/core/cusparse_macros.hpp @@ -37,6 +37,11 @@ namespace raft { +/** + * @ingroup error_handling + * @{ + */ + /** * @brief Exception thrown when a cuSparse error is encountered. 
*/ @@ -45,6 +50,9 @@ struct cusparse_error : public raft::exception { explicit cusparse_error(std::string const& message) : raft::exception(message) {} }; +/** + * @} + */ namespace sparse { namespace detail { @@ -73,6 +81,11 @@ inline const char* cusparse_error_to_string(cusparseStatus_t err) #undef _CUSPARSE_ERR_TO_STR +/** + * @ingroup assertion + * @{ + */ + /** * @brief Error checking macro for cuSparse runtime API functions. * @@ -94,6 +107,10 @@ inline const char* cusparse_error_to_string(cusparseStatus_t err) } \ } while (0) +/** + * @} + */ + // FIXME: Remove after consumer rename #ifndef CUSPARSE_TRY #define CUSPARSE_TRY(call) RAFT_CUSPARSE_TRY(call) @@ -104,6 +121,10 @@ inline const char* cusparse_error_to_string(cusparseStatus_t err) #define CUSPARSE_CHECK(call) CUSPARSE_TRY(call) #endif +/** + * @ingroup assertion + * @{ + */ //@todo: use logger here once logging is enabled /** check for cusparse runtime API errors but do not assert */ #define RAFT_CUSPARSE_TRY_NO_THROW(call) \ @@ -117,6 +138,10 @@ inline const char* cusparse_error_to_string(cusparseStatus_t err) } \ } while (0) +/** + * @} + */ + // FIXME: Remove after consumer rename #ifndef CUSPARSE_CHECK_NO_THROW #define CUSPARSE_CHECK_NO_THROW(call) RAFT_CUSPARSE_TRY_NO_THROW(call) diff --git a/cpp/include/raft/core/detail/device_mdarray.hpp b/cpp/include/raft/core/detail/device_mdarray.hpp index ad6831794e..31dfaba70a 100644 --- a/cpp/include/raft/core/detail/device_mdarray.hpp +++ b/cpp/include/raft/core/detail/device_mdarray.hpp @@ -6,7 +6,7 @@ */ /* - * Copyright (c) 2022, NVIDIA CORPORATION. + * Copyright (c) 2022-2023, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
@@ -22,7 +22,7 @@ */ #pragma once #include -#include +#include #include #include // dynamic_extent diff --git a/cpp/include/raft/core/detail/mdspan_numpy_serializer.hpp b/cpp/include/raft/core/detail/mdspan_numpy_serializer.hpp new file mode 100644 index 0000000000..df89811636 --- /dev/null +++ b/cpp/include/raft/core/detail/mdspan_numpy_serializer.hpp @@ -0,0 +1,487 @@ +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace raft { + +namespace detail { + +namespace numpy_serializer { + +/** + * A small implementation of NumPy serialization format. 
+ * Reference: https://numpy.org/doc/1.23/reference/generated/numpy.lib.format.html + * + * Adapted from https://github.com/llohse/libnpy/blob/master/include/npy.hpp, using the following + * license: + * + * MIT License + * + * Copyright (c) 2021 Leon Merten Lohse + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. 
+ */ + +#define RAFT_NUMPY_LITTLE_ENDIAN_CHAR '<' +#define RAFT_NUMPY_BIG_ENDIAN_CHAR '>' +#define RAFT_NUMPY_NO_ENDIAN_CHAR '|' +#define RAFT_NUMPY_MAGIC_STRING "\x93NUMPY" +#define RAFT_NUMPY_MAGIC_STRING_LENGTH 6 + +#if RAFT_SYSTEM_LITTLE_ENDIAN == 1 +#define RAFT_NUMPY_HOST_ENDIAN_CHAR RAFT_NUMPY_LITTLE_ENDIAN_CHAR +#else // RAFT_SYSTEM_LITTLE_ENDIAN == 1 +#define RAFT_NUMPY_HOST_ENDIAN_CHAR RAFT_NUMPY_BIG_ENDIAN_CHAR +#endif // RAFT_SYSTEM_LITTLE_ENDIAN == 1 + +using ndarray_len_t = std::uint64_t; + +struct dtype_t { + const char byteorder; + const char kind; + const unsigned int itemsize; + + std::string to_string() const + { + char buf[16] = {0}; + std::sprintf(buf, "%c%c%u", byteorder, kind, itemsize); + return std::string(buf); + } + + bool operator==(const dtype_t& other) const + { + return (byteorder == other.byteorder && kind == other.kind && itemsize == other.itemsize); + } +}; + +struct header_t { + const dtype_t dtype; + const bool fortran_order; + const std::vector shape; + + bool operator==(const header_t& other) const + { + return (dtype == other.dtype && fortran_order == other.fortran_order && shape == other.shape); + } +}; + +template +struct is_complex : std::false_type { +}; +template +struct is_complex> : std::true_type { +}; + +template , bool> = true> +inline dtype_t get_numpy_dtype() +{ + return {RAFT_NUMPY_HOST_ENDIAN_CHAR, 'f', sizeof(T)}; +} + +template && std::is_signed_v, bool> = true> +inline dtype_t get_numpy_dtype() +{ + const char endian_char = + (sizeof(T) == 1 ? RAFT_NUMPY_NO_ENDIAN_CHAR : RAFT_NUMPY_HOST_ENDIAN_CHAR); + return {endian_char, 'i', sizeof(T)}; +} + +template && std::is_unsigned_v, bool> = true> +inline dtype_t get_numpy_dtype() +{ + const char endian_char = + (sizeof(T) == 1 ? 
RAFT_NUMPY_NO_ENDIAN_CHAR : RAFT_NUMPY_HOST_ENDIAN_CHAR); + return {endian_char, 'u', sizeof(T)}; +} + +template {}, bool> = true> +inline dtype_t get_numpy_dtype() +{ + return {RAFT_NUMPY_HOST_ENDIAN_CHAR, 'c', sizeof(T)}; +} + +template , bool> = true> +inline dtype_t get_numpy_dtype() +{ + return get_numpy_dtype>(); +} + +template +inline std::string tuple_to_string(const std::vector& tuple) +{ + std::ostringstream oss; + if (tuple.empty()) { + oss << "()"; + } else if (tuple.size() == 1) { + oss << "(" << tuple.front() << ",)"; + } else { + oss << "("; + for (std::size_t i = 0; i < tuple.size() - 1; ++i) { + oss << tuple[i] << ", "; + } + oss << tuple.back() << ")"; + } + return oss.str(); +} + +inline std::string header_to_string(const header_t& header) +{ + std::ostringstream oss; + oss << "{'descr': '" << header.dtype.to_string() + << "', 'fortran_order': " << (header.fortran_order ? "True" : "False") + << ", 'shape': " << tuple_to_string(header.shape) << "}"; + return oss.str(); +} + +inline std::string trim(const std::string& str) +{ + const std::string whitespace = " \t"; + auto begin = str.find_first_not_of(whitespace); + if (begin == std::string::npos) { return ""; } + auto end = str.find_last_not_of(whitespace); + + return str.substr(begin, end - begin + 1); +} + +// A poor man's parser for Python dictionary +// TODO(hcho3): Consider writing a proper parser +// Limitation: can only parse a flat dictionary; all values are assumed to non-objects +// Limitation: must know all the keys ahead of time; you get undefined behavior if you omit any key +inline std::map parse_pydict(std::string str, + const std::vector& keys) +{ + std::map result; + + // Unwrap dictionary + str = trim(str); + RAFT_EXPECTS(str.front() == '{' && str.back() == '}', "Expected a Python dictionary"); + str = str.substr(1, str.length() - 2); + + // Get the position of each key and put it in the list + std::vector> positions; + for (auto const& key : keys) { + std::size_t pos = 
str.find("'" + key + "'"); + RAFT_EXPECTS(pos != std::string::npos, "Missing '%s' key.", key.c_str()); + positions.emplace_back(pos, key); + } + // Sort the list + std::sort(positions.begin(), positions.end()); + + // Extract each key-value pair + for (std::size_t i = 0; i < positions.size(); ++i) { + std::string key = positions[i].second; + + std::size_t begin = positions[i].first; + std::size_t end = (i + 1 < positions.size() ? positions[i + 1].first : std::string::npos); + std::string raw_value = trim(str.substr(begin, end - begin)); + if (raw_value.back() == ',') { raw_value.pop_back(); } + std::size_t sep_pos = raw_value.find_first_of(":"); + if (sep_pos == std::string::npos) { + result[key] = ""; + } else { + result[key] = trim(raw_value.substr(sep_pos + 1)); + } + } + + return result; +} + +inline std::string parse_pystring(std::string str) +{ + RAFT_EXPECTS(str.front() == '\'' && str.back() == '\'', "Invalid Python string: %s", str.c_str()); + return str.substr(1, str.length() - 2); +} + +inline bool parse_pybool(std::string str) +{ + if (str == "True") { + return true; + } else if (str == "False") { + return false; + } else { + RAFT_FAIL("Invalid Python boolean: %s", str.c_str()); + } +} + +inline std::vector parse_pytuple(std::string str) +{ + std::vector result; + + str = trim(str); + RAFT_EXPECTS(str.front() == '(' && str.back() == ')', "Invalid Python tuple: %s", str.c_str()); + str = str.substr(1, str.length() - 2); + + std::istringstream iss(str); + for (std::string token; std::getline(iss, token, ',');) { + result.push_back(trim(token)); + } + + return result; +} + +inline dtype_t parse_descr(std::string typestr) +{ + RAFT_EXPECTS(typestr.length() >= 3, "Invalid typestr: Too short"); + char byteorder_c = typestr.at(0); + char kind_c = typestr.at(1); + std::string itemsize_s = typestr.substr(2); + + const char endian_chars[] = { + RAFT_NUMPY_LITTLE_ENDIAN_CHAR, RAFT_NUMPY_BIG_ENDIAN_CHAR, RAFT_NUMPY_NO_ENDIAN_CHAR}; + const char numtype_chars[] = 
{'f', 'i', 'u', 'c'}; + + RAFT_EXPECTS(std::find(std::begin(endian_chars), std::end(endian_chars), byteorder_c) != + std::end(endian_chars), + "Invalid typestr: unrecognized byteorder %c", + byteorder_c); + RAFT_EXPECTS(std::find(std::begin(numtype_chars), std::end(numtype_chars), kind_c) != + std::end(numtype_chars), + "Invalid typestr: unrecognized kind %c", + kind_c); + unsigned int itemsize = std::stoul(itemsize_s); + + return {byteorder_c, kind_c, itemsize}; +} + +inline void write_magic(std::ostream& os) +{ + os.write(RAFT_NUMPY_MAGIC_STRING, RAFT_NUMPY_MAGIC_STRING_LENGTH); + RAFT_EXPECTS(os.good(), "Error writing magic string"); + // Use version 1.0 + os.put(1); + os.put(0); + RAFT_EXPECTS(os.good(), "Error writing magic string"); +} + +inline void read_magic(std::istream& is) +{ + char magic_buf[RAFT_NUMPY_MAGIC_STRING_LENGTH + 2] = {0}; + is.read(magic_buf, RAFT_NUMPY_MAGIC_STRING_LENGTH + 2); + RAFT_EXPECTS(is.good(), "Error reading magic string"); + + RAFT_EXPECTS(std::memcmp(magic_buf, RAFT_NUMPY_MAGIC_STRING, RAFT_NUMPY_MAGIC_STRING_LENGTH) == 0, + "The given stream does not have a valid NumPy format."); + + std::uint8_t version_major = magic_buf[RAFT_NUMPY_MAGIC_STRING_LENGTH]; + std::uint8_t version_minor = magic_buf[RAFT_NUMPY_MAGIC_STRING_LENGTH + 1]; + RAFT_EXPECTS(version_major == 1 && version_minor == 0, + "Unsupported NumPy version: %d.%d", + version_major, + version_minor); +} + +inline void write_header(std::ostream& os, const header_t& header) +{ + std::string header_dict = header_to_string(header); + std::size_t preamble_length = RAFT_NUMPY_MAGIC_STRING_LENGTH + 2 + 2 + header_dict.length() + 1; + RAFT_EXPECTS(preamble_length < 255 * 255, "Header too long"); + // Enforce 64-byte alignment + std::size_t padding_len = 64 - preamble_length % 64; + std::string padding(padding_len, ' '); + + write_magic(os); + + // Write header length + std::uint8_t header_len_le16[2]; + std::uint16_t header_len = + static_cast(header_dict.length() + 
padding.length() + 1); + header_len_le16[0] = (header_len >> 0) & 0xff; + header_len_le16[1] = (header_len >> 8) & 0xff; + os.put(header_len_le16[0]); + os.put(header_len_le16[1]); + RAFT_EXPECTS(os.good(), "Error writing HEADER_LEN"); + + os << header_dict << padding << "\n"; + RAFT_EXPECTS(os.good(), "Error writing header dict"); +} + +inline std::string read_header_bytes(std::istream& is) +{ + read_magic(is); + + // Read header length + std::uint8_t header_len_le16[2]; + is.read(reinterpret_cast(header_len_le16), 2); + RAFT_EXPECTS(is.good(), "Error while reading HEADER_LEN"); + const std::uint32_t header_length = (header_len_le16[0] << 0) | (header_len_le16[1] << 8); + + std::vector header_bytes(header_length); + is.read(header_bytes.data(), header_length); + RAFT_EXPECTS(is.good(), "Error while reading the header"); + + return std::string(header_bytes.data(), header_length); +} + +inline header_t read_header(std::istream& is) +{ + std::string header_bytes = read_header_bytes(is); + + // remove trailing newline + RAFT_EXPECTS(header_bytes.back() == '\n', "Invalid NumPy header"); + header_bytes.pop_back(); + + // parse the header dict + auto header_dict = parse_pydict(header_bytes, {"descr", "fortran_order", "shape"}); + dtype_t descr = parse_descr(parse_pystring(header_dict["descr"])); + bool fortran_order = parse_pybool(header_dict["fortran_order"]); + std::vector shape; + auto shape_tup_str = parse_pytuple(header_dict["shape"]); + for (const auto& e : shape_tup_str) { + shape.push_back(static_cast(std::stoul(e))); + } + + RAFT_EXPECTS( + descr.byteorder == RAFT_NUMPY_HOST_ENDIAN_CHAR || descr.byteorder == RAFT_NUMPY_NO_ENDIAN_CHAR, + "The mdspan was serialized on a %s machine but you're attempting to load it on " + "a %s machine. This use case is not currently supported.", + (RAFT_SYSTEM_LITTLE_ENDIAN ? "big-endian" : "little-endian"), + (RAFT_SYSTEM_LITTLE_ENDIAN ? 
"little-endian" : "big-endian")); + + return {descr, fortran_order, shape}; +} + +template +inline void serialize_host_mdspan( + std::ostream& os, + const raft::host_mdspan& obj) +{ + static_assert(std::is_same_v || + std::is_same_v, + "The serializer only supports row-major and column-major layouts"); + + using obj_t = raft::host_mdspan; + + const auto dtype = get_numpy_dtype(); + const bool fortran_order = std::is_same_v; + std::vector shape; + for (typename obj_t::rank_type i = 0; i < obj.rank(); ++i) { + shape.push_back(obj.extent(i)); + } + const header_t header = {dtype, fortran_order, shape}; + write_header(os, header); + + // For contiguous layouts, size() == product of dimensions + os.write(reinterpret_cast(obj.data_handle()), obj.size() * sizeof(ElementType)); + RAFT_EXPECTS(os.good(), "Error writing content of mdspan"); +} + +template +inline void serialize_scalar(std::ostream& os, const T& value) +{ + const auto dtype = get_numpy_dtype(); + const bool fortran_order = false; + const std::vector shape{}; + const header_t header = {dtype, fortran_order, shape}; + write_header(os, header); + os.write(reinterpret_cast(&value), sizeof(T)); + RAFT_EXPECTS(os.good(), "Error serializing a scalar"); +} + +template +inline void deserialize_host_mdspan( + std::istream& is, + const raft::host_mdspan& obj) +{ + static_assert(std::is_same_v || + std::is_same_v, + "The serializer only supports row-major and column-major layouts"); + + using obj_t = raft::host_mdspan; + + // Check if given dtype and fortran_order are compatible with the mdspan + const auto expected_dtype = get_numpy_dtype(); + const bool expected_fortran_order = std::is_same_v; + header_t header = read_header(is); + RAFT_EXPECTS(header.dtype == expected_dtype, + "Expected dtype %s but got %s instead", + header.dtype.to_string().c_str(), + expected_dtype.to_string().c_str()); + RAFT_EXPECTS(header.fortran_order == expected_fortran_order, + "Wrong matrix layout; expected %s but got a different layout", + 
(expected_fortran_order ? "Fortran layout" : "C layout")); + + // Check if dimensions are correct + RAFT_EXPECTS(obj.rank() == header.shape.size(), + "Incorrect rank: expected %zu but got %zu", + obj.rank(), + header.shape.size()); + for (typename obj_t::rank_type i = 0; i < obj.rank(); ++i) { + RAFT_EXPECTS(static_cast(obj.extent(i)) == header.shape[i], + "Incorrect dimension: expected %zu but got %zu", + static_cast(obj.extent(i)), + header.shape[i]); + } + + // For contiguous layouts, size() == product of dimensions + is.read(reinterpret_cast(obj.data_handle()), obj.size() * sizeof(ElementType)); + RAFT_EXPECTS(is.good(), "Error while reading mdspan content"); +} + +template +inline T deserialize_scalar(std::istream& is) +{ + // Check if dtype is correct + const auto expected_dtype = get_numpy_dtype(); + header_t header = read_header(is); + RAFT_EXPECTS(header.dtype == expected_dtype, + "Expected dtype %s but got %s instead", + header.dtype.to_string().c_str(), + expected_dtype.to_string().c_str()); + // Check if dimensions are correct; shape should be () + RAFT_EXPECTS(header.shape.empty(), "Incorrect rank: expected 0 but got %zu", header.shape.size()); + + T value; + is.read(reinterpret_cast(&value), sizeof(T)); + RAFT_EXPECTS(is.good(), "Error while deserializing scalar"); + return value; +} + +} // end namespace numpy_serializer +} // end namespace detail +} // end namespace raft diff --git a/cpp/include/raft/core/device_mdarray.hpp b/cpp/include/raft/core/device_mdarray.hpp index 693e50a506..03cb09eecb 100644 --- a/cpp/include/raft/core/device_mdarray.hpp +++ b/cpp/include/raft/core/device_mdarray.hpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2022, NVIDIA CORPORATION. + * Copyright (c) 2022-2023, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
@@ -72,7 +72,7 @@ using device_matrix = device_mdarray, Layo * @tparam ElementType the data type of the matrix elements * @tparam IndexType the index type of the extents * @tparam LayoutPolicy policy for strides and layout ordering - * @param handle raft::handle_t + * @param handle raft::device_resources * @param exts dimensionality of the array (series of integers) * @return raft::device_mdarray */ @@ -80,7 +80,7 @@ template -auto make_device_mdarray(const raft::handle_t& handle, extents exts) +auto make_device_mdarray(raft::device_resources const& handle, extents exts) { using mdarray_t = device_mdarray; @@ -95,7 +95,7 @@ auto make_device_mdarray(const raft::handle_t& handle, extents -auto make_device_mdarray(const raft::handle_t& handle, +auto make_device_mdarray(raft::device_resources const& handle, rmm::mr::device_memory_resource* mr, extents exts) { @@ -130,7 +130,7 @@ auto make_device_mdarray(const raft::handle_t& handle, template -auto make_device_matrix(raft::handle_t const& handle, IndexType n_rows, IndexType n_cols) +auto make_device_matrix(raft::device_resources const& handle, IndexType n_rows, IndexType n_cols) { return make_device_mdarray( handle.get_stream(), make_extents(n_rows, n_cols)); @@ -146,7 +146,7 @@ auto make_device_matrix(raft::handle_t const& handle, IndexType n_rows, IndexTyp * @return raft::device_scalar */ template -auto make_device_scalar(raft::handle_t const& handle, ElementType const& v) +auto make_device_scalar(raft::device_resources const& handle, ElementType const& v) { scalar_extent extents; using policy_t = typename device_scalar::container_policy_type; @@ -168,7 +168,7 @@ auto make_device_scalar(raft::handle_t const& handle, ElementType const& v) template -auto make_device_vector(raft::handle_t const& handle, IndexType n) +auto make_device_vector(raft::device_resources const& handle, IndexType n) { return make_device_mdarray(handle.get_stream(), make_extents(n)); diff --git a/cpp/include/raft/core/device_mdspan.hpp 
b/cpp/include/raft/core/device_mdspan.hpp index f64f15d0d5..f72ae36d64 100644 --- a/cpp/include/raft/core/device_mdspan.hpp +++ b/cpp/include/raft/core/device_mdspan.hpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2022, NVIDIA CORPORATION. + * Copyright (c) 2022-2023, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -197,7 +197,9 @@ auto make_device_aligned_matrix_view(ElementType* ptr, IndexType n_rows, IndexTy detail::alignment::value>::data_handle_type; static_assert(std::is_same>::value || std::is_same>::value); - assert(ptr == alignTo(ptr, detail::alignment::value)); + assert(reinterpret_cast(ptr) == + std::experimental::details::alignTo(reinterpret_cast(ptr), + detail::alignment::value)); data_handle_type aligned_pointer = ptr; diff --git a/cpp/include/raft/core/device_resources.hpp b/cpp/include/raft/core/device_resources.hpp new file mode 100644 index 0000000000..68c56dc9b6 --- /dev/null +++ b/cpp/include/raft/core/device_resources.hpp @@ -0,0 +1,259 @@ +/* + * Copyright (c) 2019-2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#ifndef __RAFT_DEVICE_RESOURCES +#define __RAFT_DEVICE_RESOURCES + +#pragma once + +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include + +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace raft { + +/** + * @brief Main resource container object that stores all necessary resources + * used for calling necessary device functions, cuda kernels and/or libraries + */ +class device_resources : public resources { + public: + device_resources(const device_resources& handle, + rmm::mr::device_memory_resource* workspace_resource) + : resources{handle} + { + // replace the resource factory for the workspace_resources + resources::add_resource_factory( + std::make_shared(workspace_resource)); + } + + device_resources(const device_resources& handle) : resources{handle} {} + + device_resources(device_resources&&) = delete; + device_resources& operator=(device_resources&&) = delete; + + /** + * @brief Construct a resources instance with a stream view and stream pool + * + * @param[in] stream_view the default stream (which has the default per-thread stream if + * unspecified) + * @param[in] stream_pool the stream pool used (which has default of nullptr if unspecified) + * @param[in] workspace_resource an optional resource used by some functions for allocating + * temporary workspaces. 
+ */ + device_resources(rmm::cuda_stream_view stream_view = rmm::cuda_stream_per_thread, + std::shared_ptr stream_pool = {nullptr}, + rmm::mr::device_memory_resource* workspace_resource = nullptr) + : resources{} + { + resources::add_resource_factory(std::make_shared()); + resources::add_resource_factory( + std::make_shared(stream_view)); + resources::add_resource_factory( + std::make_shared(stream_pool)); + resources::add_resource_factory( + std::make_shared(workspace_resource)); + } + + /** Destroys all held-up resources */ + virtual ~device_resources() {} + + int get_device() const { return resource::get_device_id(*this); } + + cublasHandle_t get_cublas_handle() const { return resource::get_cublas_handle(*this); } + + cusolverDnHandle_t get_cusolver_dn_handle() const + { + return resource::get_cusolver_dn_handle(*this); + } + + cusolverSpHandle_t get_cusolver_sp_handle() const + { + return resource::get_cusolver_sp_handle(*this); + } + + cusparseHandle_t get_cusparse_handle() const { return resource::get_cusparse_handle(*this); } + + rmm::exec_policy& get_thrust_policy() const { return resource::get_thrust_policy(*this); } + + /** + * @brief synchronize a stream on the current container + */ + void sync_stream(rmm::cuda_stream_view stream) const { resource::sync_stream(*this, stream); } + + /** + * @brief synchronize main stream on the current container + */ + void sync_stream() const { resource::sync_stream(*this); } + + /** + * @brief returns main stream on the current container + */ + rmm::cuda_stream_view get_stream() const { return resource::get_cuda_stream(*this); } + + /** + * @brief returns whether stream pool was initialized on the current container + */ + + bool is_stream_pool_initialized() const { return resource::is_stream_pool_initialized(*this); } + + /** + * @brief returns stream pool on the current container + */ + const rmm::cuda_stream_pool& get_stream_pool() const + { + return resource::get_cuda_stream_pool(*this); + } + + std::size_t 
get_stream_pool_size() const { return resource::get_stream_pool_size(*this); } + + /** + * @brief return stream from pool + */ + rmm::cuda_stream_view get_stream_from_stream_pool() const + { + return resource::get_stream_from_stream_pool(*this); + } + + /** + * @brief return stream from pool at index + */ + rmm::cuda_stream_view get_stream_from_stream_pool(std::size_t stream_idx) const + { + return resource::get_stream_from_stream_pool(*this, stream_idx); + } + + /** + * @brief return stream from pool if size > 0, else main stream on current container + */ + rmm::cuda_stream_view get_next_usable_stream() const + { + return resource::get_next_usable_stream(*this); + } + + /** + * @brief return stream from pool at index if size > 0, else main stream on current container + * + * @param[in] stream_idx the required index of the stream in the stream pool if available + */ + rmm::cuda_stream_view get_next_usable_stream(std::size_t stream_idx) const + { + return resource::get_next_usable_stream(*this, stream_idx); + } + + /** + * @brief synchronize the stream pool on the current container + */ + void sync_stream_pool() const { return resource::sync_stream_pool(*this); } + + /** + * @brief synchronize subset of stream pool + * + * @param[in] stream_indices the indices of the streams in the stream pool to synchronize + */ + void sync_stream_pool(const std::vector stream_indices) const + { + return resource::sync_stream_pool(*this, stream_indices); + } + + /** + * @brief ask stream pool to wait on last event in main stream + */ + void wait_stream_pool_on_stream() const { return resource::wait_stream_pool_on_stream(*this); } + + void set_comms(std::shared_ptr communicator) + { + resource::set_comms(*this, communicator); + } + + const comms::comms_t& get_comms() const { return resource::get_comms(*this); } + + void set_subcomm(std::string key, std::shared_ptr subcomm) + { + resource::set_subcomm(*this, key, subcomm); + } + + const comms::comms_t& get_subcomm(std::string key) 
const + { + return resource::get_subcomm(*this, key); + } + + rmm::mr::device_memory_resource* get_workspace_resource() const + { + return resource::get_workspace_resource(*this); + } + + bool comms_initialized() const { return resource::comms_initialized(*this); } + + const cudaDeviceProp& get_device_properties() const + { + return resource::get_device_properties(*this); + } +}; // class device_resources + +/** + * @brief RAII approach to synchronizing across all streams in the current container + */ +class stream_syncer { + public: + explicit stream_syncer(const device_resources& handle) : handle_(handle) + { + handle_.sync_stream(); + } + ~stream_syncer() + { + handle_.wait_stream_pool_on_stream(); + handle_.sync_stream_pool(); + } + + stream_syncer(const stream_syncer& other) = delete; + stream_syncer& operator=(const stream_syncer& other) = delete; + + private: + const device_resources& handle_; +}; // class stream_syncer + +} // namespace raft + +#endif \ No newline at end of file diff --git a/cpp/include/raft/core/error.hpp b/cpp/include/raft/core/error.hpp index b932309d24..84b244f4dc 100644 --- a/cpp/include/raft/core/error.hpp +++ b/cpp/include/raft/core/error.hpp @@ -30,6 +30,11 @@ namespace raft { +/** + * @defgroup error_handling Exceptions & Error Handling + * @{ + */ + /** base exception class for the whole of raft */ class exception : public std::exception { public: @@ -93,6 +98,10 @@ struct logic_error : public raft::exception { explicit logic_error(std::string const& message) : raft::exception(message) {} }; +/** + * @} + */ + } // namespace raft // FIXME: Need to be replaced with RAFT_FAIL @@ -143,6 +152,11 @@ struct logic_error : public raft::exception { msg += std::string(buf.data(), buf.data() + size - 1); /* -1 to remove final '\0' */ \ } while (0) +/** + * @defgroup assertion Assertion and error macros + * @{ + */ + /** * @brief Macro for checking (pre-)conditions that throws an exception when a condition is false * @@ -174,4 +188,8 @@ 
struct logic_error : public raft::exception { throw raft::logic_error(msg); \ } while (0) +/** + * @} + */ + #endif diff --git a/cpp/include/raft/core/handle.hpp b/cpp/include/raft/core/handle.hpp index 08cb812bb7..02efebec9e 100644 --- a/cpp/include/raft/core/handle.hpp +++ b/cpp/include/raft/core/handle.hpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2019-2022, NVIDIA CORPORATION. + * Copyright (c) 2019-2023, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -14,326 +14,52 @@ * limitations under the License. */ -#ifndef __RAFT_RT_HANDLE -#define __RAFT_RT_HANDLE - #pragma once -#include -#include -#include -#include -#include -#include - -#include -#include -#include -#include -#include - -///@todo: enable once we have migrated cuml-comms layer too -//#include - -#include - -#include -#include -#include -#include -#include -#include -#include +#include namespace raft { /** - * @brief Main handle object that stores all necessary context used for calling - * necessary cuda kernels and/or libraries + * raft::handle_t is being kept around for backwards + * compatibility and will be removed in a future version. + * + * Extending the `raft::handle_t` instead of `using` to + * minimize needed changes downstream + * (e.g. existing forward declarations, etc...) + * + * Use of `raft::resources` or `raft::device_resources` is preferred. 
*/ -class handle_t { +class handle_t : public raft::device_resources { public: - // delete copy/move constructors and assignment operators as - // copying and moving underlying resources is unsafe - handle_t(const handle_t&) = delete; - handle_t& operator=(const handle_t&) = delete; - handle_t(handle_t&&) = delete; + handle_t(const handle_t& handle, rmm::mr::device_memory_resource* workspace_resource) + : device_resources(handle, workspace_resource) + { + } + + handle_t(const handle_t& handle) : device_resources{handle} {} + + handle_t(handle_t&&) = delete; handle_t& operator=(handle_t&&) = delete; /** - * @brief Construct a handle with a stream view and stream pool + * @brief Construct a resources instance with a stream view and stream pool * * @param[in] stream_view the default stream (which has the default per-thread stream if * unspecified) * @param[in] stream_pool the stream pool used (which has default of nullptr if unspecified) + * @param[in] workspace_resource an optional resource used by some functions for allocating + * temporary workspaces. 
*/ - handle_t(rmm::cuda_stream_view stream_view = rmm::cuda_stream_per_thread, - std::shared_ptr stream_pool = {nullptr}) - : dev_id_([]() -> int { - int cur_dev = -1; - RAFT_CUDA_TRY(cudaGetDevice(&cur_dev)); - return cur_dev; - }()), - stream_view_{stream_view}, - stream_pool_{stream_pool} + handle_t(rmm::cuda_stream_view stream_view = rmm::cuda_stream_per_thread, + std::shared_ptr stream_pool = {nullptr}, + rmm::mr::device_memory_resource* workspace_resource = nullptr) + : device_resources{stream_view, stream_pool, workspace_resource} { - create_resources(); } /** Destroys all held-up resources */ - virtual ~handle_t() { destroy_resources(); } - - int get_device() const { return dev_id_; } - - cublasHandle_t get_cublas_handle() const - { - std::lock_guard _(mutex_); - if (!cublas_initialized_) { - RAFT_CUBLAS_TRY_NO_THROW(cublasCreate(&cublas_handle_)); - RAFT_CUBLAS_TRY_NO_THROW(cublasSetStream(cublas_handle_, stream_view_)); - cublas_initialized_ = true; - } - return cublas_handle_; - } - - cusolverDnHandle_t get_cusolver_dn_handle() const - { - std::lock_guard _(mutex_); - if (!cusolver_dn_initialized_) { - RAFT_CUSOLVER_TRY_NO_THROW(cusolverDnCreate(&cusolver_dn_handle_)); - RAFT_CUSOLVER_TRY_NO_THROW(cusolverDnSetStream(cusolver_dn_handle_, stream_view_)); - cusolver_dn_initialized_ = true; - } - return cusolver_dn_handle_; - } - - cusolverSpHandle_t get_cusolver_sp_handle() const - { - std::lock_guard _(mutex_); - if (!cusolver_sp_initialized_) { - RAFT_CUSOLVER_TRY_NO_THROW(cusolverSpCreate(&cusolver_sp_handle_)); - RAFT_CUSOLVER_TRY_NO_THROW(cusolverSpSetStream(cusolver_sp_handle_, stream_view_)); - cusolver_sp_initialized_ = true; - } - return cusolver_sp_handle_; - } - - cusparseHandle_t get_cusparse_handle() const - { - std::lock_guard _(mutex_); - if (!cusparse_initialized_) { - RAFT_CUSPARSE_TRY_NO_THROW(cusparseCreate(&cusparse_handle_)); - RAFT_CUSPARSE_TRY_NO_THROW(cusparseSetStream(cusparse_handle_, stream_view_)); - cusparse_initialized_ = 
true; - } - return cusparse_handle_; - } - - rmm::exec_policy& get_thrust_policy() const { return *thrust_policy_; } - - /** - * @brief synchronize a stream on the handle - */ - void sync_stream(rmm::cuda_stream_view stream) const { interruptible::synchronize(stream); } - - /** - * @brief synchronize main stream on the handle - */ - void sync_stream() const { sync_stream(stream_view_); } - - /** - * @brief returns main stream on the handle - */ - rmm::cuda_stream_view get_stream() const { return stream_view_; } - - /** - * @brief returns whether stream pool was initialized on the handle - */ - - bool is_stream_pool_initialized() const { return stream_pool_.get() != nullptr; } - - /** - * @brief returns stream pool on the handle - */ - const rmm::cuda_stream_pool& get_stream_pool() const - { - RAFT_EXPECTS(stream_pool_, "ERROR: rmm::cuda_stream_pool was not initialized"); - return *stream_pool_; - } - - std::size_t get_stream_pool_size() const - { - return is_stream_pool_initialized() ? stream_pool_->get_pool_size() : 0; - } - - /** - * @brief return stream from pool - */ - rmm::cuda_stream_view get_stream_from_stream_pool() const - { - RAFT_EXPECTS(stream_pool_, "ERROR: rmm::cuda_stream_pool was not initialized"); - return stream_pool_->get_stream(); - } - - /** - * @brief return stream from pool at index - */ - rmm::cuda_stream_view get_stream_from_stream_pool(std::size_t stream_idx) const - { - RAFT_EXPECTS(stream_pool_, "ERROR: rmm::cuda_stream_pool was not initialized"); - return stream_pool_->get_stream(stream_idx); - } - - /** - * @brief return stream from pool if size > 0, else main stream on handle - */ - rmm::cuda_stream_view get_next_usable_stream() const - { - return is_stream_pool_initialized() ? 
get_stream_from_stream_pool() : stream_view_; - } - - /** - * @brief return stream from pool at index if size > 0, else main stream on handle - * - * @param[in] stream_idx the required index of the stream in the stream pool if available - */ - rmm::cuda_stream_view get_next_usable_stream(std::size_t stream_idx) const - { - return is_stream_pool_initialized() ? get_stream_from_stream_pool(stream_idx) : stream_view_; - } - - /** - * @brief synchronize the stream pool on the handle - */ - void sync_stream_pool() const - { - for (std::size_t i = 0; i < get_stream_pool_size(); i++) { - sync_stream(stream_pool_->get_stream(i)); - } - } - - /** - * @brief synchronize subset of stream pool - * - * @param[in] stream_indices the indices of the streams in the stream pool to synchronize - */ - void sync_stream_pool(const std::vector stream_indices) const - { - RAFT_EXPECTS(stream_pool_, "ERROR: rmm::cuda_stream_pool was not initialized"); - for (const auto& stream_index : stream_indices) { - sync_stream(stream_pool_->get_stream(stream_index)); - } - } - - /** - * @brief ask stream pool to wait on last event in main stream - */ - void wait_stream_pool_on_stream() const - { - RAFT_CUDA_TRY(cudaEventRecord(event_, stream_view_)); - for (std::size_t i = 0; i < get_stream_pool_size(); i++) { - RAFT_CUDA_TRY(cudaStreamWaitEvent(stream_pool_->get_stream(i), event_, 0)); - } - } - - void set_comms(std::shared_ptr communicator) { communicator_ = communicator; } - - const comms::comms_t& get_comms() const - { - RAFT_EXPECTS(this->comms_initialized(), "ERROR: Communicator was not initialized\n"); - return *communicator_; - } - - void set_subcomm(std::string key, std::shared_ptr subcomm) - { - subcomms_[key] = subcomm; - } - - const comms::comms_t& get_subcomm(std::string key) const - { - RAFT_EXPECTS( - subcomms_.find(key) != subcomms_.end(), "%s was not found in subcommunicators.", key.c_str()); - - auto subcomm = subcomms_.at(key); - - RAFT_EXPECTS(nullptr != subcomm.get(), "ERROR: 
Subcommunicator was not initialized"); - - return *subcomm; - } - - bool comms_initialized() const { return (nullptr != communicator_.get()); } - - const cudaDeviceProp& get_device_properties() const - { - std::lock_guard _(mutex_); - if (!device_prop_initialized_) { - RAFT_CUDA_TRY_NO_THROW(cudaGetDeviceProperties(&prop_, dev_id_)); - device_prop_initialized_ = true; - } - return prop_; - } - - private: - std::shared_ptr communicator_; - std::unordered_map> subcomms_; - - const int dev_id_; - mutable cublasHandle_t cublas_handle_; - mutable bool cublas_initialized_{false}; - mutable cusolverDnHandle_t cusolver_dn_handle_; - mutable bool cusolver_dn_initialized_{false}; - mutable cusolverSpHandle_t cusolver_sp_handle_; - mutable bool cusolver_sp_initialized_{false}; - mutable cusparseHandle_t cusparse_handle_; - mutable bool cusparse_initialized_{false}; - std::unique_ptr thrust_policy_{nullptr}; - rmm::cuda_stream_view stream_view_{rmm::cuda_stream_per_thread}; - std::shared_ptr stream_pool_{nullptr}; - cudaEvent_t event_; - mutable cudaDeviceProp prop_; - mutable bool device_prop_initialized_{false}; - mutable std::mutex mutex_; - - void create_resources() - { - thrust_policy_ = std::make_unique(stream_view_); - - RAFT_CUDA_TRY(cudaEventCreateWithFlags(&event_, cudaEventDisableTiming)); - } - - void destroy_resources() - { - if (cusparse_initialized_) { RAFT_CUSPARSE_TRY_NO_THROW(cusparseDestroy(cusparse_handle_)); } - if (cusolver_dn_initialized_) { - RAFT_CUSOLVER_TRY_NO_THROW(cusolverDnDestroy(cusolver_dn_handle_)); - } - if (cusolver_sp_initialized_) { - RAFT_CUSOLVER_TRY_NO_THROW(cusolverSpDestroy(cusolver_sp_handle_)); - } - if (cublas_initialized_) { RAFT_CUBLAS_TRY_NO_THROW(cublasDestroy(cublas_handle_)); } - RAFT_CUDA_TRY_NO_THROW(cudaEventDestroy(event_)); - } -}; // class handle_t - -/** - * @brief RAII approach to synchronizing across all streams in the handle - */ -class stream_syncer { - public: - explicit stream_syncer(const handle_t& handle) : 
handle_(handle) { handle_.sync_stream(); } - ~stream_syncer() - { - handle_.wait_stream_pool_on_stream(); - handle_.sync_stream_pool(); - } - - stream_syncer(const stream_syncer& other) = delete; - stream_syncer& operator=(const stream_syncer& other) = delete; - - private: - const handle_t& handle_; -}; // class stream_syncer - -} // namespace raft + ~handle_t() override {} +}; -#endif \ No newline at end of file +} // end NAMESPACE raft diff --git a/cpp/include/raft/core/host_mdspan.hpp b/cpp/include/raft/core/host_mdspan.hpp index 1a0ea6432f..a6cdec7a84 100644 --- a/cpp/include/raft/core/host_mdspan.hpp +++ b/cpp/include/raft/core/host_mdspan.hpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2022, NVIDIA CORPORATION. + * Copyright (c) 2022-2023, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -144,7 +144,9 @@ auto make_host_aligned_matrix_view(ElementType* ptr, IndexType n_rows, IndexType static_assert(std::is_same>::value || std::is_same>::value); - assert(ptr == alignTo(ptr, detail::alignment::value)); + assert(reinterpret_cast(ptr) == + std::experimental::details::alignTo(reinterpret_cast(ptr), + detail::alignment::value)); data_handle_type aligned_pointer = ptr; matrix_extent extents{n_rows, n_cols}; diff --git a/cpp/include/raft/core/kvp.hpp b/cpp/include/raft/core/kvp.hpp index f6ea841dc4..8d3321eb77 100644 --- a/cpp/include/raft/core/kvp.hpp +++ b/cpp/include/raft/core/kvp.hpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2022, NVIDIA CORPORATION. + * Copyright (c) 2022-2023, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
@@ -20,6 +20,7 @@ #ifdef _RAFT_HAS_CUDA #include +#include #endif namespace raft { /** @@ -58,5 +59,27 @@ struct KeyValuePair { { return (value != b.value) || (key != b.key); } + + RAFT_INLINE_FUNCTION bool operator<(const KeyValuePair<_Key, _Value>& b) const + { + return (key < b.key) || ((key == b.key) && value < b.value); + } + + RAFT_INLINE_FUNCTION bool operator>(const KeyValuePair<_Key, _Value>& b) const + { + return (key > b.key) || ((key == b.key) && value > b.value); + } }; + +#ifdef _RAFT_HAS_CUDA +template +RAFT_INLINE_FUNCTION KeyValuePair<_Key, _Value> shfl_xor(const KeyValuePair<_Key, _Value>& input, + int laneMask, + int width = WarpSize, + uint32_t mask = 0xffffffffu) +{ + return KeyValuePair<_Key, _Value>(shfl_xor(input.key, laneMask, width, mask), + shfl_xor(input.value, laneMask, width, mask)); +} +#endif } // end namespace raft diff --git a/cpp/include/raft/core/math.hpp b/cpp/include/raft/core/math.hpp new file mode 100644 index 0000000000..c5f08b84b7 --- /dev/null +++ b/cpp/include/raft/core/math.hpp @@ -0,0 +1,320 @@ +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#pragma once + +#include +#include +#include + +#include + +namespace raft { + +/** + * @defgroup Absolute Absolute value + * @{ + */ +template +RAFT_INLINE_FUNCTION auto abs(T x) + -> std::enable_if_t || std::is_same_v || + std::is_same_v || std::is_same_v || + std::is_same_v, + T> +{ +#ifdef __CUDA_ARCH__ + return ::abs(x); +#else + return std::abs(x); +#endif +} +template +constexpr RAFT_INLINE_FUNCTION auto abs(T x) + -> std::enable_if_t && !std::is_same_v && + !std::is_same_v && !std::is_same_v && + !std::is_same_v, + T> +{ + return x < T{0} ? -x : x; +} +/** @} */ + +/** + * @defgroup Trigonometry Trigonometry functions + * @{ + */ +/** Inverse cosine */ +template +RAFT_INLINE_FUNCTION auto acos(T x) +{ +#ifdef __CUDA_ARCH__ + return ::acos(x); +#else + return std::acos(x); +#endif +} + +/** Inverse sine */ +template +RAFT_INLINE_FUNCTION auto asin(T x) +{ +#ifdef __CUDA_ARCH__ + return ::asin(x); +#else + return std::asin(x); +#endif +} + +/** Inverse hyperbolic tangent */ +template +RAFT_INLINE_FUNCTION auto atanh(T x) +{ +#ifdef __CUDA_ARCH__ + return ::atanh(x); +#else + return std::atanh(x); +#endif +} + +/** Cosine */ +template +RAFT_INLINE_FUNCTION auto cos(T x) +{ +#ifdef __CUDA_ARCH__ + return ::cos(x); +#else + return std::cos(x); +#endif +} + +/** Sine */ +template +RAFT_INLINE_FUNCTION auto sin(T x) +{ +#ifdef __CUDA_ARCH__ + return ::sin(x); +#else + return std::sin(x); +#endif +} + +/** Sine and cosine */ +template +RAFT_INLINE_FUNCTION std::enable_if_t || std::is_same_v> sincos( + const T& x, T* s, T* c) +{ +#ifdef __CUDA_ARCH__ + ::sincos(x, s, c); +#else + *s = std::sin(x); + *c = std::cos(x); +#endif +} + +/** Hyperbolic tangent */ +template +RAFT_INLINE_FUNCTION auto tanh(T x) +{ +#ifdef __CUDA_ARCH__ + return ::tanh(x); +#else + return std::tanh(x); +#endif +} +/** @} */ + +/** + * @defgroup Exponential Exponential and logarithm + * @{ + */ +/** Exponential function */ +template +RAFT_INLINE_FUNCTION auto exp(T x) +{ +#ifdef 
__CUDA_ARCH__ + return ::exp(x); +#else + return std::exp(x); +#endif +} + +/** Natural logarithm */ +template +RAFT_INLINE_FUNCTION auto log(T x) +{ +#ifdef __CUDA_ARCH__ + return ::log(x); +#else + return std::log(x); +#endif +} +/** @} */ + +/** + * @defgroup Maximum Maximum of two or more values. + * + * The CUDA Math API has overloads for all combinations of float/double. We provide similar + * functionality while wrapping around std::max, which only supports arguments of the same type. + * However, though the CUDA Math API supports combinations of unsigned and signed integers, this is + * very error-prone so we do not support that and require the user to cast instead. (e.g the max of + * -1 and 1u is 4294967295u...) + * + * When no overload matches, we provide a generic implementation but require that both types be the + * same (and that the less-than operator be defined). + * @{ + */ +template +RAFT_INLINE_FUNCTION auto max(const T1& x, const T2& y) +{ +#ifdef __CUDA_ARCH__ + // Combinations of types supported by the CUDA Math API + if constexpr ((std::is_integral_v && std::is_integral_v && std::is_same_v) || + ((std::is_same_v || std::is_same_v)&&( + std::is_same_v || std::is_same_v))) { + return ::max(x, y); + } + // Else, check that the types are the same and provide a generic implementation + else { + static_assert( + std::is_same_v, + "No native max overload for these types. Both argument types must be the same to use " + "the generic max. Please cast appropriately."); + return (x < y) ? y : x; + } +#else + if constexpr (std::is_same_v && std::is_same_v) { + return std::max(static_cast(x), y); + } else if constexpr (std::is_same_v && std::is_same_v) { + return std::max(x, static_cast(y)); + } else { + static_assert( + std::is_same_v, + "std::max requires that both argument types be the same. 
Please cast appropriately."); + return std::max(x, y); + } +#endif +} + +/** Many-argument overload to avoid verbose nested calls or use with variadic arguments */ +template +RAFT_INLINE_FUNCTION auto max(const T1& x, const T2& y, Args&&... args) +{ + return raft::max(x, raft::max(y, std::forward(args)...)); +} + +/** One-argument overload for convenience when using with variadic arguments */ +template +constexpr RAFT_INLINE_FUNCTION auto max(const T& x) +{ + return x; +} +/** @} */ + +/** + * @defgroup Minimum Minimum of two or more values. + * + * The CUDA Math API has overloads for all combinations of float/double. We provide similar + * functionality while wrapping around std::min, which only supports arguments of the same type. + * However, though the CUDA Math API supports combinations of unsigned and signed integers, this is + * very error-prone so we do not support that and require the user to cast instead. (e.g the min of + * -1 and 1u is 1u...) + * + * When no overload matches, we provide a generic implementation but require that both types be the + * same (and that the less-than operator be defined). + * @{ + */ +template +RAFT_INLINE_FUNCTION auto min(const T1& x, const T2& y) +{ +#ifdef __CUDA_ARCH__ + // Combinations of types supported by the CUDA Math API + if constexpr ((std::is_integral_v && std::is_integral_v && std::is_same_v) || + ((std::is_same_v || std::is_same_v)&&( + std::is_same_v || std::is_same_v))) { + return ::min(x, y); + } + // Else, check that the types are the same and provide a generic implementation + else { + static_assert( + std::is_same_v, + "No native min overload for these types. Both argument types must be the same to use " + "the generic min. Please cast appropriately."); + return (y < x) ? 
y : x; + } +#else + if constexpr (std::is_same_v && std::is_same_v) { + return std::min(static_cast(x), y); + } else if constexpr (std::is_same_v && std::is_same_v) { + return std::min(x, static_cast(y)); + } else { + static_assert( + std::is_same_v, + "std::min requires that both argument types be the same. Please cast appropriately."); + return std::min(x, y); + } +#endif +} + +/** Many-argument overload to avoid verbose nested calls or use with variadic arguments */ +template +RAFT_INLINE_FUNCTION auto min(const T1& x, const T2& y, Args&&... args) +{ + return raft::min(x, raft::min(y, std::forward(args)...)); +} + +/** One-argument overload for convenience when using with variadic arguments */ +template +constexpr RAFT_INLINE_FUNCTION auto min(const T& x) +{ + return x; +} +/** @} */ + +/** + * @defgroup Power Power and root functions + * @{ + */ +/** Power */ +template +RAFT_INLINE_FUNCTION auto pow(T1 x, T2 y) +{ +#ifdef __CUDA_ARCH__ + return ::pow(x, y); +#else + return std::pow(x, y); +#endif +} + +/** Square root */ +template +RAFT_INLINE_FUNCTION auto sqrt(T x) +{ +#ifdef __CUDA_ARCH__ + return ::sqrt(x); +#else + return std::sqrt(x); +#endif +} +/** @} */ + +/** Sign */ +template +RAFT_INLINE_FUNCTION auto sgn(T val) -> int +{ + return (T(0) < val) - (val < T(0)); +} + +} // namespace raft diff --git a/cpp/include/raft/core/mdspan.hpp b/cpp/include/raft/core/mdspan.hpp index 786ce69f89..f805d20064 100644 --- a/cpp/include/raft/core/mdspan.hpp +++ b/cpp/include/raft/core/mdspan.hpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2022, NVIDIA CORPORATION. + * Copyright (c) 2022-2023, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
@@ -304,4 +304,52 @@ RAFT_INLINE_FUNCTION auto unravel_index(Idx idx, } } +/** + * @brief Const accessor specialization for default_accessor + * + * @tparam ElementType + * @param a + * @return std::experimental::default_accessor> + */ +template +std::experimental::default_accessor> accessor_of_const( + std::experimental::default_accessor a) +{ + return {a}; +} + +/** + * @brief Const accessor specialization for host_device_accessor + * + * @tparam ElementType the data type of the mdspan elements + * @tparam MemType the type of memory where the elements are stored. + * @param a host_device_accessor + * @return host_device_accessor>, + * MemType> + */ +template +host_device_accessor>, MemType> +accessor_of_const(host_device_accessor, MemType> a) +{ + return {a}; +} + +/** + * @brief Create a copy of the given mdspan with const element type + * + * @tparam ElementType the const-qualified data type of the mdspan elements + * @tparam Extents raft::extents for dimensions + * @tparam Layout policy for strides and layout ordering + * @tparam Accessor Accessor policy for the input and output + * @param mds raft::mdspan object + * @return raft::mdspan + */ +template +auto make_const_mdspan(mdspan mds) +{ + auto acc_c = accessor_of_const(mds.accessor()); + return mdspan, Extents, Layout, decltype(acc_c)>{ + mds.data_handle(), mds.mapping(), acc_c}; +} + } // namespace raft diff --git a/cpp/include/raft/core/operators.cuh b/cpp/include/raft/core/operators.cuh new file mode 100644 index 0000000000..cafb404ef6 --- /dev/null +++ b/cpp/include/raft/core/operators.cuh @@ -0,0 +1,54 @@ +/* + * Copyright (c) 2022, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include +#include + +namespace raft { + +/** + * @defgroup DeviceFunctors Commonly used device-only functors. + * @{ + */ + +struct atomic_add_op { + template + _RAFT_DEVICE _RAFT_FORCEINLINE Type operator()(Type* address, const Type& val) + { + return atomicAdd(address, val); + } +}; + +struct atomic_max_op { + template + _RAFT_DEVICE _RAFT_FORCEINLINE Type operator()(Type* address, const Type& val) + { + return atomicMax(address, val); + } +}; + +struct atomic_min_op { + template + _RAFT_DEVICE _RAFT_FORCEINLINE Type operator()(Type* address, const Type& val) + { + return atomicMin(address, val); + } +}; +/** @} */ + +} // namespace raft diff --git a/cpp/include/raft/core/operators.hpp b/cpp/include/raft/core/operators.hpp new file mode 100644 index 0000000000..7acc907c49 --- /dev/null +++ b/cpp/include/raft/core/operators.hpp @@ -0,0 +1,421 @@ +/* + * Copyright (c) 2022-2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#pragma once + +#include +#include +#include +#include +#include + +#include +#include + +namespace raft { + +/** + * @defgroup operators Commonly used functors. + * The optional unused arguments are useful for kernels that pass the index along with the value. + * @{ + */ + +struct identity_op { + template + constexpr RAFT_INLINE_FUNCTION auto operator()(const Type& in, UnusedArgs...) const + { + return in; + } +}; + +struct void_op { + template + constexpr RAFT_INLINE_FUNCTION void operator()(UnusedArgs...) const + { + return; + } +}; + +template +struct cast_op { + template + constexpr RAFT_INLINE_FUNCTION auto operator()(InT in, UnusedArgs...) const + { + return static_cast(in); + } +}; + +struct key_op { + template + constexpr RAFT_INLINE_FUNCTION auto operator()(const KVP& p, UnusedArgs...) const + { + return p.key; + } +}; + +struct value_op { + template + constexpr RAFT_INLINE_FUNCTION auto operator()(const KVP& p, UnusedArgs...) const + { + return p.value; + } +}; + +struct sqrt_op { + template + RAFT_INLINE_FUNCTION auto operator()(const Type& in, UnusedArgs...) const + { + return raft::sqrt(in); + } +}; + +struct nz_op { + template + constexpr RAFT_INLINE_FUNCTION auto operator()(const Type& in, UnusedArgs...) const + { + return in != Type(0) ? Type(1) : Type(0); + } +}; + +struct abs_op { + template + RAFT_INLINE_FUNCTION auto operator()(const Type& in, UnusedArgs...) const + { + return raft::abs(in); + } +}; + +struct sq_op { + template + constexpr RAFT_INLINE_FUNCTION auto operator()(const Type& in, UnusedArgs...) 
const + { + return in * in; + } +}; + +struct add_op { + template + constexpr RAFT_INLINE_FUNCTION auto operator()(const T1& a, const T2& b) const + { + return a + b; + } +}; + +struct sub_op { + template + constexpr RAFT_INLINE_FUNCTION auto operator()(const T1& a, const T2& b) const + { + return a - b; + } +}; + +struct mul_op { + template + constexpr RAFT_INLINE_FUNCTION auto operator()(const T1& a, const T2& b) const + { + return a * b; + } +}; + +struct div_op { + template + constexpr RAFT_INLINE_FUNCTION auto operator()(const T1& a, const T2& b) const + { + return a / b; + } +}; + +struct div_checkzero_op { + template + constexpr RAFT_INLINE_FUNCTION auto operator()(const T1& a, const T2& b) const + { + if (b == T2{0}) { return T1{0} / T2{1}; } + return a / b; + } +}; + +struct pow_op { + template + RAFT_INLINE_FUNCTION auto operator()(const Type& a, const Type& b) const + { + return raft::pow(a, b); + } +}; + +struct mod_op { + template + constexpr RAFT_INLINE_FUNCTION auto operator()(const T1& a, const T2& b) const + { + return a % b; + } +}; + +struct min_op { + template + RAFT_INLINE_FUNCTION auto operator()(Args&&... args) const + { + return raft::min(std::forward(args)...); + } +}; + +struct max_op { + template + RAFT_INLINE_FUNCTION auto operator()(Args&&... 
args) const + { + return raft::max(std::forward(args)...); + } +}; + +struct argmin_op { + template + constexpr RAFT_INLINE_FUNCTION auto operator()(const KVP& a, const KVP& b) const + { + if ((b.value < a.value) || ((a.value == b.value) && (b.key < a.key))) { return b; } + return a; + } +}; + +struct argmax_op { + template + constexpr RAFT_INLINE_FUNCTION auto operator()(const KVP& a, const KVP& b) const + { + if ((b.value > a.value) || ((a.value == b.value) && (b.key < a.key))) { return b; } + return a; + } +}; + +struct greater_op { + template + constexpr RAFT_INLINE_FUNCTION auto operator()(const T1& a, const T2& b) const + { + return a > b; + } +}; + +struct less_op { + template + constexpr RAFT_INLINE_FUNCTION auto operator()(const T1& a, const T2& b) const + { + return a < b; + } +}; + +struct greater_or_equal_op { + template + constexpr RAFT_INLINE_FUNCTION auto operator()(const T1& a, const T2& b) const + { + return a >= b; + } +}; + +struct less_or_equal_op { + template + constexpr RAFT_INLINE_FUNCTION auto operator()(const T1& a, const T2& b) const + { + return a <= b; + } +}; + +struct equal_op { + template + constexpr RAFT_INLINE_FUNCTION auto operator()(const T1& a, const T2& b) const + { + return a == b; + } +}; + +struct notequal_op { + template + constexpr RAFT_INLINE_FUNCTION auto operator()(const T1& a, const T2& b) const + { + return a != b; + } +}; + +template +struct const_op { + const ScalarT scalar; + + constexpr explicit const_op(const ScalarT& s) : scalar{s} {} + + template + constexpr RAFT_INLINE_FUNCTION auto operator()(Args...) const + { + return scalar; + } +}; + +/** + * @brief Wraps around a binary operator, passing a constant on the right-hand side. 
+ * + * Usage example: + * @code{.cpp} + * #include + * + * raft::plug_const_op op(2.0f); + * std::cout << op(2.1f) << std::endl; // 4.2 + * @endcode + * + * @tparam ConstT + * @tparam BinaryOpT + */ +template +struct plug_const_op { + const ConstT c; + const BinaryOpT composed_op; + + template >> + constexpr explicit plug_const_op(const ConstT& s) + : c{s}, composed_op{} // The compiler complains if composed_op is not initialized explicitly + { + } + constexpr plug_const_op(const ConstT& s, BinaryOpT o) : c{s}, composed_op{o} {} + + template + constexpr RAFT_INLINE_FUNCTION auto operator()(InT a) const + { + return composed_op(a, c); + } +}; + +template +using add_const_op = plug_const_op; + +template +using sub_const_op = plug_const_op; + +template +using mul_const_op = plug_const_op; + +template +using div_const_op = plug_const_op; + +template +using div_checkzero_const_op = plug_const_op; + +template +using pow_const_op = plug_const_op; + +template +using mod_const_op = plug_const_op; + +template +using mod_const_op = plug_const_op; + +template +using equal_const_op = plug_const_op; + +/** + * @brief Constructs an operator by composing a chain of operators. + * + * Note that all arguments are passed to the innermost operator. + * + * Usage example: + * @code{.cpp} + * #include + * + * auto op = raft::compose_op(raft::sqrt_op(), raft::abs_op(), raft::cast_op(), + * raft::add_const_op(8)); + * std::cout << op(-50) << std::endl; // 6.48074 + * @endcode + * + * @tparam OpsT Any number of operation types. + */ +template +struct compose_op { + const std::tuple ops; + + template , + typename CondT = std::enable_if_t>> + constexpr compose_op() + { + } + constexpr explicit compose_op(OpsT... ops) : ops{ops...} {} + + template + constexpr RAFT_INLINE_FUNCTION auto operator()(Args&&... args) const + { + return compose(std::forward(args)...); + } + + private: + template + constexpr RAFT_INLINE_FUNCTION auto compose(Args&&... 
args) const + { + if constexpr (RemOps > 0) { + return compose(std::get(ops)(std::forward(args)...)); + } else { + return identity_op{}(std::forward(args)...); + } + } +}; + +using absdiff_op = compose_op; + +using sqdiff_op = compose_op; + +/** + * @brief Constructs an operator by composing an outer op with one inner op for each of its inputs. + * + * Usage example: + * @code{.cpp} + * #include + * + * raft::map_args_op> op; + * std::cout << op(42.0f, 10) << std::endl; // 16.4807 + * @endcode + * + * @tparam OuterOpT Outer operation type + * @tparam ArgOpsT Operation types for each input of the outer operation + */ +template +struct map_args_op { + const OuterOpT outer_op; + const std::tuple arg_ops; + + template , + typename CondT = std::enable_if_t && + std::is_default_constructible_v>> + constexpr map_args_op() + : outer_op{} // The compiler complains if outer_op is not initialized explicitly + { + } + constexpr explicit map_args_op(OuterOpT outer_op, ArgOpsT... arg_ops) + : outer_op{outer_op}, arg_ops{arg_ops...} + { + } + + template + constexpr RAFT_INLINE_FUNCTION auto operator()(Args&&... args) const + { + constexpr size_t kNumOps = sizeof...(ArgOpsT); + static_assert(kNumOps == sizeof...(Args), + "The number of arguments does not match the number of mapping operators"); + return map_args(std::make_index_sequence{}, std::forward(args)...); + } + + private: + template + constexpr RAFT_INLINE_FUNCTION auto map_args(std::index_sequence, Args&&... args) const + { + return outer_op(std::get(arg_ops)(std::forward(args))...); + } +}; + +/** @} */ +} // namespace raft diff --git a/cpp/include/raft/core/resource/comms.hpp b/cpp/include/raft/core/resource/comms.hpp new file mode 100644 index 0000000000..73de166c14 --- /dev/null +++ b/cpp/include/raft/core/resource/comms.hpp @@ -0,0 +1,78 @@ +/* + * Copyright (c) 2022-2023, NVIDIA CORPORATION. 
+ * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once + +#include +#include +#include + +namespace raft::resource { +class comms_resource : public resource { + public: + comms_resource(std::shared_ptr comnumicator) : communicator_(comnumicator) {} + + void* get_resource() override { return &communicator_; } + + ~comms_resource() override {} + + private: + std::shared_ptr communicator_; +}; + +/** + * Factory that knows how to construct a + * specific raft::resource to populate + * the res_t. 
+ */ +class comms_resource_factory : public resource_factory { + public: + comms_resource_factory(std::shared_ptr communicator) : communicator_(communicator) + { + } + + resource_type get_resource_type() override { return resource_type::COMMUNICATOR; } + + resource* make_resource() override { return new comms_resource(communicator_); } + + private: + std::shared_ptr communicator_; +}; + +/** + * @defgroup resource_comms Comms resource functions + * @{ + */ + +inline bool comms_initialized(resources const& res) +{ + return res.has_resource_factory(resource_type::COMMUNICATOR); +} + +inline comms::comms_t const& get_comms(resources const& res) +{ + RAFT_EXPECTS(comms_initialized(res), "ERROR: Communicator was not initialized\n"); + return *(*res.get_resource>(resource_type::COMMUNICATOR)); +} + +inline void set_comms(resources const& res, std::shared_ptr communicator) +{ + res.add_resource_factory(std::make_shared(communicator)); +} + +/** + * @} + */ +} // namespace raft::resource diff --git a/cpp/include/raft/core/resource/cublas_handle.hpp b/cpp/include/raft/core/resource/cublas_handle.hpp new file mode 100644 index 0000000000..710fcc7e60 --- /dev/null +++ b/cpp/include/raft/core/resource/cublas_handle.hpp @@ -0,0 +1,81 @@ +/* + * Copyright (c) 2022-2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#pragma once + +#include +#include +#include +#include +#include + +namespace raft::resource { + +class cublas_resource : public resource { + public: + cublas_resource(rmm::cuda_stream_view stream) + { + RAFT_CUBLAS_TRY_NO_THROW(cublasCreate(&cublas_res)); + RAFT_CUBLAS_TRY_NO_THROW(cublasSetStream(cublas_res, stream)); + } + + ~cublas_resource() override { RAFT_CUBLAS_TRY_NO_THROW(cublasDestroy(cublas_res)); } + + void* get_resource() override { return &cublas_res; } + + private: + cublasHandle_t cublas_res; +}; + +/** + * Factory that knows how to construct a + * specific raft::resource to populate + * the resources instance. + */ +class cublas_resource_factory : public resource_factory { + public: + cublas_resource_factory(rmm::cuda_stream_view stream) : stream_(stream) {} + resource_type get_resource_type() override { return resource_type::CUBLAS_HANDLE; } + resource* make_resource() override { return new cublas_resource(stream_); } + + private: + rmm::cuda_stream_view stream_; +}; + +/** + * @defgroup resource_cublas cuBLAS handle resource functions + * @{ + */ + +/** + * Load the cublasHandle_t from the raft resources if it exists, otherwise + * add it (bound to the stream from get_cuda_stream) and return it. + * @param[in] res the raft resources object + * @return cublas handle + */ +inline cublasHandle_t get_cublas_handle(resources const& res) +{ + if (!res.has_resource_factory(resource_type::CUBLAS_HANDLE)) { + cudaStream_t stream = get_cuda_stream(res); + res.add_resource_factory(std::make_shared(stream)); + } + return *res.get_resource(resource_type::CUBLAS_HANDLE); +}; + +/** + * @} + */ + +} // namespace raft::resource diff --git a/cpp/include/raft/core/resource/cuda_event.hpp b/cpp/include/raft/core/resource/cuda_event.hpp new file mode 100644 index 0000000000..4859d95ee9 --- /dev/null +++ b/cpp/include/raft/core/resource/cuda_event.hpp @@ -0,0 +1,38 @@ +/* + * Copyright (c) 2022-2023, NVIDIA CORPORATION. 
+ * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once + +#include +#include +#include +#include + +namespace raft::resource { + +class cuda_event_resource : public resource { + public: + cuda_event_resource() + { + RAFT_CUDA_TRY_NO_THROW(cudaEventCreateWithFlags(&event_, cudaEventDisableTiming)); + } + void* get_resource() override { return &event_; } + + ~cuda_event_resource() override { RAFT_CUDA_TRY_NO_THROW(cudaEventDestroy(event_)); } + + private: + cudaEvent_t event_; +}; +} // namespace raft::resource diff --git a/cpp/include/raft/core/resource/cuda_stream.hpp b/cpp/include/raft/core/resource/cuda_stream.hpp new file mode 100644 index 0000000000..fc69f10d83 --- /dev/null +++ b/cpp/include/raft/core/resource/cuda_stream.hpp @@ -0,0 +1,109 @@ +/* + * Copyright (c) 2022-2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#pragma once + +#include +#include +#include +#include +#include +#include + +namespace raft::resource { +class cuda_stream_resource : public resource { + public: + cuda_stream_resource(rmm::cuda_stream_view stream_view = rmm::cuda_stream_per_thread) + : stream(stream_view) + { + } + void* get_resource() override { return &stream; } + + ~cuda_stream_resource() override {} + + private: + rmm::cuda_stream_view stream; +}; + +/** + * Factory that knows how to construct a specific raft::resource to populate + * the resources instance. + */ +class cuda_stream_resource_factory : public resource_factory { + public: + cuda_stream_resource_factory(rmm::cuda_stream_view stream_view = rmm::cuda_stream_per_thread) + : stream(stream_view) + { + } + resource_type get_resource_type() override { return resource_type::CUDA_STREAM_VIEW; } + resource* make_resource() override { return new cuda_stream_resource(stream); } + + private: + rmm::cuda_stream_view stream; +}; + +/** + * @defgroup resource_cuda_stream CUDA stream resource functions + * @{ + */ +/** + * Load the rmm::cuda_stream_view from a resources instance (registering a factory that + * defaults to rmm::cuda_stream_per_thread first if none is present). + * @param[in] res raft resources object for managing resources + * @return the rmm::cuda_stream_view stored on the resources instance + */ +inline rmm::cuda_stream_view get_cuda_stream(resources const& res) +{ + if (!res.has_resource_factory(resource_type::CUDA_STREAM_VIEW)) { + res.add_resource_factory(std::make_shared()); + } + return *res.get_resource(resource_type::CUDA_STREAM_VIEW); +}; + +/** + * Set the rmm::cuda_stream_view on a resources instance by registering a stream + * factory for it. 
+ * @param[in] res raft resources object for managing resources + * @param[in] stream_view cuda stream view + */ +inline void set_cuda_stream(resources const& res, rmm::cuda_stream_view stream_view) +{ + res.add_resource_factory(std::make_shared(stream_view)); +}; + +/** + * @brief synchronize a specific stream + * + * @param[in] res the raft resources object + * @param[in] stream stream to synchronize + */ +inline void sync_stream(const resources& res, rmm::cuda_stream_view stream) +{ + // TODO: Fix interruptible segfault: + // https://github.com/rapidsai/raft/issues/1225 + // interruptible::synchronize(stream); + RAFT_CUDA_TRY(cudaStreamSynchronize(stream)); +} + +/** + * @brief synchronize main stream on the resources instance + */ +inline void sync_stream(const resources& res) { sync_stream(res, get_cuda_stream(res)); } + +/** + * @} + */ + +} // namespace raft::resource \ No newline at end of file diff --git a/cpp/include/raft/core/resource/cuda_stream_pool.hpp b/cpp/include/raft/core/resource/cuda_stream_pool.hpp new file mode 100644 index 0000000000..dbce75b3a5 --- /dev/null +++ b/cpp/include/raft/core/resource/cuda_stream_pool.hpp @@ -0,0 +1,187 @@ +/* + * Copyright (c) 2022-2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#pragma once + +#include +#include +#include +#include + +#include +#include +#include + +namespace raft::resource { + +class cuda_stream_pool_resource : public resource { + public: + cuda_stream_pool_resource(std::shared_ptr stream_pool) + : stream_pool_(stream_pool) + { + } + + ~cuda_stream_pool_resource() override {} + void* get_resource() override { return &stream_pool_; } + + private: + std::shared_ptr stream_pool_{nullptr}; +}; + +/** + * Factory that knows how to construct a + * specific raft::resource to populate + * the res_t. + */ +class cuda_stream_pool_resource_factory : public resource_factory { + public: + cuda_stream_pool_resource_factory(std::shared_ptr stream_pool = {nullptr}) + : stream_pool_(stream_pool) + { + } + + resource_type get_resource_type() override { return resource_type::CUDA_STREAM_POOL; } + resource* make_resource() override { return new cuda_stream_pool_resource(stream_pool_); } + + private: + std::shared_ptr stream_pool_{nullptr}; +}; + +inline bool is_stream_pool_initialized(const resources& res) +{ + return *res.get_resource>( + resource_type::CUDA_STREAM_POOL) != nullptr; +} + +/** + * @defgroup resource_stream_pool CUDA Stream pool resource functions + * @{ + */ + +/** + * Load a cuda_stream_pool, and create a new one if it doesn't already exist + * @param res raft res object for managing resources + * @return + */ +inline const rmm::cuda_stream_pool& get_cuda_stream_pool(const resources& res) +{ + if (!res.has_resource_factory(resource_type::CUDA_STREAM_POOL)) { + res.add_resource_factory(std::make_shared()); + } + return *( + *res.get_resource>(resource_type::CUDA_STREAM_POOL)); +}; + +/** + * Explicitly set a stream pool on the current res. Note that this will overwrite + * an existing stream pool on the res. 
+ * @param res + * @param stream_pool + */ +inline void set_cuda_stream_pool(const resources& res, + std::shared_ptr stream_pool) +{ + res.add_resource_factory(std::make_shared(stream_pool)); +}; + +inline std::size_t get_stream_pool_size(const resources& res) +{ + return is_stream_pool_initialized(res) ? get_cuda_stream_pool(res).get_pool_size() : 0; +} + +/** + * @brief return stream from pool + */ +inline rmm::cuda_stream_view get_stream_from_stream_pool(const resources& res) +{ + RAFT_EXPECTS(is_stream_pool_initialized(res), "ERROR: rmm::cuda_stream_pool was not initialized"); + return get_cuda_stream_pool(res).get_stream(); +} + +/** + * @brief return stream from pool at index + */ +inline rmm::cuda_stream_view get_stream_from_stream_pool(const resources& res, + std::size_t stream_idx) +{ + RAFT_EXPECTS(is_stream_pool_initialized(res), "ERROR: rmm::cuda_stream_pool was not initialized"); + return get_cuda_stream_pool(res).get_stream(stream_idx); +} + +/** + * @brief return stream from pool if size > 0, else main stream on res + */ +inline rmm::cuda_stream_view get_next_usable_stream(const resources& res) +{ + return is_stream_pool_initialized(res) ? get_stream_from_stream_pool(res) : get_cuda_stream(res); +} + +/** + * @brief return stream from pool at index if size > 0, else main stream on res + * + * @param[in] res the raft resources object + * @param[in] stream_idx the required index of the stream in the stream pool if available + */ +inline rmm::cuda_stream_view get_next_usable_stream(const resources& res, std::size_t stream_idx) +{ + return is_stream_pool_initialized(res) ? 
get_stream_from_stream_pool(res, stream_idx) + : get_cuda_stream(res); +} + +/** + * @brief synchronize the stream pool on the res + * + * @param[in] res the raft resources object + */ +inline void sync_stream_pool(const resources& res) +{ + for (std::size_t i = 0; i < get_stream_pool_size(res); i++) { + sync_stream(res, get_cuda_stream_pool(res).get_stream(i)); + } +} + +/** + * @brief synchronize subset of stream pool + * + * @param[in] res the raft resources object + * @param[in] stream_indices the indices of the streams in the stream pool to synchronize + */ +inline void sync_stream_pool(const resources& res, const std::vector stream_indices) +{ + RAFT_EXPECTS(is_stream_pool_initialized(res), "ERROR: rmm::cuda_stream_pool was not initialized"); + for (const auto& stream_index : stream_indices) { + sync_stream(res, get_cuda_stream_pool(res).get_stream(stream_index)); + } +} + +/** + * @brief ask stream pool to wait on last event in main stream + * + * @param[in] res the raft resources object + */ +inline void wait_stream_pool_on_stream(const resources& res) +{ + cudaEvent_t event = detail::get_cuda_stream_sync_event(res); + RAFT_CUDA_TRY(cudaEventRecord(event, get_cuda_stream(res))); + for (std::size_t i = 0; i < get_stream_pool_size(res); i++) { + RAFT_CUDA_TRY(cudaStreamWaitEvent(get_cuda_stream_pool(res).get_stream(i), event, 0)); + } +} + +/** + * @} + */ + +} // namespace raft::resource diff --git a/cpp/include/raft/core/resource/cusolver_dn_handle.hpp b/cpp/include/raft/core/resource/cusolver_dn_handle.hpp new file mode 100644 index 0000000000..7a33e2dd2a --- /dev/null +++ b/cpp/include/raft/core/resource/cusolver_dn_handle.hpp @@ -0,0 +1,85 @@ +/* + * Copyright (c) 2022-2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once + +#include "cuda_stream.hpp" +#include +#include +#include +#include +#include + +namespace raft::resource { + +/** + * Resource that creates a cusolverDnHandle_t, binds it to the given stream and destroys it on destruction. + */ +class cusolver_dn_resource : public resource { + public: + cusolver_dn_resource(rmm::cuda_stream_view stream) + { + RAFT_CUSOLVER_TRY_NO_THROW(cusolverDnCreate(&cusolver_res)); + RAFT_CUSOLVER_TRY_NO_THROW(cusolverDnSetStream(cusolver_res, stream)); + } + + void* get_resource() override { return &cusolver_res; } + + ~cusolver_dn_resource() override { RAFT_CUSOLVER_TRY_NO_THROW(cusolverDnDestroy(cusolver_res)); } + + private: + cusolverDnHandle_t cusolver_res; +}; + +/** + * @defgroup resource_cusolver_dn cuSolver DN handle resource functions + * @{ + */ + +/** + * Factory that knows how to construct a + * specific raft::resource to populate + * the resources instance. + */ +class cusolver_dn_resource_factory : public resource_factory { + public: + cusolver_dn_resource_factory(rmm::cuda_stream_view stream) : stream_(stream) {} + resource_type get_resource_type() override { return resource_type::CUSOLVER_DN_HANDLE; } + resource* make_resource() override { return new cusolver_dn_resource(stream_); } + + private: + rmm::cuda_stream_view stream_; +}; + +/** + * Load the cusolverDnHandle_t from the raft resources if it exists, otherwise + * add it and return it. 
+ * @param[in] res the raft resources object + * @return cusolver dn handle + */ +inline cusolverDnHandle_t get_cusolver_dn_handle(resources const& res) +{ + if (!res.has_resource_factory(resource_type::CUSOLVER_DN_HANDLE)) { + cudaStream_t stream = get_cuda_stream(res); + res.add_resource_factory(std::make_shared(stream)); + } + return *res.get_resource(resource_type::CUSOLVER_DN_HANDLE); +}; + +/** + * @} + */ + +} // namespace raft::resource diff --git a/cpp/include/raft/core/resource/cusolver_sp_handle.hpp b/cpp/include/raft/core/resource/cusolver_sp_handle.hpp new file mode 100644 index 0000000000..61fd95b44f --- /dev/null +++ b/cpp/include/raft/core/resource/cusolver_sp_handle.hpp @@ -0,0 +1,84 @@ +/* + * Copyright (c) 2022-2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#pragma once + +#include +#include +#include +#include +#include + +namespace raft::resource { + +/** + * Resource that creates a cusolverSpHandle_t, binds it to the given stream and destroys it on destruction. + */ +class cusolver_sp_resource : public resource { + public: + cusolver_sp_resource(rmm::cuda_stream_view stream) + { + RAFT_CUSOLVER_TRY_NO_THROW(cusolverSpCreate(&cusolver_res)); + RAFT_CUSOLVER_TRY_NO_THROW(cusolverSpSetStream(cusolver_res, stream)); + } + + void* get_resource() override { return &cusolver_res; } + + ~cusolver_sp_resource() override { RAFT_CUSOLVER_TRY_NO_THROW(cusolverSpDestroy(cusolver_res)); } + + private: + cusolverSpHandle_t cusolver_res; +}; + +/** + * Factory that knows how to construct a + * specific raft::resource to populate + * the resources instance. + */ +class cusolver_sp_resource_factory : public resource_factory { + public: + cusolver_sp_resource_factory(rmm::cuda_stream_view stream) : stream_(stream) {} + resource_type get_resource_type() override { return resource_type::CUSOLVER_SP_HANDLE; } + resource* make_resource() override { return new cusolver_sp_resource(stream_); } + + private: + rmm::cuda_stream_view stream_; +}; + +/** + * @defgroup resource_cusolver_sp cuSolver SP handle resource functions + * @{ + */ + +/** + * Load the cusolverSpHandle_t from the raft resources if it exists, otherwise + * add it and return it. 
+ * @param[in] res the raft resources object + * @return cusolver sp handle + */ +inline cusolverSpHandle_t get_cusolver_sp_handle(resources const& res) +{ + if (!res.has_resource_factory(resource_type::CUSOLVER_SP_HANDLE)) { + cudaStream_t stream = get_cuda_stream(res); + res.add_resource_factory(std::make_shared(stream)); + } + return *res.get_resource(resource_type::CUSOLVER_SP_HANDLE); +}; + +/** + * @} + */ + +} // namespace raft::resource diff --git a/cpp/include/raft/core/resource/cusparse_handle.hpp b/cpp/include/raft/core/resource/cusparse_handle.hpp new file mode 100644 index 0000000000..9893ed2f86 --- /dev/null +++ b/cpp/include/raft/core/resource/cusparse_handle.hpp @@ -0,0 +1,79 @@ +/* + * Copyright (c) 2022-2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once + +#include +#include +#include +#include +#include + +namespace raft::resource { +class cusparse_resource : public resource { + public: + cusparse_resource(rmm::cuda_stream_view stream) + { + RAFT_CUSPARSE_TRY_NO_THROW(cusparseCreate(&cusparse_res)); + RAFT_CUSPARSE_TRY_NO_THROW(cusparseSetStream(cusparse_res, stream)); + } + + ~cusparse_resource() { RAFT_CUSPARSE_TRY_NO_THROW(cusparseDestroy(cusparse_res)); } + void* get_resource() override { return &cusparse_res; } + + private: + cusparseHandle_t cusparse_res; +}; + +/** + * Factory that knows how to construct a + * specific raft::resource to populate + * the res_t. 
+ */ +class cusparse_resource_factory : public resource_factory { + public: + cusparse_resource_factory(rmm::cuda_stream_view stream) : stream_(stream) {} + resource_type get_resource_type() override { return resource_type::CUSPARSE_HANDLE; } + resource* make_resource() override { return new cusparse_resource(stream_); } + + private: + rmm::cuda_stream_view stream_; +}; + +/** + * @defgroup resource_cusparse cuSparse handle resource functions + * @{ + */ + +/** + * Load the cusparseHandle_t from the raft resources if it exists, otherwise + * add it (bound to the stream from get_cuda_stream) and return it. + * @param[in] res the raft resources object + * @return cusparse handle + */ +inline cusparseHandle_t get_cusparse_handle(resources const& res) +{ + if (!res.has_resource_factory(resource_type::CUSPARSE_HANDLE)) { + rmm::cuda_stream_view stream = get_cuda_stream(res); + res.add_resource_factory(std::make_shared(stream)); + } + return *res.get_resource(resource_type::CUSPARSE_HANDLE); +}; + +/** + * @} + */ + +} // namespace raft::resource diff --git a/cpp/include/raft/core/resource/detail/stream_sync_event.hpp b/cpp/include/raft/core/resource/detail/stream_sync_event.hpp new file mode 100644 index 0000000000..1d02fef20d --- /dev/null +++ b/cpp/include/raft/core/resource/detail/stream_sync_event.hpp @@ -0,0 +1,50 @@ +/* + * Copyright (c) 2022-2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#pragma once + +#include +#include +#include +#include +#include + +namespace raft::resource::detail { + +/** + * Factory that knows how to construct a specific raft::resource to populate + * the resources instance. + */ +class cuda_stream_sync_event_resource_factory : public resource_factory { + public: + resource_type get_resource_type() override { return resource_type::CUDA_STREAM_SYNC_EVENT; } + resource* make_resource() override { return new cuda_event_resource(); } +}; + +/** + * Load a cudaEvent_t from a resources instance (and populate it on the resources instance + * if needed) for syncing the main cuda stream. + * @param[in] res raft resources instance for managing resources + * @return reference to the cudaEvent_t stored on the resources instance + */ +inline cudaEvent_t& get_cuda_stream_sync_event(resources const& res) +{ + if (!res.has_resource_factory(resource_type::CUDA_STREAM_SYNC_EVENT)) { + res.add_resource_factory(std::make_shared()); + } + return *res.get_resource(resource_type::CUDA_STREAM_SYNC_EVENT); +}; + +} // namespace raft::resource::detail diff --git a/cpp/include/raft/core/resource/device_id.hpp b/cpp/include/raft/core/resource/device_id.hpp new file mode 100644 index 0000000000..b55e56ca45 --- /dev/null +++ b/cpp/include/raft/core/resource/device_id.hpp @@ -0,0 +1,75 @@ +/* + * Copyright (c) 2022-2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#pragma once + +#include +#include +#include +#include + +namespace raft::resource { + +class device_id_resource : public resource { + public: + device_id_resource() + : dev_id_([]() -> int { + int cur_dev = -1; + RAFT_CUDA_TRY_NO_THROW(cudaGetDevice(&cur_dev)); + return cur_dev; + }()) + { + } + void* get_resource() override { return &dev_id_; } + + ~device_id_resource() override {} + + private: + int dev_id_; +}; + +/** + * Factory that knows how to construct a + * specific raft::resource to populate + * the res_t. + */ +class device_id_resource_factory : public resource_factory { + public: + resource_type get_resource_type() override { return resource_type::DEVICE_ID; } + resource* make_resource() override { return new device_id_resource(); } +}; + +/** + * @defgroup resource_device_id Device ID resource functions + * @{ + */ + +/** + * Load a device id from a res (and populate it on the res if needed). + * @param res raft res object for managing resources + * @return device id + */ +inline int get_device_id(resources const& res) +{ + if (!res.has_resource_factory(resource_type::DEVICE_ID)) { + res.add_resource_factory(std::make_shared()); + } + return *res.get_resource(resource_type::DEVICE_ID); +}; + +/** + * @} + */ +} // namespace raft::resource \ No newline at end of file diff --git a/cpp/include/raft/core/resource/device_memory_resource.hpp b/cpp/include/raft/core/resource/device_memory_resource.hpp new file mode 100644 index 0000000000..35ae3d715f --- /dev/null +++ b/cpp/include/raft/core/resource/device_memory_resource.hpp @@ -0,0 +1,75 @@ +/* + * Copyright (c) 2022-2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once + +#include +#include +#include + +namespace raft::resource { +class device_memory_resource : public resource { + public: + device_memory_resource(rmm::mr::device_memory_resource* mr_ = nullptr) : mr(mr_) + { + if (mr_ == nullptr) { mr = rmm::mr::get_current_device_resource(); } + } + void* get_resource() override { return mr; } + + ~device_memory_resource() override {} + + private: + rmm::mr::device_memory_resource* mr; +}; + +/** + * Factory that knows how to construct a specific raft::resource to populate + * the resources instance. + */ +class workspace_resource_factory : public resource_factory { + public: + workspace_resource_factory(rmm::mr::device_memory_resource* mr_ = nullptr) : mr(mr_) {} + resource_type get_resource_type() override { return resource_type::WORKSPACE_RESOURCE; } + resource* make_resource() override { return new device_memory_resource(mr); } + + private: + rmm::mr::device_memory_resource* mr; +}; + +/** + * Load a temp workspace resource from a resources instance (and populate it on the res + * if needed). + * @param res raft resources object for managing resources + * @return device memory resource object + */ +inline rmm::mr::device_memory_resource* get_workspace_resource(resources const& res) +{ + if (!res.has_resource_factory(resource_type::WORKSPACE_RESOURCE)) { + res.add_resource_factory(std::make_shared()); + } + return res.get_resource(resource_type::WORKSPACE_RESOURCE); +}; + +/** + * Set a temp workspace resource on a resources instance. 
+ * + * @param res raft resources object for managing resources + * @param mr a valid rmm device_memory_resource + */ +inline void set_workspace_resource(resources const& res, rmm::mr::device_memory_resource* mr) +{ + res.add_resource_factory(std::make_shared(mr)); +}; +} // namespace raft::resource \ No newline at end of file diff --git a/cpp/include/raft/core/resource/device_properties.hpp b/cpp/include/raft/core/resource/device_properties.hpp new file mode 100644 index 0000000000..c3b0b8f2b9 --- /dev/null +++ b/cpp/include/raft/core/resource/device_properties.hpp @@ -0,0 +1,77 @@ +/* + * Copyright (c) 2022-2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once + +#include +#include +#include +#include +#include + +namespace raft::resource { + +class device_properties_resource : public resource { + public: + device_properties_resource(int dev_id) + { + RAFT_CUDA_TRY_NO_THROW(cudaGetDeviceProperties(&prop_, dev_id)); + } + void* get_resource() override { return &prop_; } + + ~device_properties_resource() override {} + + private: + cudaDeviceProp prop_; +}; + +/** + * @defgroup resource_device_props Device properties resource functions + * @{ + */ + +/** + * Factory that knows how to construct a + * specific raft::resource to populate + * the res_t. 
+ */ +class device_properties_resource_factory : public resource_factory { + public: + device_properties_resource_factory(int dev_id) : dev_id_(dev_id) {} + resource_type get_resource_type() override { return resource_type::DEVICE_PROPERTIES; } + resource* make_resource() override { return new device_properties_resource(dev_id_); } + + private: + int dev_id_; +}; + +/** + * Load a cudaDeviceProp from a res (and populate it on the res if needed). + * @param res raft res object for managing resources + * @return populated cuda device properties instance + */ +inline cudaDeviceProp& get_device_properties(resources const& res) +{ + if (!res.has_resource_factory(resource_type::DEVICE_PROPERTIES)) { + int dev_id = get_device_id(res); + res.add_resource_factory(std::make_shared(dev_id)); + } + return *res.get_resource(resource_type::DEVICE_PROPERTIES); +}; + +/** + * @} + */ +} // namespace raft::resource \ No newline at end of file diff --git a/cpp/include/raft/core/resource/resource_types.hpp b/cpp/include/raft/core/resource/resource_types.hpp new file mode 100644 index 0000000000..cf302e25f9 --- /dev/null +++ b/cpp/include/raft/core/resource/resource_types.hpp @@ -0,0 +1,115 @@ +/* + * Copyright (c) 2022-2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#pragma once + +namespace raft::resource { + +/** + * @defgroup resource_types Core resource vocabulary types + * @{ + */ + +/** + * @brief Resource types can apply to any resource and don't have to be host- or device-specific. + */ +enum resource_type { + // device-specific resource types + CUBLAS_HANDLE = 0, // cublas handle + CUSOLVER_DN_HANDLE, // cusolver dn handle + CUSOLVER_SP_HANDLE, // cusolver sp handle + CUSPARSE_HANDLE, // cusparse handle + CUDA_STREAM_VIEW, // view of a cuda stream + CUDA_STREAM_POOL, // cuda stream pool + CUDA_STREAM_SYNC_EVENT, // cuda event for syncing streams + COMMUNICATOR, // raft communicator + SUB_COMMUNICATOR, // raft sub communicator + DEVICE_PROPERTIES, // cuda device properties + DEVICE_ID, // cuda device id + THRUST_POLICY, // thrust execution policy + WORKSPACE_RESOURCE, // rmm device memory resource + + LAST_KEY // reserved for the last key +}; + +/** + * @brief A resource constructs and contains an instance of + * some pre-determined object type and facades that object + * behind a common API. + */ +class resource { + public: + virtual void* get_resource() = 0; + + virtual ~resource() {} +}; + +class empty_resource : public resource { + public: + empty_resource() : resource() {} + + void* get_resource() override { return nullptr; } + + ~empty_resource() override {} +}; + +/** + * @brief A resource factory knows how to construct an instance of + * a specific raft::resource::resource. + */ +class resource_factory { + public: + /** + * @brief Return the resource_type associated with the current factory + * @return resource_type corresponding to the current factory + */ + virtual resource_type get_resource_type() = 0; + + /** + * @brief Construct an instance of the factory's underlying resource. + * @return resource instance + */ + virtual resource* make_resource() = 0; +}; + +/** + * @brief A resource factory knows how to construct an instance of + * a specific raft::resource::resource. 
+ */ +class empty_resource_factory : public resource_factory { + public: + empty_resource_factory() : resource_factory() {} + /** + * @brief Return the resource_type associated with the current factory + * @return resource_type corresponding to the current factory + */ + resource_type get_resource_type() override { return resource_type::LAST_KEY; } + + /** + * @brief Construct an instance of the factory's underlying resource. + * @return resource instance + */ + resource* make_resource() override { return &res; } + + private: + empty_resource res; +}; + +/** + * @} + */ + +} // namespace raft::resource diff --git a/cpp/include/raft/core/resource/sub_comms.hpp b/cpp/include/raft/core/resource/sub_comms.hpp new file mode 100644 index 0000000000..7070b61c54 --- /dev/null +++ b/cpp/include/raft/core/resource/sub_comms.hpp @@ -0,0 +1,82 @@ +/* + * Copyright (c) 2022-2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once + +#include +#include +#include + +namespace raft::resource { +class sub_comms_resource : public resource { + public: + sub_comms_resource() : communicators_() {} + void* get_resource() override { return &communicators_; } + + ~sub_comms_resource() override {} + + private: + std::unordered_map> communicators_; +}; + +/** + * Factory that knows how to construct a + * specific raft::resource to populate + * the res_t. 
+ */ +class sub_comms_resource_factory : public resource_factory { + public: + resource_type get_resource_type() override { return resource_type::SUB_COMMUNICATOR; } + resource* make_resource() override { return new sub_comms_resource(); } +}; + +/** + * @defgroup resource_subcomms Subcommunicator resource functions + * @{ + */ + +inline const comms::comms_t& get_subcomm(const resources& res, std::string key) +{ + if (!res.has_resource_factory(resource_type::SUB_COMMUNICATOR)) { + res.add_resource_factory(std::make_shared()); + } + + auto sub_comms = + res.get_resource>>( + resource_type::SUB_COMMUNICATOR); + auto sub_comm = sub_comms->at(key); + RAFT_EXPECTS(nullptr != sub_comm.get(), "ERROR: Subcommunicator was not initialized"); + + return *sub_comm; +} + +inline void set_subcomm(resources const& res, + std::string key, + std::shared_ptr subcomm) +{ + if (!res.has_resource_factory(resource_type::SUB_COMMUNICATOR)) { + res.add_resource_factory(std::make_shared()); + } + auto sub_comms = + res.get_resource>>( + resource_type::SUB_COMMUNICATOR); + sub_comms->insert(std::make_pair(key, subcomm)); +} + +/** + * @} + */ + +} // namespace raft::resource \ No newline at end of file diff --git a/cpp/include/raft/core/resource/thrust_policy.hpp b/cpp/include/raft/core/resource/thrust_policy.hpp new file mode 100644 index 0000000000..1e7441e5e4 --- /dev/null +++ b/cpp/include/raft/core/resource/thrust_policy.hpp @@ -0,0 +1,74 @@ +/* + * Copyright (c) 2022-2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once + +#include +#include +#include +namespace raft::resource { +class thrust_policy_resource : public resource { + public: + thrust_policy_resource(rmm::cuda_stream_view stream_view) + : thrust_policy_(std::make_unique(stream_view)) + { + } + void* get_resource() override { return thrust_policy_.get(); } + + ~thrust_policy_resource() override {} + + private: + std::unique_ptr thrust_policy_; +}; + +/** + * Factory that knows how to construct a + * specific raft::resource to populate + * the res_t. + */ +class thrust_policy_resource_factory : public resource_factory { + public: + thrust_policy_resource_factory(rmm::cuda_stream_view stream_view) : stream_view_(stream_view) {} + resource_type get_resource_type() override { return resource_type::THRUST_POLICY; } + resource* make_resource() override { return new thrust_policy_resource(stream_view_); } + + private: + rmm::cuda_stream_view stream_view_; +}; + +/** + * @defgroup resource_thrust_policy Thrust policy resource functions + * @{ + */ + +/** + * Load a thrust policy from a res (and populate it on the res if needed). + * @param res raft res object for managing resources + * @return thrust execution policy + */ +inline rmm::exec_policy& get_thrust_policy(resources const& res) +{ + if (!res.has_resource_factory(resource_type::THRUST_POLICY)) { + rmm::cuda_stream_view stream = get_cuda_stream(res); + res.add_resource_factory(std::make_shared(stream)); + } + return *res.get_resource(resource_type::THRUST_POLICY); +}; + +/** + * @} + */ + +} // namespace raft::resource \ No newline at end of file diff --git a/cpp/include/raft/core/resources.hpp b/cpp/include/raft/core/resources.hpp new file mode 100644 index 0000000000..64e281e934 --- /dev/null +++ b/cpp/include/raft/core/resources.hpp @@ -0,0 +1,131 @@ +/* + * Copyright (c) 2022-2023, NVIDIA CORPORATION. 
+ * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once + +#include "resource/resource_types.hpp" +#include +#include +#include +#include +#include + +namespace raft { + +/** + * @brief Resource container which allows lazy-loading and registration + * of resource_factory implementations, which in turn generate resource instances. + * + * This class is intended to be agnostic of the resources it contains and + * does not, itself, differentiate between host and device resources. Downstream + * accessor functions can then register and load resources as needed in order + * to keep its usage somewhat opaque to end-users. + * + * @code{.cpp} + * #include + * #include + * #include + * + * raft::resources res; + * auto stream = raft::resource::get_cuda_stream(res); + * auto cublas_handle = raft::resource::get_cublas_handle(res); + * @endcode + */ +class resources { + public: + template + using pair_res = std::pair>; + + using pair_res_factory = pair_res; + using pair_resource = pair_res; + + resources() + : factories_(resource::resource_type::LAST_KEY), resources_(resource::resource_type::LAST_KEY) + { + for (int i = 0; i < resource::resource_type::LAST_KEY; ++i) { + factories_.at(i) = std::make_pair(resource::resource_type::LAST_KEY, + std::make_shared()); + resources_.at(i) = std::make_pair(resource::resource_type::LAST_KEY, + std::make_shared()); + } + } + + /** + * @brief Shallow copy of underlying resources instance. 
+ * Note that this does not create any new resources. + */ + resources(const resources& res) : factories_(res.factories_), resources_(res.resources_) {} + resources(resources&&) = delete; + resources& operator=(resources&&) = delete; + + /** + * @brief Returns true if a resource_factory has been registered for the + * given resource_type, false otherwise. + * @param resource_type resource type to check + * @return true if resource_factory is registered for the given resource_type + */ + bool has_resource_factory(resource::resource_type resource_type) const + { + std::lock_guard _(mutex_); + return factories_.at(resource_type).first != resource::resource_type::LAST_KEY; + } + + /** + * @brief Register a resource_factory with the current instance. + * This will overwrite any existing resource factories. + * @param factory resource factory to register on the current instance + */ + void add_resource_factory(std::shared_ptr factory) const + { + std::lock_guard _(mutex_); + resource::resource_type rtype = factory.get()->get_resource_type(); + RAFT_EXPECTS(rtype != resource::resource_type::LAST_KEY, + "LAST_KEY is a placeholder and not a valid resource factory type."); + factories_.at(rtype) = std::make_pair(rtype, factory); + } + + /** + * @brief Retrieve a resource for the given resource_type and cast to given pointer type. + * Note that the resources are loaded lazily on-demand and resources which don't yet + * exist on the current instance will be created using the corresponding factory, if + * it exists. + * @tparam res_t pointer type for which retrieved resource will be casted + * @param resource_type resource type to retrieve + * @return the given resource, if it exists. 
+ */ + template + res_t* get_resource(resource::resource_type resource_type) const + { + std::lock_guard _(mutex_); + + if (resources_.at(resource_type).first == resource::resource_type::LAST_KEY) { + RAFT_EXPECTS(factories_.at(resource_type).first != resource::resource_type::LAST_KEY, + "No resource factory has been registered for the given resource %d.", + resource_type); + resource::resource_factory* factory = factories_.at(resource_type).second.get(); + resources_.at(resource_type) = std::make_pair( + resource_type, std::shared_ptr(factory->make_resource())); + } + + resource::resource* res = resources_.at(resource_type).second.get(); + return reinterpret_cast(res->get_resource()); + } + + protected: + mutable std::mutex mutex_; + mutable std::vector factories_; + mutable std::vector resources_; +}; +} // namespace raft \ No newline at end of file diff --git a/cpp/include/raft/core/serialize.hpp b/cpp/include/raft/core/serialize.hpp new file mode 100644 index 0000000000..05814e2845 --- /dev/null +++ b/cpp/include/raft/core/serialize.hpp @@ -0,0 +1,167 @@ +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#pragma once + +#include +#include +#include +#include + +#include +#include + +/** + * Collection of serialization functions for RAFT data types + */ + +namespace raft { + +template +inline void serialize_mdspan( + const raft::device_resources&, + std::ostream& os, + const raft::host_mdspan& obj) +{ + detail::numpy_serializer::serialize_host_mdspan(os, obj); +} + +template +inline void serialize_mdspan( + const raft::device_resources& handle, + std::ostream& os, + const raft::device_mdspan& obj) +{ + static_assert(std::is_same_v || + std::is_same_v, + "The serializer only supports row-major and column-major layouts"); + using obj_t = raft::device_mdspan; + + // Copy to host before serializing + // For contiguous layouts, size() == product of dimensions + std::vector tmp(obj.size()); + cudaStream_t stream = handle.get_stream(); + raft::update_host(tmp.data(), obj.data_handle(), obj.size(), stream); + handle.sync_stream(); + using inner_accessor_type = typename obj_t::accessor_type::accessor_type; + auto tmp_mdspan = + raft::host_mdspan>( + tmp.data(), obj.extents()); + detail::numpy_serializer::serialize_host_mdspan(os, tmp_mdspan); +} + +template +inline void serialize_mdspan( + const raft::device_resources&, + std::ostream& os, + const raft::managed_mdspan& obj) +{ + using obj_t = raft::managed_mdspan; + using inner_accessor_type = typename obj_t::accessor_type::accessor_type; + auto tmp_mdspan = + raft::host_mdspan>( + obj.data_handle(), obj.extents()); + detail::numpy_serializer::serialize_host_mdspan(os, tmp_mdspan); +} + +template +inline void deserialize_mdspan( + const raft::device_resources&, + std::istream& is, + raft::host_mdspan& obj) +{ + detail::numpy_serializer::deserialize_host_mdspan(is, obj); +} + +template +inline void deserialize_mdspan( + const raft::device_resources& handle, + std::istream& is, + raft::device_mdspan& obj) +{ + static_assert(std::is_same_v || + std::is_same_v, + "The serializer only supports row-major and column-major 
layouts"); + using obj_t = raft::device_mdspan; + + // Copy to device after serializing + // For contiguous layouts, size() == product of dimensions + std::vector tmp(obj.size()); + using inner_accessor_type = typename obj_t::accessor_type::accessor_type; + auto tmp_mdspan = + raft::host_mdspan>( + tmp.data(), obj.extents()); + detail::numpy_serializer::deserialize_host_mdspan(is, tmp_mdspan); + + cudaStream_t stream = handle.get_stream(); + raft::update_device(obj.data_handle(), tmp.data(), obj.size(), stream); + handle.sync_stream(); +} + +template +inline void deserialize_mdspan( + const raft::device_resources& handle, + std::istream& is, + raft::host_mdspan&& obj) +{ + deserialize_mdspan(handle, is, obj); +} + +template +inline void deserialize_mdspan( + const raft::device_resources& handle, + std::istream& is, + raft::managed_mdspan& obj) +{ + using obj_t = raft::managed_mdspan; + using inner_accessor_type = typename obj_t::accessor_type::accessor_type; + auto tmp_mdspan = + raft::host_mdspan>( + obj.data_handle(), obj.extents()); + detail::numpy_serializer::deserialize_host_mdspan(is, tmp_mdspan); +} + +template +inline void deserialize_mdspan( + const raft::device_resources& handle, + std::istream& is, + raft::managed_mdspan&& obj) +{ + deserialize_mdspan(handle, is, obj); +} + +template +inline void deserialize_mdspan( + const raft::device_resources& handle, + std::istream& is, + raft::device_mdspan&& obj) +{ + deserialize_mdspan(handle, is, obj); +} + +template +inline void serialize_scalar(const raft::device_resources&, std::ostream& os, const T& value) +{ + detail::numpy_serializer::serialize_scalar(os, value); +} + +template +inline T deserialize_scalar(const raft::device_resources&, std::istream& is) +{ + return detail::numpy_serializer::deserialize_scalar(is); +} + +} // end namespace raft diff --git a/cpp/include/raft/distance/detail/canberra.cuh b/cpp/include/raft/distance/detail/canberra.cuh index 6be994b80a..f17a26dc4b 100644 --- 
a/cpp/include/raft/distance/detail/canberra.cuh +++ b/cpp/include/raft/distance/detail/canberra.cuh @@ -1,5 +1,5 @@ /* - * Copyright (c) 2021, NVIDIA CORPORATION. + * Copyright (c) 2021-2023, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -73,19 +73,15 @@ static void canberraImpl(const DataT* x, // Accumulation operation lambda auto core_lambda = [] __device__(AccT & acc, DataT & x, DataT & y) { - const auto diff = raft::L1Op()(x - y); - const auto add = raft::myAbs(x) + raft::myAbs(y); + const auto diff = raft::abs(x - y); + const auto add = raft::abs(x) + raft::abs(y); // deal with potential for 0 in denominator by // forcing 1/0 instead acc += ((add != 0) * diff / (add + (add == 0))); }; // epilogue operation lambda for final value calculation - auto epilog_lambda = [] __device__(AccT acc[KPolicy::AccRowsPerTh][KPolicy::AccColsPerTh], - DataT * regxn, - DataT * regyn, - IdxT gridStrideX, - IdxT gridStrideY) { return; }; + auto epilog_lambda = raft::void_op(); if (isRowMajor) { auto canberraRowMajor = pairwiseDistanceMatKernel #include namespace raft { @@ -72,16 +73,12 @@ static void chebyshevImpl(const DataT* x, // Accumulation operation lambda auto core_lambda = [] __device__(AccT & acc, DataT & x, DataT & y) { - const auto diff = raft::L1Op()(x - y); - acc = raft::myMax(acc, diff); + const auto diff = raft::abs(x - y); + acc = raft::max(acc, diff); }; // epilogue operation lambda for final value calculation - auto epilog_lambda = [] __device__(AccT acc[KPolicy::AccRowsPerTh][KPolicy::AccColsPerTh], - DataT * regxn, - DataT * regyn, - IdxT gridStrideX, - IdxT gridStrideY) { return; }; + auto epilog_lambda = raft::void_op(); if (isRowMajor) { auto chebyshevRowMajor = pairwiseDistanceMatKernel(), - raft::Sum()); + raft::identity_op(), + raft::add_op()); raft::linalg::reduce(norm_row_vec, pB, k, @@ -273,8 +273,8 @@ void correlationImpl(int m, true, 
stream, false, - raft::Nop(), - raft::Sum()); + raft::identity_op(), + raft::add_op()); sq_norm_col_vec += (m + n); sq_norm_row_vec = sq_norm_col_vec + m; @@ -290,8 +290,8 @@ void correlationImpl(int m, true, stream, false, - raft::Nop(), - raft::Sum()); + raft::identity_op(), + raft::add_op()); sq_norm_col_vec += m; sq_norm_row_vec = sq_norm_col_vec; raft::linalg::rowNorm(sq_norm_col_vec, pA, k, m, raft::linalg::L2Norm, isRowMajor, stream); diff --git a/cpp/include/raft/distance/detail/cosine.cuh b/cpp/include/raft/distance/detail/cosine.cuh index f06051962f..46a694aa51 100644 --- a/cpp/include/raft/distance/detail/cosine.cuh +++ b/cpp/include/raft/distance/detail/cosine.cuh @@ -19,6 +19,7 @@ #include #include #include +#include namespace raft { namespace distance { @@ -229,8 +230,6 @@ void cosineAlgo1(Index_ m, cudaStream_t stream, bool isRowMajor) { - auto norm_op = [] __device__(AccType in) { return raft::mySqrt(in); }; - // raft distance support inputs as float/double and output as uint8_t/float/double. 
static_assert(!((sizeof(OutType) > 1) && (sizeof(AccType) != sizeof(OutType))), "OutType can be uint8_t, float, double," @@ -248,10 +247,13 @@ void cosineAlgo1(Index_ m, InType* row_vec = workspace; if (pA != pB) { row_vec += m; - raft::linalg::rowNorm(col_vec, pA, k, m, raft::linalg::L2Norm, isRowMajor, stream, norm_op); - raft::linalg::rowNorm(row_vec, pB, k, n, raft::linalg::L2Norm, isRowMajor, stream, norm_op); + raft::linalg::rowNorm( + col_vec, pA, k, m, raft::linalg::L2Norm, isRowMajor, stream, raft::sqrt_op{}); + raft::linalg::rowNorm( + row_vec, pB, k, n, raft::linalg::L2Norm, isRowMajor, stream, raft::sqrt_op{}); } else { - raft::linalg::rowNorm(col_vec, pA, k, m, raft::linalg::L2Norm, isRowMajor, stream, norm_op); + raft::linalg::rowNorm( + col_vec, pA, k, m, raft::linalg::L2Norm, isRowMajor, stream, raft::sqrt_op{}); } if (isRowMajor) { diff --git a/cpp/include/raft/distance/detail/euclidean.cuh b/cpp/include/raft/distance/detail/euclidean.cuh index 5ea74fa884..1a2db63f5c 100644 --- a/cpp/include/raft/distance/detail/euclidean.cuh +++ b/cpp/include/raft/distance/detail/euclidean.cuh @@ -1,5 +1,5 @@ /* - * Copyright (c) 2018-2022, NVIDIA CORPORATION. + * Copyright (c) 2018-2023, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -19,6 +19,7 @@ #include #include #include +#include namespace raft { namespace distance { @@ -33,7 +34,7 @@ struct L2ExpandedOp { __device__ AccT operator()(DataT& aNorm, const DataT& bNorm, DataT& accVal) const noexcept { AccT outVal = aNorm + bNorm - DataT(2.0) * accVal; - return sqrt ? raft::mySqrt(outVal) : outVal; + return sqrt ? 
raft::sqrt(outVal) : outVal; } __device__ AccT operator()(DataT aData) const noexcept { return aData; } @@ -129,7 +130,7 @@ void euclideanExpImpl(const DataT* x, for (int i = 0; i < KPolicy::AccRowsPerTh; ++i) { #pragma unroll for (int j = 0; j < KPolicy::AccColsPerTh; ++j) { - acc[i][j] = raft::mySqrt(acc[i][j]); + acc[i][j] = raft::sqrt(acc[i][j]); } } } @@ -247,8 +248,6 @@ void euclideanAlgo1(Index_ m, cudaStream_t stream, bool isRowMajor) { - auto norm_op = [] __device__(InType in) { return in; }; - // raft distance support inputs as float/double and output as uint8_t/float/double. static_assert(!((sizeof(OutType) > 1) && (sizeof(AccType) != sizeof(OutType))), "OutType can be uint8_t, float, double," @@ -266,10 +265,13 @@ void euclideanAlgo1(Index_ m, InType* row_vec = workspace; if (pA != pB) { row_vec += m; - raft::linalg::rowNorm(col_vec, pA, k, m, raft::linalg::L2Norm, isRowMajor, stream, norm_op); - raft::linalg::rowNorm(row_vec, pB, k, n, raft::linalg::L2Norm, isRowMajor, stream, norm_op); + raft::linalg::rowNorm( + col_vec, pA, k, m, raft::linalg::L2Norm, isRowMajor, stream, raft::identity_op{}); + raft::linalg::rowNorm( + row_vec, pB, k, n, raft::linalg::L2Norm, isRowMajor, stream, raft::identity_op{}); } else { - raft::linalg::rowNorm(col_vec, pA, k, m, raft::linalg::L2Norm, isRowMajor, stream, norm_op); + raft::linalg::rowNorm( + col_vec, pA, k, m, raft::linalg::L2Norm, isRowMajor, stream, raft::identity_op{}); } if (isRowMajor) { @@ -348,7 +350,7 @@ void euclideanUnExpImpl(const DataT* x, for (int i = 0; i < KPolicy::AccRowsPerTh; ++i) { #pragma unroll for (int j = 0; j < KPolicy::AccColsPerTh; ++j) { - acc[i][j] = raft::mySqrt(acc[i][j]); + acc[i][j] = raft::sqrt(acc[i][j]); } } } diff --git a/cpp/include/raft/distance/detail/fused_l2_nn.cuh b/cpp/include/raft/distance/detail/fused_l2_nn.cuh index e8c2648c2e..447359ffe6 100644 --- a/cpp/include/raft/distance/detail/fused_l2_nn.cuh +++ b/cpp/include/raft/distance/detail/fused_l2_nn.cuh @@ -1,5 +1,5 
@@ /* - * Copyright (c) 2021-2022, NVIDIA CORPORATION. + * Copyright (c) 2021-2023, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -175,7 +175,7 @@ __global__ __launch_bounds__(P::Nthreads, 2) void fusedL2NNkernel(OutT* min, #pragma unroll for (int j = 0; j < P::AccColsPerTh; ++j) { auto acc_ij = acc[i][j]; - acc[i][j] = acc_ij > DataT{0} ? raft::mySqrt(acc_ij) : DataT{0}; + acc[i][j] = acc_ij > DataT{0} ? raft::sqrt(acc_ij) : DataT{0}; } } } @@ -298,8 +298,6 @@ void fusedL2NNImpl(OutT* min, RAFT_CUDA_TRY(cudaGetLastError()); } - auto fin_op = [] __device__(DataT d_val, int g_d_idx) { return d_val; }; - constexpr size_t shmemSize = P::SmemSize + ((P::Mblk + P::Nblk) * sizeof(DataT)); if (sqrt) { auto fusedL2NNSqrt = fusedL2NNkernel; + raft::identity_op>; dim3 grid = launchConfigGenerator

(m, n, shmemSize, fusedL2NNSqrt); - fusedL2NNSqrt<<>>( - min, x, y, xn, yn, m, n, k, maxVal, workspace, redOp, pairRedOp, core_lambda, fin_op); + fusedL2NNSqrt<<>>(min, + x, + y, + xn, + yn, + m, + n, + k, + maxVal, + workspace, + redOp, + pairRedOp, + core_lambda, + raft::identity_op{}); } else { auto fusedL2NN = fusedL2NNkernel; + raft::identity_op>; dim3 grid = launchConfigGenerator

(m, n, shmemSize, fusedL2NN); - fusedL2NN<<>>( - min, x, y, xn, yn, m, n, k, maxVal, workspace, redOp, pairRedOp, core_lambda, fin_op); + fusedL2NN<<>>(min, + x, + y, + xn, + yn, + m, + n, + k, + maxVal, + workspace, + redOp, + pairRedOp, + core_lambda, + raft::identity_op{}); } RAFT_CUDA_TRY(cudaGetLastError()); diff --git a/cpp/include/raft/distance/detail/hellinger.cuh b/cpp/include/raft/distance/detail/hellinger.cuh index 31854fd1d6..13507fe84f 100644 --- a/cpp/include/raft/distance/detail/hellinger.cuh +++ b/cpp/include/raft/distance/detail/hellinger.cuh @@ -1,5 +1,5 @@ /* - * Copyright (c) 2021-2022, NVIDIA CORPORATION. + * Copyright (c) 2021-2023, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -17,6 +17,7 @@ #pragma once #include #include +#include namespace raft { namespace distance { @@ -78,14 +79,10 @@ static void hellingerImpl(const DataT* x, dim3 blk(KPolicy::Nthreads); - auto unaryOp_lambda = [] __device__(DataT input) { return raft::mySqrt(input); }; // First sqrt x and y - raft::linalg::unaryOp( - (DataT*)x, x, m * k, unaryOp_lambda, stream); - + raft::linalg::unaryOp((DataT*)x, x, m * k, raft::sqrt_op{}, stream); if (x != y) { - raft::linalg::unaryOp( - (DataT*)y, y, n * k, unaryOp_lambda, stream); + raft::linalg::unaryOp((DataT*)y, y, n * k, raft::sqrt_op{}, stream); } // Accumulation operation lambda @@ -108,7 +105,7 @@ static void hellingerImpl(const DataT* x, // Adjust to replace NaN in sqrt with 0 if input to sqrt is negative const auto finalVal = (1 - acc[i][j]); const auto rectifier = (!signbit(finalVal)); - acc[i][j] = raft::mySqrt(rectifier * finalVal); + acc[i][j] = raft::sqrt(rectifier * finalVal); } } }; @@ -145,11 +142,9 @@ static void hellingerImpl(const DataT* x, } // Revert sqrt of x and y - raft::linalg::unaryOp( - (DataT*)x, x, m * k, unaryOp_lambda, stream); + raft::linalg::unaryOp((DataT*)x, x, m * k, raft::sqrt_op{}, 
stream); if (x != y) { - raft::linalg::unaryOp( - (DataT*)y, y, n * k, unaryOp_lambda, stream); + raft::linalg::unaryOp((DataT*)y, y, n * k, raft::sqrt_op{}, stream); } RAFT_CUDA_TRY(cudaGetLastError()); diff --git a/cpp/include/raft/distance/detail/jensen_shannon.cuh b/cpp/include/raft/distance/detail/jensen_shannon.cuh index 92ee071cf5..f96da01b87 100644 --- a/cpp/include/raft/distance/detail/jensen_shannon.cuh +++ b/cpp/include/raft/distance/detail/jensen_shannon.cuh @@ -1,5 +1,5 @@ /* - * Copyright (c) 2021, NVIDIA CORPORATION. + * Copyright (c) 2021-2023, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -78,11 +78,11 @@ static void jensenShannonImpl(const DataT* x, auto core_lambda = [] __device__(AccT & acc, DataT & x, DataT & y) { const DataT m = 0.5f * (x + y); const bool m_zero = (m == 0); - const auto logM = (!m_zero) * raft::myLog(m + m_zero); + const auto logM = (!m_zero) * raft::log(m + m_zero); const bool x_zero = (x == 0); const bool y_zero = (y == 0); - acc += (-x * (logM - raft::myLog(x + x_zero))) + (-y * (logM - raft::myLog(y + y_zero))); + acc += (-x * (logM - raft::log(x + x_zero))) + (-y * (logM - raft::log(y + y_zero))); }; // epilogue operation lambda for final value calculation @@ -95,7 +95,7 @@ static void jensenShannonImpl(const DataT* x, for (int i = 0; i < KPolicy::AccRowsPerTh; ++i) { #pragma unroll for (int j = 0; j < KPolicy::AccColsPerTh; ++j) { - acc[i][j] = raft::mySqrt(0.5 * acc[i][j]); + acc[i][j] = raft::sqrt(0.5 * acc[i][j]); } } }; diff --git a/cpp/include/raft/distance/detail/kl_divergence.cuh b/cpp/include/raft/distance/detail/kl_divergence.cuh index 4c0c4b6ace..7ebeaf4de9 100644 --- a/cpp/include/raft/distance/detail/kl_divergence.cuh +++ b/cpp/include/raft/distance/detail/kl_divergence.cuh @@ -1,5 +1,5 @@ /* - * Copyright (c) 2021, NVIDIA CORPORATION. + * Copyright (c) 2021-2023, NVIDIA CORPORATION. 
* * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -81,10 +81,10 @@ static void klDivergenceImpl(const DataT* x, auto core_lambda = [] __device__(AccT & acc, DataT & x, DataT & y) { if (isRowMajor) { const bool x_zero = (x == 0); - acc += x * (raft::myLog(x + x_zero) - y); + acc += x * (raft::log(x + x_zero) - y); } else { const bool y_zero = (y == 0); - acc += y * (raft::myLog(y + y_zero) - x); + acc += y * (raft::log(y + y_zero) - x); } }; @@ -92,23 +92,23 @@ static void klDivergenceImpl(const DataT* x, if (isRowMajor) { const bool x_zero = (x == 0); const bool y_zero = (y == 0); - acc += x * (raft::myLog(x + x_zero) - (!y_zero) * raft::myLog(y + y_zero)); + acc += x * (raft::log(x + x_zero) - (!y_zero) * raft::log(y + y_zero)); } else { const bool y_zero = (y == 0); const bool x_zero = (x == 0); - acc += y * (raft::myLog(y + y_zero) - (!x_zero) * raft::myLog(x + x_zero)); + acc += y * (raft::log(y + y_zero) - (!x_zero) * raft::log(x + x_zero)); } }; auto unaryOp_lambda = [] __device__(DataT input) { const bool x_zero = (input == 0); - return (!x_zero) * raft::myLog(input + x_zero); + return (!x_zero) * raft::log(input + x_zero); }; auto unaryOp_lambda_reverse = [] __device__(DataT input) { // reverse previous log (x) back to x using (e ^ log(x)) const bool x_zero = (input == 0); - return (!x_zero) * raft::myExp(input); + return (!x_zero) * raft::exp(input); }; // epilogue operation lambda for final value calculation diff --git a/cpp/include/raft/distance/detail/l1.cuh b/cpp/include/raft/distance/detail/l1.cuh index 6372019fd3..bf10651b60 100644 --- a/cpp/include/raft/distance/detail/l1.cuh +++ b/cpp/include/raft/distance/detail/l1.cuh @@ -1,5 +1,5 @@ /* - * Copyright (c) 2018-2021, NVIDIA CORPORATION. + * Copyright (c) 2018-2023, NVIDIA CORPORATION. 
* * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -71,16 +71,12 @@ static void l1Impl(const DataT* x, // Accumulation operation lambda auto core_lambda = [] __device__(AccT & acc, DataT & x, DataT & y) { - const auto diff = raft::L1Op()(x - y); + const auto diff = raft::abs(x - y); acc += diff; }; // epilogue operation lambda for final value calculation - auto epilog_lambda = [] __device__(AccT acc[KPolicy::AccRowsPerTh][KPolicy::AccColsPerTh], - DataT * regxn, - DataT * regyn, - IdxT gridStrideX, - IdxT gridStrideY) { return; }; + auto epilog_lambda = raft::void_op(); if (isRowMajor) { auto l1RowMajor = pairwiseDistanceMatKernel()(x - y); - acc += raft::myPow(diff, p); + const auto diff = raft::abs(x - y); + acc += raft::pow(diff, p); }; // epilogue operation lambda for final value calculation @@ -89,7 +89,7 @@ void minkowskiUnExpImpl(const DataT* x, for (int i = 0; i < KPolicy::AccRowsPerTh; ++i) { #pragma unroll for (int j = 0; j < KPolicy::AccColsPerTh; ++j) { - acc[i][j] = raft::myPow(acc[i][j], one_over_p); + acc[i][j] = raft::pow(acc[i][j], one_over_p); } } }; diff --git a/cpp/include/raft/distance/detail/pairwise_distance_base.cuh b/cpp/include/raft/distance/detail/pairwise_distance_base.cuh index 26536d13cd..445b4bac52 100644 --- a/cpp/include/raft/distance/detail/pairwise_distance_base.cuh +++ b/cpp/include/raft/distance/detail/pairwise_distance_base.cuh @@ -1,5 +1,5 @@ /* - * Copyright (c) 2022, NVIDIA CORPORATION. + * Copyright (c) 2022-2023, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -14,6 +14,7 @@ * limitations under the License. 
*/ #pragma once +#include #include #include #include diff --git a/cpp/include/raft/distance/detail/pairwise_distance_gemm.h b/cpp/include/raft/distance/detail/pairwise_distance_gemm.h index ea9ed77fb5..8dcccfc14f 100644 --- a/cpp/include/raft/distance/detail/pairwise_distance_gemm.h +++ b/cpp/include/raft/distance/detail/pairwise_distance_gemm.h @@ -66,7 +66,7 @@ struct PairwiseDistanceGemm { /// Warp-level tile size (concept: GemmShape) // This code section describes the size of MMA op using InstructionShape = - cutlass::gemm::GemmShape<16, 8, 8>; // <- MMA Op tile M = 16, N = 8, K = 8 + cutlass::gemm::GemmShape<16, 8, 4>; // <- MMA Op tile M = 16, N = 8, K = 4 /// Operation performed by GEMM using Operator = cutlass::arch::OpMultiplyAddFastF32; diff --git a/cpp/include/raft/distance/distance.cuh b/cpp/include/raft/distance/distance.cuh index 6e3f97b45c..93a5ce7f1a 100644 --- a/cpp/include/raft/distance/distance.cuh +++ b/cpp/include/raft/distance/distance.cuh @@ -1,5 +1,5 @@ /* - * Copyright (c) 2018-2022, NVIDIA CORPORATION. + * Copyright (c) 2018-2023, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -18,21 +18,21 @@ #pragma once -#include +#include #include #include #include #include +namespace raft { +namespace distance { + /** - * @defgroup pairwise_distance pairwise distance prims + * @defgroup pairwise_distance pointer-based pairwise distance prims * @{ */ -namespace raft { -namespace distance { - /** * @brief Evaluate pairwise distances with the user epilogue lamba allowed * @tparam DistanceType which distance to evaluate @@ -219,58 +219,6 @@ void distance(const InType* x, x, y, dist, m, n, k, workspace.data(), worksize, stream, isRowMajor, metric_arg); } -/** - * @brief Evaluate pairwise distances for the simple use case. - * - * Note: Only contiguous row- or column-major layouts supported currently. 
- * - * @tparam DistanceType which distance to evaluate - * @tparam InType input argument type - * @tparam AccType accumulation type - * @tparam OutType output type - * @tparam Index_ Index type - * @param handle raft handle for managing expensive resources - * @param x first set of points (size n*k) - * @param y second set of points (size m*k) - * @param dist output distance matrix (size n*m) - * @param metric_arg metric argument (used for Minkowski distance) - */ -template -void distance(raft::handle_t const& handle, - raft::device_matrix_view const x, - raft::device_matrix_view const y, - raft::device_matrix_view dist, - InType metric_arg = 2.0f) -{ - RAFT_EXPECTS(x.extent(1) == y.extent(1), "Number of columns must be equal."); - RAFT_EXPECTS(dist.extent(0) == x.extent(0), - "Number of rows in output must be equal to " - "number of rows in X"); - RAFT_EXPECTS(dist.extent(1) == y.extent(0), - "Number of columns in output must be equal to " - "number of rows in Y"); - - RAFT_EXPECTS(x.is_exhaustive(), "Input x must be contiguous."); - RAFT_EXPECTS(y.is_exhaustive(), "Input y must be contiguous."); - - constexpr auto is_rowmajor = std::is_same_v; - - distance(x.data_handle(), - y.data_handle(), - dist.data_handle(), - x.extent(0), - y.extent(0), - x.extent(1), - handle.get_stream(), - is_rowmajor, - metric_arg); -} - /** * @brief Convenience wrapper around 'distance' prim to convert runtime metric * into compile time for the purpose of dispatch @@ -290,7 +238,7 @@ void distance(raft::handle_t const& handle, * @param metric_arg metric argument (used for Minkowski distance) */ template -void pairwise_distance(const raft::handle_t& handle, +void pairwise_distance(raft::device_resources const& handle, const Type* x, const Type* y, Type* dist, @@ -385,7 +333,7 @@ void pairwise_distance(const raft::handle_t& handle, * @param metric_arg metric argument (used for Minkowski distance) */ template -void pairwise_distance(const raft::handle_t& handle, +void 
pairwise_distance(raft::device_resources const& handle, const Type* x, const Type* y, Type* dist, @@ -401,6 +349,85 @@ void pairwise_distance(const raft::handle_t& handle, handle, x, y, dist, m, n, k, workspace, metric, isRowMajor, metric_arg); } +/** @} */ + +/** + * \defgroup distance_mdspan Pairwise distance functions + * @{ + */ + +/** + * @brief Evaluate pairwise distances for the simple use case. + * + * Note: Only contiguous row- or column-major layouts supported currently. + * + * Usage example: + * @code{.cpp} + * #include + * #include + * #include + * #include + * + * raft::device_resources handle; + * int n_samples = 5000; + * int n_features = 50; + * + * auto input = raft::make_device_matrix(handle, n_samples, n_features); + * auto labels = raft::make_device_vector(handle, n_samples); + * auto output = raft::make_device_matrix(handle, n_samples, n_samples); + * + * raft::random::make_blobs(handle, input.view(), labels.view()); + * auto metric = raft::distance::DistanceType::L2SqrtExpanded; + * raft::distance::pairwise_distance(handle, input.view(), input.view(), output.view(), metric); + * @endcode + * + * @tparam DistanceType which distance to evaluate + * @tparam InType input argument type + * @tparam AccType accumulation type + * @tparam OutType output type + * @tparam Index_ Index type + * @param handle raft handle for managing expensive resources + * @param x first set of points (size n*k) + * @param y second set of points (size m*k) + * @param dist output distance matrix (size n*m) + * @param metric_arg metric argument (used for Minkowski distance) + */ +template +void distance(raft::device_resources const& handle, + raft::device_matrix_view const x, + raft::device_matrix_view const y, + raft::device_matrix_view dist, + InType metric_arg = 2.0f) +{ + RAFT_EXPECTS(x.extent(1) == y.extent(1), "Number of columns must be equal."); + RAFT_EXPECTS(dist.extent(0) == x.extent(0), + "Number of rows in output must be equal to " + "number of rows in
X"); + RAFT_EXPECTS(dist.extent(1) == y.extent(0), + "Number of columns in output must be equal to " + "number of rows in Y"); + + RAFT_EXPECTS(x.is_exhaustive(), "Input x must be contiguous."); + RAFT_EXPECTS(y.is_exhaustive(), "Input y must be contiguous."); + + constexpr auto is_rowmajor = std::is_same_v; + + distance(x.data_handle(), + y.data_handle(), + dist.data_handle(), + x.extent(0), + y.extent(0), + x.extent(1), + handle.get_stream(), + is_rowmajor, + metric_arg); +} + /** * @brief Convenience wrapper around 'distance' prim to convert runtime metric * into compile time for the purpose of dispatch @@ -414,7 +441,7 @@ void pairwise_distance(const raft::handle_t& handle, * @param metric_arg metric argument (used for Minkowski distance) */ template -void pairwise_distance(raft::handle_t const& handle, +void pairwise_distance(raft::device_resources const& handle, device_matrix_view const x, device_matrix_view const y, device_matrix_view dist, @@ -449,9 +476,9 @@ void pairwise_distance(raft::handle_t const& handle, metric_arg); } +/** @} */ + }; // namespace distance }; // namespace raft -/** @} */ - #endif \ No newline at end of file diff --git a/cpp/include/raft/distance/fused_l2_nn.cuh b/cpp/include/raft/distance/fused_l2_nn.cuh index ef51a54622..e832bcb020 100644 --- a/cpp/include/raft/distance/fused_l2_nn.cuh +++ b/cpp/include/raft/distance/fused_l2_nn.cuh @@ -1,5 +1,5 @@ /* - * Copyright (c) 2021-2022, NVIDIA CORPORATION. + * Copyright (c) 2021-2023, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
@@ -21,7 +21,7 @@ #include #include -#include +#include #include #include #include @@ -30,6 +30,10 @@ namespace raft { namespace distance { +/** + * \defgroup fused_l2_nn Fused 1-nearest neighbors + * @{ + */ template using KVPMinReduce = detail::KVPMinReduceImpl; @@ -40,15 +44,22 @@ using MinAndDistanceReduceOp = detail::MinAndDistanceReduceOpImpl template using MinReduceOp = detail::MinReduceOpImpl; +/** @} */ + /** * Initialize array using init value from reduction op */ template -void initialize(const raft::handle_t& handle, OutT* min, IdxT m, DataT maxVal, ReduceOpT redOp) +void initialize( + raft::device_resources const& handle, OutT* min, IdxT m, DataT maxVal, ReduceOpT redOp) { detail::initialize(min, m, maxVal, redOp, handle.get_stream()); } +/** + * \ingroup fused_l2_nn + * @{ + */ /** * @brief Fused L2 distance and 1-nearest-neighbor computation in a single call. * @@ -211,6 +222,8 @@ void fusedL2NNMinReduce(OutT* min, min, x, y, xn, yn, m, n, k, workspace, redOp, pairRedOp, sqrt, initOutBuffer, stream); } +/** @} */ + } // namespace distance } // namespace raft diff --git a/cpp/include/raft/handle.hpp b/cpp/include/raft/handle.hpp index 4525af49d2..caa68061db 100644 --- a/cpp/include/raft/handle.hpp +++ b/cpp/include/raft/handle.hpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2019-2022, NVIDIA CORPORATION. + * Copyright (c) 2019-2023, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
@@ -21,4 +21,4 @@ #pragma once -#include +#include diff --git a/cpp/include/raft/label/detail/classlabels.cuh b/cpp/include/raft/label/detail/classlabels.cuh index 0af1c70b91..64d8b4bfae 100644 --- a/cpp/include/raft/label/detail/classlabels.cuh +++ b/cpp/include/raft/label/detail/classlabels.cuh @@ -18,6 +18,7 @@ #include +#include #include #include #include @@ -194,8 +195,7 @@ void make_monotonic( template void make_monotonic(Type* out, Type* in, size_t N, cudaStream_t stream, bool zero_based = false) { - make_monotonic( - out, in, N, stream, [] __device__(Type val) { return false; }, zero_based); + make_monotonic(out, in, N, stream, raft::const_op(false), zero_based); } }; // namespace detail diff --git a/cpp/include/raft/linalg/add.cuh b/cpp/include/raft/linalg/add.cuh index 37956fe762..608c63e1a9 100644 --- a/cpp/include/raft/linalg/add.cuh +++ b/cpp/include/raft/linalg/add.cuh @@ -1,5 +1,5 @@ /* - * Copyright (c) 2022, NVIDIA CORPORATION. + * Copyright (c) 2022-2023, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
@@ -16,11 +16,6 @@ #ifndef __ADD_H #define __ADD_H -/** - * @defgroup arithmetic Dense matrix arithmetic - * @{ - */ - #pragma once #include "detail/add.cuh" @@ -32,8 +27,6 @@ namespace raft { namespace linalg { -using detail::adds_scalar; - /** * @ingroup arithmetic * @brief Elementwise scalar add operation on the input buffer @@ -94,7 +87,7 @@ void addDevScalar( } /** - * @defgroup add Addition Arithmetic + * @defgroup add_dense Addition Arithmetic * @{ */ @@ -102,7 +95,7 @@ void addDevScalar( * @brief Elementwise add operation * @tparam InType Input Type raft::device_mdspan * @tparam OutType Output Type raft::device_mdspan - * @param[in] handle raft::handle_t + * @param[in] handle raft::device_resources * @param[in] in1 First Input * @param[in] in2 Second Input * @param[out] out Output @@ -111,7 +104,7 @@ template , typename = raft::enable_if_output_device_mdspan> -void add(const raft::handle_t& handle, InType in1, InType in2, OutType out) +void add(raft::device_resources const& handle, InType in1, InType in2, OutType out) { using in_value_t = typename InType::value_type; using out_value_t = typename OutType::value_type; @@ -142,7 +135,7 @@ void add(const raft::handle_t& handle, InType in1, InType in2, OutType out) * @tparam InType Input Type raft::device_mdspan * @tparam OutType Output Type raft::device_mdspan * @tparam ScalarIdxType Index Type of scalar - * @param[in] handle raft::handle_t + * @param[in] handle raft::device_resources * @param[in] in Input * @param[in] scalar raft::device_scalar_view * @param[in] out Output @@ -152,7 +145,7 @@ template , typename = raft::enable_if_output_device_mdspan> -void add_scalar(const raft::handle_t& handle, +void add_scalar(raft::device_resources const& handle, InType in, OutType out, raft::device_scalar_view scalar) @@ -184,7 +177,7 @@ void add_scalar(const raft::handle_t& handle, * @tparam InType Input Type raft::device_mdspan * @tparam OutType Output Type raft::device_mdspan * @tparam ScalarIdxType Index Type of 
scalar - * @param[in] handle raft::handle_t + * @param[in] handle raft::device_resources * @param[in] in Input * @param[in] scalar raft::host_scalar_view * @param[in] out Output @@ -194,7 +187,7 @@ template , typename = raft::enable_if_output_device_mdspan> -void add_scalar(const raft::handle_t& handle, +void add_scalar(raft::device_resources const& handle, const InType in, OutType out, raft::host_scalar_view scalar) @@ -226,6 +219,4 @@ void add_scalar(const raft::handle_t& handle, }; // end namespace linalg }; // end namespace raft -/** @} */ - #endif \ No newline at end of file diff --git a/cpp/include/raft/linalg/axpy.cuh b/cpp/include/raft/linalg/axpy.cuh index 88b065c8b0..9b3af73234 100644 --- a/cpp/include/raft/linalg/axpy.cuh +++ b/cpp/include/raft/linalg/axpy.cuh @@ -1,5 +1,5 @@ /* - * Copyright (c) 2022, NVIDIA CORPORATION. + * Copyright (c) 2022-2023, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
@@ -41,7 +41,7 @@ namespace raft::linalg { * @param [in] stream */ template -void axpy(const raft::handle_t& handle, +void axpy(raft::device_resources const& handle, const int n, const T* alpha, const T* x, @@ -54,7 +54,7 @@ void axpy(const raft::handle_t& handle, } /** - * @defgroup axpy axpy + * @defgroup axpy axpy routine * @{ */ @@ -62,7 +62,7 @@ void axpy(const raft::handle_t& handle, * @brief axpy function * It computes the following equation: y = alpha * x + y * - * @param [in] handle raft::handle_t + * @param [in] handle raft::device_resources * @param [in] alpha raft::device_scalar_view * @param [in] x Input vector * @param [inout] y Output vector @@ -72,7 +72,7 @@ template -void axpy(const raft::handle_t& handle, +void axpy(raft::device_resources const& handle, raft::device_scalar_view alpha, raft::device_vector_view x, raft::device_vector_view y) @@ -92,7 +92,7 @@ void axpy(const raft::handle_t& handle, /** * @brief axpy function * It computes the following equation: y = alpha * x + y - * @param [in] handle raft::handle_t + * @param [in] handle raft::device_resources * @param [in] alpha raft::device_scalar_view * @param [in] x Input vector * @param [inout] y Output vector @@ -102,7 +102,7 @@ template -void axpy(const raft::handle_t& handle, +void axpy(raft::device_resources const& handle, raft::host_scalar_view alpha, raft::device_vector_view x, raft::device_vector_view y) diff --git a/cpp/include/raft/linalg/binary_op.cuh b/cpp/include/raft/linalg/binary_op.cuh index 693ef961c2..966e84965d 100644 --- a/cpp/include/raft/linalg/binary_op.cuh +++ b/cpp/include/raft/linalg/binary_op.cuh @@ -1,5 +1,5 @@ /* - * Copyright (c) 2022, NVIDIA CORPORATION. + * Copyright (c) 2022-2023, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
@@ -21,7 +21,7 @@ #include "detail/binary_op.cuh" #include -#include +#include #include #include @@ -65,7 +65,7 @@ void binaryOp( * @tparam InType Input Type raft::device_mdspan * @tparam Lambda the device-lambda performing the actual operation * @tparam OutType Output Type raft::device_mdspan - * @param[in] handle raft::handle_t + * @param[in] handle raft::device_resources * @param[in] in1 First input * @param[in] in2 Second input * @param[out] out Output @@ -78,7 +78,7 @@ template , typename = raft::enable_if_output_device_mdspan> -void binary_op(const raft::handle_t& handle, InType in1, InType in2, OutType out, Lambda op) +void binary_op(raft::device_resources const& handle, InType in1, InType in2, OutType out, Lambda op) { RAFT_EXPECTS(raft::is_row_or_column_major(out), "Output must be contiguous"); RAFT_EXPECTS(raft::is_row_or_column_major(in1), "Input 1 must be contiguous"); diff --git a/cpp/include/raft/linalg/cholesky_r1_update.cuh b/cpp/include/raft/linalg/cholesky_r1_update.cuh index af8d12d873..e10f43653b 100644 --- a/cpp/include/raft/linalg/cholesky_r1_update.cuh +++ b/cpp/include/raft/linalg/cholesky_r1_update.cuh @@ -1,5 +1,5 @@ /* - * Copyright (c) 2022, NVIDIA CORPORATION. + * Copyright (c) 2022-2023, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -121,7 +121,7 @@ namespace linalg { * conditioned systems. Negative values mean no regularizaton. */ template -void choleskyRank1Update(const raft::handle_t& handle, +void choleskyRank1Update(raft::device_resources const& handle, math_t* L, int n, int ld, diff --git a/cpp/include/raft/linalg/coalesced_reduction.cuh b/cpp/include/raft/linalg/coalesced_reduction.cuh index e9e5a99f46..674be207d8 100644 --- a/cpp/include/raft/linalg/coalesced_reduction.cuh +++ b/cpp/include/raft/linalg/coalesced_reduction.cuh @@ -1,5 +1,5 @@ /* - * Copyright (c) 2022, NVIDIA CORPORATION. 
+ * Copyright (c) 2022-2023, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -21,7 +21,8 @@ #include "detail/coalesced_reduction.cuh" #include -#include +#include +#include namespace raft { namespace linalg { @@ -56,9 +57,9 @@ namespace linalg { template , - typename ReduceLambda = raft::Sum, - typename FinalLambda = raft::Nop> + typename MainLambda = raft::identity_op, + typename ReduceLambda = raft::add_op, + typename FinalLambda = raft::identity_op> void coalescedReduction(OutType* dots, const InType* data, IdxType D, @@ -66,9 +67,9 @@ void coalescedReduction(OutType* dots, OutType init, cudaStream_t stream, bool inplace = false, - MainLambda main_op = raft::Nop(), - ReduceLambda reduce_op = raft::Sum(), - FinalLambda final_op = raft::Nop()) + MainLambda main_op = raft::identity_op(), + ReduceLambda reduce_op = raft::add_op(), + FinalLambda final_op = raft::identity_op()) { detail::coalescedReduction( dots, data, D, N, init, stream, inplace, main_op, reduce_op, final_op); @@ -100,7 +101,7 @@ void coalescedReduction(OutType* dots, * @tparam FinalLambda the final lambda applied before STG (eg: Sqrt for L2 norm) * It must be a 'callable' supporting the following input and output: *

OutType (*FinalLambda)(OutType);
- * @param handle raft::handle_t + * @param handle raft::device_resources * @param[in] data Input of type raft::device_matrix_view * @param[out] dots Output of type raft::device_matrix_view * @param[in] init initial value to use for the reduction @@ -113,17 +114,17 @@ template , - typename ReduceLambda = raft::Sum, - typename FinalLambda = raft::Nop> -void coalesced_reduction(const raft::handle_t& handle, + typename MainLambda = raft::identity_op, + typename ReduceLambda = raft::add_op, + typename FinalLambda = raft::identity_op> +void coalesced_reduction(raft::device_resources const& handle, raft::device_matrix_view data, raft::device_vector_view dots, OutValueType init, bool inplace = false, - MainLambda main_op = raft::Nop(), - ReduceLambda reduce_op = raft::Sum(), - FinalLambda final_op = raft::Nop()) + MainLambda main_op = raft::identity_op(), + ReduceLambda reduce_op = raft::add_op(), + FinalLambda final_op = raft::identity_op()) { if constexpr (std::is_same_v) { RAFT_EXPECTS(static_cast(dots.size()) == data.extent(0), diff --git a/cpp/include/raft/linalg/detail/add.cuh b/cpp/include/raft/linalg/detail/add.cuh index 34966ebbc2..bf9b2bd1d8 100644 --- a/cpp/include/raft/linalg/detail/add.cuh +++ b/cpp/include/raft/linalg/detail/add.cuh @@ -16,14 +16,11 @@ #pragma once -#include "functional.cuh" - +#include #include #include #include -#include - namespace raft { namespace linalg { namespace detail { @@ -31,13 +28,13 @@ namespace detail { template void addScalar(OutT* out, const InT* in, InT scalar, IdxType len, cudaStream_t stream) { - raft::linalg::unaryOp(out, in, len, adds_scalar(scalar), stream); + raft::linalg::unaryOp(out, in, len, raft::add_const_op(scalar), stream); } template void add(OutT* out, const InT* in1, const InT* in2, IdxType len, cudaStream_t stream) { - raft::linalg::binaryOp(out, in1, in2, len, thrust::plus(), stream); + raft::linalg::binaryOp(out, in1, in2, len, raft::add_op(), stream); } template diff --git 
a/cpp/include/raft/linalg/detail/axpy.cuh b/cpp/include/raft/linalg/detail/axpy.cuh index f3e1a177c8..5747e840c4 100644 --- a/cpp/include/raft/linalg/detail/axpy.cuh +++ b/cpp/include/raft/linalg/detail/axpy.cuh @@ -1,5 +1,5 @@ /* - * Copyright (c) 2022, NVIDIA CORPORATION. + * Copyright (c) 2022-2023, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -20,12 +20,12 @@ #include "cublas_wrappers.hpp" -#include +#include namespace raft::linalg::detail { template -void axpy(const raft::handle_t& handle, +void axpy(raft::device_resources const& handle, const int n, const T* alpha, const T* x, diff --git a/cpp/include/raft/linalg/detail/cholesky_r1_update.cuh b/cpp/include/raft/linalg/detail/cholesky_r1_update.cuh index 47937815bd..afa9155753 100644 --- a/cpp/include/raft/linalg/detail/cholesky_r1_update.cuh +++ b/cpp/include/raft/linalg/detail/cholesky_r1_update.cuh @@ -1,5 +1,5 @@ /* - * Copyright (c) 2022, NVIDIA CORPORATION. + * Copyright (c) 2022-2023, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
@@ -18,7 +18,7 @@ #include "cublas_wrappers.hpp" #include "cusolver_wrappers.hpp" -#include +#include #include namespace raft { @@ -26,7 +26,7 @@ namespace linalg { namespace detail { template -void choleskyRank1Update(const raft::handle_t& handle, +void choleskyRank1Update(raft::device_resources const& handle, math_t* L, int n, int ld, diff --git a/cpp/include/raft/linalg/detail/coalesced_reduction.cuh b/cpp/include/raft/linalg/detail/coalesced_reduction.cuh index 63351f5475..238e17fa56 100644 --- a/cpp/include/raft/linalg/detail/coalesced_reduction.cuh +++ b/cpp/include/raft/linalg/detail/coalesced_reduction.cuh @@ -18,6 +18,7 @@ #include #include +#include #include #include @@ -71,9 +72,9 @@ template , - typename ReduceLambda = raft::Sum, - typename FinalLambda = raft::Nop> + typename MainLambda = raft::identity_op, + typename ReduceLambda = raft::add_op, + typename FinalLambda = raft::identity_op> void coalescedReductionThin(OutType* dots, const InType* data, IdxType D, @@ -81,9 +82,9 @@ void coalescedReductionThin(OutType* dots, OutType init, cudaStream_t stream, bool inplace = false, - MainLambda main_op = raft::Nop(), - ReduceLambda reduce_op = raft::Sum(), - FinalLambda final_op = raft::Nop()) + MainLambda main_op = raft::identity_op(), + ReduceLambda reduce_op = raft::add_op(), + FinalLambda final_op = raft::identity_op()) { common::nvtx::range fun_scope( "coalescedReductionThin<%d,%d>", Policy::LogicalWarpSize, Policy::RowsPerBlock); @@ -97,9 +98,9 @@ void coalescedReductionThin(OutType* dots, template , - typename ReduceLambda = raft::Sum, - typename FinalLambda = raft::Nop> + typename MainLambda = raft::identity_op, + typename ReduceLambda = raft::add_op, + typename FinalLambda = raft::identity_op> void coalescedReductionThinDispatcher(OutType* dots, const InType* data, IdxType D, @@ -107,9 +108,9 @@ void coalescedReductionThinDispatcher(OutType* dots, OutType init, cudaStream_t stream, bool inplace = false, - MainLambda main_op = raft::Nop(), - 
ReduceLambda reduce_op = raft::Sum(), - FinalLambda final_op = raft::Nop()) + MainLambda main_op = raft::identity_op(), + ReduceLambda reduce_op = raft::add_op(), + FinalLambda final_op = raft::identity_op()) { if (D <= IdxType(2)) { coalescedReductionThin>( @@ -168,9 +169,9 @@ template , - typename ReduceLambda = raft::Sum, - typename FinalLambda = raft::Nop> + typename MainLambda = raft::identity_op, + typename ReduceLambda = raft::add_op, + typename FinalLambda = raft::identity_op> void coalescedReductionMedium(OutType* dots, const InType* data, IdxType D, @@ -178,9 +179,9 @@ void coalescedReductionMedium(OutType* dots, OutType init, cudaStream_t stream, bool inplace = false, - MainLambda main_op = raft::Nop(), - ReduceLambda reduce_op = raft::Sum(), - FinalLambda final_op = raft::Nop()) + MainLambda main_op = raft::identity_op(), + ReduceLambda reduce_op = raft::add_op(), + FinalLambda final_op = raft::identity_op()) { common::nvtx::range fun_scope("coalescedReductionMedium<%d>", TPB); coalescedReductionMediumKernel @@ -191,9 +192,9 @@ void coalescedReductionMedium(OutType* dots, template , - typename ReduceLambda = raft::Sum, - typename FinalLambda = raft::Nop> + typename MainLambda = raft::identity_op, + typename ReduceLambda = raft::add_op, + typename FinalLambda = raft::identity_op> void coalescedReductionMediumDispatcher(OutType* dots, const InType* data, IdxType D, @@ -201,9 +202,9 @@ void coalescedReductionMediumDispatcher(OutType* dots, OutType init, cudaStream_t stream, bool inplace = false, - MainLambda main_op = raft::Nop(), - ReduceLambda reduce_op = raft::Sum(), - FinalLambda final_op = raft::Nop()) + MainLambda main_op = raft::identity_op(), + ReduceLambda reduce_op = raft::add_op(), + FinalLambda final_op = raft::identity_op()) { // Note: for now, this kernel is only used when D > 256. If this changes in the future, use // smaller block sizes when relevant. 
@@ -251,9 +252,9 @@ template , - typename ReduceLambda = raft::Sum, - typename FinalLambda = raft::Nop> + typename MainLambda = raft::identity_op, + typename ReduceLambda = raft::add_op, + typename FinalLambda = raft::identity_op> void coalescedReductionThick(OutType* dots, const InType* data, IdxType D, @@ -261,9 +262,9 @@ void coalescedReductionThick(OutType* dots, OutType init, cudaStream_t stream, bool inplace = false, - MainLambda main_op = raft::Nop(), - ReduceLambda reduce_op = raft::Sum(), - FinalLambda final_op = raft::Nop()) + MainLambda main_op = raft::identity_op(), + ReduceLambda reduce_op = raft::add_op(), + FinalLambda final_op = raft::identity_op()) { common::nvtx::range fun_scope( "coalescedReductionThick<%d,%d>", ThickPolicy::ThreadsPerBlock, ThickPolicy::BlocksPerRow); @@ -291,7 +292,7 @@ void coalescedReductionThick(OutType* dots, init, stream, inplace, - raft::Nop(), + raft::identity_op(), reduce_op, final_op); } @@ -299,9 +300,9 @@ void coalescedReductionThick(OutType* dots, template , - typename ReduceLambda = raft::Sum, - typename FinalLambda = raft::Nop> + typename MainLambda = raft::identity_op, + typename ReduceLambda = raft::add_op, + typename FinalLambda = raft::identity_op> void coalescedReductionThickDispatcher(OutType* dots, const InType* data, IdxType D, @@ -309,9 +310,9 @@ void coalescedReductionThickDispatcher(OutType* dots, OutType init, cudaStream_t stream, bool inplace = false, - MainLambda main_op = raft::Nop(), - ReduceLambda reduce_op = raft::Sum(), - FinalLambda final_op = raft::Nop()) + MainLambda main_op = raft::identity_op(), + ReduceLambda reduce_op = raft::add_op(), + FinalLambda final_op = raft::identity_op()) { // Note: multiple elements per thread to take advantage of the sequential reduction and loop // unrolling @@ -330,9 +331,9 @@ void coalescedReductionThickDispatcher(OutType* dots, template , - typename ReduceLambda = raft::Sum, - typename FinalLambda = raft::Nop> + typename MainLambda = raft::identity_op, + 
typename ReduceLambda = raft::add_op, + typename FinalLambda = raft::identity_op> void coalescedReduction(OutType* dots, const InType* data, IdxType D, @@ -340,9 +341,9 @@ void coalescedReduction(OutType* dots, OutType init, cudaStream_t stream, bool inplace = false, - MainLambda main_op = raft::Nop(), - ReduceLambda reduce_op = raft::Sum(), - FinalLambda final_op = raft::Nop()) + MainLambda main_op = raft::identity_op(), + ReduceLambda reduce_op = raft::add_op(), + FinalLambda final_op = raft::identity_op()) { /* The primitive selects one of three implementations based on heuristics: * - Thin: very efficient when D is small and/or N is large diff --git a/cpp/include/raft/linalg/detail/contractions.cuh b/cpp/include/raft/linalg/detail/contractions.cuh index 5d83f88e71..e247f39bc7 100644 --- a/cpp/include/raft/linalg/detail/contractions.cuh +++ b/cpp/include/raft/linalg/detail/contractions.cuh @@ -1,5 +1,5 @@ /* - * Copyright (c) 2022, NVIDIA CORPORATION. + * Copyright (c) 2022-2023, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
diff --git a/cpp/include/raft/linalg/detail/divide.cuh b/cpp/include/raft/linalg/detail/divide.cuh index 333cd3e83c..eef1d19d6e 100644 --- a/cpp/include/raft/linalg/detail/divide.cuh +++ b/cpp/include/raft/linalg/detail/divide.cuh @@ -16,9 +16,8 @@ #pragma once -#include "functional.cuh" - #include +#include #include namespace raft { @@ -28,7 +27,7 @@ namespace detail { template void divideScalar(OutT* out, const InT* in, InT scalar, IdxType len, cudaStream_t stream) { - raft::linalg::unaryOp(out, in, len, divides_scalar(scalar), stream); + raft::linalg::unaryOp(out, in, len, raft::div_const_op(scalar), stream); } }; // end namespace detail diff --git a/cpp/include/raft/linalg/detail/eig.cuh b/cpp/include/raft/linalg/detail/eig.cuh index d48b42fc57..94493efb24 100644 --- a/cpp/include/raft/linalg/detail/eig.cuh +++ b/cpp/include/raft/linalg/detail/eig.cuh @@ -1,5 +1,5 @@ /* - * Copyright (c) 2022, NVIDIA CORPORATION. + * Copyright (c) 2022-2023, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
@@ -18,7 +18,7 @@ #include "cusolver_wrappers.hpp" #include -#include +#include #include #include #include @@ -29,7 +29,7 @@ namespace linalg { namespace detail { template -void eigDC_legacy(const raft::handle_t& handle, +void eigDC_legacy(raft::device_resources const& handle, const math_t* in, std::size_t n_rows, std::size_t n_cols, @@ -74,7 +74,7 @@ void eigDC_legacy(const raft::handle_t& handle, } template -void eigDC(const raft::handle_t& handle, +void eigDC(raft::device_resources const& handle, const math_t* in, std::size_t n_rows, std::size_t n_cols, @@ -137,7 +137,7 @@ void eigDC(const raft::handle_t& handle, enum EigVecMemUsage { OVERWRITE_INPUT, COPY_INPUT }; template -void eigSelDC(const raft::handle_t& handle, +void eigSelDC(raft::device_resources const& handle, math_t* in, std::size_t n_rows, std::size_t n_cols, @@ -228,7 +228,7 @@ void eigSelDC(const raft::handle_t& handle, } template -void eigJacobi(const raft::handle_t& handle, +void eigJacobi(raft::device_resources const& handle, const math_t* in, std::size_t n_rows, std::size_t n_cols, diff --git a/cpp/include/raft/linalg/detail/eltwise.cuh b/cpp/include/raft/linalg/detail/eltwise.cuh index 019f86a779..25b4ca0499 100644 --- a/cpp/include/raft/linalg/detail/eltwise.cuh +++ b/cpp/include/raft/linalg/detail/eltwise.cuh @@ -16,13 +16,10 @@ #pragma once -#include "functional.cuh" - +#include #include #include -#include - namespace raft { namespace linalg { namespace detail { @@ -30,48 +27,48 @@ namespace detail { template void scalarAdd(OutType* out, const InType* in, InType scalar, IdxType len, cudaStream_t stream) { - raft::linalg::unaryOp(out, in, len, adds_scalar(scalar), stream); + raft::linalg::unaryOp(out, in, len, raft::add_const_op(scalar), stream); } template void scalarMultiply(OutType* out, const InType* in, InType scalar, IdxType len, cudaStream_t stream) { - raft::linalg::unaryOp(out, in, len, multiplies_scalar(scalar), stream); + raft::linalg::unaryOp(out, in, len, 
raft::mul_const_op(scalar), stream); } template void eltwiseAdd( OutType* out, const InType* in1, const InType* in2, IdxType len, cudaStream_t stream) { - raft::linalg::binaryOp(out, in1, in2, len, thrust::plus(), stream); + raft::linalg::binaryOp(out, in1, in2, len, raft::add_op(), stream); } template void eltwiseSub( OutType* out, const InType* in1, const InType* in2, IdxType len, cudaStream_t stream) { - raft::linalg::binaryOp(out, in1, in2, len, thrust::minus(), stream); + raft::linalg::binaryOp(out, in1, in2, len, raft::sub_op(), stream); } template void eltwiseMultiply( OutType* out, const InType* in1, const InType* in2, IdxType len, cudaStream_t stream) { - raft::linalg::binaryOp(out, in1, in2, len, thrust::multiplies(), stream); + raft::linalg::binaryOp(out, in1, in2, len, raft::mul_op(), stream); } template void eltwiseDivide( OutType* out, const InType* in1, const InType* in2, IdxType len, cudaStream_t stream) { - raft::linalg::binaryOp(out, in1, in2, len, thrust::divides(), stream); + raft::linalg::binaryOp(out, in1, in2, len, raft::div_op(), stream); } template void eltwiseDivideCheckZero( OutType* out, const InType* in1, const InType* in2, IdxType len, cudaStream_t stream) { - raft::linalg::binaryOp(out, in1, in2, len, divides_check_zero(), stream); + raft::linalg::binaryOp(out, in1, in2, len, raft::div_checkzero_op(), stream); } }; // end namespace detail diff --git a/cpp/include/raft/linalg/detail/functional.cuh b/cpp/include/raft/linalg/detail/functional.cuh deleted file mode 100644 index 067b1565e0..0000000000 --- a/cpp/include/raft/linalg/detail/functional.cuh +++ /dev/null @@ -1,69 +0,0 @@ -/* - * Copyright (c) 2022, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. 
- * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#pragma once - -#include - -namespace raft { -namespace linalg { -namespace detail { - -template -struct divides_scalar { - public: - divides_scalar(ArgType scalar) : scalar_(scalar) {} - - __host__ __device__ inline ReturnType operator()(ArgType in) { return in / scalar_; } - - private: - ArgType scalar_; -}; - -template -struct adds_scalar { - public: - adds_scalar(ArgType scalar) : scalar_(scalar) {} - - __host__ __device__ inline ReturnType operator()(ArgType in) { return in + scalar_; } - - private: - ArgType scalar_; -}; - -template -struct multiplies_scalar { - public: - multiplies_scalar(ArgType scalar) : scalar_(scalar) {} - - __host__ __device__ inline ReturnType operator()(ArgType in) { return in * scalar_; } - - private: - ArgType scalar_; -}; - -template -struct divides_check_zero { - public: - __host__ __device__ inline ReturnType operator()(ArgType a, ArgType b) - { - return (b == static_cast(0)) ? 0.0 : a / b; - } -}; - -} // namespace detail -} // namespace linalg -} // namespace raft \ No newline at end of file diff --git a/cpp/include/raft/linalg/detail/gemm.hpp b/cpp/include/raft/linalg/detail/gemm.hpp index baa066984b..ba9496c3b9 100644 --- a/cpp/include/raft/linalg/detail/gemm.hpp +++ b/cpp/include/raft/linalg/detail/gemm.hpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2022, NVIDIA CORPORATION. + * Copyright (c) 2022-2023, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
@@ -20,7 +20,7 @@ #include "cublas_wrappers.hpp" -#include +#include namespace raft { namespace linalg { @@ -49,7 +49,7 @@ namespace detail { * @param [in] stream */ template -void gemm(const raft::handle_t& handle, +void gemm(raft::device_resources const& handle, const bool trans_a, const bool trans_b, const int m, @@ -103,7 +103,7 @@ void gemm(const raft::handle_t& handle, * @param stream cuda stream */ template -void gemm(const raft::handle_t& handle, +void gemm(raft::device_resources const& handle, const math_t* a, int n_rows_a, int n_cols_a, @@ -130,7 +130,7 @@ void gemm(const raft::handle_t& handle, } template -void gemm(const raft::handle_t& handle, +void gemm(raft::device_resources const& handle, const math_t* a, int n_rows_a, int n_cols_a, @@ -149,7 +149,7 @@ void gemm(const raft::handle_t& handle, } template -void gemm(const raft::handle_t& handle, +void gemm(raft::device_resources const& handle, T* z, T* x, T* y, diff --git a/cpp/include/raft/linalg/detail/gemv.hpp b/cpp/include/raft/linalg/detail/gemv.hpp index 38fcdcd82e..b3e001a851 100644 --- a/cpp/include/raft/linalg/detail/gemv.hpp +++ b/cpp/include/raft/linalg/detail/gemv.hpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2022, NVIDIA CORPORATION. + * Copyright (c) 2022-2023, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
@@ -20,14 +20,14 @@ #include "cublas_wrappers.hpp" -#include +#include namespace raft { namespace linalg { namespace detail { template -void gemv(const raft::handle_t& handle, +void gemv(raft::device_resources const& handle, const bool trans_a, const int m, const int n, @@ -59,7 +59,7 @@ void gemv(const raft::handle_t& handle, } template -void gemv(const raft::handle_t& handle, +void gemv(raft::device_resources const& handle, const math_t* A, const int n_rows, const int n_cols, @@ -76,7 +76,7 @@ void gemv(const raft::handle_t& handle, } template -void gemv(const raft::handle_t& handle, +void gemv(raft::device_resources const& handle, const math_t* A, const int n_rows_a, const int n_cols_a, @@ -91,7 +91,7 @@ void gemv(const raft::handle_t& handle, } template -void gemv(const raft::handle_t& handle, +void gemv(raft::device_resources const& handle, const math_t* A, const int n_rows_a, const int n_cols_a, @@ -107,7 +107,7 @@ void gemv(const raft::handle_t& handle, } template -void gemv(const raft::handle_t& handle, +void gemv(raft::device_resources const& handle, const math_t* A, const int n_rows_a, const int n_cols_a, @@ -126,7 +126,7 @@ void gemv(const raft::handle_t& handle, } template -void gemv(const raft::handle_t& handle, +void gemv(raft::device_resources const& handle, const math_t* A, const int n_rows_a, const int n_cols_a, diff --git a/cpp/include/raft/linalg/detail/lanczos.cuh b/cpp/include/raft/linalg/detail/lanczos.cuh index 5a3c595512..8c0cfeba28 100644 --- a/cpp/include/raft/linalg/detail/lanczos.cuh +++ b/cpp/include/raft/linalg/detail/lanczos.cuh @@ -1,5 +1,5 @@ /* - * Copyright (c) 2022, NVIDIA CORPORATION. + * Copyright (c) 2022-2023, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
@@ -26,7 +26,7 @@ #include #include "cublas_wrappers.hpp" -#include +#include #include #include #include @@ -82,7 +82,7 @@ inline curandStatus_t curandGenerateNormalX( * @return Zero if successful. Otherwise non-zero. */ template -int performLanczosIteration(handle_t const& handle, +int performLanczosIteration(raft::device_resources const& handle, spectral::matrix::sparse_matrix_t const* A, index_type_t* iter, index_type_t maxIter, @@ -540,7 +540,7 @@ static int francisQRIteration(index_type_t n, * @return error flag. */ template -static int lanczosRestart(handle_t const& handle, +static int lanczosRestart(raft::device_resources const& handle, index_type_t n, index_type_t iter, index_type_t iter_new, @@ -743,7 +743,7 @@ static int lanczosRestart(handle_t const& handle, */ template int computeSmallestEigenvectors( - handle_t const& handle, + raft::device_resources const& handle, spectral::matrix::sparse_matrix_t const* A, index_type_t nEigVecs, index_type_t maxIter, @@ -984,7 +984,7 @@ int computeSmallestEigenvectors( template int computeSmallestEigenvectors( - handle_t const& handle, + raft::device_resources const& handle, spectral::matrix::sparse_matrix_t const& A, index_type_t nEigVecs, index_type_t maxIter, @@ -1087,7 +1087,7 @@ int computeSmallestEigenvectors( */ template int computeLargestEigenvectors( - handle_t const& handle, + raft::device_resources const& handle, spectral::matrix::sparse_matrix_t const* A, index_type_t nEigVecs, index_type_t maxIter, @@ -1331,7 +1331,7 @@ int computeLargestEigenvectors( template int computeLargestEigenvectors( - handle_t const& handle, + raft::device_resources const& handle, spectral::matrix::sparse_matrix_t const& A, index_type_t nEigVecs, index_type_t maxIter, diff --git a/cpp/include/raft/linalg/detail/lstsq.cuh b/cpp/include/raft/linalg/detail/lstsq.cuh index 1273956b21..207bcefc32 100644 --- a/cpp/include/raft/linalg/detail/lstsq.cuh +++ b/cpp/include/raft/linalg/detail/lstsq.cuh @@ -1,5 +1,5 @@ /* - * Copyright (c) 
2018-2022, NVIDIA CORPORATION. + * Copyright (c) 2018-2023, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -104,7 +104,7 @@ struct DivideByNonZero { operator()(const math_t a, const math_t b) const { - return raft::myAbs(b) >= eps ? a / b : a; + return raft::abs(b) >= eps ? a / b : a; } }; @@ -117,7 +117,7 @@ struct DivideByNonZero { * so it's not guaranteed to stay unmodified. */ template -void lstsqSvdQR(const raft::handle_t& handle, +void lstsqSvdQR(raft::device_resources const& handle, math_t* A, const int n_rows, const int n_cols, @@ -177,7 +177,7 @@ void lstsqSvdQR(const raft::handle_t& handle, * so it's not guaranteed to stay unmodified. */ template -void lstsqSvdJacobi(const raft::handle_t& handle, +void lstsqSvdJacobi(raft::device_resources const& handle, math_t* A, const int n_rows, const int n_cols, @@ -248,7 +248,7 @@ void lstsqSvdJacobi(const raft::handle_t& handle, * (`w = (A^T A)^-1 A^T b`) */ template -void lstsqEig(const raft::handle_t& handle, +void lstsqEig(raft::device_resources const& handle, const math_t* A, const int n_rows, const int n_cols, @@ -352,7 +352,7 @@ void lstsqEig(const raft::handle_t& handle, * Warning: the content of this vector is modified by the cuSOLVER routines. */ template -void lstsqQR(const raft::handle_t& handle, +void lstsqQR(raft::device_resources const& handle, math_t* A, const int n_rows, const int n_cols, diff --git a/cpp/include/raft/linalg/detail/map.cuh b/cpp/include/raft/linalg/detail/map.cuh index add003eb52..e0b473bdd4 100644 --- a/cpp/include/raft/linalg/detail/map.cuh +++ b/cpp/include/raft/linalg/detail/map.cuh @@ -1,5 +1,5 @@ /* - * Copyright (c) 2022, NVIDIA CORPORATION. + * Copyright (c) 2022-2023, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
@@ -17,7 +17,7 @@ #pragma once #include -#include +#include #include #include diff --git a/cpp/include/raft/linalg/detail/map_then_reduce.cuh b/cpp/include/raft/linalg/detail/map_then_reduce.cuh index 7ef9ca1c43..70bb2df4f5 100644 --- a/cpp/include/raft/linalg/detail/map_then_reduce.cuh +++ b/cpp/include/raft/linalg/detail/map_then_reduce.cuh @@ -1,5 +1,5 @@ /* - * Copyright (c) 2022, NVIDIA CORPORATION. + * Copyright (c) 2022-2023, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -17,7 +17,7 @@ #pragma once #include -#include +#include #include #include diff --git a/cpp/include/raft/linalg/detail/multiply.cuh b/cpp/include/raft/linalg/detail/multiply.cuh index f1a8548bfa..84b832d875 100644 --- a/cpp/include/raft/linalg/detail/multiply.cuh +++ b/cpp/include/raft/linalg/detail/multiply.cuh @@ -16,6 +16,7 @@ #pragma once +#include #include namespace raft { @@ -26,8 +27,7 @@ template void multiplyScalar( math_t* out, const math_t* in, const math_t scalar, IdxType len, cudaStream_t stream) { - raft::linalg::unaryOp( - out, in, len, [scalar] __device__(math_t in) { return in * scalar; }, stream); + raft::linalg::unaryOp(out, in, len, raft::mul_const_op{scalar}, stream); } }; // end namespace detail diff --git a/cpp/include/raft/linalg/detail/norm.cuh b/cpp/include/raft/linalg/detail/norm.cuh index f2f08233d5..ed7e360848 100644 --- a/cpp/include/raft/linalg/detail/norm.cuh +++ b/cpp/include/raft/linalg/detail/norm.cuh @@ -16,6 +16,7 @@ #pragma once +#include #include #include @@ -44,8 +45,8 @@ void rowNormCaller(Type* dots, true, stream, false, - raft::L1Op(), - raft::Sum(), + raft::abs_op(), + raft::add_op(), fin_op); break; case L2Norm: @@ -58,8 +59,8 @@ void rowNormCaller(Type* dots, true, stream, false, - raft::L2Op(), - raft::Sum(), + raft::sq_op(), + raft::add_op(), fin_op); break; case LinfNorm: @@ -72,8 +73,8 @@ void rowNormCaller(Type* dots, true, stream, 
false, - raft::L1Op(), - raft::Max(), + raft::abs_op(), + raft::max_op(), fin_op); break; default: THROW("Unsupported norm type: %d", type); @@ -101,8 +102,8 @@ void colNormCaller(Type* dots, false, stream, false, - raft::L1Op(), - raft::Sum(), + raft::abs_op(), + raft::add_op(), fin_op); break; case L2Norm: @@ -115,8 +116,8 @@ void colNormCaller(Type* dots, false, stream, false, - raft::L2Op(), - raft::Sum(), + raft::sq_op(), + raft::add_op(), fin_op); break; case LinfNorm: @@ -129,8 +130,8 @@ void colNormCaller(Type* dots, false, stream, false, - raft::L1Op(), - raft::Max(), + raft::abs_op(), + raft::max_op(), fin_op); break; default: THROW("Unsupported norm type: %d", type); diff --git a/cpp/include/raft/linalg/detail/qr.cuh b/cpp/include/raft/linalg/detail/qr.cuh index 74e9c3e1aa..4cba028d87 100644 --- a/cpp/include/raft/linalg/detail/qr.cuh +++ b/cpp/include/raft/linalg/detail/qr.cuh @@ -1,5 +1,5 @@ /* - * Copyright (c) 2022, NVIDIA CORPORATION. + * Copyright (c) 2022-2023, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
@@ -42,7 +42,7 @@ namespace detail { */ template void qrGetQ_inplace( - const raft::handle_t& handle, math_t* Q, int n_rows, int n_cols, cudaStream_t stream) + raft::device_resources const& handle, math_t* Q, int n_rows, int n_cols, cudaStream_t stream) { RAFT_EXPECTS(n_rows >= n_cols, "QR decomposition expects n_rows >= n_cols."); cusolverDnHandle_t cusolver = handle.get_cusolver_dn_handle(); @@ -83,7 +83,7 @@ void qrGetQ_inplace( } template -void qrGetQ(const raft::handle_t& handle, +void qrGetQ(raft::device_resources const& handle, const math_t* M, math_t* Q, int n_rows, @@ -95,7 +95,7 @@ void qrGetQ(const raft::handle_t& handle, } template -void qrGetQR(const raft::handle_t& handle, +void qrGetQR(raft::device_resources const& handle, math_t* M, math_t* Q, math_t* R, diff --git a/cpp/include/raft/linalg/detail/reduce.cuh b/cpp/include/raft/linalg/detail/reduce.cuh index 3022973b43..721ca8179f 100644 --- a/cpp/include/raft/linalg/detail/reduce.cuh +++ b/cpp/include/raft/linalg/detail/reduce.cuh @@ -16,9 +16,9 @@ #pragma once +#include #include #include -#include namespace raft { namespace linalg { @@ -27,9 +27,9 @@ namespace detail { template , - typename ReduceLambda = raft::Sum, - typename FinalLambda = raft::Nop> + typename MainLambda = raft::identity_op, + typename ReduceLambda = raft::add_op, + typename FinalLambda = raft::identity_op> void reduce(OutType* dots, const InType* data, IdxType D, @@ -39,9 +39,9 @@ void reduce(OutType* dots, bool alongRows, cudaStream_t stream, bool inplace = false, - MainLambda main_op = raft::Nop(), - ReduceLambda reduce_op = raft::Sum(), - FinalLambda final_op = raft::Nop()) + MainLambda main_op = raft::identity_op(), + ReduceLambda reduce_op = raft::add_op(), + FinalLambda final_op = raft::identity_op()) { if (rowMajor && alongRows) { raft::linalg::coalescedReduction( diff --git a/cpp/include/raft/linalg/detail/reduce_cols_by_key.cuh b/cpp/include/raft/linalg/detail/reduce_cols_by_key.cuh index 450fb415e2..a85e04acca 100644 
--- a/cpp/include/raft/linalg/detail/reduce_cols_by_key.cuh +++ b/cpp/include/raft/linalg/detail/reduce_cols_by_key.cuh @@ -29,12 +29,12 @@ namespace detail { ///@todo: specialize this to support shared-mem based atomics template -__global__ void reduce_cols_by_key_kernel( +__global__ void reduce_cols_by_key_direct_kernel( const T* data, const KeyIteratorT keys, T* out, IdxType nrows, IdxType ncols, IdxType nkeys) { typedef typename std::iterator_traits::value_type KeyType; - IdxType idx = blockIdx.x * blockDim.x + threadIdx.x; + IdxType idx = static_cast(blockIdx.x) * blockDim.x + threadIdx.x; if (idx >= (nrows * ncols)) return; ///@todo: yikes! use fast-int-div IdxType colId = idx % ncols; @@ -43,6 +43,38 @@ __global__ void reduce_cols_by_key_kernel( raft::myAtomicAdd(out + rowId * nkeys + key, data[idx]); } +template +__global__ void reduce_cols_by_key_cached_kernel( + const T* data, const KeyIteratorT keys, T* out, IdxType nrows, IdxType ncols, IdxType nkeys) +{ + typedef typename std::iterator_traits::value_type KeyType; + extern __shared__ char smem[]; + T* out_cache = reinterpret_cast(smem); + + // Initialize the shared memory accumulators to 0. + for (IdxType idx = threadIdx.x; idx < nrows * nkeys; idx += blockDim.x) { + out_cache[idx] = T{0}; + } + __syncthreads(); + + // Accumulate in shared memory + for (IdxType idx = static_cast(blockIdx.x) * blockDim.x + threadIdx.x; + idx < nrows * ncols; + idx += blockDim.x * static_cast(gridDim.x)) { + IdxType colId = idx % ncols; + IdxType rowId = idx / ncols; + KeyType key = keys[colId]; + raft::myAtomicAdd(out_cache + rowId * nkeys + key, data[idx]); + } + + // Add the shared-memory accumulators to the global results. 
+ __syncthreads(); + for (IdxType idx = threadIdx.x; idx < nrows * nkeys; idx += blockDim.x) { + T val = out_cache[idx]; + if (val != T{0}) { raft::myAtomicAdd(out + idx, val); } + } +} + /** * @brief Computes the sum-reduction of matrix columns for each given key * @tparam T the input data type (as well as the output reduced matrix) @@ -60,6 +92,7 @@ __global__ void reduce_cols_by_key_kernel( * @param ncols number of columns in the input data * @param nkeys number of unique keys in the keys array * @param stream cuda stream to launch the kernel onto + * @param reset_sums Whether to reset the output sums to zero before reducing */ template void reduce_cols_by_key(const T* data, @@ -68,16 +101,42 @@ void reduce_cols_by_key(const T* data, IdxType nrows, IdxType ncols, IdxType nkeys, - cudaStream_t stream) + cudaStream_t stream, + bool reset_sums) { typedef typename std::iterator_traits::value_type KeyType; - RAFT_CUDA_TRY(cudaMemsetAsync(out, 0, sizeof(T) * nrows * nkeys, stream)); - constexpr int TPB = 256; - int nblks = (int)raft::ceildiv(nrows * ncols, TPB); - reduce_cols_by_key_kernel<<>>(data, keys, out, nrows, ncols, nkeys); + RAFT_EXPECTS(static_cast(nrows) * static_cast(ncols) <= + static_cast(std::numeric_limits::max()), + "Index type too small to represent indices in the input array."); + RAFT_EXPECTS(static_cast(nrows) * static_cast(nkeys) <= + static_cast(std::numeric_limits::max()), + "Index type too small to represent indices in the output array."); + + // Memset the output to zero to use atomics-based reduction. + if (reset_sums) { RAFT_CUDA_TRY(cudaMemsetAsync(out, 0, sizeof(T) * nrows * nkeys, stream)); } + + // The cached version is used when the cache fits in shared memory and the number of input + // elements is above a threshold (the cached version is slightly slower for small input arrays, + // and orders of magnitude faster for large input arrays). 
+ size_t cache_size = static_cast(nrows * nkeys) * sizeof(T); + if (cache_size <= 49152ull && nrows * ncols >= IdxType{8192}) { + constexpr int TPB = 256; + int n_sm = raft::getMultiProcessorCount(); + int target_nblks = 4 * n_sm; + int max_nblks = raft::ceildiv(nrows * ncols, TPB); + int nblks = std::min(target_nblks, max_nblks); + reduce_cols_by_key_cached_kernel<<>>( + data, keys, out, nrows, ncols, nkeys); + } else { + constexpr int TPB = 256; + int nblks = raft::ceildiv(nrows * ncols, TPB); + reduce_cols_by_key_direct_kernel<<>>( + data, keys, out, nrows, ncols, nkeys); + } RAFT_CUDA_TRY(cudaPeekAtLastError()); } + }; // end namespace detail }; // end namespace linalg }; // end namespace raft diff --git a/cpp/include/raft/linalg/detail/rsvd.cuh b/cpp/include/raft/linalg/detail/rsvd.cuh index f96598d9e6..a66a23179b 100644 --- a/cpp/include/raft/linalg/detail/rsvd.cuh +++ b/cpp/include/raft/linalg/detail/rsvd.cuh @@ -1,5 +1,5 @@ /* - * Copyright (c) 2018-2022, NVIDIA CORPORATION. + * Copyright (c) 2018-2023, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
@@ -54,7 +54,7 @@ namespace detail { * @param stream cuda stream */ template -void rsvdFixedRank(const raft::handle_t& handle, +void rsvdFixedRank(raft::device_resources const& handle, math_t* M, int n_rows, int n_cols, @@ -371,7 +371,7 @@ void rsvdFixedRank(const raft::handle_t& handle, * @param stream cuda stream */ template -void rsvdPerc(const raft::handle_t& handle, +void rsvdPerc(raft::device_resources const& handle, math_t* M, int n_rows, int n_cols, diff --git a/cpp/include/raft/linalg/detail/strided_reduction.cuh b/cpp/include/raft/linalg/detail/strided_reduction.cuh index d72bd54a32..0e516b4750 100644 --- a/cpp/include/raft/linalg/detail/strided_reduction.cuh +++ b/cpp/include/raft/linalg/detail/strided_reduction.cuh @@ -18,6 +18,7 @@ #include "unary_op.cuh" #include +#include #include #include #include @@ -107,9 +108,9 @@ __global__ void stridedReductionKernel(OutType* dots, template , - typename ReduceLambda = raft::Sum, - typename FinalLambda = raft::Nop> + typename MainLambda = raft::identity_op, + typename ReduceLambda = raft::add_op, + typename FinalLambda = raft::identity_op> void stridedReduction(OutType* dots, const InType* data, IdxType D, @@ -117,15 +118,13 @@ void stridedReduction(OutType* dots, OutType init, cudaStream_t stream, bool inplace = false, - MainLambda main_op = raft::Nop(), - ReduceLambda reduce_op = raft::Sum(), - FinalLambda final_op = raft::Nop()) + MainLambda main_op = raft::identity_op(), + ReduceLambda reduce_op = raft::add_op(), + FinalLambda final_op = raft::identity_op()) { ///@todo: this extra should go away once we have eliminated the need /// for atomics in stridedKernel (redesign for this is already underway) - if (!inplace) - raft::linalg::unaryOp( - dots, dots, D, [init] __device__(OutType a) { return init; }, stream); + if (!inplace) raft::linalg::unaryOp(dots, dots, D, raft::const_op(init), stream); // Arbitrary numbers for now, probably need to tune const dim3 thrds(32, 16); @@ -137,7 +136,7 @@ void 
stridedReduction(OutType* dots, ///@todo: this complication should go away once we have eliminated the need /// for atomics in stridedKernel (redesign for this is already underway) - if constexpr (std::is_same>::value && + if constexpr (std::is_same::value && std::is_same::value) stridedSummationKernel <<>>(dots, data, D, N, init, main_op); @@ -148,7 +147,7 @@ void stridedReduction(OutType* dots, ///@todo: this complication should go away once we have eliminated the need /// for atomics in stridedKernel (redesign for this is already underway) // Perform final op on output data - if (!std::is_same>::value) + if (!std::is_same::value) raft::linalg::unaryOp(dots, dots, D, final_op, stream); } diff --git a/cpp/include/raft/linalg/detail/subtract.cuh b/cpp/include/raft/linalg/detail/subtract.cuh index ae0f09d2fe..6df09df8ed 100644 --- a/cpp/include/raft/linalg/detail/subtract.cuh +++ b/cpp/include/raft/linalg/detail/subtract.cuh @@ -16,6 +16,7 @@ #pragma once +#include #include #include #include @@ -27,15 +28,13 @@ namespace detail { template void subtractScalar(OutT* out, const InT* in, InT scalar, IdxType len, cudaStream_t stream) { - auto op = [scalar] __device__(InT in) { return OutT(in - scalar); }; - raft::linalg::unaryOp(out, in, len, op, stream); + raft::linalg::unaryOp(out, in, len, raft::sub_const_op(scalar), stream); } template void subtract(OutT* out, const InT* in1, const InT* in2, IdxType len, cudaStream_t stream) { - auto op = [] __device__(InT a, InT b) { return OutT(a - b); }; - raft::linalg::binaryOp(out, in1, in2, len, op, stream); + raft::linalg::binaryOp(out, in1, in2, len, raft::sub_op(), stream); } template diff --git a/cpp/include/raft/linalg/detail/svd.cuh b/cpp/include/raft/linalg/detail/svd.cuh index 90a7ddec1f..4850744f51 100644 --- a/cpp/include/raft/linalg/detail/svd.cuh +++ b/cpp/include/raft/linalg/detail/svd.cuh @@ -1,5 +1,5 @@ /* - * Copyright (c) 2022, NVIDIA CORPORATION. + * Copyright (c) 2022-2023, NVIDIA CORPORATION. 
* * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -23,7 +23,7 @@ #include #include -#include +#include #include #include #include @@ -36,7 +36,7 @@ namespace linalg { namespace detail { template -void svdQR(const raft::handle_t& handle, +void svdQR(raft::device_resources const& handle, T* in, int n_rows, int n_cols, @@ -101,14 +101,14 @@ void svdQR(const raft::handle_t& handle, "This usually occurs when some of the features do not vary enough."); } -template -void svdEig(const raft::handle_t& handle, - T* in, - int n_rows, - int n_cols, - T* S, - T* U, - T* V, +template +void svdEig(raft::device_resources const& handle, + math_t* in, + idx_t n_rows, + idx_t n_cols, + math_t* S, + math_t* U, + math_t* V, bool gen_left_vec, cudaStream_t stream) { @@ -117,11 +117,11 @@ void svdEig(const raft::handle_t& handle, cusolverDnHandle_t cusolverH = handle.get_cusolver_dn_handle(); cublasHandle_t cublasH = handle.get_cublas_handle(); - int len = n_cols * n_cols; - rmm::device_uvector in_cross_mult(len, stream); + auto len = n_cols * n_cols; + rmm::device_uvector in_cross_mult(len, stream); - T alpha = T(1); - T beta = T(0); + math_t alpha = math_t(1); + math_t beta = math_t(0); raft::linalg::gemm(handle, in, n_rows, @@ -139,7 +139,7 @@ void svdEig(const raft::handle_t& handle, raft::linalg::eigDC(handle, in_cross_mult.data(), n_cols, n_cols, V, S, stream); raft::matrix::colReverse(V, n_cols, n_cols, stream); - raft::matrix::rowReverse(S, n_cols, 1, stream); + raft::matrix::rowReverse(S, n_cols, idx_t(1), stream); raft::matrix::seqRoot(S, S, alpha, n_cols, stream, true); @@ -162,7 +162,7 @@ void svdEig(const raft::handle_t& handle, } template -void svdJacobi(const raft::handle_t& handle, +void svdJacobi(raft::device_resources const& handle, math_t* in, int n_rows, int n_cols, @@ -232,7 +232,7 @@ void svdJacobi(const raft::handle_t& handle, } template -void svdReconstruction(const 
raft::handle_t& handle, +void svdReconstruction(raft::device_resources const& handle, math_t* U, math_t* S, math_t* V, @@ -263,7 +263,7 @@ void svdReconstruction(const raft::handle_t& handle, } template -bool evaluateSVDByL2Norm(const raft::handle_t& handle, +bool evaluateSVDByL2Norm(raft::device_resources const& handle, math_t* A_d, math_t* U, math_t* S_vec, diff --git a/cpp/include/raft/linalg/detail/transpose.cuh b/cpp/include/raft/linalg/detail/transpose.cuh index ef5551ea7e..9e7b236fed 100644 --- a/cpp/include/raft/linalg/detail/transpose.cuh +++ b/cpp/include/raft/linalg/detail/transpose.cuh @@ -1,5 +1,5 @@ /* - * Copyright (c) 2022, NVIDIA CORPORATION. + * Copyright (c) 2022-2023, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -19,7 +19,7 @@ #include "cublas_wrappers.hpp" #include -#include +#include #include #include #include @@ -29,7 +29,7 @@ namespace linalg { namespace detail { template -void transpose(const raft::handle_t& handle, +void transpose(raft::device_resources const& handle, math_t* in, math_t* out, int n_rows, @@ -82,7 +82,7 @@ void transpose(math_t* inout, int n, cudaStream_t stream) template void transpose_row_major_impl( - handle_t const& handle, + raft::device_resources const& handle, raft::mdspan, LayoutPolicy, AccessorPolicy> in, raft::mdspan, LayoutPolicy, AccessorPolicy> out) { @@ -108,7 +108,7 @@ void transpose_row_major_impl( template void transpose_col_major_impl( - handle_t const& handle, + raft::device_resources const& handle, raft::mdspan, LayoutPolicy, AccessorPolicy> in, raft::mdspan, LayoutPolicy, AccessorPolicy> out) { diff --git a/cpp/include/raft/linalg/divide.cuh b/cpp/include/raft/linalg/divide.cuh index 53b083045e..0b18e6175c 100644 --- a/cpp/include/raft/linalg/divide.cuh +++ b/cpp/include/raft/linalg/divide.cuh @@ -1,5 +1,5 @@ /* - * Copyright (c) 2022, NVIDIA CORPORATION. 
+ * Copyright (c) 2022-2023, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -21,13 +21,12 @@ #include "detail/divide.cuh" #include +#include #include namespace raft { namespace linalg { -using detail::divides_scalar; - /** * @defgroup ScalarOps Scalar operations on the input buffer * @tparam OutT output data-type upon which the math operation will be performed @@ -57,7 +56,7 @@ void divideScalar(OutT* out, const InT* in, InT scalar, IdxType len, cudaStream_ * @tparam InType Input Type raft::device_mdspan * @tparam OutType Output Type raft::device_mdspan * @tparam ScalarIdxType Index Type of scalar - * @param[in] handle raft::handle_t + * @param[in] handle raft::device_resources * @param[in] in Input * @param[in] scalar raft::host_scalar_view * @param[out] out Output @@ -67,7 +66,7 @@ template , typename = raft::enable_if_output_device_mdspan> -void divide_scalar(const raft::handle_t& handle, +void divide_scalar(raft::device_resources const& handle, InType in, OutType out, raft::host_scalar_view scalar) diff --git a/cpp/include/raft/linalg/dot.cuh b/cpp/include/raft/linalg/dot.cuh index 48577650bc..917188d695 100644 --- a/cpp/include/raft/linalg/dot.cuh +++ b/cpp/include/raft/linalg/dot.cuh @@ -1,5 +1,5 @@ /* - * Copyright (c) 2022, NVIDIA CORPORATION. + * Copyright (c) 2022-2023, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -21,13 +21,19 @@ #include #include -#include +#include #include namespace raft::linalg { + +/** + * @defgroup dot BLAS dot routine + * @{ + */ + /** * @brief Computes the dot product of two vectors. - * @param[in] handle raft::handle_t + * @param[in] handle raft::device_resources * @param[in] x First input vector * @param[in] y Second input vector * @param[out] out The output dot product between the x and y vectors. 
@@ -37,7 +43,7 @@ template -void dot(const raft::handle_t& handle, +void dot(raft::device_resources const& handle, raft::device_vector_view x, raft::device_vector_view y, raft::device_scalar_view out) @@ -57,7 +63,7 @@ void dot(const raft::handle_t& handle, /** * @brief Computes the dot product of two vectors. - * @param[in] handle raft::handle_t + * @param[in] handle raft::device_resources * @param[in] x First input vector * @param[in] y Second input vector * @param[out] out The output dot product between the x and y vectors. @@ -67,7 +73,7 @@ template -void dot(const raft::handle_t& handle, +void dot(raft::device_resources const& handle, raft::device_vector_view x, raft::device_vector_view y, raft::host_scalar_view out) @@ -84,5 +90,8 @@ void dot(const raft::handle_t& handle, out.data_handle(), handle.get_stream())); } + +/** @} */ // end of group dot + } // namespace raft::linalg #endif diff --git a/cpp/include/raft/linalg/eig.cuh b/cpp/include/raft/linalg/eig.cuh index 2ad222d42d..03e94a10b1 100644 --- a/cpp/include/raft/linalg/eig.cuh +++ b/cpp/include/raft/linalg/eig.cuh @@ -1,5 +1,5 @@ /* - * Copyright (c) 2022, NVIDIA CORPORATION. + * Copyright (c) 2022-2023, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
@@ -25,11 +25,6 @@ namespace raft { namespace linalg { -/** - * @defgroup eig Eigen Decomposition Methods - * @{ - */ - /** * @brief eig decomp with divide and conquer method for the column-major * symmetric matrices @@ -43,7 +38,7 @@ namespace linalg { * @param stream cuda stream */ template -void eigDC(const raft::handle_t& handle, +void eigDC(raft::device_resources const& handle, const math_t* in, std::size_t n_rows, std::size_t n_cols, @@ -73,7 +68,7 @@ using detail::OVERWRITE_INPUT; * @param stream cuda stream */ template -void eigSelDC(const raft::handle_t& handle, +void eigSelDC(raft::device_resources const& handle, math_t* in, std::size_t n_rows, std::size_t n_cols, @@ -102,7 +97,7 @@ void eigSelDC(const raft::handle_t& handle, * accuracy. */ template -void eigJacobi(const raft::handle_t& handle, +void eigJacobi(raft::device_resources const& handle, const math_t* in, std::size_t n_rows, std::size_t n_cols, @@ -115,19 +110,24 @@ void eigJacobi(const raft::handle_t& handle, detail::eigJacobi(handle, in, n_rows, n_cols, eig_vectors, eig_vals, stream, tol, sweeps); } +/** + * @defgroup eig Eigen Decomposition Methods + * @{ + */ + /** * @brief eig decomp with divide and conquer method for the column-major * symmetric matrices * @tparam ValueType the data-type of input and output * @tparam IntegerType Integer used for addressing - * @param handle raft::handle_t + * @param handle raft::device_resources * @param[in] in input raft::device_matrix_view (symmetric matrix that has real eig values and * vectors) * @param[out] eig_vectors: eigenvectors output of type raft::device_matrix_view * @param[out] eig_vals: eigen values output of type raft::device_vector_view */ template -void eig_dc(const raft::handle_t& handle, +void eig_dc(raft::device_resources const& handle, raft::device_matrix_view in, raft::device_matrix_view eig_vectors, raft::device_vector_view eig_vals) @@ -149,7 +149,7 @@ void eig_dc(const raft::handle_t& handle, * for the column-major symmetric 
matrices * @tparam ValueType the data-type of input and output * @tparam IntegerType Integer used for addressing - * @param[in] handle raft::handle_t + * @param[in] handle raft::device_resources * @param[in] in input raft::device_matrix_view (symmetric matrix that has real eig values and * vectors) * @param[out] eig_vectors: eigenvectors output of type raft::device_matrix_view @@ -158,7 +158,7 @@ void eig_dc(const raft::handle_t& handle, * @param[in] memUsage: the memory selection for eig vector output */ template -void eig_dc_selective(const raft::handle_t& handle, +void eig_dc_selective(raft::device_resources const& handle, raft::device_matrix_view in, raft::device_matrix_view eig_vectors, raft::device_vector_view eig_vals, @@ -185,7 +185,7 @@ void eig_dc_selective(const raft::handle_t& handle, * column-major symmetric matrices (in parameter) * @tparam ValueType the data-type of input and output * @tparam IntegerType Integer used for addressing - * @param handle raft::handle_t + * @param handle raft::device_resources * @param[in] in input raft::device_matrix_view (symmetric matrix that has real eig values and * vectors) * @param[out] eig_vectors: eigenvectors output of type raft::device_matrix_view @@ -196,7 +196,7 @@ void eig_dc_selective(const raft::handle_t& handle, * accuracy. 
*/ template -void eig_jacobi(const raft::handle_t& handle, +void eig_jacobi(raft::device_resources const& handle, raft::device_matrix_view in, raft::device_matrix_view eig_vectors, raft::device_vector_view eig_vals, diff --git a/cpp/include/raft/linalg/eltwise.cuh b/cpp/include/raft/linalg/eltwise.cuh index dbc06a4af3..2e6c1a4ab5 100644 --- a/cpp/include/raft/linalg/eltwise.cuh +++ b/cpp/include/raft/linalg/eltwise.cuh @@ -23,8 +23,6 @@ namespace raft { namespace linalg { -using detail::adds_scalar; - /** * @defgroup ScalarOps Scalar operations on the input buffer * @tparam InType data-type upon which the math operation will be performed @@ -42,8 +40,6 @@ void scalarAdd(OutType* out, const InType* in, InType scalar, IdxType len, cudaS detail::scalarAdd(out, in, scalar, len, stream); } -using detail::multiplies_scalar; - template void scalarMultiply(OutType* out, const InType* in, InType scalar, IdxType len, cudaStream_t stream) { @@ -90,8 +86,6 @@ void eltwiseDivide( detail::eltwiseDivide(out, in1, in2, len, stream); } -using detail::divides_check_zero; - template void eltwiseDivideCheckZero( OutType* out, const InType* in1, const InType* in2, IdxType len, cudaStream_t stream) diff --git a/cpp/include/raft/linalg/gemm.cuh b/cpp/include/raft/linalg/gemm.cuh index f2354da6c6..d5dc5ffab5 100644 --- a/cpp/include/raft/linalg/gemm.cuh +++ b/cpp/include/raft/linalg/gemm.cuh @@ -1,5 +1,5 @@ /* - * Copyright (c) 2022, NVIDIA CORPORATION. + * Copyright (c) 2022-2023, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
@@ -52,7 +52,7 @@ namespace linalg { * @param [in] stream */ template -void gemm(const raft::handle_t& handle, +void gemm(raft::device_resources const& handle, const bool trans_a, const bool trans_b, const int m, @@ -91,7 +91,7 @@ void gemm(const raft::handle_t& handle, * @param stream cuda stream */ template -void gemm(const raft::handle_t& handle, +void gemm(raft::device_resources const& handle, const math_t* a, int n_rows_a, int n_cols_a, @@ -126,7 +126,7 @@ void gemm(const raft::handle_t& handle, * @param stream cuda stream */ template -void gemm(const raft::handle_t& handle, +void gemm(raft::device_resources const& handle, const math_t* a, int n_rows_a, int n_cols_a, @@ -161,7 +161,7 @@ void gemm(const raft::handle_t& handle, * @param beta scalar */ template -void gemm(const raft::handle_t& handle, +void gemm(raft::device_resources const& handle, T* z, T* x, T* y, @@ -213,7 +213,7 @@ template >, std::is_same>>>> -void gemm(const raft::handle_t& handle, +void gemm(raft::device_resources const& handle, raft::device_matrix_view x, raft::device_matrix_view y, raft::device_matrix_view z, diff --git a/cpp/include/raft/linalg/gemv.cuh b/cpp/include/raft/linalg/gemv.cuh index 8132a742f8..96846003f6 100644 --- a/cpp/include/raft/linalg/gemv.cuh +++ b/cpp/include/raft/linalg/gemv.cuh @@ -1,5 +1,5 @@ /* - * Copyright (c) 2022, NVIDIA CORPORATION. + * Copyright (c) 2022-2023, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
@@ -50,7 +50,7 @@ namespace linalg { * @param [in] stream */ template -void gemv(const raft::handle_t& handle, +void gemv(raft::device_resources const& handle, const bool trans_a, const int m, const int n, @@ -69,7 +69,7 @@ void gemv(const raft::handle_t& handle, } template -void gemv(const raft::handle_t& handle, +void gemv(raft::device_resources const& handle, const math_t* A, const int n_rows, const int n_cols, @@ -103,7 +103,7 @@ void gemv(const raft::handle_t& handle, * @param stream stream on which this function is run */ template -void gemv(const raft::handle_t& handle, +void gemv(raft::device_resources const& handle, const math_t* A, const int n_rows_a, const int n_cols_a, @@ -133,7 +133,7 @@ void gemv(const raft::handle_t& handle, * @param stream stream on which this function is run */ template -void gemv(const raft::handle_t& handle, +void gemv(raft::device_resources const& handle, const math_t* A, const int n_rows_a, const int n_cols_a, @@ -165,7 +165,7 @@ void gemv(const raft::handle_t& handle, * @param stream stream on which this function is run */ template -void gemv(const raft::handle_t& handle, +void gemv(raft::device_resources const& handle, const math_t* A, const int n_rows_a, const int n_cols_a, @@ -199,7 +199,7 @@ void gemv(const raft::handle_t& handle, * */ template -void gemv(const raft::handle_t& handle, +void gemv(raft::device_resources const& handle, const math_t* A, const int n_rows_a, const int n_cols_a, @@ -246,7 +246,7 @@ template >, std::is_same>>>> -void gemv(const raft::handle_t& handle, +void gemv(raft::device_resources const& handle, raft::device_matrix_view A, raft::device_vector_view x, raft::device_vector_view y, diff --git a/cpp/include/raft/linalg/lstsq.cuh b/cpp/include/raft/linalg/lstsq.cuh index 7654812886..b36a9eba96 100644 --- a/cpp/include/raft/linalg/lstsq.cuh +++ b/cpp/include/raft/linalg/lstsq.cuh @@ -1,5 +1,5 @@ /* - * Copyright (c) 2018-2022, NVIDIA CORPORATION. + * Copyright (c) 2018-2023, NVIDIA CORPORATION. 
* * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -18,7 +18,7 @@ #pragma once -#include +#include #include namespace raft { namespace linalg { @@ -37,7 +37,7 @@ namespace linalg { * @param[in] stream cuda stream for ordering operations */ template -void lstsqSvdQR(const raft::handle_t& handle, +void lstsqSvdQR(raft::device_resources const& handle, math_t* A, const int n_rows, const int n_cols, @@ -62,7 +62,7 @@ void lstsqSvdQR(const raft::handle_t& handle, * @param[in] stream cuda stream for ordering operations */ template -void lstsqSvdJacobi(const raft::handle_t& handle, +void lstsqSvdJacobi(raft::device_resources const& handle, math_t* A, const int n_rows, const int n_cols, @@ -78,7 +78,7 @@ void lstsqSvdJacobi(const raft::handle_t& handle, * (`w = (A^T A)^-1 A^T b`) */ template -void lstsqEig(const raft::handle_t& handle, +void lstsqEig(raft::device_resources const& handle, const math_t* A, const int n_rows, const int n_cols, @@ -104,7 +104,7 @@ void lstsqEig(const raft::handle_t& handle, * @param[in] stream cuda stream for ordering operations */ template -void lstsqQR(const raft::handle_t& handle, +void lstsqQR(raft::device_resources const& handle, math_t* A, const int n_rows, const int n_cols, @@ -125,7 +125,7 @@ void lstsqQR(const raft::handle_t& handle, * Via SVD decomposition of `A = U S Vt`. * * @tparam ValueType the data-type of input/output - * @param[in] handle raft::handle_t + * @param[in] handle raft::device_resources * @param[inout] A input raft::device_matrix_view * Warning: the content of this matrix is modified. 
* @param[inout] b input target raft::device_vector_view @@ -133,7 +133,7 @@ void lstsqQR(const raft::handle_t& handle, * @param[out] w output coefficient raft::device_vector_view */ template -void lstsq_svd_qr(const raft::handle_t& handle, +void lstsq_svd_qr(raft::device_resources const& handle, raft::device_matrix_view A, raft::device_vector_view b, raft::device_vector_view w) @@ -155,7 +155,7 @@ void lstsq_svd_qr(const raft::handle_t& handle, * Via SVD decomposition of `A = U S V^T` using Jacobi iterations. * * @tparam ValueType the data-type of input/output - * @param[in] handle raft::handle_t + * @param[in] handle raft::device_resources * @param[inout] A input raft::device_matrix_view * Warning: the content of this matrix is modified. * @param[inout] b input target raft::device_vector_view @@ -163,7 +163,7 @@ void lstsq_svd_qr(const raft::handle_t& handle, * @param[out] w output coefficient raft::device_vector_view */ template -void lstsq_svd_jacobi(const raft::handle_t& handle, +void lstsq_svd_jacobi(raft::device_resources const& handle, raft::device_matrix_view A, raft::device_vector_view b, raft::device_vector_view w) @@ -186,7 +186,7 @@ void lstsq_svd_jacobi(const raft::handle_t& handle, * (`w = (A^T A)^-1 A^T b`) * * @tparam ValueType the data-type of input/output - * @param[in] handle raft::handle_t + * @param[in] handle raft::device_resources * @param[inout] A input raft::device_matrix_view * Warning: the content of this matrix is modified by the cuSOLVER routines. 
* @param[inout] b input target raft::device_vector_view @@ -194,7 +194,7 @@ void lstsq_svd_jacobi(const raft::handle_t& handle, * @param[out] w output coefficient raft::device_vector_view */ template -void lstsq_eig(const raft::handle_t& handle, +void lstsq_eig(raft::device_resources const& handle, raft::device_matrix_view A, raft::device_vector_view b, raft::device_vector_view w) @@ -217,7 +217,7 @@ void lstsq_eig(const raft::handle_t& handle, * (triangular system of equations `Rw = Q^T b`) * * @tparam ValueType the data-type of input/output - * @param[in] handle raft::handle_t + * @param[in] handle raft::device_resources * @param[inout] A input raft::device_matrix_view * Warning: the content of this matrix is modified. * @param[inout] b input target raft::device_vector_view @@ -225,7 +225,7 @@ void lstsq_eig(const raft::handle_t& handle, * @param[out] w output coefficient raft::device_vector_view */ template -void lstsq_qr(const raft::handle_t& handle, +void lstsq_qr(raft::device_resources const& handle, raft::device_matrix_view A, raft::device_vector_view b, raft::device_vector_view w) diff --git a/cpp/include/raft/linalg/map.cuh b/cpp/include/raft/linalg/map.cuh index ad35cc5880..2b9e6c80a0 100644 --- a/cpp/include/raft/linalg/map.cuh +++ b/cpp/include/raft/linalg/map.cuh @@ -1,5 +1,5 @@ /* - * Copyright (c) 2022, NVIDIA CORPORATION. + * Copyright (c) 2022-2023, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
@@ -21,7 +21,9 @@ #include "detail/map.cuh" #include +#include #include +#include namespace raft { namespace linalg { @@ -65,7 +67,7 @@ void map_k( * @tparam TPB threads-per-block in the final kernel launched * @tparam OutType data-type of result of type raft::device_mdspan * @tparam Args additional parameters - * @param[in] handle raft::handle_t + * @param[in] handle raft::device_resources * @param[in] in the input of type raft::device_mdspan * @param[out] out the output of the map operation of type raft::device_mdspan * @param[in] map the device-lambda @@ -78,7 +80,7 @@ template , typename = raft::enable_if_output_device_mdspan> -void map(const raft::handle_t& handle, InType in, OutType out, MapOp map, Args... args) +void map(raft::device_resources const& handle, InType in, OutType out, MapOp map, Args... args) { using in_value_t = typename InType::value_type; using out_value_t = typename OutType::value_type; @@ -96,9 +98,43 @@ void map(const raft::handle_t& handle, InType in, OutType out, MapOp map, Args.. } } +/** + * @brief Perform an element-wise unary operation on the input offset into the output array + * + * Usage example: + * @code{.cpp} + * #include + * #include + * #include + * #include + * ... 
+ * raft::handle_t handle; + * auto squares = raft::make_device_vector(handle, n); + * raft::linalg::map_offset(handle, squares.view(), raft::sq_op()); + * @endcode + * + * @tparam OutType Output mdspan type + * @tparam MapOp The unary operation type with signature `OutT func(const IdxT& idx);` + * @param[in] handle The raft handle + * @param[out] out Output array + * @param[in] op The unary operation + */ +template > +void map_offset(const raft::device_resources& handle, OutType out, MapOp op) +{ + RAFT_EXPECTS(raft::is_row_or_column_major(out), "Output must be contiguous"); + + using out_value_t = typename OutType::value_type; + + thrust::tabulate( + handle.get_thrust_policy(), out.data_handle(), out.data_handle() + out.size(), op); +} + /** @} */ // end of map } // namespace linalg }; // namespace raft -#endif \ No newline at end of file +#endif diff --git a/cpp/include/raft/linalg/map_reduce.cuh b/cpp/include/raft/linalg/map_reduce.cuh index 180ed128a1..b89f3bdd54 100644 --- a/cpp/include/raft/linalg/map_reduce.cuh +++ b/cpp/include/raft/linalg/map_reduce.cuh @@ -1,5 +1,5 @@ /* - * Copyright (c) 2022, NVIDIA CORPORATION. + * Copyright (c) 2022-2023, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
@@ -24,11 +24,6 @@ namespace raft::linalg { -/** - * @defgroup map_reduce Map-Reduce ops - * @{ - */ - /** * @brief CUDA version of map and then generic reduction operation * @tparam Type data-type upon which the math operation will be performed @@ -67,6 +62,10 @@ void mapReduce(OutType* out, out, len, neutral, map, op, stream, in, args...); } +/** + * @defgroup map_reduce Map-Reduce ops + * @{ + */ /** * @brief CUDA version of map and then generic reduction operation * @tparam InValueType the data-type of the input @@ -76,7 +75,7 @@ void mapReduce(OutType* out, * @tparam OutValueType the data-type of the output * @tparam ScalarIdxType index type of scalar * @tparam Args additional parameters - * @param[in] handle raft::handle_t + * @param[in] handle raft::device_resources * @param[in] in the input of type raft::device_vector_view * @param[in] neutral The neutral element of the reduction operation. For example: * 0 for sum, 1 for multiply, +Inf for Min, -Inf for Max @@ -92,7 +91,7 @@ template -void map_reduce(const raft::handle_t& handle, +void map_reduce(raft::device_resources const& handle, raft::device_vector_view in, raft::device_scalar_view out, OutValueType neutral, diff --git a/cpp/include/raft/linalg/matrix_vector.cuh b/cpp/include/raft/linalg/matrix_vector.cuh index 57bc0cf21f..fa24ea28b7 100644 --- a/cpp/include/raft/linalg/matrix_vector.cuh +++ b/cpp/include/raft/linalg/matrix_vector.cuh @@ -1,5 +1,5 @@ /* - * Copyright (c) 2022, NVIDIA CORPORATION. + * Copyright (c) 2022-2023, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
@@ -23,6 +23,11 @@ namespace raft::linalg { +/** + * @defgroup matrix_vector Matrix-Vector Operations + * @{ + */ + /** * @brief multiply each row or column of matrix with vector, skipping zeros in vector * @param [in] handle: raft handle for managing library resources @@ -32,7 +37,7 @@ namespace raft::linalg { * the rows of the matrix or columns using enum class raft::linalg::Apply */ template -void binary_mult_skip_zero(const raft::handle_t& handle, +void binary_mult_skip_zero(raft::device_resources const& handle, raft::device_matrix_view data, raft::device_vector_view vec, Apply apply) @@ -65,7 +70,7 @@ void binary_mult_skip_zero(const raft::handle_t& handle, * the rows of the matrix or columns using enum class raft::linalg::Apply */ template -void binary_div(const raft::handle_t& handle, +void binary_div(raft::device_resources const& handle, raft::device_matrix_view data, raft::device_vector_view vec, Apply apply) @@ -100,7 +105,7 @@ void binary_div(const raft::handle_t& handle, * value if false */ template -void binary_div_skip_zero(const raft::handle_t& handle, +void binary_div_skip_zero(raft::device_resources const& handle, raft::device_matrix_view data, raft::device_vector_view vec, Apply apply, @@ -135,7 +140,7 @@ void binary_div_skip_zero(const raft::handle_t& handle, * the rows of the matrix or columns using enum class raft::linalg::Apply */ template -void binary_add(const raft::handle_t& handle, +void binary_add(raft::device_resources const& handle, raft::device_matrix_view data, raft::device_vector_view vec, Apply apply) @@ -168,7 +173,7 @@ void binary_add(const raft::handle_t& handle, * the rows of the matrix or columns using enum class raft::linalg::Apply */ template -void binary_sub(const raft::handle_t& handle, +void binary_sub(raft::device_resources const& handle, raft::device_matrix_view data, raft::device_vector_view vec, Apply apply) @@ -191,4 +196,7 @@ void binary_sub(const raft::handle_t& handle, bcast_along_rows, handle.get_stream()); } + 
+/** @} */ // end of matrix_vector + } // namespace raft::linalg \ No newline at end of file diff --git a/cpp/include/raft/linalg/matrix_vector_op.cuh b/cpp/include/raft/linalg/matrix_vector_op.cuh index 8b5163a714..59b2ca5ee5 100644 --- a/cpp/include/raft/linalg/matrix_vector_op.cuh +++ b/cpp/include/raft/linalg/matrix_vector_op.cuh @@ -1,5 +1,5 @@ /* - * Copyright (c) 2022, NVIDIA CORPORATION. + * Copyright (c) 2022-2023, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -122,7 +122,7 @@ void matrixVectorOp(MatT* out, * @tparam LayoutPolicy the layout of input and output (raft::row_major or raft::col_major) * @tparam Lambda a device function which represents a binary operator * @tparam IndexType Integer used for addressing - * @param[in] handle raft::handle_t + * @param[in] handle raft::device_resources * @param[in] matrix input raft::matrix_view * @param[in] vec vector raft::vector_view * @param[out] out output raft::matrix_view @@ -135,7 +135,7 @@ template -void matrix_vector_op(const raft::handle_t& handle, +void matrix_vector_op(raft::device_resources const& handle, raft::device_matrix_view matrix, raft::device_vector_view vec, raft::device_matrix_view out, @@ -182,7 +182,7 @@ void matrix_vector_op(const raft::handle_t& handle, * @tparam LayoutPolicy the layout of input and output (raft::row_major or raft::col_major) * @tparam Lambda a device function which represents a binary operator * @tparam IndexType Integer used for addressing - * @param handle raft::handle_t + * @param handle raft::device_resources * @param matrix input raft::matrix_view * @param vec1 the first vector raft::vector_view * @param vec2 the second vector raft::vector_view @@ -197,7 +197,7 @@ template -void matrix_vector_op(const raft::handle_t& handle, +void matrix_vector_op(raft::device_resources const& handle, raft::device_matrix_view matrix, raft::device_vector_view vec1, 
raft::device_vector_view vec2, diff --git a/cpp/include/raft/linalg/mean_squared_error.cuh b/cpp/include/raft/linalg/mean_squared_error.cuh index a3360ae35a..62f4896d01 100644 --- a/cpp/include/raft/linalg/mean_squared_error.cuh +++ b/cpp/include/raft/linalg/mean_squared_error.cuh @@ -1,5 +1,5 @@ /* - * Copyright (c) 2022, NVIDIA CORPORATION. + * Copyright (c) 2022-2023, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -53,14 +53,14 @@ void meanSquaredError( * @tparam IndexType Input/Output index type * @tparam OutValueType Output data-type * @tparam TPB threads-per-block - * @param[in] handle raft::handle_t + * @param[in] handle raft::device_resources * @param[in] A input raft::device_vector_view * @param[in] B input raft::device_vector_view * @param[out] out the output mean squared error value of type raft::device_scalar_view * @param[in] weight weight to apply to every term in the mean squared error calculation */ template -void mean_squared_error(const raft::handle_t& handle, +void mean_squared_error(raft::device_resources const& handle, raft::device_vector_view A, raft::device_vector_view B, raft::device_scalar_view out, diff --git a/cpp/include/raft/linalg/multiply.cuh b/cpp/include/raft/linalg/multiply.cuh index 119cf667d1..574b88c63d 100644 --- a/cpp/include/raft/linalg/multiply.cuh +++ b/cpp/include/raft/linalg/multiply.cuh @@ -1,5 +1,5 @@ /* - * Copyright (c) 2022, NVIDIA CORPORATION. + * Copyright (c) 2022-2023, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
@@ -56,7 +56,7 @@ void multiplyScalar(out_t* out, const in_t* in, in_t scalar, IdxType len, cudaSt * @tparam InType Input Type raft::device_mdspan * @tparam OutType Output Type raft::device_mdspan * @tparam ScalarIdxType Index Type of scalar - * @param[in] handle raft::handle_t + * @param[in] handle raft::device_resources * @param[in] in the input buffer * @param[out] out the output buffer * @param[in] scalar the scalar used in the operations @@ -68,7 +68,7 @@ template , typename = raft::enable_if_output_device_mdspan> void multiply_scalar( - const raft::handle_t& handle, + raft::device_resources const& handle, InType in, OutType out, raft::host_scalar_view scalar) diff --git a/cpp/include/raft/linalg/norm.cuh b/cpp/include/raft/linalg/norm.cuh index 9abfd3bdb0..8bc6720b4e 100644 --- a/cpp/include/raft/linalg/norm.cuh +++ b/cpp/include/raft/linalg/norm.cuh @@ -1,5 +1,5 @@ /* - * Copyright (c) 2022, NVIDIA CORPORATION. + * Copyright (c) 2022-2023, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
@@ -22,6 +22,7 @@ #include "linalg_types.hpp" #include +#include #include #include @@ -47,7 +48,7 @@ namespace linalg { * @param stream cuda stream where to launch work * @param fin_op the final lambda op */ -template > +template void rowNorm(Type* dots, const Type* data, IdxType D, @@ -55,7 +56,7 @@ void rowNorm(Type* dots, NormType type, bool rowMajor, cudaStream_t stream, - Lambda fin_op = raft::Nop()) + Lambda fin_op = raft::identity_op()) { detail::rowNormCaller(dots, data, D, N, type, rowMajor, stream, fin_op); } @@ -74,7 +75,7 @@ void rowNorm(Type* dots, * @param stream cuda stream where to launch work * @param fin_op the final lambda op */ -template > +template void colNorm(Type* dots, const Type* data, IdxType D, @@ -82,18 +83,23 @@ void colNorm(Type* dots, NormType type, bool rowMajor, cudaStream_t stream, - Lambda fin_op = raft::Nop()) + Lambda fin_op = raft::identity_op()) { detail::colNormCaller(dots, data, D, N, type, rowMajor, stream, fin_op); } +/** + * @defgroup norm Row- or Col-norm computation + * @{ + */ + /** * @brief Compute norm of the input matrix and perform fin_op * @tparam ElementType Input/Output data type * @tparam LayoutPolicy the layout of input (raft::row_major or raft::col_major) * @tparam IdxType Integer type used to for addressing * @tparam Lambda device final lambda - * @param[in] handle raft::handle_t + * @param[in] handle raft::device_resources * @param[in] in the input raft::device_matrix_view * @param[out] out the output raft::device_vector_view * @param[in] type the type of norm to be applied @@ -104,13 +110,13 @@ void colNorm(Type* dots, template > -void norm(const raft::handle_t& handle, + typename Lambda = raft::identity_op> +void norm(raft::device_resources const& handle, raft::device_matrix_view in, raft::device_vector_view out, NormType type, Apply apply, - Lambda fin_op = raft::Nop()) + Lambda fin_op = raft::identity_op()) { RAFT_EXPECTS(raft::is_row_or_column_major(in), "Input must be contiguous"); @@ -142,6 +148,8 
@@ void norm(const raft::handle_t& handle, } } +/** @} */ + }; // end namespace linalg }; // end namespace raft diff --git a/cpp/include/raft/linalg/normalize.cuh b/cpp/include/raft/linalg/normalize.cuh index 4bdf697581..027ebb16e8 100644 --- a/cpp/include/raft/linalg/normalize.cuh +++ b/cpp/include/raft/linalg/normalize.cuh @@ -1,5 +1,5 @@ /* - * Copyright (c) 2022, NVIDIA CORPORATION. + * Copyright (c) 2022-2023, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -18,11 +18,17 @@ #include "detail/normalize.cuh" +#include #include namespace raft { namespace linalg { +/** + * @defgroup norm Row- or Col-norm computation + * @{ + */ + /** * @brief Divide rows by their norm defined by main_op, reduce_op and fin_op * @@ -31,7 +37,7 @@ namespace linalg { * @tparam MainLambda Type of main_op * @tparam ReduceLambda Type of reduce_op * @tparam FinalLambda Type of fin_op - * @param[in] handle raft::handle_t + * @param[in] handle raft::device_resources * @param[in] in the input raft::device_matrix_view * @param[out] out the output raft::device_matrix_view * @param[in] init Initialization value, i.e identity element for the reduction operation @@ -46,7 +52,7 @@ template -void row_normalize(const raft::handle_t& handle, +void row_normalize(raft::device_resources const& handle, raft::device_matrix_view in, raft::device_matrix_view out, ElementType init, @@ -79,14 +85,14 @@ void row_normalize(const raft::handle_t& handle, * * @tparam ElementType Input/Output data type * @tparam IndexType Integer type used to for addressing - * @param[in] handle raft::handle_t + * @param[in] handle raft::device_resources * @param[in] in the input raft::device_matrix_view * @param[out] out the output raft::device_matrix_view * @param[in] norm_type the type of norm to be applied * @param[in] eps If the norm is below eps, the row is considered zero and no division is applied */ template -void 
row_normalize(const raft::handle_t& handle, +void row_normalize(raft::device_resources const& handle, raft::device_matrix_view in, raft::device_matrix_view out, NormType norm_type, @@ -94,38 +100,22 @@ void row_normalize(const raft::handle_t& handle, { switch (norm_type) { case L1Norm: - row_normalize(handle, - in, - out, - ElementType(0), - raft::L1Op(), - raft::Sum(), - raft::Nop(), - eps); + row_normalize( + handle, in, out, ElementType(0), raft::abs_op(), raft::add_op(), raft::identity_op(), eps); break; case L2Norm: - row_normalize(handle, - in, - out, - ElementType(0), - raft::L2Op(), - raft::Sum(), - raft::SqrtOp(), - eps); + row_normalize( + handle, in, out, ElementType(0), raft::sq_op(), raft::add_op(), raft::sqrt_op(), eps); break; case LinfNorm: - row_normalize(handle, - in, - out, - ElementType(0), - raft::L1Op(), - raft::Max(), - raft::Nop(), - eps); + row_normalize( + handle, in, out, ElementType(0), raft::abs_op(), raft::max_op(), raft::identity_op(), eps); break; default: THROW("Unsupported norm type: %d", norm_type); } } +/** @} */ + } // namespace linalg } // namespace raft diff --git a/cpp/include/raft/linalg/power.cuh b/cpp/include/raft/linalg/power.cuh index acd226b71d..1fdfcb3780 100644 --- a/cpp/include/raft/linalg/power.cuh +++ b/cpp/include/raft/linalg/power.cuh @@ -1,5 +1,5 @@ /* - * Copyright (c) 2018-2022, NVIDIA CORPORATION. + * Copyright (c) 2018-2023, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
@@ -19,9 +19,9 @@ #pragma once #include +#include #include #include -#include #include namespace raft { @@ -41,8 +41,7 @@ namespace linalg { template void powerScalar(out_t* out, const in_t* in, const in_t scalar, IdxType len, cudaStream_t stream) { - raft::linalg::unaryOp( - out, in, len, [scalar] __device__(in_t in) { return raft::myPow(in, scalar); }, stream); + raft::linalg::unaryOp(out, in, len, raft::pow_const_op(scalar), stream); } /** @} */ @@ -61,8 +60,7 @@ void powerScalar(out_t* out, const in_t* in, const in_t scalar, IdxType len, cud template void power(out_t* out, const in_t* in1, const in_t* in2, IdxType len, cudaStream_t stream) { - raft::linalg::binaryOp( - out, in1, in2, len, [] __device__(in_t a, in_t b) { return raft::myPow(a, b); }, stream); + raft::linalg::binaryOp(out, in1, in2, len, raft::pow_op(), stream); } /** @} */ @@ -75,7 +73,7 @@ void power(out_t* out, const in_t* in1, const in_t* in2, IdxType len, cudaStream * @brief Elementwise power operation on the input buffers * @tparam InType Input Type raft::device_mdspan * @tparam OutType Output Type raft::device_mdspan - * @param[in] handle raft::handle_t + * @param[in] handle raft::device_resources * @param[in] in1 First Input * @param[in] in2 Second Input * @param[out] out Output @@ -84,7 +82,7 @@ template , typename = raft::enable_if_output_device_mdspan> -void power(const raft::handle_t& handle, InType in1, InType in2, OutType out) +void power(raft::device_resources const& handle, InType in1, InType in2, OutType out) { using in_value_t = typename InType::value_type; using out_value_t = typename OutType::value_type; @@ -115,7 +113,7 @@ void power(const raft::handle_t& handle, InType in1, InType in2, OutType out) * @tparam InType Input Type raft::device_mdspan * @tparam OutType Output Type raft::device_mdspan * @tparam ScalarIdxType Index Type of scalar - * @param[in] handle raft::handle_t + * @param[in] handle raft::device_resources * @param[in] in Input * @param[out] out Output * 
@param[in] scalar raft::host_scalar_view @@ -126,7 +124,7 @@ template , typename = raft::enable_if_output_device_mdspan> void power_scalar( - const raft::handle_t& handle, + raft::device_resources const& handle, InType in, OutType out, const raft::host_scalar_view scalar) diff --git a/cpp/include/raft/linalg/qr.cuh b/cpp/include/raft/linalg/qr.cuh index 7e6e14e680..8e58af63c1 100644 --- a/cpp/include/raft/linalg/qr.cuh +++ b/cpp/include/raft/linalg/qr.cuh @@ -1,5 +1,5 @@ /* - * Copyright (c) 2022, NVIDIA CORPORATION. + * Copyright (c) 2022-2023, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -23,11 +23,6 @@ namespace raft { namespace linalg { -/** - * @defgroup QRdecomp QR decomposition - * @{ - */ - /** * @brief compute QR decomp and return only Q matrix * @param handle: raft handle @@ -38,7 +33,7 @@ namespace linalg { * @param stream cuda stream */ template -void qrGetQ(const raft::handle_t& handle, +void qrGetQ(raft::device_resources const& handle, const math_t* M, math_t* Q, int n_rows, @@ -59,7 +54,7 @@ void qrGetQ(const raft::handle_t& handle, * @param stream cuda stream */ template -void qrGetQR(const raft::handle_t& handle, +void qrGetQR(raft::device_resources const& handle, math_t* M, math_t* Q, math_t* R, @@ -70,14 +65,19 @@ void qrGetQR(const raft::handle_t& handle, detail::qrGetQR(handle, M, Q, R, n_rows, n_cols, stream); } +/** + * @defgroup qr QR Decomposition + * @{ + */ + /** * @brief Compute the QR decomposition of matrix M and return only the Q matrix. 
- * @param[in] handle raft::handle_t + * @param[in] handle raft::device_resources * @param[in] M Input raft::device_matrix_view * @param[out] Q Output raft::device_matrix_view */ template -void qr_get_q(const raft::handle_t& handle, +void qr_get_q(raft::device_resources const& handle, raft::device_matrix_view M, raft::device_matrix_view Q) { @@ -88,13 +88,13 @@ void qr_get_q(const raft::handle_t& handle, /** * @brief Compute the QR decomposition of matrix M and return both the Q and R matrices. - * @param[in] handle raft::handle_t + * @param[in] handle raft::device_resources * @param[in] M Input raft::device_matrix_view * @param[in] Q Output raft::device_matrix_view * @param[out] R Output raft::device_matrix_view */ template -void qr_get_qr(const raft::handle_t& handle, +void qr_get_qr(raft::device_resources const& handle, raft::device_matrix_view M, raft::device_matrix_view Q, raft::device_matrix_view R) diff --git a/cpp/include/raft/linalg/reduce.cuh b/cpp/include/raft/linalg/reduce.cuh index 5579acf355..ae5457c44f 100644 --- a/cpp/include/raft/linalg/reduce.cuh +++ b/cpp/include/raft/linalg/reduce.cuh @@ -1,5 +1,5 @@ /* - * Copyright (c) 2022, NVIDIA CORPORATION. + * Copyright (c) 2022-2023, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
@@ -22,6 +22,7 @@ #include "linalg_types.hpp" #include +#include #include namespace raft { @@ -59,9 +60,9 @@ namespace linalg { template , - typename ReduceLambda = raft::Sum, - typename FinalLambda = raft::Nop> + typename MainLambda = raft::identity_op, + typename ReduceLambda = raft::add_op, + typename FinalLambda = raft::identity_op> void reduce(OutType* dots, const InType* data, IdxType D, @@ -71,9 +72,9 @@ void reduce(OutType* dots, bool alongRows, cudaStream_t stream, bool inplace = false, - MainLambda main_op = raft::Nop(), - ReduceLambda reduce_op = raft::Sum(), - FinalLambda final_op = raft::Nop()) + MainLambda main_op = raft::identity_op(), + ReduceLambda reduce_op = raft::add_op(), + FinalLambda final_op = raft::identity_op()) { detail::reduce( dots, data, D, N, init, rowMajor, alongRows, stream, inplace, main_op, reduce_op, final_op); @@ -104,7 +105,7 @@ void reduce(OutType* dots, * @tparam FinalLambda the final lambda applied before STG (eg: Sqrt for L2 norm) * It must be a 'callable' supporting the following input and output: *
OutType (*FinalLambda)(OutType);
- * @param[in] handle raft::handle_t + * @param[in] handle raft::device_resources * @param[in] data Input of type raft::device_matrix_view * @param[out] dots Output of type raft::device_matrix_view * @param[in] init initial value to use for the reduction @@ -118,18 +119,18 @@ template , - typename ReduceLambda = raft::Sum, - typename FinalLambda = raft::Nop> -void reduce(const raft::handle_t& handle, + typename MainLambda = raft::identity_op, + typename ReduceLambda = raft::add_op, + typename FinalLambda = raft::identity_op> +void reduce(raft::device_resources const& handle, raft::device_matrix_view data, raft::device_vector_view dots, OutElementType init, Apply apply, bool inplace = false, - MainLambda main_op = raft::Nop(), - ReduceLambda reduce_op = raft::Sum(), - FinalLambda final_op = raft::Nop()) + MainLambda main_op = raft::identity_op(), + ReduceLambda reduce_op = raft::add_op(), + FinalLambda final_op = raft::identity_op()) { RAFT_EXPECTS(raft::is_row_or_column_major(data), "Input must be contiguous"); diff --git a/cpp/include/raft/linalg/reduce_cols_by_key.cuh b/cpp/include/raft/linalg/reduce_cols_by_key.cuh index 436fce26fd..2b744d8134 100644 --- a/cpp/include/raft/linalg/reduce_cols_by_key.cuh +++ b/cpp/include/raft/linalg/reduce_cols_by_key.cuh @@ -1,5 +1,5 @@ /* - * Copyright (c) 2019-2022, NVIDIA CORPORATION. + * Copyright (c) 2019-2023, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
@@ -21,7 +21,7 @@ #include "detail/reduce_cols_by_key.cuh" #include -#include +#include namespace raft { namespace linalg { @@ -43,6 +43,7 @@ namespace linalg { * @param ncols number of columns in the input data * @param nkeys number of unique keys in the keys array * @param stream cuda stream to launch the kernel onto + * @param reset_sums Whether to reset the output sums to zero before reducing */ template void reduce_cols_by_key(const T* data, @@ -51,9 +52,10 @@ void reduce_cols_by_key(const T* data, IdxType nrows, IdxType ncols, IdxType nkeys, - cudaStream_t stream) + cudaStream_t stream, + bool reset_sums = true) { - detail::reduce_cols_by_key(data, keys, out, nrows, ncols, nkeys, stream); + detail::reduce_cols_by_key(data, keys, out, nrows, ncols, nkeys, stream, reset_sums); } /** @@ -67,7 +69,7 @@ void reduce_cols_by_key(const T* data, * @tparam ElementType the input data type (as well as the output reduced matrix) * @tparam KeyType data type of the keys * @tparam IndexType indexing arithmetic type - * @param[in] handle raft::handle_t + * @param[in] handle raft::device_resources * @param[in] data the input data (dim = nrows x ncols). This is assumed to be in * row-major layout of type raft::device_matrix_view * @param[in] keys keys raft::device_vector_view (len = ncols). It is assumed that each key in this @@ -76,18 +78,26 @@ void reduce_cols_by_key(const T* data, * monotonically increasing keys array. * @param[out] out the output reduced raft::device_matrix_view along columns (dim = nrows x nkeys). * This will be assumed to be in row-major layout - * @param[in] nkeys number of unique keys in the keys array + * @param[in] nkeys Number of unique keys in the keys array. 
By default, inferred from the number of + * columns of out + * @param[in] reset_sums Whether to reset the output sums to zero before reducing */ template void reduce_cols_by_key( - const raft::handle_t& handle, + raft::device_resources const& handle, raft::device_matrix_view data, raft::device_vector_view keys, raft::device_matrix_view out, - IndexType nkeys) + IndexType nkeys = 0, + bool reset_sums = true) { - RAFT_EXPECTS(out.extent(0) == data.extent(0) && out.extent(1) == nkeys, - "Output is not of size nrows * nkeys"); + if (nkeys > 0) { + RAFT_EXPECTS(out.extent(1) == nkeys, "Output doesn't have nkeys columns"); + } else { + nkeys = out.extent(1); + } + RAFT_EXPECTS(out.extent(0) == data.extent(0), + "Output doesn't have the same number of rows as input"); RAFT_EXPECTS(keys.extent(0) == data.extent(1), "Keys is not of size ncols"); reduce_cols_by_key(data.data_handle(), @@ -96,7 +106,8 @@ void reduce_cols_by_key( data.extent(0), data.extent(1), nkeys, - handle.get_stream()); + handle.get_stream(), + reset_sums); } /** @} */ // end of group reduce_cols_by_key diff --git a/cpp/include/raft/linalg/reduce_rows_by_key.cuh b/cpp/include/raft/linalg/reduce_rows_by_key.cuh index 1dabd92087..484b60238b 100644 --- a/cpp/include/raft/linalg/reduce_rows_by_key.cuh +++ b/cpp/include/raft/linalg/reduce_rows_by_key.cuh @@ -1,5 +1,5 @@ /* - * Copyright (c) 2019-2022, NVIDIA CORPORATION. + * Copyright (c) 2019-2023, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
@@ -21,7 +21,7 @@ #include "detail/reduce_rows_by_key.cuh" #include -#include +#include namespace raft { namespace linalg { @@ -136,7 +136,7 @@ void reduce_rows_by_key(const DataIteratorT d_A, * @tparam KeyType data-type of keys * @tparam WeightType data-type of weights * @tparam IndexType index type - * @param[in] handle raft::handle_t + * @param[in] handle raft::device_resources * @param[in] d_A Input raft::device_mdspan (ncols * nrows) * @param[in] d_keys Keys for each row raft::device_vector_view (1 x nrows) * @param[out] d_sums Row sums by key raft::device_matrix_view (ncols x d_keys) @@ -148,7 +148,7 @@ void reduce_rows_by_key(const DataIteratorT d_A, */ template void reduce_rows_by_key( - const raft::handle_t& handle, + raft::device_resources const& handle, raft::device_matrix_view d_A, raft::device_vector_view d_keys, raft::device_matrix_view d_sums, diff --git a/cpp/include/raft/linalg/rsvd.cuh b/cpp/include/raft/linalg/rsvd.cuh index 6f0315642b..eb94547f13 100644 --- a/cpp/include/raft/linalg/rsvd.cuh +++ b/cpp/include/raft/linalg/rsvd.cuh @@ -1,5 +1,5 @@ /* - * Copyright (c) 2018-2022, NVIDIA CORPORATION. + * Copyright (c) 2018-2023, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
@@ -47,7 +47,7 @@ namespace linalg { * @param stream cuda stream */ template -void rsvdFixedRank(const raft::handle_t& handle, +void rsvdFixedRank(raft::device_resources const& handle, math_t* M, int n_rows, int n_cols, @@ -104,7 +104,7 @@ void rsvdFixedRank(const raft::handle_t& handle, * @param stream cuda stream */ template -void rsvdPerc(const raft::handle_t& handle, +void rsvdPerc(raft::device_resources const& handle, math_t* M, int n_rows, int n_cols, @@ -154,7 +154,7 @@ void rsvdPerc(const raft::handle_t& handle, * U_in * @tparam VType std::optional> @c * V_in - * @param[in] handle raft::handle_t + * @param[in] handle raft::device_resources * @param[in] M input raft::device_matrix_view with layout raft::col_major of shape (M, N) * @param[out] S_vec singular values raft::device_vector_view of shape (K) * @param[in] p no. of upsamples @@ -164,7 +164,7 @@ void rsvdPerc(const raft::handle_t& handle, * raft::col_major */ template -void rsvd_fixed_rank(const raft::handle_t& handle, +void rsvd_fixed_rank(raft::device_resources const& handle, raft::device_matrix_view M, raft::device_vector_view S_vec, IndexType p, @@ -228,7 +228,7 @@ void rsvd_fixed_rank(Args... args) * U_in * @tparam VType std::optional> @c * V_in - * @param[in] handle raft::handle_t + * @param[in] handle raft::device_resources * @param[in] M input raft::device_matrix_view with layout raft::col_major of shape (M, N) * @param[out] S_vec singular values raft::device_vector_view of shape (K) * @param[in] p no. of upsamples @@ -239,7 +239,7 @@ void rsvd_fixed_rank(Args... args) */ template void rsvd_fixed_rank_symmetric( - const raft::handle_t& handle, + raft::device_resources const& handle, raft::device_matrix_view M, raft::device_vector_view S_vec, IndexType p, @@ -303,7 +303,7 @@ void rsvd_fixed_rank_symmetric(Args... 
args) * U_in * @tparam VType std::optional> @c * V_in - * @param[in] handle raft::handle_t + * @param[in] handle raft::device_resources * @param[in] M input raft::device_matrix_view with layout raft::col_major of shape (M, N) * @param[out] S_vec singular values raft::device_vector_view of shape (K) * @param[in] p no. of upsamples @@ -315,7 +315,7 @@ void rsvd_fixed_rank_symmetric(Args... args) * raft::col_major */ template -void rsvd_fixed_rank_jacobi(const raft::handle_t& handle, +void rsvd_fixed_rank_jacobi(raft::device_resources const& handle, raft::device_matrix_view M, raft::device_vector_view S_vec, IndexType p, @@ -381,7 +381,7 @@ void rsvd_fixed_rank_jacobi(Args... args) * U_in * @tparam VType std::optional> @c * V_in - * @param[in] handle raft::handle_t + * @param[in] handle raft::device_resources * @param[in] M input raft::device_matrix_view with layout raft::col_major of shape (M, N) * @param[out] S_vec singular values raft::device_vector_view of shape (K) * @param[in] p no. of upsamples @@ -394,7 +394,7 @@ void rsvd_fixed_rank_jacobi(Args... args) */ template void rsvd_fixed_rank_symmetric_jacobi( - const raft::handle_t& handle, + raft::device_resources const& handle, raft::device_matrix_view M, raft::device_vector_view S_vec, IndexType p, @@ -460,7 +460,7 @@ void rsvd_fixed_rank_symmetric_jacobi(Args... args) * U_in * @tparam VType std::optional> @c * V_in - * @param[in] handle raft::handle_t + * @param[in] handle raft::device_resources * @param[in] M input raft::device_matrix_view with layout raft::col_major of shape (M, N) * @param[out] S_vec singular values raft::device_vector_view of shape (K) * @param[in] PC_perc percentage of singular values to be computed @@ -471,7 +471,7 @@ void rsvd_fixed_rank_symmetric_jacobi(Args... 
args) * raft::col_major */ template -void rsvd_perc(const raft::handle_t& handle, +void rsvd_perc(raft::device_resources const& handle, raft::device_matrix_view M, raft::device_vector_view S_vec, ValueType PC_perc, @@ -536,7 +536,7 @@ void rsvd_perc(Args... args) * U_in * @tparam VType std::optional> @c * V_in - * @param[in] handle raft::handle_t + * @param[in] handle raft::device_resources * @param[in] M input raft::device_matrix_view with layout raft::col_major of shape (M, N) * @param[out] S_vec singular values raft::device_vector_view of shape (K) * @param[in] PC_perc percentage of singular values to be computed @@ -547,7 +547,7 @@ void rsvd_perc(Args... args) * raft::col_major */ template -void rsvd_perc_symmetric(const raft::handle_t& handle, +void rsvd_perc_symmetric(raft::device_resources const& handle, raft::device_matrix_view M, raft::device_vector_view S_vec, ValueType PC_perc, @@ -612,7 +612,7 @@ void rsvd_perc_symmetric(Args... args) * U_in * @tparam VType std::optional> @c * V_in - * @param[in] handle raft::handle_t + * @param[in] handle raft::device_resources * @param[in] M input raft::device_matrix_view with layout raft::col_major of shape (M, N) * @param[out] S_vec singular values raft::device_vector_view of shape (K) * @param[in] PC_perc percentage of singular values to be computed @@ -625,7 +625,7 @@ void rsvd_perc_symmetric(Args... args) * raft::col_major */ template -void rsvd_perc_jacobi(const raft::handle_t& handle, +void rsvd_perc_jacobi(raft::device_resources const& handle, raft::device_matrix_view M, raft::device_vector_view S_vec, ValueType PC_perc, @@ -692,7 +692,7 @@ void rsvd_perc_jacobi(Args... 
args) * U_in * @tparam VType std::optional> @c * V_in - * @param[in] handle raft::handle_t + * @param[in] handle raft::device_resources * @param[in] M input raft::device_matrix_view with layout raft::col_major of shape (M, N) * @param[out] S_vec singular values raft::device_vector_view of shape (K) * @param[in] PC_perc percentage of singular values to be computed @@ -706,7 +706,7 @@ void rsvd_perc_jacobi(Args... args) */ template void rsvd_perc_symmetric_jacobi( - const raft::handle_t& handle, + raft::device_resources const& handle, raft::device_matrix_view M, raft::device_vector_view S_vec, ValueType PC_perc, diff --git a/cpp/include/raft/linalg/sqrt.cuh b/cpp/include/raft/linalg/sqrt.cuh index 2951285c3a..55e661897d 100644 --- a/cpp/include/raft/linalg/sqrt.cuh +++ b/cpp/include/raft/linalg/sqrt.cuh @@ -1,5 +1,5 @@ /* - * Copyright (c) 2018-2022, NVIDIA CORPORATION. + * Copyright (c) 2018-2023, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
@@ -19,8 +19,8 @@ #pragma once #include +#include #include -#include namespace raft { namespace linalg { @@ -38,8 +38,7 @@ namespace linalg { template void sqrt(out_t* out, const in_t* in, IdxType len, cudaStream_t stream) { - raft::linalg::unaryOp( - out, in, len, [] __device__(in_t in) { return raft::mySqrt(in); }, stream); + raft::linalg::unaryOp(out, in, len, raft::sqrt_op{}, stream); } /** @} */ @@ -52,7 +51,7 @@ void sqrt(out_t* out, const in_t* in, IdxType len, cudaStream_t stream) * @brief Elementwise sqrt operation * @tparam InType Input Type raft::device_mdspan * @tparam OutType Output Type raft::device_mdspan - * @param[in] handle raft::handle_t + * @param[in] handle raft::device_resources * @param[in] in Input * @param[out] out Output */ @@ -60,7 +59,7 @@ template , typename = raft::enable_if_output_device_mdspan> -void sqrt(const raft::handle_t& handle, InType in, OutType out) +void sqrt(raft::device_resources const& handle, InType in, OutType out) { using in_value_t = typename InType::value_type; using out_value_t = typename OutType::value_type; diff --git a/cpp/include/raft/linalg/strided_reduction.cuh b/cpp/include/raft/linalg/strided_reduction.cuh index 0aa4aecef5..d282a2e1fa 100644 --- a/cpp/include/raft/linalg/strided_reduction.cuh +++ b/cpp/include/raft/linalg/strided_reduction.cuh @@ -1,5 +1,5 @@ /* - * Copyright (c) 2022, NVIDIA CORPORATION. + * Copyright (c) 2022-2023, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
@@ -22,6 +22,7 @@ #include "detail/strided_reduction.cuh" #include +#include #include #include @@ -59,9 +60,9 @@ namespace linalg { template , - typename ReduceLambda = raft::Sum, - typename FinalLambda = raft::Nop> + typename MainLambda = raft::identity_op, + typename ReduceLambda = raft::add_op, + typename FinalLambda = raft::identity_op> void stridedReduction(OutType* dots, const InType* data, IdxType D, @@ -69,9 +70,9 @@ void stridedReduction(OutType* dots, OutType init, cudaStream_t stream, bool inplace = false, - MainLambda main_op = raft::Nop(), - ReduceLambda reduce_op = raft::Sum(), - FinalLambda final_op = raft::Nop()) + MainLambda main_op = raft::identity_op(), + ReduceLambda reduce_op = raft::add_op(), + FinalLambda final_op = raft::identity_op()) { // Only compile for types supported by myAtomicReduce, but don't make the compilation fail in // other cases, because coalescedReduction supports arbitrary types. @@ -111,7 +112,7 @@ void stridedReduction(OutType* dots, * @tparam FinalLambda the final lambda applied before STG (eg: Sqrt for L2 norm) * It must be a 'callable' supporting the following input and output: *
OutType (*FinalLambda)(OutType);
- * @param[in] handle raft::handle_t + * @param[in] handle raft::device_resources * @param[in] data Input of type raft::device_matrix_view * @param[out] dots Output of type raft::device_matrix_view * @param[in] init initial value to use for the reduction @@ -124,17 +125,17 @@ template , - typename ReduceLambda = raft::Sum, - typename FinalLambda = raft::Nop> -void strided_reduction(const raft::handle_t& handle, + typename MainLambda = raft::identity_op, + typename ReduceLambda = raft::add_op, + typename FinalLambda = raft::identity_op> +void strided_reduction(raft::device_resources const& handle, raft::device_matrix_view data, raft::device_vector_view dots, OutValueType init, bool inplace = false, - MainLambda main_op = raft::Nop(), - ReduceLambda reduce_op = raft::Sum(), - FinalLambda final_op = raft::Nop()) + MainLambda main_op = raft::identity_op(), + ReduceLambda reduce_op = raft::add_op(), + FinalLambda final_op = raft::identity_op()) { if constexpr (std::is_same_v) { RAFT_EXPECTS(static_cast(dots.size()) == data.extent(1), diff --git a/cpp/include/raft/linalg/subtract.cuh b/cpp/include/raft/linalg/subtract.cuh index e6f2fa8724..da995b7a2a 100644 --- a/cpp/include/raft/linalg/subtract.cuh +++ b/cpp/include/raft/linalg/subtract.cuh @@ -1,5 +1,5 @@ /* - * Copyright (c) 2022, NVIDIA CORPORATION. + * Copyright (c) 2022-2023, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
@@ -97,7 +97,7 @@ void subtractDevScalar(math_t* outDev, * @brief Elementwise subtraction operation on the input buffers * @tparam InType Input Type raft::device_mdspan * @tparam OutType Output Type raft::device_mdspan - * @param handle raft::handle_t + * @param handle raft::device_resources * @param[in] in1 First Input * @param[in] in2 Second Input * @param[out] out Output @@ -106,7 +106,7 @@ template , typename = raft::enable_if_output_device_mdspan> -void subtract(const raft::handle_t& handle, InType in1, InType in2, OutType out) +void subtract(raft::device_resources const& handle, InType in1, InType in2, OutType out) { using in_value_t = typename InType::value_type; using out_value_t = typename OutType::value_type; @@ -137,7 +137,7 @@ void subtract(const raft::handle_t& handle, InType in1, InType in2, OutType out) * @tparam InType Input Type raft::device_mdspan * @tparam OutType Output Type raft::device_mdspan * @tparam ScalarIdxType Index Type of scalar - * @param[in] handle raft::handle_t + * @param[in] handle raft::device_resources * @param[in] in Input * @param[out] out Output * @param[in] scalar raft::device_scalar_view @@ -148,7 +148,7 @@ template , typename = raft::enable_if_output_device_mdspan> void subtract_scalar( - const raft::handle_t& handle, + raft::device_resources const& handle, InType in, OutType out, raft::device_scalar_view scalar) @@ -182,7 +182,7 @@ void subtract_scalar( * @tparam InType Input Type raft::device_mdspan * @tparam OutType Output Type raft::device_mdspan * @tparam ScalarIdxType Index Type of scalar - * @param[in] handle raft::handle_t + * @param[in] handle raft::device_resources * @param[in] in Input * @param[out] out Output * @param[in] scalar raft::host_scalar_view @@ -193,7 +193,7 @@ template , typename = raft::enable_if_output_device_mdspan> void subtract_scalar( - const raft::handle_t& handle, + raft::device_resources const& handle, InType in, OutType out, raft::host_scalar_view scalar) diff --git 
a/cpp/include/raft/linalg/svd.cuh b/cpp/include/raft/linalg/svd.cuh index 7be1b9d63c..eb51093240 100644 --- a/cpp/include/raft/linalg/svd.cuh +++ b/cpp/include/raft/linalg/svd.cuh @@ -1,5 +1,5 @@ /* - * Copyright (c) 2022, NVIDIA CORPORATION. + * Copyright (c) 2022-2023, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -41,7 +41,7 @@ namespace linalg { * @param stream cuda stream */ template -void svdQR(const raft::handle_t& handle, +void svdQR(raft::device_resources const& handle, T* in, int n_rows, int n_cols, @@ -66,14 +66,14 @@ void svdQR(const raft::handle_t& handle, stream); } -template -void svdEig(const raft::handle_t& handle, - T* in, - int n_rows, - int n_cols, - T* S, - T* U, - T* V, +template +void svdEig(raft::device_resources const& handle, + math_t* in, + idx_t n_rows, + idx_t n_cols, + math_t* S, + math_t* U, + math_t* V, bool gen_left_vec, cudaStream_t stream) { @@ -98,7 +98,7 @@ void svdEig(const raft::handle_t& handle, * @param stream cuda stream */ template -void svdJacobi(const raft::handle_t& handle, +void svdJacobi(raft::device_resources const& handle, math_t* in, int n_rows, int n_cols, @@ -139,7 +139,7 @@ void svdJacobi(const raft::handle_t& handle, * @param stream cuda stream */ template -void svdReconstruction(const raft::handle_t& handle, +void svdReconstruction(raft::device_resources const& handle, math_t* U, math_t* S, math_t* V, @@ -167,7 +167,7 @@ void svdReconstruction(const raft::handle_t& handle, * @param stream cuda stream */ template -bool evaluateSVDByL2Norm(const raft::handle_t& handle, +bool evaluateSVDByL2Norm(raft::device_resources const& handle, math_t* A_d, math_t* U, math_t* S_vec, @@ -195,7 +195,7 @@ bool evaluateSVDByL2Norm(const raft::handle_t& handle, * U_in * @tparam VType std::optional> @c * V_in - * @param[in] handle raft::handle_t + * @param[in] handle raft::device_resources * @param[in] in input 
raft::device_matrix_view with layout raft::col_major of shape (M, N) * @param[out] sing_vals singular values raft::device_vector_view of shape (K) * @param[out] U_in std::optional left singular values of raft::device_matrix_view with layout @@ -204,7 +204,7 @@ bool evaluateSVDByL2Norm(const raft::handle_t& handle, * layout raft::col_major and dimensions (n, n) */ template -void svd_qr(const raft::handle_t& handle, +void svd_qr(raft::device_resources const& handle, raft::device_matrix_view in, raft::device_vector_view sing_vals, UType&& U_in, @@ -258,7 +258,7 @@ void svd_qr(Args... args) * U_in * @tparam VType std::optional> @c * V_in - * @param[in] handle raft::handle_t + * @param[in] handle raft::device_resources * @param[in] in input raft::device_matrix_view with layout raft::col_major of shape (M, N) * @param[out] sing_vals singular values raft::device_vector_view of shape (K) * @param[out] U_in std::optional left singular values of raft::device_matrix_view with layout @@ -268,7 +268,7 @@ void svd_qr(Args... args) */ template void svd_qr_transpose_right_vec( - const raft::handle_t& handle, + raft::device_resources const& handle, raft::device_matrix_view in, raft::device_vector_view sing_vals, UType&& U_in, @@ -316,7 +316,7 @@ void svd_qr_transpose_right_vec(Args... args) /** * @brief singular value decomposition (SVD) on a column major * matrix using Eigen decomposition. A square symmetric covariance matrix is constructed for the SVD - * @param[in] handle raft::handle_t + * @param[in] handle raft::device_resources * @param[in] in input raft::device_matrix_view with layout raft::col_major of shape (M, N) * @param[out] S singular values raft::device_vector_view of shape (K) * @param[out] V right singular values of raft::device_matrix_view with layout @@ -326,7 +326,7 @@ void svd_qr_transpose_right_vec(Args... 
args) */ template void svd_eig( - const raft::handle_t& handle, + raft::device_resources const& handle, raft::device_matrix_view in, raft::device_vector_view S, raft::device_matrix_view V, @@ -352,7 +352,7 @@ void svd_eig( /** * @brief reconstruct a matrix use left and right singular vectors and * singular values - * @param[in] handle raft::handle_t + * @param[in] handle raft::device_resources * @param[in] U left singular values of raft::device_matrix_view with layout * raft::col_major and dimensions (m, k) * @param[in] S singular values raft::device_vector_view of shape (k, k) @@ -361,7 +361,7 @@ void svd_eig( * @param[out] out output raft::device_matrix_view with layout raft::col_major of shape (m, n) */ template -void svd_reconstruction(const raft::handle_t& handle, +void svd_reconstruction(raft::device_resources const& handle, raft::device_matrix_view U, raft::device_vector_view S, raft::device_matrix_view V, diff --git a/cpp/include/raft/linalg/ternary_op.cuh b/cpp/include/raft/linalg/ternary_op.cuh index 10e91a0313..aa3859bc23 100644 --- a/cpp/include/raft/linalg/ternary_op.cuh +++ b/cpp/include/raft/linalg/ternary_op.cuh @@ -1,5 +1,5 @@ /* - * Copyright (c) 2018-2022, NVIDIA CORPORATION. + * Copyright (c) 2018-2023, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
@@ -22,7 +22,7 @@ #include "detail/ternary_op.cuh" #include -#include +#include #include namespace raft { @@ -63,7 +63,7 @@ void ternaryOp(out_t* out, * @tparam InType Input Type raft::device_mdspan * @tparam Lambda the device-lambda performing the actual operation * @tparam OutType Output Type raft::device_mdspan - * @param[in] handle raft::handle_t + * @param[in] handle raft::device_resources * @param[in] in1 First input * @param[in] in2 Second input * @param[in] in3 Third input @@ -78,7 +78,7 @@ template , typename = raft::enable_if_output_device_mdspan> void ternary_op( - const raft::handle_t& handle, InType in1, InType in2, InType in3, OutType out, Lambda op) + raft::device_resources const& handle, InType in1, InType in2, InType in3, OutType out, Lambda op) { RAFT_EXPECTS(raft::is_row_or_column_major(out), "Output must be contiguous"); RAFT_EXPECTS(raft::is_row_or_column_major(in1), "Input 1 must be contiguous"); diff --git a/cpp/include/raft/linalg/transpose.cuh b/cpp/include/raft/linalg/transpose.cuh index e765ea7925..a0f418b4f7 100644 --- a/cpp/include/raft/linalg/transpose.cuh +++ b/cpp/include/raft/linalg/transpose.cuh @@ -1,5 +1,5 @@ /* - * Copyright (c) 2022, NVIDIA CORPORATION. + * Copyright (c) 2022-2023, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -34,7 +34,7 @@ namespace linalg { * @param stream: cuda stream */ template -void transpose(const raft::handle_t& handle, +void transpose(raft::device_resources const& handle, math_t* in, math_t* out, int n_rows, @@ -56,6 +56,11 @@ void transpose(math_t* inout, int n, cudaStream_t stream) detail::transpose(inout, n, stream); } +/** + * @defgroup transpose Matrix transpose + * @{ + */ + /** * @brief Transpose a matrix. The output has same layout policy as the input. 
* @@ -71,7 +76,7 @@ void transpose(math_t* inout, int n, cudaStream_t stream) * @param[out] out Output matirx, storage is pre-allocated by caller. */ template -auto transpose(handle_t const& handle, +auto transpose(raft::device_resources const& handle, raft::mdspan, LayoutPolicy, AccessorPolicy> in, raft::mdspan, LayoutPolicy, AccessorPolicy> out) -> std::enable_if_t, void> @@ -94,6 +99,9 @@ auto transpose(handle_t const& handle, } } } + +/** @} */ // end of group transpose + }; // end namespace linalg }; // end namespace raft diff --git a/cpp/include/raft/linalg/unary_op.cuh b/cpp/include/raft/linalg/unary_op.cuh index a90bda06d5..ce102adfd2 100644 --- a/cpp/include/raft/linalg/unary_op.cuh +++ b/cpp/include/raft/linalg/unary_op.cuh @@ -1,5 +1,5 @@ /* - * Copyright (c) 2022, NVIDIA CORPORATION. + * Copyright (c) 2022-2023, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -21,7 +21,7 @@ #include "detail/unary_op.cuh" #include -#include +#include #include namespace raft { @@ -30,17 +30,16 @@ namespace linalg { /** * @brief perform element-wise unary operation in the input array * @tparam InType input data-type - * @tparam Lambda the device-lambda performing the actual operation + * @tparam Lambda Device lambda performing the actual operation, with the signature + * `OutType func(const InType& val);` * @tparam OutType output data-type * @tparam IdxType Integer type used to for addressing * @tparam TPB threads-per-block in the final kernel launched - * @param out the output array - * @param in the input array - * @param len number of elements in the input array - * @param op the device-lambda - * @param stream cuda stream where to launch work - * @note Lambda must be a functor with the following signature: - * `OutType func(const InType& val);` + * @param[out] out Output array [on device], dim = [len] + * @param[in] in Input array [on device], dim = [len] + * 
@param[in] len Number of elements in the input array + * @param[in] op Device lambda + * @param[in] stream cuda stream where to launch work */ template @@ -81,23 +80,22 @@ void writeOnlyUnaryOp(OutType* out, IdxType len, Lambda op, cudaStream_t stream) */ /** - * @brief perform element-wise binary operation on the input arrays + * @brief Perform an element-wise unary operation into the output array * @tparam InType Input Type raft::device_mdspan - * @tparam Lambda the device-lambda performing the actual operation + * @tparam Lambda Device lambda performing the actual operation, with the signature + * `out_value_t func(const in_value_t& val);` * @tparam OutType Output Type raft::device_mdspan - * @param[in] handle raft::handle_t - * @param[in] in Input - * @param[out] out Output - * @param[in] op the device-lambda - * @note Lambda must be a functor with the following signature: - * `InType func(const InType& val);` + * @param[in] handle The raft handle + * @param[in] in Input + * @param[out] out Output + * @param[in] op Device lambda */ template , typename = raft::enable_if_output_device_mdspan> -void unary_op(const raft::handle_t& handle, InType in, OutType out, Lambda op) +void unary_op(raft::device_resources const& handle, InType in, OutType out, Lambda op) { RAFT_EXPECTS(raft::is_row_or_column_major(out), "Output must be contiguous"); RAFT_EXPECTS(raft::is_row_or_column_major(in), "Input must be contiguous"); @@ -116,29 +114,32 @@ void unary_op(const raft::handle_t& handle, InType in, OutType out, Lambda op) } /** - * @brief perform element-wise binary operation on the input arrays - * This function does not read from the input - * @tparam InType Input Type raft::device_mdspan - * @tparam Lambda the device-lambda performing the actual operation - * @param[in] handle raft::handle_t - * @param[inout] in Input/Output - * @param[in] op the device-lambda - * @note Lambda must be a functor with the following signature: - * `InType func(const InType& val);` + * @brief 
Perform an element-wise unary operation on the input index into the output array + * + * @note This operation is deprecated. Please use map_offset in `raft/linalg/map.cuh` instead. + * + * @tparam OutType Output Type raft::device_mdspan + * @tparam Lambda Device lambda performing the actual operation, with the signature + * `void func(out_value_t* out_location, index_t idx);` + * @param[in] handle The raft handle + * @param[out] out Output + * @param[in] op Device lambda */ -template > -void write_only_unary_op(const raft::handle_t& handle, InType in, Lambda op) +template > +void write_only_unary_op(const raft::device_resources& handle, OutType out, Lambda op) { - RAFT_EXPECTS(raft::is_row_or_column_major(in), "Input must be contiguous"); + RAFT_EXPECTS(raft::is_row_or_column_major(out), "Output must be contiguous"); - using in_value_t = typename InType::value_type; + using out_value_t = typename OutType::value_type; - if (in.size() <= std::numeric_limits::max()) { - writeOnlyUnaryOp( - in.data_handle(), in.size(), op, handle.get_stream()); + if (out.size() <= std::numeric_limits::max()) { + writeOnlyUnaryOp( + out.data_handle(), out.size(), op, handle.get_stream()); } else { - writeOnlyUnaryOp( - in.data_handle(), in.size(), op, handle.get_stream()); + writeOnlyUnaryOp( + out.data_handle(), out.size(), op, handle.get_stream()); } } @@ -147,4 +148,4 @@ void write_only_unary_op(const raft::handle_t& handle, InType in, Lambda op) }; // end namespace linalg }; // end namespace raft -#endif \ No newline at end of file +#endif diff --git a/cpp/include/raft/matrix/argmax.cuh b/cpp/include/raft/matrix/argmax.cuh index e6736b14de..433c161079 100644 --- a/cpp/include/raft/matrix/argmax.cuh +++ b/cpp/include/raft/matrix/argmax.cuh @@ -1,5 +1,5 @@ /* - * Copyright (c) 2022, NVIDIA CORPORATION. + * Copyright (c) 2022-2023, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
@@ -21,6 +21,11 @@ namespace raft::matrix { +/** + * @defgroup argmax Argmax operation + * @{ + */ + /** * @brief Argmax: find the col idx with maximum value for each row * @param[in] handle: raft handle @@ -28,7 +33,7 @@ namespace raft::matrix { * @param[out] out: output vector of size n_rows */ template -void argmax(const raft::handle_t& handle, +void argmax(raft::device_resources const& handle, raft::device_matrix_view in, raft::device_vector_view out) { @@ -37,4 +42,7 @@ void argmax(const raft::handle_t& handle, detail::argmax( in.data_handle(), in.extent(1), in.extent(0), out.data_handle(), handle.get_stream()); } + +/** @} */ // end of group argmax + } // namespace raft::matrix diff --git a/cpp/include/raft/matrix/argmin.cuh b/cpp/include/raft/matrix/argmin.cuh index e8cf763f70..31ef0c1c1b 100644 --- a/cpp/include/raft/matrix/argmin.cuh +++ b/cpp/include/raft/matrix/argmin.cuh @@ -1,5 +1,5 @@ /* - * Copyright (c) 2022, NVIDIA CORPORATION. + * Copyright (c) 2022-2023, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
@@ -21,6 +21,11 @@ namespace raft::matrix { +/** + * @defgroup argmin Argmin operation + * @{ + */ + /** * @brief Argmin: find the col idx with minimum value for each row * @param[in] handle: raft handle @@ -28,7 +33,7 @@ namespace raft::matrix { * @param[out] out: output vector of size n_rows */ template -void argmin(const raft::handle_t& handle, +void argmin(raft::device_resources const& handle, raft::device_matrix_view in, raft::device_vector_view out) { @@ -37,4 +42,7 @@ void argmin(const raft::handle_t& handle, detail::argmin( in.data_handle(), in.extent(1), in.extent(0), out.data_handle(), handle.get_stream()); } + +/** @} */ // end of group argmin + } // namespace raft::matrix diff --git a/cpp/include/raft/matrix/col_wise_sort.cuh b/cpp/include/raft/matrix/col_wise_sort.cuh index 5f9b3ab848..a4daf097e5 100644 --- a/cpp/include/raft/matrix/col_wise_sort.cuh +++ b/cpp/include/raft/matrix/col_wise_sort.cuh @@ -1,5 +1,5 @@ /* - * Copyright (c) 2019-2022, NVIDIA CORPORATION. + * Copyright (c) 2019-2023, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
@@ -52,6 +52,11 @@ void sort_cols_per_row(const InType* in, in, out, n_rows, n_columns, bAllocWorkspace, workspacePtr, workspaceSize, stream, sortedKeys); } +/** + * @defgroup col_wise_sort Sort rows within each column + * @{ + */ + /** * @brief sort columns within each row of row-major input matrix and return sorted indexes * modelled as key-value sort with key being input matrix and value being index of values @@ -66,7 +71,7 @@ void sort_cols_per_row(const InType* in, * @param[out] sorted_keys_opt: std::optional, output matrix for sorted keys (input) */ template -void sort_cols_per_row(const raft::handle_t& handle, +void sort_cols_per_row(raft::device_resources const& handle, raft::device_matrix_view in, raft::device_matrix_view out, sorted_keys_t&& sorted_keys_opt) @@ -126,6 +131,8 @@ void sort_cols_per_row(Args... args) sort_cols_per_row(std::forward(args)..., std::nullopt); } +/** @} */ // end of group col_wise_sort + }; // end namespace raft::matrix #endif \ No newline at end of file diff --git a/cpp/include/raft/matrix/copy.cuh b/cpp/include/raft/matrix/copy.cuh index 5f1d16485c..42d2562e5e 100644 --- a/cpp/include/raft/matrix/copy.cuh +++ b/cpp/include/raft/matrix/copy.cuh @@ -1,5 +1,5 @@ /* - * Copyright (c) 2022, NVIDIA CORPORATION. + * Copyright (c) 2022-2023, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -22,6 +22,11 @@ namespace raft::matrix { +/** + * @defgroup matrix_copy Matrix copy operations + * @{ + */ + /** * @brief Copy selected rows of the input matrix into contiguous space. 
* @@ -34,7 +39,7 @@ namespace raft::matrix { * @param[in] indices of the rows to be copied */ template -void copy_rows(const raft::handle_t& handle, +void copy_rows(raft::device_resources const& handle, raft::device_matrix_view in, raft::device_matrix_view out, raft::device_vector_view indices) @@ -60,7 +65,7 @@ void copy_rows(const raft::handle_t& handle, * @param[out] out: output matrix */ template -void copy(const raft::handle_t& handle, +void copy(raft::device_resources const& handle, raft::device_matrix_view in, raft::device_matrix_view out) { @@ -79,7 +84,7 @@ void copy(const raft::handle_t& handle, * @param out: output matrix */ template -void trunc_zero_origin(const raft::handle_t& handle, +void trunc_zero_origin(raft::device_resources const& handle, raft::device_matrix_view in, raft::device_matrix_view out) { @@ -94,4 +99,6 @@ void trunc_zero_origin(const raft::handle_t& handle, handle.get_stream()); } +/** @} */ // end of group matrix_copy + } // namespace raft::matrix diff --git a/cpp/include/raft/matrix/detail/gather.cuh b/cpp/include/raft/matrix/detail/gather.cuh index 3738afba5d..f6dc60bf85 100644 --- a/cpp/include/raft/matrix/detail/gather.cuh +++ b/cpp/include/raft/matrix/detail/gather.cuh @@ -1,5 +1,5 @@ /* - * Copyright (c) 2022, NVIDIA CORPORATION. + * Copyright (c) 2022-2023, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
@@ -16,40 +16,64 @@ #pragma once +#include +#include + namespace raft { namespace matrix { namespace detail { -// gatherKernel conditionally copies rows from the source matrix 'in' into the destination matrix -// 'out' according to a map (or a transformed map) -template +struct gather_policy { + static constexpr int n_threads = tpb; + static constexpr int work_per_thread = wpt; + static constexpr int stride = tpb * wpt; +}; + +/** Conditionally copies rows from the source matrix 'in' into the destination matrix + * 'out' according to a map (or a transformed map) */ +template -__global__ void gatherKernel(const MatrixIteratorT in, - IndexT D, - IndexT N, - MapIteratorT map, - StencilIteratorT stencil, - MatrixIteratorT out, - PredicateOp pred_op, - MapTransformOp transform_op) + typename OutputIteratorT, + typename IndexT> +__global__ void gather_kernel(const InputIteratorT in, + IndexT D, + IndexT len, + const MapIteratorT map, + StencilIteratorT stencil, + OutputIteratorT out, + PredicateOp pred_op, + MapTransformOp transform_op) { typedef typename std::iterator_traits::value_type MapValueT; typedef typename std::iterator_traits::value_type StencilValueT; - IndexT outRowStart = blockIdx.x * D; - MapValueT map_val = map[blockIdx.x]; - StencilValueT stencil_val = stencil[blockIdx.x]; +#pragma unroll + for (IndexT wid = 0; wid < Policy::work_per_thread; wid++) { + IndexT tid = threadIdx.x + (Policy::work_per_thread * static_cast(blockIdx.x) + wid) * + Policy::n_threads; + if (tid < len) { + IndexT i_dst = tid / D; + IndexT j = tid % D; + + MapValueT map_val = map[i_dst]; + StencilValueT stencil_val = stencil[i_dst]; - bool predicate = pred_op(stencil_val); - if (predicate) { - IndexT inRowStart = transform_op(map_val) * D; - for (int i = threadIdx.x; i < D; i += TPB) { - out[outRowStart + i] = in[inRowStart + i]; + bool predicate = pred_op(stencil_val); + if (predicate) { + IndexT i_src = transform_op(map_val); + out[tid] = in[i_src * D + j]; + } } } } @@ -58,7 +82,7 
@@ __global__ void gatherKernel(const MatrixIteratorT in, * @brief gather conditionally copies rows from a source matrix into a destination matrix according * to a transformed map. * - * @tparam MatrixIteratorT Random-access iterator type, for reading input matrix (may be a + * @tparam InputIteratorT Random-access iterator type, for reading input matrix (may be a * simple pointer type). * @tparam MapIteratorT Random-access iterator type, for reading input map (may be a simple * pointer type). @@ -67,7 +91,10 @@ __global__ void gatherKernel(const MatrixIteratorT in, * @tparam UnaryPredicateOp Unary lambda expression or operator type, UnaryPredicateOp's result * type must be convertible to bool type. * @tparam MapTransformOp Unary lambda expression or operator type, MapTransformOp's result - * type must be convertible to IndexT (= int) type. + * type must be convertible to IndexT. + * @tparam OutputIteratorT Random-access iterator type, for writing output matrix (may be a + * simple pointer type). + * @tparam IndexT Index type. 
* * @param in Pointer to the input matrix (assumed to be row-major) * @param D Leading dimension of the input matrix 'in', which in-case of row-major @@ -81,18 +108,20 @@ __global__ void gatherKernel(const MatrixIteratorT in, * @param transform_op The transformation operation, transforms the map values to IndexT * @param stream CUDA stream to launch kernels within */ -template -void gatherImpl(const MatrixIteratorT in, - int D, - int N, - MapIteratorT map, + typename MapTransformOp, + typename OutputIteratorT, + typename IndexT> +void gatherImpl(const InputIteratorT in, + IndexT D, + IndexT N, + const MapIteratorT map, StencilIteratorT stencil, - int map_length, - MatrixIteratorT out, + IndexT map_length, + OutputIteratorT out, UnaryPredicateOp pred_op, MapTransformOp transform_op, cudaStream_t stream) @@ -100,9 +129,6 @@ void gatherImpl(const MatrixIteratorT in, // skip in case of 0 length input if (map_length <= 0 || N <= 0 || D <= 0) return; - // signed integer type for indexing or global offsets - typedef int IndexT; - // map value type typedef typename std::iterator_traits::value_type MapValueT; @@ -119,38 +145,26 @@ void gatherImpl(const MatrixIteratorT in, static_assert((std::is_convertible::value), "UnaryPredicateOp's result type must be convertible to bool type"); - if (D <= 32) { - gatherKernel - <<>>(in, D, N, map, stencil, out, pred_op, transform_op); - } else if (D <= 64) { - gatherKernel - <<>>(in, D, N, map, stencil, out, pred_op, transform_op); - } else if (D <= 128) { - gatherKernel - <<>>(in, D, N, map, stencil, out, pred_op, transform_op); + IndexT len = map_length * D; + constexpr int TPB = 128; + const int n_sm = raft::getMultiProcessorCount(); + // The following empirical heuristics enforce that we keep a good balance between having enough + // blocks and enough work per thread. 
+ if (len < static_cast(32 * TPB * n_sm)) { + using Policy = gather_policy; + IndexT n_blocks = raft::ceildiv(map_length * D, static_cast(Policy::stride)); + gather_kernel<<>>( + in, D, len, map, stencil, out, pred_op, transform_op); + } else if (len < static_cast(32 * 4 * TPB * n_sm)) { + using Policy = gather_policy; + IndexT n_blocks = raft::ceildiv(map_length * D, static_cast(Policy::stride)); + gather_kernel<<>>( + in, D, len, map, stencil, out, pred_op, transform_op); } else { - gatherKernel - <<>>(in, D, N, map, stencil, out, pred_op, transform_op); + using Policy = gather_policy; + IndexT n_blocks = raft::ceildiv(map_length * D, static_cast(Policy::stride)); + gather_kernel<<>>( + in, D, len, map, stencil, out, pred_op, transform_op); } RAFT_CUDA_TRY(cudaPeekAtLastError()); } @@ -158,10 +172,13 @@ void gatherImpl(const MatrixIteratorT in, /** * @brief gather copies rows from a source matrix into a destination matrix according to a map. * - * @tparam MatrixIteratorT Random-access iterator type, for reading input matrix (may be a + * @tparam InputIteratorT Random-access iterator type, for reading input matrix (may be a * simple pointer type). * @tparam MapIteratorT Random-access iterator type, for reading input map (may be a simple * pointer type). + * @tparam OutputIteratorT Random-access iterator type, for writing output matrix (may be a + * simple pointer type). + * @tparam IndexT Index type. 
* * @param in Pointer to the input matrix (assumed to be row-major) * @param D Leading dimension of the input matrix 'in', which in-case of row-major @@ -172,39 +189,33 @@ void gatherImpl(const MatrixIteratorT in, * @param out Pointer to the output matrix (assumed to be row-major) * @param stream CUDA stream to launch kernels within */ -template -void gather(const MatrixIteratorT in, - int D, - int N, - MapIteratorT map, - int map_length, - MatrixIteratorT out, +template +void gather(const InputIteratorT in, + IndexT D, + IndexT N, + const MapIteratorT map, + IndexT map_length, + OutputIteratorT out, cudaStream_t stream) { typedef typename std::iterator_traits::value_type MapValueT; gatherImpl( - in, - D, - N, - map, - map, - map_length, - out, - [] __device__(MapValueT val) { return true; }, - [] __device__(MapValueT val) { return val; }, - stream); + in, D, N, map, map, map_length, out, raft::const_op(true), raft::identity_op(), stream); } /** * @brief gather copies rows from a source matrix into a destination matrix according to a * transformed map. * - * @tparam MatrixIteratorT Random-access iterator type, for reading input matrix (may be a + * @tparam InputIteratorT Random-access iterator type, for reading input matrix (may be a * simple pointer type). * @tparam MapIteratorT Random-access iterator type, for reading input map (may be a simple * pointer type). * @tparam MapTransformOp Unary lambda expression or operator type, MapTransformOp's result - * type must be convertible to IndexT (= int) type. + * type must be convertible to IndexT. + * @tparam OutputIteratorT Random-access iterator type, for writing output matrix (may be a + * simple pointer type). + * @tparam IndexT Index type. 
* * @param in Pointer to the input matrix (assumed to be row-major) * @param D Leading dimension of the input matrix 'in', which in-case of row-major @@ -216,35 +227,29 @@ void gather(const MatrixIteratorT in, * @param transform_op The transformation operation, transforms the map values to IndexT * @param stream CUDA stream to launch kernels within */ -template -void gather(const MatrixIteratorT in, - int D, - int N, - MapIteratorT map, - int map_length, - MatrixIteratorT out, +template +void gather(const InputIteratorT in, + IndexT D, + IndexT N, + const MapIteratorT map, + IndexT map_length, + OutputIteratorT out, MapTransformOp transform_op, cudaStream_t stream) { typedef typename std::iterator_traits::value_type MapValueT; - gatherImpl( - in, - D, - N, - map, - map, - map_length, - out, - [] __device__(MapValueT val) { return true; }, - transform_op, - stream); + gatherImpl(in, D, N, map, map, map_length, out, raft::const_op(true), transform_op, stream); } /** * @brief gather_if conditionally copies rows from a source matrix into a destination matrix * according to a map. * - * @tparam MatrixIteratorT Random-access iterator type, for reading input matrix (may be a + * @tparam InputIteratorT Random-access iterator type, for reading input matrix (may be a * simple pointer type). * @tparam MapIteratorT Random-access iterator type, for reading input map (may be a simple * pointer type). @@ -252,6 +257,9 @@ void gather(const MatrixIteratorT in, * simple pointer type). * @tparam UnaryPredicateOp Unary lambda expression or operator type, UnaryPredicateOp's result * type must be convertible to bool type. + * @tparam OutputIteratorT Random-access iterator type, for writing output matrix (may be a + * simple pointer type). + * @tparam IndexT Index type. 
* * @param in Pointer to the input matrix (assumed to be row-major) * @param D Leading dimension of the input matrix 'in', which in-case of row-major @@ -264,39 +272,31 @@ void gather(const MatrixIteratorT in, * @param pred_op Predicate to apply to the stencil values * @param stream CUDA stream to launch kernels within */ -template -void gather_if(const MatrixIteratorT in, - int D, - int N, - MapIteratorT map, + typename UnaryPredicateOp, + typename OutputIteratorT, + typename IndexT> +void gather_if(const InputIteratorT in, + IndexT D, + IndexT N, + const MapIteratorT map, StencilIteratorT stencil, - int map_length, - MatrixIteratorT out, + IndexT map_length, + OutputIteratorT out, UnaryPredicateOp pred_op, cudaStream_t stream) { typedef typename std::iterator_traits::value_type MapValueT; - gatherImpl( - in, - D, - N, - map, - stencil, - map_length, - out, - pred_op, - [] __device__(MapValueT val) { return val; }, - stream); + gatherImpl(in, D, N, map, stencil, map_length, out, pred_op, raft::identity_op(), stream); } /** * @brief gather_if conditionally copies rows from a source matrix into a destination matrix * according to a transformed map. * - * @tparam MatrixIteratorT Random-access iterator type, for reading input matrix (may be a + * @tparam InputIteratorT Random-access iterator type, for reading input matrix (may be a * simple pointer type). * @tparam MapIteratorT Random-access iterator type, for reading input map (may be a simple * pointer type). @@ -305,7 +305,10 @@ void gather_if(const MatrixIteratorT in, * @tparam UnaryPredicateOp Unary lambda expression or operator type, UnaryPredicateOp's result * type must be convertible to bool type. * @tparam MapTransformOp Unary lambda expression or operator type, MapTransformOp's result - * type must be convertible to IndexT (= int) type. + * type must be convertible to IndexT type. + * @tparam OutputIteratorT Random-access iterator type, for writing output matrix (may be a + * simple pointer type). 
+ * @tparam IndexT Index type. * * @param in Pointer to the input matrix (assumed to be row-major) * @param D Leading dimension of the input matrix 'in', which in-case of row-major @@ -319,18 +322,20 @@ void gather_if(const MatrixIteratorT in, * @param transform_op The transformation operation, transforms the map values to IndexT * @param stream CUDA stream to launch kernels within */ -template -void gather_if(const MatrixIteratorT in, - int D, - int N, - MapIteratorT map, + typename MapTransformOp, + typename OutputIteratorT, + typename IndexT> +void gather_if(const InputIteratorT in, + IndexT D, + IndexT N, + const MapIteratorT map, StencilIteratorT stencil, - int map_length, - MatrixIteratorT out, + IndexT map_length, + OutputIteratorT out, UnaryPredicateOp pred_op, MapTransformOp transform_op, cudaStream_t stream) diff --git a/cpp/include/raft/matrix/detail/linewise_op.cuh b/cpp/include/raft/matrix/detail/linewise_op.cuh index 605726bea6..ef8f0e88c1 100644 --- a/cpp/include/raft/matrix/detail/linewise_op.cuh +++ b/cpp/include/raft/matrix/detail/linewise_op.cuh @@ -1,5 +1,5 @@ /* - * Copyright (c) 2021-2022, NVIDIA CORPORATION. + * Copyright (c) 2021-2023, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
@@ -796,7 +796,8 @@ struct MatrixLinewiseOp { "layout for in and out must be either padded row or col major"); // also statically assert padded matrix alignment == 2^i*VecBytes - assert(raft::Pow2::areSameAlignOffsets(in, out)); + RAFT_EXPECTS(raft::Pow2::areSameAlignOffsets(in.data_handle(), out.data_handle()), + "The matrix views in and out does not have correct alignment"); if (alongLines) return matrixLinewiseVecRowsSpan +#include #include +#include #include #include #include @@ -86,10 +87,10 @@ void seqRoot(math_t* in, if (a < math_t(0)) { return math_t(0); } else { - return sqrt(a * scalar); + return raft::sqrt(a * scalar); } } else { - return sqrt(a * scalar); + return raft::sqrt(a * scalar); } }, stream); @@ -188,21 +189,19 @@ void reciprocal(math_t* in, math_t* out, IdxType len, cudaStream_t stream) template void setValue(math_t* out, const math_t* in, math_t scalar, int len, cudaStream_t stream = 0) { - raft::linalg::unaryOp( - out, in, len, [scalar] __device__(math_t in) { return scalar; }, stream); + raft::linalg::unaryOp(out, in, len, raft::const_op(scalar), stream); } template void ratio( - const raft::handle_t& handle, math_t* src, math_t* dest, IdxType len, cudaStream_t stream) + raft::device_resources const& handle, math_t* src, math_t* dest, IdxType len, cudaStream_t stream) { auto d_src = src; auto d_dest = dest; rmm::device_scalar d_sum(stream); auto* d_sum_ptr = d_sum.data(); - auto no_op = [] __device__(math_t in) { return in; }; - raft::linalg::mapThenSumReduce(d_sum_ptr, len, no_op, stream, src); + raft::linalg::mapThenSumReduce(d_sum_ptr, len, raft::identity_op{}, stream, src); raft::linalg::unaryOp( d_dest, d_src, len, [=] __device__(math_t a) { return a / (*d_sum_ptr); }, stream); } @@ -217,15 +216,7 @@ void matrixVectorBinaryMult(Type* data, cudaStream_t stream) { raft::linalg::matrixVectorOp( - data, - data, - vec, - n_col, - n_row, - rowMajor, - bcastAlongRows, - [] __device__(Type a, Type b) { return a * b; }, - stream); + data, data, 
vec, n_col, n_row, rowMajor, bcastAlongRows, raft::mul_op(), stream); } template @@ -264,15 +255,7 @@ void matrixVectorBinaryDiv(Type* data, cudaStream_t stream) { raft::linalg::matrixVectorOp( - data, - data, - vec, - n_col, - n_row, - rowMajor, - bcastAlongRows, - [] __device__(Type a, Type b) { return a / b; }, - stream); + data, data, vec, n_col, n_row, rowMajor, bcastAlongRows, raft::div_op(), stream); } template @@ -295,7 +278,7 @@ void matrixVectorBinaryDivSkipZero(Type* data, rowMajor, bcastAlongRows, [] __device__(Type a, Type b) { - if (raft::myAbs(b) < Type(1e-10)) + if (raft::abs(b) < Type(1e-10)) return Type(0); else return a / b; @@ -311,7 +294,7 @@ void matrixVectorBinaryDivSkipZero(Type* data, rowMajor, bcastAlongRows, [] __device__(Type a, Type b) { - if (raft::myAbs(b) < Type(1e-10)) + if (raft::abs(b) < Type(1e-10)) return a; else return a / b; @@ -330,15 +313,7 @@ void matrixVectorBinaryAdd(Type* data, cudaStream_t stream) { raft::linalg::matrixVectorOp( - data, - data, - vec, - n_col, - n_row, - rowMajor, - bcastAlongRows, - [] __device__(Type a, Type b) { return a + b; }, - stream); + data, data, vec, n_col, n_row, rowMajor, bcastAlongRows, raft::add_op(), stream); } template @@ -351,15 +326,7 @@ void matrixVectorBinarySub(Type* data, cudaStream_t stream) { raft::linalg::matrixVectorOp( - data, - data, - vec, - n_col, - n_row, - rowMajor, - bcastAlongRows, - [] __device__(Type a, Type b) { return a - b; }, - stream); + data, data, vec, n_col, n_row, rowMajor, bcastAlongRows, raft::sub_op(), stream); } // Computes an argmin/argmax column-wise in a DxN matrix diff --git a/cpp/include/raft/matrix/detail/matrix.cuh b/cpp/include/raft/matrix/detail/matrix.cuh index 17a40be5d6..ef3a873d90 100644 --- a/cpp/include/raft/matrix/detail/matrix.cuh +++ b/cpp/include/raft/matrix/detail/matrix.cuh @@ -1,5 +1,5 @@ /* - * Copyright (c) 2021-2022, NVIDIA CORPORATION. + * Copyright (c) 2021-2023, NVIDIA CORPORATION. 
* * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -28,7 +28,7 @@ #include #include #include -#include +#include #include #include @@ -299,7 +299,7 @@ void getDiagonalInverseMatrix(m_t* in, idx_t len, cudaStream_t stream) } template -m_t getL2Norm(const raft::handle_t& handle, const m_t* in, idx_t size, cudaStream_t stream) +m_t getL2Norm(raft::device_resources const& handle, const m_t* in, idx_t size, cudaStream_t stream) { cublasHandle_t cublasH = handle.get_cublas_handle(); m_t normval = 0; diff --git a/cpp/include/raft/matrix/detail/print.hpp b/cpp/include/raft/matrix/detail/print.hpp index fc3d14861c..814c6a0b4b 100644 --- a/cpp/include/raft/matrix/detail/print.hpp +++ b/cpp/include/raft/matrix/detail/print.hpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2021-2022, NVIDIA CORPORATION. + * Copyright (c) 2021-2023, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -28,7 +28,7 @@ #include #include #include -#include +#include #include #include diff --git a/cpp/include/raft/spatial/knn/detail/topk.cuh b/cpp/include/raft/matrix/detail/select_k.cuh similarity index 59% rename from cpp/include/raft/spatial/knn/detail/topk.cuh rename to cpp/include/raft/matrix/detail/select_k.cuh index f4dcb53088..ac1ba3dfa3 100644 --- a/cpp/include/raft/spatial/knn/detail/topk.cuh +++ b/cpp/include/raft/matrix/detail/select_k.cuh @@ -1,5 +1,5 @@ /* - * Copyright (c) 2022, NVIDIA CORPORATION. + * Copyright (c) 2022-2023, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
@@ -16,34 +16,34 @@ #pragma once -#include "topk/radix_topk.cuh" -#include "topk/warpsort_topk.cuh" +#include "select_radix.cuh" +#include "select_warpsort.cuh" #include #include #include -namespace raft::spatial::knn::detail { +namespace raft::matrix::detail { /** * Select k smallest or largest key/values from each row in the input data. * - * If you think of the input data `in_keys` as a row-major matrix with len columns and - * batch_size rows, then this function selects k smallest/largest values in each row and fills - * in the row-major matrix `out` of size (batch_size, k). + * If you think of the input data `in_val` as a row-major matrix with `len` columns and + * `batch_size` rows, then this function selects `k` smallest/largest values in each row and fills + * in the row-major matrix `out_val` of size (batch_size, k). * * @tparam T * the type of the keys (what is being compared). * @tparam IdxT * the index type (what is being selected together with the keys). * - * @param[in] in + * @param[in] in_val * contiguous device array of inputs of size (len * batch_size); * these are compared and selected. * @param[in] in_idx * contiguous device array of inputs of size (len * batch_size); - * typically, these are indices of the corresponding in_keys. + * typically, these are indices of the corresponding in_val. * @param batch_size * number of input rows, i.e. the batch size. * @param len @@ -51,12 +51,12 @@ namespace raft::spatial::knn::detail { * Invariant: len >= k. * @param k * the number of outputs to select in each input row. - * @param[out] out + * @param[out] out_val * contiguous device array of outputs of size (k * batch_size); - * the k smallest/largest values from each row of the `in_keys`. + * the k smallest/largest values from each row of the `in_val`. * @param[out] out_idx * contiguous device array of outputs of size (k * batch_size); - * the payload selected together with `out`. + * the payload selected together with `out_val`. 
* @param select_min * whether to select k smallest (true) or largest (false) keys. * @param stream @@ -64,28 +64,28 @@ namespace raft::spatial::knn::detail { * memory pool here to avoid memory allocations within the call). */ template -void select_topk(const T* in, - const IdxT* in_idx, - size_t batch_size, - size_t len, - int k, - T* out, - IdxT* out_idx, - bool select_min, - rmm::cuda_stream_view stream, - rmm::mr::device_memory_resource* mr = nullptr) +void select_k(const T* in_val, + const IdxT* in_idx, + size_t batch_size, + size_t len, + int k, + T* out_val, + IdxT* out_idx, + bool select_min, + rmm::cuda_stream_view stream, + rmm::mr::device_memory_resource* mr = nullptr) { common::nvtx::range fun_scope( - "matrix::select_topk(batch_size = %zu, len = %zu, k = %d)", batch_size, len, k); + "matrix::select_k(batch_size = %zu, len = %zu, k = %d)", batch_size, len, k); // TODO (achirkin): investigate the trade-off for a wider variety of inputs. const bool radix_faster = batch_size >= 64 && len >= 102400 && k >= 128; - if (k <= raft::spatial::knn::detail::topk::kMaxCapacity && !radix_faster) { - topk::warp_sort_topk( - in, in_idx, batch_size, len, k, out, out_idx, select_min, stream, mr); + if (k <= select::warpsort::kMaxCapacity && !radix_faster) { + select::warpsort::select_k( + in_val, in_idx, batch_size, len, k, out_val, out_idx, select_min, stream, mr); } else { - topk::radix_topk= 4 ? 11 : 8), 512>( - in, in_idx, batch_size, len, k, out, out_idx, select_min, stream, mr); + select::radix::select_k= 4 ? 
11 : 8), 512>( + in_val, in_idx, batch_size, len, k, out_val, out_idx, select_min, stream, mr); } } -} // namespace raft::spatial::knn::detail +} // namespace raft::matrix::detail diff --git a/cpp/include/raft/spatial/knn/detail/topk/radix_topk.cuh b/cpp/include/raft/matrix/detail/select_radix.cuh similarity index 87% rename from cpp/include/raft/spatial/knn/detail/topk/radix_topk.cuh rename to cpp/include/raft/matrix/detail/select_radix.cuh index 9c0f20b706..de19e63a4c 100644 --- a/cpp/include/raft/spatial/knn/detail/topk/radix_topk.cuh +++ b/cpp/include/raft/matrix/detail/select_radix.cuh @@ -17,6 +17,7 @@ #pragma once #include +#include #include #include #include @@ -27,29 +28,29 @@ #include #include -#include +#include #include -namespace raft::spatial::knn::detail::topk { +namespace raft::matrix::detail::select::radix { constexpr int ITEM_PER_THREAD = 32; constexpr int VECTORIZED_READ_SIZE = 16; template -__host__ __device__ constexpr int calc_num_buckets() +_RAFT_HOST_DEVICE constexpr int calc_num_buckets() { return 1 << BitsPerPass; } template -__host__ __device__ constexpr int calc_num_passes() +_RAFT_HOST_DEVICE constexpr int calc_num_passes() { return ceildiv(sizeof(T) * 8, BitsPerPass); } // Minimum reasonable block size for the given radix size. template -__host__ __device__ constexpr int calc_min_block_size() +_RAFT_HOST_DEVICE constexpr int calc_min_block_size() { return 1 << std::max(BitsPerPass - 4, Pow2::Log2 + 1); } @@ -62,7 +63,7 @@ __host__ __device__ constexpr int calc_min_block_size() * NB: Use pass=-1 for calc_mask(). 
*/ template -__device__ constexpr int calc_start_bit(int pass) +_RAFT_DEVICE constexpr int calc_start_bit(int pass) { int start_bit = static_cast(sizeof(T) * 8) - (pass + 1) * BitsPerPass; if (start_bit < 0) { start_bit = 0; } @@ -70,7 +71,7 @@ __device__ constexpr int calc_start_bit(int pass) } template -__device__ constexpr unsigned calc_mask(int pass) +_RAFT_DEVICE constexpr unsigned calc_mask(int pass) { static_assert(BitsPerPass <= 31); int num_bits = calc_start_bit(pass - 1) - calc_start_bit(pass); @@ -82,7 +83,7 @@ __device__ constexpr unsigned calc_mask(int pass) * as of integers. */ template -__device__ typename cub::Traits::UnsignedBits twiddle_in(T key, bool greater) +_RAFT_DEVICE typename cub::Traits::UnsignedBits twiddle_in(T key, bool greater) { auto bits = reinterpret_cast::UnsignedBits&>(key); bits = cub::Traits::TwiddleIn(bits); @@ -91,7 +92,7 @@ __device__ typename cub::Traits::UnsignedBits twiddle_in(T key, bool greater) } template -__device__ int calc_bucket(T x, int start_bit, unsigned mask, bool greater) +_RAFT_DEVICE int calc_bucket(T x, int start_bit, unsigned mask, bool greater) { static_assert(BitsPerPass <= sizeof(int) * 8 - 1); // so return type can be int return (twiddle_in(x, greater) >> start_bit) & mask; @@ -112,7 +113,7 @@ __device__ int calc_bucket(T x, int start_bit, unsigned mask, bool greater) * @param f the lambda taking two arguments (T x, IdxT idx) */ template -__device__ void vectorized_process(const T* in, IdxT len, Func f) +_RAFT_DEVICE void vectorized_process(const T* in, IdxT len, Func f) { const IdxT stride = blockDim.x * gridDim.x; const int tid = blockIdx.x * blockDim.x + threadIdx.x; @@ -167,18 +168,18 @@ struct Counter { * (see steps 4-1 in `radix_kernel` description). 
*/ template -__device__ void filter_and_histogram(const T* in_buf, - const IdxT* in_idx_buf, - T* out_buf, - IdxT* out_idx_buf, - T* out, - IdxT* out_idx, - IdxT len, - Counter* counter, - IdxT* histogram, - bool greater, - int pass, - int k) +_RAFT_DEVICE void filter_and_histogram(const T* in_buf, + const IdxT* in_idx_buf, + T* out_buf, + IdxT* out_idx_buf, + T* out, + IdxT* out_idx, + IdxT len, + Counter* counter, + IdxT* histogram, + bool greater, + int pass, + int k) { constexpr int num_buckets = calc_num_buckets(); __shared__ IdxT histogram_smem[num_buckets]; @@ -260,10 +261,10 @@ __device__ void filter_and_histogram(const T* in_buf, * (step 2 in `radix_kernel` description) */ template -__device__ void scan(volatile IdxT* histogram, - const int start, - const int num_buckets, - const IdxT current) +_RAFT_DEVICE void scan(volatile IdxT* histogram, + const int start, + const int num_buckets, + const IdxT current) { typedef cub::BlockScan BlockScan; __shared__ typename BlockScan::TempStorage temp_storage; @@ -284,7 +285,7 @@ __device__ void scan(volatile IdxT* histogram, * (steps 2-3 in `radix_kernel` description) */ template -__device__ void choose_bucket(Counter* counter, IdxT* histogram, const IdxT k) +_RAFT_DEVICE void choose_bucket(Counter* counter, IdxT* histogram, const IdxT k) { constexpr int num_buckets = calc_num_buckets(); int index = threadIdx.x; @@ -547,21 +548,21 @@ inline dim3 get_optimal_grid_size(size_t req_batch_size, size_t len) * memory pool here to avoid memory allocations within the call). 
*/ template -void radix_topk(const T* in, - const IdxT* in_idx, - size_t batch_size, - size_t len, - int k, - T* out, - IdxT* out_idx, - bool select_min, - rmm::cuda_stream_view stream, - rmm::mr::device_memory_resource* mr = nullptr) +void select_k(const T* in, + const IdxT* in_idx, + size_t batch_size, + size_t len, + int k, + T* out, + IdxT* out_idx, + bool select_min, + rmm::cuda_stream_view stream, + rmm::mr::device_memory_resource* mr = nullptr) { // reduce the block size if the input length is too small. if constexpr (BlockSize > calc_min_block_size()) { if (BlockSize * ITEM_PER_THREAD > len) { - return radix_topk( + return select_k( in, in_idx, batch_size, len, k, out, out_idx, select_min, stream); } } @@ -573,23 +574,33 @@ void radix_topk(const T* in, dim3 blocks = get_optimal_grid_size(batch_size, len); size_t max_chunk_size = blocks.y; - auto pool_guard = raft::get_pool_memory_resource( - mr, - max_chunk_size * (sizeof(Counter) // counters - + sizeof(IdxT) * (num_buckets + 2) // histograms and IdxT bufs - + sizeof(T) * 2 // T bufs - )); + size_t req_aux = max_chunk_size * (sizeof(Counter) + num_buckets * sizeof(IdxT)); + size_t req_buf = max_chunk_size * len * 2 * (sizeof(T) + sizeof(IdxT)); + size_t mem_req = req_aux + req_buf; + size_t mem_free, mem_total; + RAFT_CUDA_TRY(cudaMemGetInfo(&mem_free, &mem_total)); + std::optional managed_memory; + rmm::mr::device_memory_resource* mr_buf = nullptr; + if (mem_req > mem_free) { + // if there's not enough memory for buffers on the device, resort to the managed memory. 
+ mem_req = req_aux; + managed_memory.emplace(); + mr_buf = &managed_memory.value(); + } + + auto pool_guard = raft::get_pool_memory_resource(mr, mem_req); if (pool_guard) { - RAFT_LOG_DEBUG("radix_topk: using pool memory resource with initial size %zu bytes", + RAFT_LOG_DEBUG("radix::select_k: using pool memory resource with initial size %zu bytes", pool_guard->pool_size()); } + if (mr_buf == nullptr) { mr_buf = mr; } rmm::device_uvector> counters(max_chunk_size, stream, mr); - rmm::device_uvector histograms(num_buckets * max_chunk_size, stream, mr); - rmm::device_uvector buf1(len * max_chunk_size, stream, mr); - rmm::device_uvector idx_buf1(len * max_chunk_size, stream, mr); - rmm::device_uvector buf2(len * max_chunk_size, stream, mr); - rmm::device_uvector idx_buf2(len * max_chunk_size, stream, mr); + rmm::device_uvector histograms(max_chunk_size * num_buckets, stream, mr); + rmm::device_uvector buf1(max_chunk_size * len, stream, mr_buf); + rmm::device_uvector idx_buf1(max_chunk_size * len, stream, mr_buf); + rmm::device_uvector buf2(max_chunk_size * len, stream, mr_buf); + rmm::device_uvector idx_buf2(max_chunk_size * len, stream, mr_buf); for (size_t offset = 0; offset < batch_size; offset += max_chunk_size) { blocks.y = std::min(max_chunk_size, batch_size - offset); @@ -646,4 +657,4 @@ void radix_topk(const T* in, } } -} // namespace raft::spatial::knn::detail::topk +} // namespace raft::matrix::detail::select::radix diff --git a/cpp/include/raft/spatial/knn/detail/topk/warpsort_topk.cuh b/cpp/include/raft/matrix/detail/select_warpsort.cuh similarity index 71% rename from cpp/include/raft/spatial/knn/detail/topk/warpsort_topk.cuh rename to cpp/include/raft/matrix/detail/select_warpsort.cuh index cbe9f36e97..d362b73792 100644 --- a/cpp/include/raft/spatial/knn/detail/topk/warpsort_topk.cuh +++ b/cpp/include/raft/matrix/detail/select_warpsort.cuh @@ -1,5 +1,5 @@ /* - * Copyright (c) 2022, NVIDIA CORPORATION. + * Copyright (c) 2022-2023, NVIDIA CORPORATION. 
* * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -16,10 +16,11 @@ #pragma once -#include "bitonic_sort.cuh" - +#include #include +#include #include +#include #include #include @@ -31,12 +32,12 @@ /* Three APIs of different scopes are provided: - 1. host function: warp_sort_topk() + 1. host function: select_k() 2. block-wide API: class block_sort 3. warp-wide API: several implementations of warp_sort_* - 1. warp_sort_topk() + 1. select_k() (see the docstring) 2. class block_sort @@ -74,7 +75,7 @@ These two classes can be regarded as fixed size priority queue for a warp. Usage is similar to class block_sort. No shared memory is needed. - The host function (warp_sort_topk) uses a heuristic to choose between these two classes for + The host function (select_k) uses a heuristic to choose between these two classes for sorting, warp_sort_immediate being chosen when the number of inputs per warp is somewhat small (see the usage of LaunchThreshold::len_factor_for_choosing). @@ -94,7 +95,7 @@ } */ -namespace raft::spatial::knn::detail::topk { +namespace raft::matrix::detail::select::warpsort { static constexpr int kMaxCapacity = 256; @@ -102,18 +103,12 @@ namespace { /** Whether 'left` should indeed be on the left w.r.t. `right`. */ template -__device__ __forceinline__ auto is_ordered(T left, T right) -> bool +_RAFT_DEVICE _RAFT_FORCEINLINE auto is_ordered(T left, T right) -> bool { if constexpr (Ascending) { return left < right; } if constexpr (!Ascending) { return left > right; } } -constexpr auto calc_capacity(int k) -> int -{ - int capacity = isPo2(k) ? 
k : (1 << (log2(k) + 1)); - return capacity; -} - } // namespace /** @@ -134,7 +129,7 @@ constexpr auto calc_capacity(int k) -> int */ template class warp_sort { - static_assert(isPo2(Capacity)); + static_assert(is_a_power_of_two(Capacity)); static_assert(std::is_default_constructible_v); public: @@ -148,13 +143,16 @@ class warp_sort { /** The number of elements to select. */ const int k; + /** Extra memory required per-block for keeping the state (shared or global). */ + constexpr static auto mem_required(uint32_t block_size) -> size_t { return 0; } + /** * Construct the warp_sort empty queue. * * @param k * number of elements to select. */ - __device__ warp_sort(int k) : k(k) + _RAFT_DEVICE warp_sort(int k) : k(k) { #pragma unroll for (int i = 0; i < kMaxArrLen; i++) { @@ -182,7 +180,7 @@ class warp_sort { * It serves as a conditional; when `false` the function does nothing. * We need it to ensure threads within a full warp don't diverge calling `bitonic::merge()`. */ - __device__ void load_sorted(const T* in, const IdxT* in_idx, bool do_merge = true) + _RAFT_DEVICE void load_sorted(const T* in, const IdxT* in_idx, bool do_merge = true) { if (do_merge) { int idx = Pow2::mod(laneId()) ^ Pow2::Mask; @@ -198,7 +196,7 @@ class warp_sort { } } if (kWarpWidth < WarpSize || do_merge) { - topk::bitonic(Ascending, kWarpWidth).merge(val_arr_, idx_arr_); + util::bitonic(Ascending, kWarpWidth).merge(val_arr_, idx_arr_); } } @@ -211,14 +209,23 @@ class warp_sort { * @param[out] out_idx * device pointer to a contiguous array, unique per-subwarp of size `kWarpWidth` * (length: k <= kWarpWidth * kMaxArrLen). 
+ * @param valF (optional) postprocess values (T -> OutT) + * @param idxF (optional) postprocess indices (IdxT -> OutIdxT) */ - __device__ void store(T* out, IdxT* out_idx) const + template + _RAFT_DEVICE void store(OutT* out, + OutIdxT* out_idx, + ValF valF = raft::identity_op{}, + IdxF idxF = raft::identity_op{}) const { int idx = Pow2::mod(laneId()); #pragma unroll kMaxArrLen for (int i = 0; i < kMaxArrLen && idx < k; i++, idx += kWarpWidth) { - out[idx] = val_arr_[i]; - out_idx[idx] = idx_arr_[i]; + out[idx] = valF(val_arr_[i]); + out_idx[idx] = idxF(idx_arr_[i]); } } @@ -245,8 +252,8 @@ class warp_sort { * the associated indices of the elements in the same format as `keys_in`. */ template - __device__ __forceinline__ void merge_in(const T* __restrict__ keys_in, - const IdxT* __restrict__ ids_in) + _RAFT_DEVICE _RAFT_FORCEINLINE void merge_in(const T* __restrict__ keys_in, + const IdxT* __restrict__ ids_in) { #pragma unroll for (int i = std::min(kMaxArrLen, PerThreadSizeIn); i > 0; i--) { @@ -257,7 +264,7 @@ class warp_sort { idx_arr_[kMaxArrLen - i] = ids_in[PerThreadSizeIn - i]; } } - topk::bitonic(Ascending, kWarpWidth).merge(val_arr_, idx_arr_); + util::bitonic(Ascending, kWarpWidth).merge(val_arr_, idx_arr_); } }; @@ -275,8 +282,9 @@ class warp_sort_filtered : public warp_sort { using warp_sort::kDummy; using warp_sort::kWarpWidth; using warp_sort::k; + using warp_sort::mem_required; - __device__ warp_sort_filtered(int k, T limit) + explicit _RAFT_DEVICE warp_sort_filtered(int k, T limit = kDummy) : warp_sort(k), buf_len_(0), k_th_(limit) { #pragma unroll @@ -286,12 +294,14 @@ class warp_sort_filtered : public warp_sort { } } - __device__ __forceinline__ explicit warp_sort_filtered(int k) - : warp_sort_filtered(k, kDummy) + _RAFT_DEVICE _RAFT_FORCEINLINE static auto init_blockwide(int k, + uint8_t* = nullptr, + T limit = kDummy) { + return warp_sort_filtered{k, limit}; } - __device__ void add(T val, IdxT idx) + _RAFT_DEVICE void add(T val, IdxT idx) { // 
comparing for k_th should reduce the total amount of updates: // `false` means the input value is surely not in the top-k values. @@ -309,22 +319,22 @@ class warp_sort_filtered : public warp_sort { if (do_add) { add_to_buf_(val, idx); } } - __device__ void done() + _RAFT_DEVICE void done() { if (any(buf_len_ != 0)) { merge_buf_(); } } private: - __device__ __forceinline__ void set_k_th_() + _RAFT_DEVICE _RAFT_FORCEINLINE void set_k_th_() { // NB on using srcLane: it's ok if it is outside the warp size / width; // the modulo op will be done inside the __shfl_sync. k_th_ = shfl(val_arr_[kMaxArrLen - 1], k - 1, kWarpWidth); } - __device__ __forceinline__ void merge_buf_() + _RAFT_DEVICE _RAFT_FORCEINLINE void merge_buf_() { - topk::bitonic(!Ascending, kWarpWidth).sort(val_buf_, idx_buf_); + util::bitonic(!Ascending, kWarpWidth).sort(val_buf_, idx_buf_); this->merge_in(val_buf_, idx_buf_); buf_len_ = 0; set_k_th_(); // contains warp sync @@ -334,7 +344,7 @@ class warp_sort_filtered : public warp_sort { } } - __device__ __forceinline__ void add_to_buf_(T val, IdxT idx) + _RAFT_DEVICE _RAFT_FORCEINLINE void add_to_buf_(T val, IdxT idx) { // NB: the loop is used here to ensure the constant indexing, // to not force the buffers spill into the local memory. 
@@ -373,8 +383,9 @@ class warp_sort_distributed : public warp_sort { using warp_sort::kDummy; using warp_sort::kWarpWidth; using warp_sort::k; + using warp_sort::mem_required; - __device__ warp_sort_distributed(int k, T limit) + explicit _RAFT_DEVICE warp_sort_distributed(int k, T limit = kDummy) : warp_sort(k), buf_val_(kDummy), buf_idx_(IdxT{}), @@ -383,12 +394,14 @@ class warp_sort_distributed : public warp_sort { { } - __device__ __forceinline__ explicit warp_sort_distributed(int k) - : warp_sort_distributed(k, kDummy) + _RAFT_DEVICE _RAFT_FORCEINLINE static auto init_blockwide(int k, + uint8_t* = nullptr, + T limit = kDummy) { + return warp_sort_distributed{k, limit}; } - __device__ void add(T val, IdxT idx) + _RAFT_DEVICE void add(T val, IdxT idx) { // mask tells which lanes in the warp have valid items to be added uint32_t mask = ballot(is_ordered(val, k_th_)); @@ -428,7 +441,7 @@ class warp_sort_distributed : public warp_sort { } } - __device__ void done() + _RAFT_DEVICE void done() { if (buf_len_ != 0) { merge_buf_(); @@ -437,16 +450,16 @@ class warp_sort_distributed : public warp_sort { } private: - __device__ __forceinline__ void set_k_th_() + _RAFT_DEVICE _RAFT_FORCEINLINE void set_k_th_() { // NB on using srcLane: it's ok if it is outside the warp size / width; // the modulo op will be done inside the __shfl_sync. 
k_th_ = shfl(val_arr_[kMaxArrLen - 1], k - 1, kWarpWidth); } - __device__ __forceinline__ void merge_buf_() + _RAFT_DEVICE _RAFT_FORCEINLINE void merge_buf_() { - topk::bitonic<1>(!Ascending, kWarpWidth).sort(buf_val_, buf_idx_); + util::bitonic<1>(!Ascending, kWarpWidth).sort(buf_val_, buf_idx_); this->merge_in<1>(&buf_val_, &buf_idx_); set_k_th_(); // contains warp sync buf_val_ = kDummy; @@ -463,6 +476,117 @@ class warp_sort_distributed : public warp_sort { T k_th_; }; +/** + * The same as `warp_sort_distributed`, but keeps the temporary value and index buffers + * in the given external pointers (normally, a shared memory pointer should be passed in). + */ +template +class warp_sort_distributed_ext : public warp_sort { + public: + using warp_sort::kDummy; + using warp_sort::kWarpWidth; + using warp_sort::k; + + constexpr static auto mem_required(uint32_t block_size) -> size_t + { + return (sizeof(T) + sizeof(IdxT)) * block_size; + } + + _RAFT_DEVICE warp_sort_distributed_ext(int k, T* val_buf, IdxT* idx_buf, T limit = kDummy) + : warp_sort(k), + val_buf_(val_buf), + idx_buf_(idx_buf), + buf_len_(0), + k_th_(limit) + { + val_buf_[laneId()] = kDummy; + } + + _RAFT_DEVICE static auto init_blockwide(int k, uint8_t* shmem, T limit = kDummy) + { + T* val_buf = nullptr; + IdxT* idx_buf = nullptr; + if constexpr (alignof(T) >= alignof(IdxT)) { + val_buf = reinterpret_cast(shmem); + idx_buf = reinterpret_cast(val_buf + blockDim.x); + } else { + idx_buf = reinterpret_cast(shmem); + val_buf = reinterpret_cast(idx_buf + blockDim.x); + } + auto warp_offset = Pow2::roundDown(threadIdx.x); + val_buf += warp_offset; + idx_buf += warp_offset; + return warp_sort_distributed_ext{k, val_buf, idx_buf, limit}; + } + + _RAFT_DEVICE void add(T val, IdxT idx) + { + bool do_add = is_ordered(val, k_th_); + // mask tells which lanes in the warp have valid items to be added + uint32_t mask = ballot(do_add); + if (mask == 0) { return; } + // where to put the element in the tmp buffer + int 
dst_ix = buf_len_ + __popc(mask & ((1u << laneId()) - 1u)); + // put all elements, which fit into the current tmp buffer + if (do_add && dst_ix < WarpSize) { + val_buf_[dst_ix] = val; + idx_buf_[dst_ix] = idx; + do_add = false; + } + // Total number of elements to be added + buf_len_ += __popc(mask); + // If the buffer is still not full, we can return + if (buf_len_ < WarpSize) { return; } + // Otherwise, merge the warp tmp buffer into the queue + merge_buf_(); // implies warp sync + buf_len_ -= WarpSize; + // save the inputs that couldn't fit before the merge + if (do_add) { + dst_ix -= WarpSize; + val_buf_[dst_ix] = val; + idx_buf_[dst_ix] = idx; + } + } + + _RAFT_DEVICE void done() + { + if (buf_len_ != 0) { + merge_buf_(); + buf_len_ = 0; + } + __syncthreads(); + } + + private: + _RAFT_DEVICE _RAFT_FORCEINLINE void set_k_th_() + { + // NB on using srcLane: it's ok if it is outside the warp size / width; + // the modulo op will be done inside the __shfl_sync. + k_th_ = shfl(val_arr_[kMaxArrLen - 1], k - 1, kWarpWidth); + } + + _RAFT_DEVICE _RAFT_FORCEINLINE void merge_buf_() + { + __syncwarp(); // make sure the threads are aware of the data written by others + T buf_val = val_buf_[laneId()]; + IdxT buf_idx = idx_buf_[laneId()]; + val_buf_[laneId()] = kDummy; + util::bitonic<1>(!Ascending, kWarpWidth).sort(buf_val, buf_idx); + this->merge_in<1>(&buf_val, &buf_idx); + set_k_th_(); // contains warp sync + } + + using warp_sort::kMaxArrLen; + using warp_sort::val_arr_; + using warp_sort::idx_arr_; + + T* val_buf_; + IdxT* idx_buf_; + uint32_t buf_len_; // 0 <= buf_len_ < WarpSize + + T k_th_; +}; + /** * This version of warp_sort adds every input element into the intermediate sorting * buffer, and thus does the sorting step every `Capacity` input elements. 
@@ -475,8 +599,10 @@ class warp_sort_immediate : public warp_sort { using warp_sort::kDummy; using warp_sort::kWarpWidth; using warp_sort::k; + using warp_sort::mem_required; - __device__ warp_sort_immediate(int k) : warp_sort(k), buf_len_(0) + explicit _RAFT_DEVICE warp_sort_immediate(int k) + : warp_sort(k), buf_len_(0) { #pragma unroll for (int i = 0; i < kMaxArrLen; i++) { @@ -485,7 +611,12 @@ class warp_sort_immediate : public warp_sort { } } - __device__ void add(T val, IdxT idx) + _RAFT_DEVICE _RAFT_FORCEINLINE static auto init_blockwide(int k, uint8_t* = nullptr) + { + return warp_sort_immediate{k}; + } + + _RAFT_DEVICE void add(T val, IdxT idx) { // NB: the loop is used here to ensure the constant indexing, // to not force the buffers spill into the local memory. @@ -499,7 +630,7 @@ class warp_sort_immediate : public warp_sort { ++buf_len_; if (buf_len_ == kMaxArrLen) { - topk::bitonic(!Ascending, kWarpWidth).sort(val_buf_, idx_buf_); + util::bitonic(!Ascending, kWarpWidth).sort(val_buf_, idx_buf_); this->merge_in(val_buf_, idx_buf_); #pragma unroll for (int i = 0; i < kMaxArrLen; i++) { @@ -509,10 +640,10 @@ class warp_sort_immediate : public warp_sort { } } - __device__ void done() + _RAFT_DEVICE void done() { if (buf_len_ != 0) { - topk::bitonic(!Ascending, kWarpWidth).sort(val_buf_, idx_buf_); + util::bitonic(!Ascending, kWarpWidth).sort(val_buf_, idx_buf_); this->merge_in(val_buf_, idx_buf_); } } @@ -544,15 +675,11 @@ class block_sort { using queue_t = WarpSortWarpWide; template - __device__ block_sort(int k, uint8_t* smem_buf, Args... args) : queue_(k, args...) + _RAFT_DEVICE block_sort(int k, Args... 
args) : queue_(queue_t::init_blockwide(k, args...)) { - val_smem_ = reinterpret_cast(smem_buf); - const int num_of_warp = subwarp_align::div(blockDim.x); - idx_smem_ = reinterpret_cast( - smem_buf + Pow2<256>::roundUp(ceildiv(num_of_warp, 2) * sizeof(T) * k)); } - __device__ void add(T val, IdxT idx) { queue_.add(val, idx); } + _RAFT_DEVICE void add(T val, IdxT idx) { queue_.add(val, idx); } /** * At the point of calling this function, the warp-level queues consumed all input @@ -560,22 +687,26 @@ class block_sort { * * Here we tree-merge the results using the shared memory and block sync. */ - __device__ void done() + _RAFT_DEVICE void done(uint8_t* smem_buf) { queue_.done(); + int nwarps = subwarp_align::div(blockDim.x); + auto val_smem = reinterpret_cast(smem_buf); + auto idx_smem = reinterpret_cast( + smem_buf + Pow2<256>::roundUp(ceildiv(nwarps, 2) * sizeof(T) * queue_.k)); + const int warp_id = subwarp_align::div(threadIdx.x); // NB: there is no need for the second __synchthreads between .load_sorted and .store: // we shift the pointers every iteration, such that individual warps either access the same // locations or do not overlap with any of the other warps. The access patterns within warps // are different for the two functions, but .load_sorted implies warp sync at the end, so // there is no need for __syncwarp either. 
- for (int shift_mask = ~0, nwarps = subwarp_align::div(blockDim.x), split = (nwarps + 1) >> 1; - nwarps > 1; + for (int shift_mask = ~0, split = (nwarps + 1) >> 1; nwarps > 1; nwarps = split, split = (nwarps + 1) >> 1) { if (warp_id < nwarps && warp_id >= split) { int dst_warp_shift = (warp_id - (split & shift_mask)) * queue_.k; - queue_.store(val_smem_ + dst_warp_shift, idx_smem_ + dst_warp_shift); + queue_.store(val_smem + dst_warp_shift, idx_smem + dst_warp_shift); } __syncthreads(); @@ -585,22 +716,27 @@ class block_sort { // The last argument serves as a condition for loading // -- to make sure threads within a full warp do not diverge on `bitonic::merge()` queue_.load_sorted( - val_smem_ + src_warp_shift, idx_smem_ + src_warp_shift, warp_id < nwarps - split); + val_smem + src_warp_shift, idx_smem + src_warp_shift, warp_id < nwarps - split); } } } /** Save the content by the pointer location. */ - __device__ void store(T* out, IdxT* out_idx) const + template + _RAFT_DEVICE void store(OutT* out, + OutIdxT* out_idx, + ValF valF = raft::identity_op{}, + IdxF idxF = raft::identity_op{}) const { - if (threadIdx.x < subwarp_align::Value) { queue_.store(out, out_idx); } + if (threadIdx.x < subwarp_align::Value) { queue_.store(out, out_idx, valF, idxF); } } private: using subwarp_align = Pow2; queue_t queue_; - T* val_smem_; - IdxT* idx_smem_; }; /** @@ -618,7 +754,10 @@ __launch_bounds__(256) __global__ void block_kernel(const T* in, const IdxT* in_idx, IdxT len, int k, T* out, IdxT* out_idx) { extern __shared__ __align__(256) uint8_t smem_buf_bytes[]; - block_sort queue(k, smem_buf_bytes); + using bq_t = block_sort; + uint8_t* warp_smem = bq_t::queue_t::mem_required(blockDim.x) > 0 ? smem_buf_bytes : nullptr; + bq_t queue(k, warp_smem); + in += blockIdx.y * len; if (in_idx != nullptr) { in_idx += blockIdx.y * len; } @@ -629,7 +768,7 @@ __launch_bounds__(256) __global__ (i < len && in_idx != nullptr) ? 
__ldcs(in_idx + i) : i); } - queue.done(); + queue.done(smem_buf_bytes); const int block_id = blockIdx.x + gridDim.x * blockIdx.y; queue.store(out + block_id * k, out_idx + block_id * k); } @@ -656,7 +795,7 @@ struct launch_setup { int* min_grid_size, int block_size_limit = 0) { - const int capacity = calc_capacity(k); + const int capacity = bound_by_power_of_two(k); if constexpr (Capacity > 1) { if (capacity < Capacity) { return launch_setup::calc_optimal_params( @@ -689,7 +828,7 @@ struct launch_setup { IdxT* out_idx, rmm::cuda_stream_view stream) { - const int capacity = calc_capacity(k); + const int capacity = bound_by_power_of_two(k); if constexpr (Capacity > 1) { if (capacity < Capacity) { return launch_setup::kernel(k, @@ -740,6 +879,18 @@ struct LaunchThreshold { static constexpr int len_factor_for_single_block = 32; }; +template <> +struct LaunchThreshold { + static constexpr int len_factor_for_multi_block = 2; + static constexpr int len_factor_for_single_block = 32; +}; + +template <> +struct LaunchThreshold { + static constexpr int len_factor_for_multi_block = 2; + static constexpr int len_factor_for_single_block = 32; +}; + template <> struct LaunchThreshold { static constexpr int len_factor_for_choosing = 4; @@ -751,7 +902,7 @@ template