diff --git a/.github/actions/build/action.yml b/.github/actions/build/action.yml
new file mode 100644
index 000000000..2b470d9dc
--- /dev/null
+++ b/.github/actions/build/action.yml
@@ -0,0 +1,28 @@
+name: 'DiskANN Build Bootstrap'
+description: 'Prepares DiskANN build environment and executes build'
+runs:
+  using: "composite"
+  steps:
+    # ------------ Linux Build ---------------
+    - name: Prepare and Execute Build
+      if: ${{ runner.os == 'Linux' }}
+      run: |
+        sudo scripts/dev/install-dev-deps-ubuntu.bash
+        cmake -S . -B build -DCMAKE_BUILD_TYPE=Release -DUNIT_TEST=True
+        cmake --build build -- -j
+        cmake --install build --prefix="dist"
+      shell: bash
+    # ------------ End Linux Build ---------------
+    # ------------ Windows Build ---------------
+    - name: Add VisualStudio command line tools into path
+      if: runner.os == 'Windows'
+      uses: ilammy/msvc-dev-cmd@v1
+    - name: Run configure and build for Windows
+      if: runner.os == 'Windows'
+      run: |
+        mkdir build && cd build && cmake .. -DUNIT_TEST=True && msbuild diskann.sln /m /nologo /t:Build /p:Configuration="Release" /property:Platform="x64" -consoleloggerparameters:"ErrorsOnly;Summary"
+        cd ..
+        mkdir dist
+        mklink /j .\dist\bin .\x64\Release\
+      shell: cmd
+    # ------------ End Windows Build ---------------
\ No newline at end of file
diff --git a/.github/actions/format-check/action.yml b/.github/actions/format-check/action.yml
new file mode 100644
index 000000000..6ed08c095
--- /dev/null
+++ b/.github/actions/format-check/action.yml
@@ -0,0 +1,13 @@
+name: 'Checking code formatting...'
+description: 'Ensures code complies with code formatting rules'
+runs:
+  using: "composite"
+  steps:
+    - name: Checking code formatting...
+      run: |
+        sudo apt install clang-format
+        find include -name '*.h' -type f -print0 | xargs -0 -P 16 /usr/bin/clang-format --Werror --dry-run
+        find src -name '*.cpp' -type f -print0 | xargs -0 -P 16 /usr/bin/clang-format --Werror --dry-run
+        find apps -name '*.cpp' -type f -print0 | xargs -0 -P 16 /usr/bin/clang-format --Werror --dry-run
+        find python -name '*.cpp' -type f -print0 | xargs -0 -P 16 /usr/bin/clang-format --Werror --dry-run
+      shell: bash
diff --git a/.github/actions/generate-random/action.yml b/.github/actions/generate-random/action.yml
new file mode 100644
index 000000000..75554773e
--- /dev/null
+++ b/.github/actions/generate-random/action.yml
@@ -0,0 +1,35 @@
+name: 'Generating Random Data (Basic)'
+description: 'Generates the random data files used in acceptance tests'
+runs:
+  using: "composite"
+  steps:
+    - name: Generate Random Data (Basic)
+      run: |
+        mkdir data
+
+        echo "Generating random vectors for index"
+        dist/bin/rand_data_gen --data_type float --output_file data/rand_float_10D_10K_norm1.0.bin -D 10 -N 10000 --norm 1.0
+        dist/bin/rand_data_gen --data_type int8 --output_file data/rand_int8_10D_10K_norm50.0.bin -D 10 -N 10000 --norm 50.0
+        dist/bin/rand_data_gen --data_type uint8 --output_file data/rand_uint8_10D_10K_norm50.0.bin -D 10 -N 10000 --norm 50.0
+
+        echo "Generating random vectors for query"
+        dist/bin/rand_data_gen --data_type float --output_file data/rand_float_10D_1K_norm1.0.bin -D 10 -N 1000 --norm 1.0
+        dist/bin/rand_data_gen --data_type int8 --output_file data/rand_int8_10D_1K_norm50.0.bin -D 10 -N 1000 --norm 50.0
+        dist/bin/rand_data_gen --data_type uint8 --output_file data/rand_uint8_10D_1K_norm50.0.bin -D 10 -N 1000 --norm 50.0
+
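+        # Editor's note: the ground-truth files below follow the naming pattern
+        # <dist_fn>_<base stem>_<query stem>_gt<K>; the later workflows rely on
+        # this pairing of index and --gt_file (our reading of the convention
+        # used throughout these workflows, not a requirement of the tools).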
+        echo "Computing ground truth for floats across l2, mips, and cosine distance functions"
+        dist/bin/compute_groundtruth --data_type float --dist_fn l2 --base_file data/rand_float_10D_10K_norm1.0.bin --query_file data/rand_float_10D_1K_norm1.0.bin --gt_file data/l2_rand_float_10D_10K_norm1.0_10D_1K_norm1.0_gt100 --K 100
+        dist/bin/compute_groundtruth --data_type float --dist_fn mips --base_file data/rand_float_10D_10K_norm1.0.bin --query_file data/rand_float_10D_1K_norm1.0.bin --gt_file data/mips_rand_float_10D_10K_norm1.0_10D_1K_norm1.0_gt100 --K 100
+        dist/bin/compute_groundtruth --data_type float --dist_fn cosine --base_file data/rand_float_10D_10K_norm1.0.bin --query_file data/rand_float_10D_1K_norm1.0.bin --gt_file data/cosine_rand_float_10D_10K_norm1.0_10D_1K_norm1.0_gt100 --K 100
+
+        echo "Computing ground truth for int8s across l2, mips, and cosine distance functions"
+        dist/bin/compute_groundtruth --data_type int8 --dist_fn l2 --base_file data/rand_int8_10D_10K_norm50.0.bin --query_file data/rand_int8_10D_1K_norm50.0.bin --gt_file data/l2_rand_int8_10D_10K_norm50.0_10D_1K_norm50.0_gt100 --K 100
+        dist/bin/compute_groundtruth --data_type int8 --dist_fn mips --base_file data/rand_int8_10D_10K_norm50.0.bin --query_file data/rand_int8_10D_1K_norm50.0.bin --gt_file data/mips_rand_int8_10D_10K_norm50.0_10D_1K_norm50.0_gt100 --K 100
+        dist/bin/compute_groundtruth --data_type int8 --dist_fn cosine --base_file data/rand_int8_10D_10K_norm50.0.bin --query_file data/rand_int8_10D_1K_norm50.0.bin --gt_file data/cosine_rand_int8_10D_10K_norm50.0_10D_1K_norm50.0_gt100 --K 100
+
+        echo "Computing ground truth for uint8s across l2, mips, and cosine distance functions"
+        dist/bin/compute_groundtruth --data_type uint8 --dist_fn l2 --base_file data/rand_uint8_10D_10K_norm50.0.bin --query_file data/rand_uint8_10D_1K_norm50.0.bin --gt_file data/l2_rand_uint8_10D_10K_norm50.0_10D_1K_norm50.0_gt100 --K 100
+        dist/bin/compute_groundtruth --data_type uint8 --dist_fn mips --base_file data/rand_uint8_10D_10K_norm50.0.bin --query_file data/rand_uint8_10D_1K_norm50.0.bin --gt_file data/mips_rand_uint8_10D_10K_norm50.0_10D_1K_norm50.0_gt100 --K 100
+        dist/bin/compute_groundtruth --data_type uint8 --dist_fn cosine --base_file data/rand_uint8_10D_10K_norm50.0.bin --query_file data/rand_uint8_10D_1K_norm50.0.bin --gt_file data/cosine_rand_uint8_10D_10K_norm50.0_10D_1K_norm50.0_gt100 --K 100
+
+      shell: bash
diff --git a/.github/actions/python-wheel/action.yml b/.github/actions/python-wheel/action.yml
new file mode 100644
index 000000000..6a2880c6d
--- /dev/null
+++ b/.github/actions/python-wheel/action.yml
@@ -0,0 +1,22 @@
+name: Build Python Wheel
+description: Builds a python wheel with cibuildwheel
+inputs:
+  cibw-identifier:
+    description: "CI build wheel identifier to build"
+    required: true
+runs:
+  using: "composite"
+  steps:
+    - uses: actions/setup-python@v3
+    - name: Install cibuildwheel
+      run: python -m pip install cibuildwheel==2.11.3
+      shell: bash
+    - name: Building Python ${{inputs.cibw-identifier}} Wheel
+      run: python -m cibuildwheel --output-dir dist
+      env:
+        CIBW_BUILD: ${{inputs.cibw-identifier}}
+      shell: bash
+    - uses: actions/upload-artifact@v3
+      with:
+        name: wheels
+        path: ./dist/*.whl
diff --git a/.github/workflows/build-python.yml b/.github/workflows/build-python.yml
new file mode 100644
index 000000000..b825398d1
--- /dev/null
+++ b/.github/workflows/build-python.yml
@@ -0,0 +1,42 @@
+name: DiskANN Build Python Wheel
+on: [workflow_call]
+jobs:
+  linux-build:
+    name: Python - Ubuntu - ${{matrix.cibw-identifier}}
+    strategy:
+      fail-fast: false
+      matrix:
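+        # cibuildwheel identifiers are <CPython tag>-<platform tag>, e.g.
+        # cp39-manylinux_x86_64 = CPython 3.9 wheels for manylinux on x86_64
+        # (standard cibuildwheel naming, noted here for readers of the matrix).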
"cp311-manylinux_x86_64"] + runs-on: ubuntu-latest + defaults: + run: + shell: bash + steps: + - name: Checkout repository + uses: actions/checkout@v3 + with: + fetch-depth: 1 + - name: Building python wheel ${{matrix.cibw-identifier}} + uses: ./.github/actions/python-wheel + with: + cibw-identifier: ${{matrix.cibw-identifier}} + windows-build: + name: Python - Windows - ${{matrix.cibw-identifier}} + strategy: + fail-fast: false + matrix: + cibw-identifier: ["cp39-win_amd64", "cp310-win_amd64", "cp311-win_amd64"] + runs-on: windows-latest + defaults: + run: + shell: bash + steps: + - name: Checkout repository + uses: actions/checkout@v3 + with: + submodules: true + fetch-depth: 1 + - name: Building python wheel ${{matrix.cibw-identifier}} + uses: ./.github/actions/python-wheel + with: + cibw-identifier: ${{matrix.cibw-identifier}} diff --git a/.github/workflows/common.yml b/.github/workflows/common.yml new file mode 100644 index 000000000..09c020abe --- /dev/null +++ b/.github/workflows/common.yml @@ -0,0 +1,28 @@ +name: DiskANN Common Checks +# common means common to both pr-test and push-test +on: [workflow_call] +jobs: + formatting-check: + strategy: + fail-fast: true + name: Code Formatting Test + runs-on: ubuntu-latest + steps: + - name: Checkout repository + uses: actions/checkout@v3 + with: + fetch-depth: 1 + - name: Checking code formatting... + uses: ./.github/actions/format-check + docker-container-build: + name: Docker Container Build + needs: [formatting-check] + runs-on: ubuntu-latest + steps: + - name: Checkout repository + uses: actions/checkout@v3 + with: + fetch-depth: 1 + - name: Docker build + run: | + docker build . \ No newline at end of file diff --git a/.github/workflows/disk-pq.yml b/.github/workflows/disk-pq.yml new file mode 100644 index 000000000..35c662184 --- /dev/null +++ b/.github/workflows/disk-pq.yml @@ -0,0 +1,107 @@ +name: Disk With PQ +on: [workflow_call] +jobs: + acceptance-tests-disk-pq: + name: Disk, PQ + strategy: + fail-fast: false + matrix: + os: [ubuntu-latest, windows-2019, windows-latest] + runs-on: ${{matrix.os}} + defaults: + run: + shell: bash + steps: + - name: Checkout repository + if: ${{ runner.os == 'Linux' }} + uses: actions/checkout@v3 + with: + fetch-depth: 1 + - name: Checkout repository + if: ${{ runner.os == 'Windows' }} + uses: actions/checkout@v3 + with: + fetch-depth: 1 + submodules: true + - name: DiskANN Build CLI Applications + uses: ./.github/actions/build + + - name: Generate Data + uses: ./.github/actions/generate-random + + - name: build and search disk index (one shot graph build, L2, no diskPQ) (float) + if: success() || failure() + run: | + dist/bin/build_disk_index --data_type float --dist_fn l2 --data_path data/rand_float_10D_10K_norm1.0.bin --index_path_prefix data/disk_index_l2_rand_float_10D_10K_norm1.0_diskfull_oneshot -R 16 -L 32 -B 0.00003 -M 1 + dist/bin/search_disk_index --data_type float --dist_fn l2 --fail_if_recall_below 70 --index_path_prefix data/disk_index_l2_rand_float_10D_10K_norm1.0_diskfull_oneshot --result_path /tmp/res --query_file data/rand_float_10D_1K_norm1.0.bin --gt_file data/l2_rand_float_10D_10K_norm1.0_10D_1K_norm1.0_gt100 --recall_at 5 -L 5 12 -W 2 --num_nodes_to_cache 10 -T 16 + - name: build and search disk index (one shot graph build, L2, no diskPQ) (int8) + if: success() || failure() + run: | + dist/bin/build_disk_index --data_type int8 --dist_fn l2 --data_path data/rand_int8_10D_10K_norm50.0.bin --index_path_prefix data/disk_index_l2_rand_int8_10D_10K_norm50.0_diskfull_oneshot -R 16 
+      - name: build and search disk index (one shot graph build, L2, no diskPQ) (float)
+        if: success() || failure()
+        run: |
+          dist/bin/build_disk_index --data_type float --dist_fn l2 --data_path data/rand_float_10D_10K_norm1.0.bin --index_path_prefix data/disk_index_l2_rand_float_10D_10K_norm1.0_diskfull_oneshot -R 16 -L 32 -B 0.00003 -M 1
+          dist/bin/search_disk_index --data_type float --dist_fn l2 --fail_if_recall_below 70 --index_path_prefix data/disk_index_l2_rand_float_10D_10K_norm1.0_diskfull_oneshot --result_path /tmp/res --query_file data/rand_float_10D_1K_norm1.0.bin --gt_file data/l2_rand_float_10D_10K_norm1.0_10D_1K_norm1.0_gt100 --recall_at 5 -L 5 12 -W 2 --num_nodes_to_cache 10 -T 16
+      - name: build and search disk index (one shot graph build, L2, no diskPQ) (int8)
+        if: success() || failure()
+        run: |
+          dist/bin/build_disk_index --data_type int8 --dist_fn l2 --data_path data/rand_int8_10D_10K_norm50.0.bin --index_path_prefix data/disk_index_l2_rand_int8_10D_10K_norm50.0_diskfull_oneshot -R 16 -L 32 -B 0.00003 -M 1
+          dist/bin/search_disk_index --data_type int8 --dist_fn l2 --fail_if_recall_below 70 --index_path_prefix data/disk_index_l2_rand_int8_10D_10K_norm50.0_diskfull_oneshot --result_path /tmp/res --query_file data/rand_int8_10D_1K_norm50.0.bin --gt_file data/l2_rand_int8_10D_10K_norm50.0_10D_1K_norm50.0_gt100 --recall_at 5 -L 5 12 -W 2 --num_nodes_to_cache 10 -T 16
+      - name: build and search disk index (one shot graph build, L2, no diskPQ) (uint8)
+        if: success() || failure()
+        run: |
+          dist/bin/build_disk_index --data_type uint8 --dist_fn l2 --data_path data/rand_uint8_10D_10K_norm50.0.bin --index_path_prefix data/disk_index_l2_rand_uint8_10D_10K_norm50.0_diskfull_oneshot -R 16 -L 32 -B 0.00003 -M 1
+          dist/bin/search_disk_index --data_type uint8 --dist_fn l2 --fail_if_recall_below 70 --index_path_prefix data/disk_index_l2_rand_uint8_10D_10K_norm50.0_diskfull_oneshot --result_path /tmp/res --query_file data/rand_uint8_10D_1K_norm50.0.bin --gt_file data/l2_rand_uint8_10D_10K_norm50.0_10D_1K_norm50.0_gt100 --recall_at 5 -L 5 12 -W 2 --num_nodes_to_cache 10 -T 16
+
+      - name: build and search disk index (one shot graph build, L2, no diskPQ, build with PQ distance comparisons) (float)
+        if: success() || failure()
+        run: |
+          dist/bin/build_disk_index --data_type float --dist_fn l2 --data_path data/rand_float_10D_10K_norm1.0.bin --index_path_prefix data/disk_index_l2_rand_float_10D_10K_norm1.0_diskfull_oneshot_buildpq5 -R 16 -L 32 -B 0.00003 -M 1 --build_PQ_bytes 5
+          dist/bin/search_disk_index --data_type float --dist_fn l2 --fail_if_recall_below 70 --index_path_prefix data/disk_index_l2_rand_float_10D_10K_norm1.0_diskfull_oneshot_buildpq5 --result_path /tmp/res --query_file data/rand_float_10D_1K_norm1.0.bin --gt_file data/l2_rand_float_10D_10K_norm1.0_10D_1K_norm1.0_gt100 --recall_at 5 -L 5 12 -W 2 --num_nodes_to_cache 10 -T 16
+      - name: build and search disk index (one shot graph build, L2, no diskPQ, build with PQ distance comparisons) (int8)
+        if: success() || failure()
+        run: |
+          dist/bin/build_disk_index --data_type int8 --dist_fn l2 --data_path data/rand_int8_10D_10K_norm50.0.bin --index_path_prefix data/disk_index_l2_rand_int8_10D_10K_norm50.0_diskfull_oneshot_buildpq5 -R 16 -L 32 -B 0.00003 -M 1 --build_PQ_bytes 5
+          dist/bin/search_disk_index --data_type int8 --dist_fn l2 --fail_if_recall_below 70 --index_path_prefix data/disk_index_l2_rand_int8_10D_10K_norm50.0_diskfull_oneshot_buildpq5 --result_path /tmp/res --query_file data/rand_int8_10D_1K_norm50.0.bin --gt_file data/l2_rand_int8_10D_10K_norm50.0_10D_1K_norm50.0_gt100 --recall_at 5 -L 5 12 -W 2 --num_nodes_to_cache 10 -T 16
+      - name: build and search disk index (one shot graph build, L2, no diskPQ, build with PQ distance comparisons) (uint8)
+        if: success() || failure()
+        run: |
+          dist/bin/build_disk_index --data_type uint8 --dist_fn l2 --data_path data/rand_uint8_10D_10K_norm50.0.bin --index_path_prefix data/disk_index_l2_rand_uint8_10D_10K_norm50.0_diskfull_oneshot_buildpq5 -R 16 -L 32 -B 0.00003 -M 1 --build_PQ_bytes 5
+          dist/bin/search_disk_index --data_type uint8 --dist_fn l2 --fail_if_recall_below 70 --index_path_prefix data/disk_index_l2_rand_uint8_10D_10K_norm50.0_diskfull_oneshot_buildpq5 --result_path /tmp/res --query_file data/rand_uint8_10D_1K_norm50.0.bin --gt_file data/l2_rand_uint8_10D_10K_norm50.0_10D_1K_norm50.0_gt100 --recall_at 5 -L 5 12 -W 2 --num_nodes_to_cache 10 -T 16
+
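+      # The "sharded" variants below set a deliberately tiny build budget
+      # (-M 0.00006 GB instead of -M 1), which we understand forces the
+      # builder down its sharded (divide-and-merge) code path on this data.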
+      - name: build and search disk index (sharded graph build, L2, no diskPQ) (float)
+        if: success() || failure()
+        run: |
+          dist/bin/build_disk_index --data_type float --dist_fn l2 --data_path data/rand_float_10D_10K_norm1.0.bin --index_path_prefix data/disk_index_l2_rand_float_10D_10K_norm1.0_diskfull_sharded -R 16 -L 32 -B 0.00003 -M 0.00006
+          dist/bin/search_disk_index --data_type float --dist_fn l2 --fail_if_recall_below 70 --index_path_prefix data/disk_index_l2_rand_float_10D_10K_norm1.0_diskfull_sharded --result_path /tmp/res --query_file data/rand_float_10D_1K_norm1.0.bin --gt_file data/l2_rand_float_10D_10K_norm1.0_10D_1K_norm1.0_gt100 --recall_at 5 -L 5 12 -W 2 --num_nodes_to_cache 10 -T 16
+      - name: build and search disk index (sharded graph build, L2, no diskPQ) (int8)
+        if: success() || failure()
+        run: |
+          dist/bin/build_disk_index --data_type int8 --dist_fn l2 --data_path data/rand_int8_10D_10K_norm50.0.bin --index_path_prefix data/disk_index_l2_rand_int8_10D_10K_norm50.0_diskfull_sharded -R 16 -L 32 -B 0.00003 -M 0.00006
+          dist/bin/search_disk_index --data_type int8 --dist_fn l2 --fail_if_recall_below 70 --index_path_prefix data/disk_index_l2_rand_int8_10D_10K_norm50.0_diskfull_sharded --result_path /tmp/res --query_file data/rand_int8_10D_1K_norm50.0.bin --gt_file data/l2_rand_int8_10D_10K_norm50.0_10D_1K_norm50.0_gt100 --recall_at 5 -L 5 12 -W 2 --num_nodes_to_cache 10 -T 16
+      - name: build and search disk index (sharded graph build, L2, no diskPQ) (uint8)
+        if: success() || failure()
+        run: |
+          dist/bin/build_disk_index --data_type uint8 --dist_fn l2 --data_path data/rand_uint8_10D_10K_norm50.0.bin --index_path_prefix data/disk_index_l2_rand_uint8_10D_10K_norm50.0_diskfull_sharded -R 16 -L 32 -B 0.00003 -M 0.00006
+          dist/bin/search_disk_index --data_type uint8 --dist_fn l2 --fail_if_recall_below 70 --index_path_prefix data/disk_index_l2_rand_uint8_10D_10K_norm50.0_diskfull_sharded --result_path /tmp/res --query_file data/rand_uint8_10D_1K_norm50.0.bin --gt_file data/l2_rand_uint8_10D_10K_norm50.0_10D_1K_norm50.0_gt100 --recall_at 5 -L 5 12 -W 2 --num_nodes_to_cache 10 -T 16
+
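+      # --PQ_disk_bytes 5 asks the builder to store PQ-compressed vectors
+      # (5 bytes each) on disk in place of full-precision data, so these
+      # steps exercise the "diskPQ" read path (our understanding of the flag).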
+      - name: build and search disk index (one shot graph build, L2, diskPQ) (float)
+        if: success() || failure()
+        run: |
+          dist/bin/build_disk_index --data_type float --dist_fn l2 --data_path data/rand_float_10D_10K_norm1.0.bin --index_path_prefix data/disk_index_l2_rand_float_10D_10K_norm1.0_diskpq_oneshot -R 16 -L 32 -B 0.00003 -M 1 --PQ_disk_bytes 5
+          dist/bin/search_disk_index --data_type float --dist_fn l2 --fail_if_recall_below 70 --index_path_prefix data/disk_index_l2_rand_float_10D_10K_norm1.0_diskpq_oneshot --result_path /tmp/res --query_file data/rand_float_10D_1K_norm1.0.bin --gt_file data/l2_rand_float_10D_10K_norm1.0_10D_1K_norm1.0_gt100 --recall_at 5 -L 5 12 -W 2 --num_nodes_to_cache 10 -T 16
+      - name: build and search disk index (one shot graph build, L2, diskPQ) (int8)
+        if: success() || failure()
+        run: |
+          dist/bin/build_disk_index --data_type int8 --dist_fn l2 --data_path data/rand_int8_10D_10K_norm50.0.bin --index_path_prefix data/disk_index_l2_rand_int8_10D_10K_norm50.0_diskpq_oneshot -R 16 -L 32 -B 0.00003 -M 1 --PQ_disk_bytes 5
+          dist/bin/search_disk_index --data_type int8 --dist_fn l2 --fail_if_recall_below 70 --index_path_prefix data/disk_index_l2_rand_int8_10D_10K_norm50.0_diskpq_oneshot --result_path /tmp/res --query_file data/rand_int8_10D_1K_norm50.0.bin --gt_file data/l2_rand_int8_10D_10K_norm50.0_10D_1K_norm50.0_gt100 --recall_at 5 -L 5 12 -W 2 --num_nodes_to_cache 10 -T 16
+      - name: build and search disk index (one shot graph build, L2, diskPQ) (uint8)
+        if: success() || failure()
+        run: |
+          dist/bin/build_disk_index --data_type uint8 --dist_fn l2 --data_path data/rand_uint8_10D_10K_norm50.0.bin --index_path_prefix data/disk_index_l2_rand_uint8_10D_10K_norm50.0_diskpq_oneshot -R 16 -L 32 -B 0.00003 -M 1 --PQ_disk_bytes 5
+          dist/bin/search_disk_index --data_type uint8 --dist_fn l2 --fail_if_recall_below 70 --index_path_prefix data/disk_index_l2_rand_uint8_10D_10K_norm50.0_diskpq_oneshot --result_path /tmp/res --query_file data/rand_uint8_10D_1K_norm50.0.bin --gt_file data/l2_rand_uint8_10D_10K_norm50.0_10D_1K_norm50.0_gt100 --recall_at 5 -L 5 12 -W 2 --num_nodes_to_cache 10 -T 16
+
+      - name: build and search disk index (sharded graph build, MIPS, diskPQ) (float)
+        if: success() || failure()
+        run: |
+          dist/bin/build_disk_index --data_type float --dist_fn mips --data_path data/rand_float_10D_10K_norm1.0.bin --index_path_prefix data/disk_index_mips_rand_float_10D_10K_norm1.0_diskpq_sharded -R 16 -L 32 -B 0.00003 -M 0.00006 --PQ_disk_bytes 5
+          dist/bin/search_disk_index --data_type float --dist_fn l2 --fail_if_recall_below 70 --index_path_prefix data/disk_index_mips_rand_float_10D_10K_norm1.0_diskpq_sharded --result_path /tmp/res --query_file data/rand_float_10D_1K_norm1.0.bin --gt_file data/mips_rand_float_10D_10K_norm1.0_10D_1K_norm1.0_gt100 --recall_at 5 -L 5 12 -W 2 --num_nodes_to_cache 10 -T 16
+
+      - name: upload data and bin
+        uses: actions/upload-artifact@v3
+        with:
+          name: disk-pq
+          path: |
+            ./dist/**
+            ./data/**
diff --git a/.github/workflows/dynamic.yml b/.github/workflows/dynamic.yml
new file mode 100644
index 000000000..35eb6d42d
--- /dev/null
+++ b/.github/workflows/dynamic.yml
@@ -0,0 +1,75 @@
+name: Dynamic
+on: [workflow_call]
+jobs:
+  acceptance-tests-dynamic:
+    name: Dynamic
+    strategy:
+      fail-fast: false
+      matrix:
+        os: [ubuntu-latest, windows-2019, windows-latest]
+    runs-on: ${{matrix.os}}
+    defaults:
+      run:
+        shell: bash
+    steps:
+      - name: Checkout repository
+        if: ${{ runner.os == 'Linux' }}
+        uses: actions/checkout@v3
+        with:
+          fetch-depth: 1
+      - name: Checkout repository
+        if: ${{ runner.os == 'Windows' }}
+        uses: actions/checkout@v3
+        with:
+          fetch-depth: 1
+          submodules: true
+      - name: DiskANN Build CLI Applications
+        uses: ./.github/actions/build
+
+      - name: Generate Data
+        uses: ./.github/actions/generate-random
+
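+      # test_streaming_scenario keeps a sliding window of --active_window
+      # points, consolidating deletes every --consolidate_interval inserts;
+      # ground truth is then recomputed over the surviving points via the
+      # emitted .data/.tags files (our reading of the flags and file names).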
+      - name: test a streaming index (float)
+        run: |
+          dist/bin/test_streaming_scenario --data_type float --dist_fn l2 --data_path data/rand_float_10D_10K_norm1.0.bin --index_path_prefix data/index_stream -R 64 -L 600 --alpha 1.2 --insert_threads 4 --consolidate_threads 4 --max_points_to_insert 10000 --active_window 4000 --consolidate_interval 2000 --start_point_norm 3.2
+          dist/bin/compute_groundtruth --data_type float --dist_fn l2 --base_file data/index_stream.after-streaming-act4000-cons2000-max10000.data --query_file data/rand_float_10D_1K_norm1.0.bin --K 100 --gt_file data/gt100_base-act4000-cons2000-max10000 --tags_file data/index_stream.after-streaming-act4000-cons2000-max10000.tags
+          dist/bin/search_memory_index --data_type float --dist_fn l2 --fail_if_recall_below 70 --index_path_prefix data/index_stream.after-streaming-act4000-cons2000-max10000 --result_path data/res_stream --query_file data/rand_float_10D_1K_norm1.0.bin --gt_file data/gt100_base-act4000-cons2000-max10000 -K 10 -L 20 40 60 80 100 -T 64 --dynamic true --tags 1
+      - name: test a streaming index (int8)
+        if: success() || failure()
+        run: |
+          dist/bin/test_streaming_scenario --data_type int8 --dist_fn l2 --data_path data/rand_int8_10D_10K_norm50.0.bin --index_path_prefix data/index_stream -R 64 -L 600 --alpha 1.2 --insert_threads 4 --consolidate_threads 4 --max_points_to_insert 10000 --active_window 4000 --consolidate_interval 2000 --start_point_norm 200
+          dist/bin/compute_groundtruth --data_type int8 --dist_fn l2 --base_file data/index_stream.after-streaming-act4000-cons2000-max10000.data --query_file data/rand_int8_10D_1K_norm50.0.bin --K 100 --gt_file data/gt100_base-act4000-cons2000-max10000 --tags_file data/index_stream.after-streaming-act4000-cons2000-max10000.tags
+          dist/bin/search_memory_index --data_type int8 --dist_fn l2 --fail_if_recall_below 70 --index_path_prefix data/index_stream.after-streaming-act4000-cons2000-max10000 --result_path res_stream --query_file data/rand_int8_10D_1K_norm50.0.bin --gt_file data/gt100_base-act4000-cons2000-max10000 -K 10 -L 20 40 60 80 100 -T 64 --dynamic true --tags 1
+      - name: test a streaming index (uint8)
+        if: success() || failure()
+        run: |
+          dist/bin/test_streaming_scenario --data_type uint8 --dist_fn l2 --data_path data/rand_uint8_10D_10K_norm50.0.bin --index_path_prefix data/index_stream -R 64 -L 600 --alpha 1.2 --insert_threads 4 --consolidate_threads 4 --max_points_to_insert 10000 --active_window 4000 --consolidate_interval 2000 --start_point_norm 200
+          dist/bin/compute_groundtruth --data_type uint8 --dist_fn l2 --base_file data/index_stream.after-streaming-act4000-cons2000-max10000.data --query_file data/rand_uint8_10D_1K_norm50.0.bin --K 100 --gt_file data/gt100_base-act4000-cons2000-max10000 --tags_file data/index_stream.after-streaming-act4000-cons2000-max10000.tags
+          dist/bin/search_memory_index --data_type uint8 --dist_fn l2 --fail_if_recall_below 70 --index_path_prefix data/index_stream.after-streaming-act4000-cons2000-max10000 --result_path data/res_stream --query_file data/rand_uint8_10D_1K_norm50.0.bin --gt_file data/gt100_base-act4000-cons2000-max10000 -K 10 -L 20 40 60 80 100 -T 64 --dynamic true --tags 1
+
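+      # test_insert_deletes_consolidate inserts 7500 points while concurrently
+      # deleting the first 2500 (--do_concurrent true), checkpointing every
+      # 1000 points; the searches then validate the post-consolidation index.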
+      - name: build and search an incremental index (float)
+        if: success() || failure()
+        run: |
+          dist/bin/test_insert_deletes_consolidate --data_type float --dist_fn l2 --data_path data/rand_float_10D_10K_norm1.0.bin --index_path_prefix data/index_ins_del -R 64 -L 300 --alpha 1.2 -T 8 --points_to_skip 0 --max_points_to_insert 7500 --beginning_index_size 0 --points_per_checkpoint 1000 --checkpoints_per_snapshot 0 --points_to_delete_from_beginning 2500 --start_deletes_after 5000 --do_concurrent true --start_point_norm 3.2;
+          dist/bin/compute_groundtruth --data_type float --dist_fn l2 --base_file data/index_ins_del.after-concurrent-delete-del2500-7500.data --query_file data/rand_float_10D_1K_norm1.0.bin --K 100 --gt_file data/gt100_random10D_1K-conc-2500-7500 --tags_file data/index_ins_del.after-concurrent-delete-del2500-7500.tags
+          dist/bin/search_memory_index --data_type float --dist_fn l2 --fail_if_recall_below 70 --index_path_prefix data/index_ins_del.after-concurrent-delete-del2500-7500 --result_path data/res_ins_del --query_file data/rand_float_10D_1K_norm1.0.bin --gt_file data/gt100_random10D_1K-conc-2500-7500 -K 10 -L 20 40 60 80 100 -T 8 --dynamic true --tags 1
+      - name: build and search an incremental index (int8)
+        if: success() || failure()
+        run: |
+          dist/bin/test_insert_deletes_consolidate --data_type int8 --dist_fn l2 --data_path data/rand_int8_10D_10K_norm50.0.bin --index_path_prefix data/index_ins_del -R 64 -L 300 --alpha 1.2 -T 8 --points_to_skip 0 --max_points_to_insert 7500 --beginning_index_size 0 --points_per_checkpoint 1000 --checkpoints_per_snapshot 0 --points_to_delete_from_beginning 2500 --start_deletes_after 5000 --do_concurrent true --start_point_norm 200
+          dist/bin/compute_groundtruth --data_type int8 --dist_fn l2 --base_file data/index_ins_del.after-concurrent-delete-del2500-7500.data --query_file data/rand_int8_10D_1K_norm50.0.bin --K 100 --gt_file data/gt100_random10D_1K-conc-2500-7500 --tags_file data/index_ins_del.after-concurrent-delete-del2500-7500.tags
+          dist/bin/search_memory_index --data_type int8 --dist_fn l2 --fail_if_recall_below 70 --index_path_prefix data/index_ins_del.after-concurrent-delete-del2500-7500 --result_path data/res_ins_del --query_file data/rand_int8_10D_1K_norm50.0.bin --gt_file data/gt100_random10D_1K-conc-2500-7500 -K 10 -L 20 40 60 80 100 -T 8 --dynamic true --tags 1
+      - name: build and search an incremental index (uint8)
+        if: success() || failure()
+        run: |
+          dist/bin/test_insert_deletes_consolidate --data_type uint8 --dist_fn l2 --data_path data/rand_uint8_10D_10K_norm50.0.bin --index_path_prefix data/index_ins_del -R 64 -L 300 --alpha 1.2 -T 8 --points_to_skip 0 --max_points_to_insert 7500 --beginning_index_size 0 --points_per_checkpoint 1000 --checkpoints_per_snapshot 0 --points_to_delete_from_beginning 2500 --start_deletes_after 5000 --do_concurrent true --start_point_norm 200
+          dist/bin/compute_groundtruth --data_type uint8 --dist_fn l2 --base_file data/index_ins_del.after-concurrent-delete-del2500-7500.data --query_file data/rand_uint8_10D_1K_norm50.0.bin --K 100 --gt_file data/gt100_random10D_10K-conc-2500-7500 --tags_file data/index_ins_del.after-concurrent-delete-del2500-7500.tags
+          dist/bin/search_memory_index --data_type uint8 --dist_fn l2 --fail_if_recall_below 70 --index_path_prefix data/index_ins_del.after-concurrent-delete-del2500-7500 --result_path data/res_ins_del --query_file data/rand_uint8_10D_1K_norm50.0.bin --gt_file data/gt100_random10D_10K-conc-2500-7500 -K 10 -L 20 40 60 80 100 -T 8 --dynamic true --tags 1
+
+      - name: upload data and bin
+        uses: actions/upload-artifact@v3
+        with:
+          name: dynamic
+          path: |
+            ./dist/**
+            ./data/**
diff --git a/.github/workflows/in-mem-no-pq.yml b/.github/workflows/in-mem-no-pq.yml
new file mode 100644
index 000000000..0039754d2
--- /dev/null
+++ b/.github/workflows/in-mem-no-pq.yml
@@ -0,0 +1,81 @@
+name: In-Memory Without PQ
+on: [workflow_call]
+jobs:
+  acceptance-tests-mem-no-pq:
+    name: In-Mem, Without PQ
+    strategy:
+      fail-fast: false
+      matrix:
+        os: [ubuntu-latest, windows-2019, windows-latest]
+    runs-on: ${{matrix.os}}
+    defaults:
+      run:
+        shell: bash
+    steps:
+      - name: Checkout repository
+        if: ${{ runner.os == 'Linux' }}
+        uses: actions/checkout@v3
+        with:
+          fetch-depth: 1
+      - name: Checkout repository
+        if: ${{ runner.os == 'Windows' }}
+        uses: actions/checkout@v3
+        with:
+          fetch-depth: 1
+          submodules: true
+      - name: DiskANN Build CLI Applications
+        uses: ./.github/actions/build
+
+      - name: Generate Data
+        uses: ./.github/actions/generate-random
+
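+      # --fail_if_recall_below 70 turns each search into an acceptance gate:
+      # the binary exits non-zero when recall drops under 70%. The
+      # `if: success() || failure()` guards keep later steps running after a
+      # failed gate; they skip only when the job is cancelled.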
+      - name: build and search in-memory index with L2 metrics (float)
+        if: success() || failure()
+        run: |
+          dist/bin/build_memory_index --data_type float --dist_fn l2 --data_path data/rand_float_10D_10K_norm1.0.bin --index_path_prefix data/index_l2_rand_float_10D_10K_norm1.0
+          dist/bin/search_memory_index --data_type float --dist_fn l2 --fail_if_recall_below 70 --index_path_prefix data/index_l2_rand_float_10D_10K_norm1.0 --query_file data/rand_float_10D_1K_norm1.0.bin --recall_at 10 --result_path temp --gt_file data/l2_rand_float_10D_10K_norm1.0_10D_1K_norm1.0_gt100 -L 16 32
+      - name: build and search in-memory index with L2 metrics (int8)
+        if: success() || failure()
+        run: |
+          dist/bin/build_memory_index --data_type int8 --dist_fn l2 --data_path data/rand_int8_10D_10K_norm50.0.bin --index_path_prefix data/index_l2_rand_int8_10D_10K_norm50.0
+          dist/bin/search_memory_index --data_type int8 --dist_fn l2 --fail_if_recall_below 70 --index_path_prefix data/index_l2_rand_int8_10D_10K_norm50.0 --query_file data/rand_int8_10D_1K_norm50.0.bin --recall_at 10 --result_path temp --gt_file data/l2_rand_int8_10D_10K_norm50.0_10D_1K_norm50.0_gt100 -L 16 32
+      - name: build and search in-memory index with L2 metrics (uint8)
+        if: success() || failure()
+        run: |
+          dist/bin/build_memory_index --data_type uint8 --dist_fn l2 --data_path data/rand_uint8_10D_10K_norm50.0.bin --index_path_prefix data/index_l2_rand_uint8_10D_10K_norm50.0
+          dist/bin/search_memory_index --data_type uint8 --dist_fn l2 --fail_if_recall_below 70 --index_path_prefix data/index_l2_rand_uint8_10D_10K_norm50.0 --query_file data/rand_uint8_10D_1K_norm50.0.bin --recall_at 10 --result_path temp --gt_file data/l2_rand_uint8_10D_10K_norm50.0_10D_1K_norm50.0_gt100 -L 16 32
+
+      - name: Searching with fast_l2 distance function (float)
+        if: runner.os != 'Windows' && (success() || failure())
+        run: |
+          dist/bin/search_memory_index --data_type float --dist_fn fast_l2 --fail_if_recall_below 70 --index_path_prefix data/index_l2_rand_float_10D_10K_norm1.0 --query_file data/rand_float_10D_1K_norm1.0.bin --recall_at 10 --result_path temp --gt_file data/l2_rand_float_10D_10K_norm1.0_10D_1K_norm1.0_gt100 -L 16 32
+
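+      # Note: the mips/cosine searches below reuse the L2-built index prefix
+      # (index_l2_...). Since every vector in these datasets shares one norm,
+      # L2, MIPS and cosine produce the same ranking, so the reuse still
+      # exercises each distance code path; we read this as intentional,
+      # carried over from the previous pr-test workflow.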
+      - name: build and search in-memory index with MIPS metric (float)
+        if: success() || failure()
+        run: |
+          dist/bin/build_memory_index --data_type float --dist_fn mips --data_path data/rand_float_10D_10K_norm1.0.bin --index_path_prefix data/index_mips_rand_float_10D_10K_norm1.0
+          dist/bin/search_memory_index --data_type float --dist_fn mips --fail_if_recall_below 70 --index_path_prefix data/index_l2_rand_float_10D_10K_norm1.0 --query_file data/rand_float_10D_1K_norm1.0.bin --recall_at 10 --result_path temp --gt_file data/mips_rand_float_10D_10K_norm1.0_10D_1K_norm1.0_gt100 -L 16 32
+
+      - name: build and search in-memory index with cosine metric (float)
+        if: success() || failure()
+        run: |
+          dist/bin/build_memory_index --data_type float --dist_fn cosine --data_path data/rand_float_10D_10K_norm1.0.bin --index_path_prefix data/index_cosine_rand_float_10D_10K_norm1.0
+          dist/bin/search_memory_index --data_type float --dist_fn cosine --fail_if_recall_below 70 --index_path_prefix data/index_l2_rand_float_10D_10K_norm1.0 --query_file data/rand_float_10D_1K_norm1.0.bin --recall_at 10 --result_path temp --gt_file data/cosine_rand_float_10D_10K_norm1.0_10D_1K_norm1.0_gt100 -L 16 32
+      - name: build and search in-memory index with cosine metric (int8)
+        if: success() || failure()
+        run: |
+          dist/bin/build_memory_index --data_type int8 --dist_fn cosine --data_path data/rand_int8_10D_10K_norm50.0.bin --index_path_prefix data/index_cosine_rand_int8_10D_10K_norm50.0
+          dist/bin/search_memory_index --data_type int8 --dist_fn cosine --fail_if_recall_below 70 --index_path_prefix data/index_l2_rand_int8_10D_10K_norm50.0 --query_file data/rand_int8_10D_1K_norm50.0.bin --recall_at 10 --result_path temp --gt_file data/cosine_rand_int8_10D_10K_norm50.0_10D_1K_norm50.0_gt100 -L 16 32
+      - name: build and search in-memory index with cosine metric (uint8)
+        if: success() || failure()
+        run: |
+          dist/bin/build_memory_index --data_type uint8 --dist_fn cosine --data_path data/rand_uint8_10D_10K_norm50.0.bin --index_path_prefix data/index_cosine_rand_uint8_10D_10K_norm50.0
+          dist/bin/search_memory_index --data_type uint8 --dist_fn cosine --fail_if_recall_below 70 --index_path_prefix data/index_l2_rand_uint8_10D_10K_norm50.0 --query_file data/rand_uint8_10D_1K_norm50.0.bin --recall_at 10 --result_path temp --gt_file data/cosine_rand_uint8_10D_10K_norm50.0_10D_1K_norm50.0_gt100 -L 16 32
+
+      - name: upload data and bin
+        uses: actions/upload-artifact@v3
+        with:
+          name: in-memory-no-pq
+          path: |
+            ./dist/**
+            ./data/**
diff --git a/.github/workflows/in-mem-pq.yml b/.github/workflows/in-mem-pq.yml
new file mode 100644
index 000000000..f9276adfc
--- /dev/null
+++ b/.github/workflows/in-mem-pq.yml
@@ -0,0 +1,56 @@
+name: In-Memory With PQ
+on: [workflow_call]
+jobs:
+  acceptance-tests-mem-pq:
+    name: In-Mem, PQ
+    strategy:
+      fail-fast: false
+      matrix:
+        os: [ubuntu-latest, windows-2019, windows-latest]
+    runs-on: ${{matrix.os}}
+    defaults:
+      run:
+        shell: bash
+    steps:
+      - name: Checkout repository
+        if: ${{ runner.os == 'Linux' }}
+        uses: actions/checkout@v3
+        with:
+          fetch-depth: 1
+      - name: Checkout repository
+        if: ${{ runner.os == 'Windows' }}
+        uses: actions/checkout@v3
+        with:
+          fetch-depth: 1
+          submodules: true
+      - name: DiskANN Build CLI Applications
+        uses: ./.github/actions/build
+
+      - name: Generate Data
+        uses: ./.github/actions/generate-random
+
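+      # --build_PQ_bytes 5 makes graph construction use 5-byte PQ codes for
+      # its distance comparisons (the stored index stays full-precision);
+      # the searches then check that recall survives the approximation.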
+      - name: build and search in-memory index with L2 metric with PQ based distance comparisons (float)
+        if: success() || failure()
+        run: |
+          dist/bin/build_memory_index --data_type float --dist_fn l2 --data_path data/rand_float_10D_10K_norm1.0.bin --index_path_prefix data/index_l2_rand_float_10D_10K_norm1.0_buildpq5 --build_PQ_bytes 5
+          dist/bin/search_memory_index --data_type float --dist_fn l2 --fail_if_recall_below 70 --index_path_prefix data/index_l2_rand_float_10D_10K_norm1.0_buildpq5 --query_file data/rand_float_10D_1K_norm1.0.bin --recall_at 10 --result_path temp --gt_file data/l2_rand_float_10D_10K_norm1.0_10D_1K_norm1.0_gt100 -L 16 32
+
+      - name: build and search in-memory index with L2 metric with PQ based distance comparisons (int8)
+        if: success() || failure()
+        run: |
+          dist/bin/build_memory_index --data_type int8 --dist_fn l2 --data_path data/rand_int8_10D_10K_norm50.0.bin --index_path_prefix data/index_l2_rand_int8_10D_10K_norm50.0_buildpq5 --build_PQ_bytes 5
+          dist/bin/search_memory_index --data_type int8 --dist_fn l2 --fail_if_recall_below 70 --index_path_prefix data/index_l2_rand_int8_10D_10K_norm50.0_buildpq5 --query_file data/rand_int8_10D_1K_norm50.0.bin --recall_at 10 --result_path temp --gt_file data/l2_rand_int8_10D_10K_norm50.0_10D_1K_norm50.0_gt100 -L 16 32
+
+      - name: build and search in-memory index with L2 metric with PQ based distance comparisons (uint8)
+        if: success() || failure()
+        run: |
+          dist/bin/build_memory_index --data_type uint8 --dist_fn l2 --data_path data/rand_uint8_10D_10K_norm50.0.bin --index_path_prefix data/index_l2_rand_uint8_10D_10K_norm50.0_buildpq5 --build_PQ_bytes 5
+          dist/bin/search_memory_index --data_type uint8 --dist_fn l2 --fail_if_recall_below 70 --index_path_prefix data/index_l2_rand_uint8_10D_10K_norm50.0_buildpq5 --query_file data/rand_uint8_10D_1K_norm50.0.bin --recall_at 10 --result_path temp --gt_file data/l2_rand_uint8_10D_10K_norm50.0_10D_1K_norm50.0_gt100 -L 16 32
+
+      - name: upload data and bin
+        uses: actions/upload-artifact@v3
+        with:
+          name: in-memory-pq
+          path: |
+            ./dist/**
+            ./data/**
\ No newline at end of file
diff --git a/.github/workflows/labels.yml b/.github/workflows/labels.yml
new file mode 100644
index 000000000..e811c1ff5
--- /dev/null
+++ b/.github/workflows/labels.yml
@@ -0,0 +1,106 @@
+name: Labels
+on: [workflow_call]
+jobs:
+  acceptance-tests-labels:
+    name: Labels
+    strategy:
+      fail-fast: false
+      matrix:
+        os: [ubuntu-latest, windows-2019, windows-latest]
+    runs-on: ${{matrix.os}}
+    defaults:
+      run:
+        shell: bash
+    steps:
+      - name: Checkout repository
+        if: ${{ runner.os == 'Linux' }}
+        uses: actions/checkout@v3
+        with:
+          fetch-depth: 1
+      - name: Checkout repository
+        if: ${{ runner.os == 'Windows' }}
+        uses: actions/checkout@v3
+        with:
+          fetch-depth: 1
+          submodules: true
+      - name: DiskANN Build CLI Applications
+        uses: ./.github/actions/build
+
+      - name: Generate Data
+        uses: ./.github/actions/generate-random
+
+      - name: Generate Labels
+        run: |
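+          # Label semantics, as we understand DiskANN's filtered search:
+          # --universal_label 0 marks label 0 as matching every filter, and
+          # --filter_label N restricts ground truth to points carrying label N.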
+          echo "Generating synthetic labels and computing ground truth for filtered search with universal label"
+          dist/bin/generate_synthetic_labels --num_labels 50 --num_points 10000 --output_file data/rand_labels_50_10K.txt --distribution_type random
+          dist/bin/compute_groundtruth_for_filters --data_type uint8 --dist_fn l2 --universal_label 0 --filter_label 10 --base_file data/rand_uint8_10D_10K_norm50.0.bin --query_file data/rand_uint8_10D_1K_norm50.0.bin --label_file data/rand_labels_50_10K.txt --gt_file data/l2_rand_uint8_10D_10K_norm50.0_10D_1K_norm50.0_gt100_wlabel --K 100
+          dist/bin/compute_groundtruth_for_filters --data_type uint8 --dist_fn mips --universal_label 0 --filter_label 10 --base_file data/rand_uint8_10D_10K_norm50.0.bin --query_file data/rand_uint8_10D_1K_norm50.0.bin --label_file data/rand_labels_50_10K.txt --gt_file data/mips_rand_uint8_10D_10K_norm50.0_10D_1K_norm50.0_gt100_wlabel --K 100
+          dist/bin/compute_groundtruth_for_filters --data_type uint8 --dist_fn cosine --universal_label 0 --filter_label 10 --base_file data/rand_uint8_10D_10K_norm50.0.bin --query_file data/rand_uint8_10D_1K_norm50.0.bin --label_file data/rand_labels_50_10K.txt --gt_file data/cosine_rand_uint8_10D_10K_norm50.0_10D_1K_norm50.0_gt100_wlabel --K 100
+
+          echo "Generating synthetic labels with a zipf distribution and computing ground truth for filtered search with universal label"
+          dist/bin/generate_synthetic_labels --num_labels 50 --num_points 10000 --output_file data/zipf_labels_50_10K.txt --distribution_type zipf
+          dist/bin/compute_groundtruth_for_filters --data_type uint8 --dist_fn l2 --universal_label 0 --filter_label 5 --base_file data/rand_uint8_10D_10K_norm50.0.bin --query_file data/rand_uint8_10D_1K_norm50.0.bin --label_file data/zipf_labels_50_10K.txt --gt_file data/l2_zipf_uint8_10D_10K_norm50.0_10D_1K_norm50.0_gt100_wlabel --K 100
+          dist/bin/compute_groundtruth_for_filters --data_type uint8 --dist_fn mips --universal_label 0 --filter_label 5 --base_file data/rand_uint8_10D_10K_norm50.0.bin --query_file data/rand_uint8_10D_1K_norm50.0.bin --label_file data/zipf_labels_50_10K.txt --gt_file data/mips_zipf_uint8_10D_10K_norm50.0_10D_1K_norm50.0_gt100_wlabel --K 100
+          dist/bin/compute_groundtruth_for_filters --data_type uint8 --dist_fn cosine --universal_label 0 --filter_label 5 --base_file data/rand_uint8_10D_10K_norm50.0.bin --query_file data/rand_uint8_10D_1K_norm50.0.bin --label_file data/zipf_labels_50_10K.txt --gt_file data/cosine_zipf_uint8_10D_10K_norm50.0_10D_1K_norm50.0_gt100_wlabel --K 100
+
+          echo "Generating synthetic labels and computing ground truth for filtered search without a universal label"
+          dist/bin/compute_groundtruth_for_filters --data_type uint8 --dist_fn l2 --filter_label 5 --base_file data/rand_uint8_10D_10K_norm50.0.bin --query_file data/rand_uint8_10D_1K_norm50.0.bin --label_file data/zipf_labels_50_10K.txt --gt_file data/l2_zipf_uint8_10D_10K_norm50.0_10D_1K_norm50.0_gt100_wlabel_nouniversal --K 100
+          dist/bin/generate_synthetic_labels --num_labels 10 --num_points 1000 --output_file data/query_labels_1K.txt --distribution_type one_per_point
+          dist/bin/compute_groundtruth_for_filters --data_type uint8 --dist_fn l2 --universal_label 0 --filter_label_file data/query_labels_1K.txt --base_file data/rand_uint8_10D_10K_norm50.0.bin --query_file data/rand_uint8_10D_1K_norm50.0.bin --label_file data/zipf_labels_50_10K.txt --gt_file data/combined_l2_zipf_uint8_10D_10K_norm50.0_10D_1K_norm50.0_gt100_wlabel --K 100
+
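+      # --FilteredLbuild 90 appears to set the build-time search list size
+      # used for filtered index construction (the filtered analogue of -L);
+      # that reading is an assumption on our part, flagged as such.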
+      - name: build and search in-memory index with labels using L2 and Cosine metrics (random distributed labels)
+        if: success() || failure()
+        run: |
+          dist/bin/build_memory_index --data_type uint8 --dist_fn l2 --FilteredLbuild 90 --universal_label 0 --data_path data/rand_uint8_10D_10K_norm50.0.bin --label_file data/rand_labels_50_10K.txt --index_path_prefix data/index_l2_rand_uint8_10D_10K_norm50_wlabel
+          dist/bin/build_memory_index --data_type uint8 --dist_fn cosine --FilteredLbuild 90 --universal_label 0 --data_path data/rand_uint8_10D_10K_norm50.0.bin --label_file data/rand_labels_50_10K.txt --index_path_prefix data/index_cosine_rand_uint8_10D_10K_norm50_wlabel
+          dist/bin/search_memory_index --data_type uint8 --dist_fn l2 --filter_label 10 --fail_if_recall_below 70 --index_path_prefix data/index_l2_rand_uint8_10D_10K_norm50_wlabel --query_file data/rand_uint8_10D_1K_norm50.0.bin --recall_at 10 --result_path temp --gt_file data/l2_rand_uint8_10D_10K_norm50.0_10D_1K_norm50.0_gt100_wlabel -L 16 32
+          dist/bin/search_memory_index --data_type uint8 --dist_fn cosine --filter_label 10 --fail_if_recall_below 70 --index_path_prefix data/index_cosine_rand_uint8_10D_10K_norm50_wlabel --query_file data/rand_uint8_10D_1K_norm50.0.bin --recall_at 10 --result_path temp --gt_file data/cosine_rand_uint8_10D_10K_norm50.0_10D_1K_norm50.0_gt100_wlabel -L 16 32
+      - name: build and search disk index with labels using L2 and Cosine metrics (random distributed labels)
+        if: success() || failure()
+        run: |
+          dist/bin/build_disk_index --data_type uint8 --dist_fn l2 --universal_label 0 --FilteredLbuild 90 --data_path data/rand_uint8_10D_10K_norm50.0.bin --label_file data/rand_labels_50_10K.txt --index_path_prefix data/disk_index_l2_rand_uint8_10D_10K_norm50_wlabel -R 16 -L 32 -B 0.00003 -M 1
+          dist/bin/search_disk_index --data_type uint8 --dist_fn l2 --filter_label 10 --fail_if_recall_below 50 --index_path_prefix data/disk_index_l2_rand_uint8_10D_10K_norm50_wlabel --result_path /tmp/res --query_file data/rand_uint8_10D_1K_norm50.0.bin --gt_file data/l2_rand_uint8_10D_10K_norm50.0_10D_1K_norm50.0_gt100_wlabel --recall_at 5 -L 5 12 -W 2 --num_nodes_to_cache 10 -T 16
+      - name: build and search in-memory index with labels using L2 and Cosine metrics (zipf distributed labels)
+        if: success() || failure()
+        run: |
+          dist/bin/build_memory_index --data_type uint8 --dist_fn l2 --FilteredLbuild 90 --universal_label 0 --data_path data/rand_uint8_10D_10K_norm50.0.bin --label_file data/zipf_labels_50_10K.txt --index_path_prefix data/index_l2_zipf_uint8_10D_10K_norm50_wlabel
+          dist/bin/build_memory_index --data_type uint8 --dist_fn cosine --FilteredLbuild 90 --universal_label 0 --data_path data/rand_uint8_10D_10K_norm50.0.bin --label_file data/zipf_labels_50_10K.txt --index_path_prefix data/index_cosine_zipf_uint8_10D_10K_norm50_wlabel
+          dist/bin/search_memory_index --data_type uint8 --dist_fn l2 --filter_label 5 --fail_if_recall_below 70 --index_path_prefix data/index_l2_zipf_uint8_10D_10K_norm50_wlabel --query_file data/rand_uint8_10D_1K_norm50.0.bin --recall_at 10 --result_path temp --gt_file data/l2_zipf_uint8_10D_10K_norm50.0_10D_1K_norm50.0_gt100_wlabel -L 16 32
+          dist/bin/search_memory_index --data_type uint8 --dist_fn cosine --filter_label 5 --fail_if_recall_below 70 --index_path_prefix data/index_cosine_zipf_uint8_10D_10K_norm50_wlabel --query_file data/rand_uint8_10D_1K_norm50.0.bin --recall_at 10 --result_path temp --gt_file data/cosine_zipf_uint8_10D_10K_norm50.0_10D_1K_norm50.0_gt100_wlabel -L 16 32
+      - name: build and search disk index with labels using L2 and Cosine metrics (zipf distributed labels)
+        if: success() || failure()
+        run: |
+          dist/bin/build_disk_index --data_type uint8 --dist_fn l2 --universal_label 0 --FilteredLbuild 90 --data_path data/rand_uint8_10D_10K_norm50.0.bin --label_file data/zipf_labels_50_10K.txt --index_path_prefix data/disk_index_l2_zipf_uint8_10D_10K_norm50_wlabel -R 16 -L 32 -B 0.00003 -M 1
+          dist/bin/search_disk_index --data_type uint8 --dist_fn l2 --filter_label 5 --fail_if_recall_below 50 --index_path_prefix data/disk_index_l2_zipf_uint8_10D_10K_norm50_wlabel --result_path /tmp/res --query_file data/rand_uint8_10D_1K_norm50.0.bin --gt_file data/l2_zipf_uint8_10D_10K_norm50.0_10D_1K_norm50.0_gt100_wlabel --recall_at 5 -L 5 12 -W 2 --num_nodes_to_cache 10 -T 16
+      - name: build and search in-memory and disk index (without universal label, zipf distributed)
+        if: success() || failure()
+        run: |
+          dist/bin/build_memory_index --data_type uint8 --dist_fn l2 --FilteredLbuild 90 --data_path data/rand_uint8_10D_10K_norm50.0.bin --label_file data/zipf_labels_50_10K.txt --index_path_prefix data/index_l2_zipf_uint8_10D_10K_norm50_wlabel_nouniversal
+          dist/bin/build_disk_index --data_type uint8 --dist_fn l2 --FilteredLbuild 90 --data_path data/rand_uint8_10D_10K_norm50.0.bin --label_file data/zipf_labels_50_10K.txt --index_path_prefix data/disk_index_l2_zipf_uint8_10D_10K_norm50_wlabel_nouniversal -R 16 -L 32 -B 0.00003 -M 1
+          dist/bin/search_memory_index --data_type uint8 --dist_fn l2 --filter_label 5 --fail_if_recall_below 70 --index_path_prefix data/index_l2_zipf_uint8_10D_10K_norm50_wlabel_nouniversal --query_file data/rand_uint8_10D_1K_norm50.0.bin --recall_at 10 --result_path temp --gt_file data/l2_zipf_uint8_10D_10K_norm50.0_10D_1K_norm50.0_gt100_wlabel_nouniversal -L 16 32
+          dist/bin/search_disk_index --data_type uint8 --dist_fn l2 --filter_label 5 --index_path_prefix data/disk_index_l2_zipf_uint8_10D_10K_norm50_wlabel_nouniversal --result_path /tmp/res --query_file data/rand_uint8_10D_1K_norm50.0.bin --gt_file data/l2_zipf_uint8_10D_10K_norm50.0_10D_1K_norm50.0_gt100_wlabel_nouniversal --recall_at 5 -L 5 12 -W 2 --num_nodes_to_cache 10 -T 16
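+      # Unlike --filter_label, --query_filters_file supplies one label per
+      # query (data/query_labels_1K.txt, generated one_per_point above), so
+      # each query is scored against its own filter.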
+      - name: Generate combined GT for each query with a separate label and search
+        if: success() || failure()
+        run: |
+          dist/bin/build_memory_index --data_type uint8 --dist_fn l2 --FilteredLbuild 90 --universal_label 0 --data_path data/rand_uint8_10D_10K_norm50.0.bin --label_file data/zipf_labels_50_10K.txt --index_path_prefix data/index_l2_zipf_uint8_10D_10K_norm50_wlabel
+          dist/bin/search_memory_index --data_type uint8 --dist_fn l2 --query_filters_file data/query_labels_1K.txt --fail_if_recall_below 70 --index_path_prefix data/index_l2_zipf_uint8_10D_10K_norm50_wlabel --query_file data/rand_uint8_10D_1K_norm50.0.bin --recall_at 10 --result_path temp --gt_file data/combined_l2_zipf_uint8_10D_10K_norm50.0_10D_1K_norm50.0_gt100_wlabel -L 16 32
+      - name: build and search in-memory index with pq_dist of 5 with 10 dimensions
+        if: success() || failure()
+        run: |
+          dist/bin/build_memory_index --data_type uint8 --dist_fn l2 --FilteredLbuild 90 --universal_label 0 --data_path data/rand_uint8_10D_10K_norm50.0.bin --label_file data/rand_labels_50_10K.txt --index_path_prefix data/index_l2_rand_uint8_10D_10K_norm50_wlabel --build_PQ_bytes 5
+          dist/bin/search_memory_index --data_type uint8 --dist_fn l2 --filter_label 10 --fail_if_recall_below 70 --index_path_prefix data/index_l2_rand_uint8_10D_10K_norm50_wlabel --query_file data/rand_uint8_10D_1K_norm50.0.bin --recall_at 10 --result_path temp --gt_file data/l2_rand_uint8_10D_10K_norm50.0_10D_1K_norm50.0_gt100_wlabel -L 16 32
+      - name: Build and search stitched vamana with random and zipf distributed labels
+        if: success() || failure()
+        run: |
+          dist/bin/build_stitched_index --num_threads 48 --data_type uint8 --data_path data/rand_uint8_10D_10K_norm50.0.bin --label_file data/rand_labels_50_10K.txt -R 32 -L 100 --alpha 1.2 --stitched_R 64 --index_path_prefix data/stit_rand_32_100_64_new --universal_label 0
+          dist/bin/build_stitched_index --num_threads 48 --data_type uint8 --data_path data/rand_uint8_10D_10K_norm50.0.bin --label_file data/zipf_labels_50_10K.txt -R 32 -L 100 --alpha 1.2 --stitched_R 64 --index_path_prefix data/stit_zipf_32_100_64_new --universal_label 0
+          dist/bin/search_memory_index --num_threads 48 --data_type uint8 --dist_fn l2 --filter_label 10 --index_path_prefix data/stit_rand_32_100_64_new --query_file data/rand_uint8_10D_1K_norm50.0.bin --result_path data/rand_stit_96_10_90_new --gt_file data/l2_rand_uint8_10D_10K_norm50.0_10D_1K_norm50.0_gt100_wlabel -K 10 -L 16 32 150
+          dist/bin/search_memory_index --num_threads 48 --data_type uint8 --dist_fn l2 --filter_label 5 --index_path_prefix data/stit_zipf_32_100_64_new --query_file data/rand_uint8_10D_1K_norm50.0.bin --result_path data/zipf_stit_96_10_90_new --gt_file data/l2_zipf_uint8_10D_10K_norm50.0_10D_1K_norm50.0_gt100_wlabel -K 10 -L 16 32 150
+
+      - name: upload data and bin
+        uses: actions/upload-artifact@v3
+        with:
+          name: labels
+          path: |
+            ./dist/**
+            ./data/**
diff --git a/.github/workflows/pr-test.yml b/.github/workflows/pr-test.yml
index 46495dbf6..38eefb3ff 100644
--- a/.github/workflows/pr-test.yml
+++ b/.github/workflows/pr-test.yml
@@ -1,258 +1,29 @@
 name: DiskANN Pull Request Build and Test
 on: [pull_request]
 jobs:
-  build-and-run:
-    name: Build and run tests for ${{ matrix.os }}
-    runs-on: ${{ matrix.os }}
-
+  common:
     strategy:
-      fail-fast: false
-      matrix:
-        os: [ubuntu-latest, windows-2019, windows-latest]
-
-    # Use bash for Windows as well.
-    defaults:
-      run:
-        shell: bash
-
-    steps:
-    - name: Checkout repository
-      uses: actions/checkout@v2
-      with:
-        submodules: true
-    - name: Install dependencies
-      if: runner.os != 'Windows'
-      run: |
-        sudo apt install cmake g++ libaio-dev libgoogle-perftools-dev libunwind-dev clang-format libboost-dev libboost-program-options-dev libmkl-full-dev
-    - name: Ubuntu CMake Configure
-      if: runner.os != 'Windows'
-      run: |
-        mkdir build && cd build && cmake -DCMAKE_BUILD_TYPE=Release ..
-
-    - name: Clang Format Check
-      if: ${{ matrix.os == 'ubuntu-latest' }}
-      run: |
-        cd build && make checkformat
-
-    - name: build
-      if: runner.os != 'Windows'
-      run: |
-        cd build && make -j
-
-    - name: Add VisualStudio command line tools into path
-      if: runner.os == 'Windows'
-      uses: ilammy/msvc-dev-cmd@v1
-
-    - name: Run configure and build for Windows
-      if: runner.os == 'Windows'
-      run: |
-        mkdir build && cd build && cmake .. && msbuild diskann.sln /m /nologo /t:Build /p:Configuration="Release" /property:Platform="x64" -consoleloggerparameters:"ErrorsOnly;Summary"
-      shell: cmd
-
-    - name: Set environment variables for running the tests on ${{ runner.os }}
-      if: runner.os != 'Windows'
-      run: |
-        echo "diskann_built_tests=./build/tests" >> $GITHUB_ENV
-        echo "diskann_built_utils=./build/tests/utils" >> $GITHUB_ENV
-
-    - name: Set environment variables for running the tests on ${{ runner.os }}
-      if: runner.os == 'Windows'
-      run: |
-        echo "diskann_built_tests=./x64/Release" >> $GITHUB_ENV
-        echo "diskann_built_utils=./x64/Release" >> $GITHUB_ENV
-
-    - name: Generate 10K random float32 index vectors, 1K query vectors, in 10 dims and compute GT
-      run: |
-        ${{ env.diskann_built_utils }}/rand_data_gen --data_type float --output_file ./rand_float_10D_10K_norm1.0.bin -D 10 -N 10000 --norm 1.0
-        ${{ env.diskann_built_utils }}/rand_data_gen --data_type float --output_file ./rand_float_10D_1K_norm1.0.bin -D 10 -N 1000 --norm 1.0
-        ${{ env.diskann_built_utils }}/compute_groundtruth --data_type float --dist_fn l2 --base_file ./rand_float_10D_10K_norm1.0.bin --query_file ./rand_float_10D_1K_norm1.0.bin --gt_file ./l2_rand_float_10D_10K_norm1.0_10D_1K_norm1.0_gt100 --K 100
-        ${{ env.diskann_built_utils }}/compute_groundtruth --data_type float --dist_fn mips --base_file ./rand_float_10D_10K_norm1.0.bin --query_file ./rand_float_10D_1K_norm1.0.bin --gt_file ./mips_rand_float_10D_10K_norm1.0_10D_1K_norm1.0_gt100 --K 100
-        ${{ env.diskann_built_utils }}/compute_groundtruth --data_type float --dist_fn cosine --base_file ./rand_float_10D_10K_norm1.0.bin --query_file ./rand_float_10D_1K_norm1.0.bin --gt_file ./cosine_rand_float_10D_10K_norm1.0_10D_1K_norm1.0_gt100 --K 100
-    - name: build and search in-memory index with L2 metrics
-      run: |
-        ${{ env.diskann_built_tests }}/build_memory_index --data_type float --dist_fn l2 --data_path ./rand_float_10D_10K_norm1.0.bin --index_path_prefix ./index_l2_rand_float_10D_10K_norm1.0
-        ${{ env.diskann_built_tests }}/search_memory_index --data_type float --dist_fn l2 --fail_if_recall_below 70 --index_path_prefix ./index_l2_rand_float_10D_10K_norm1.0 --query_file ./rand_float_10D_1K_norm1.0.bin --recall_at 10 --result_path temp --gt_file ./l2_rand_float_10D_10K_norm1.0_10D_1K_norm1.0_gt100 -L 16 32
-    - name: Searching with fast_l2 distance function
-      if: runner.os != 'Windows'
-      run: |
-        ${{ env.diskann_built_tests }}/search_memory_index --data_type float --dist_fn fast_l2 --fail_if_recall_below 70 --index_path_prefix ./index_l2_rand_float_10D_10K_norm1.0 --query_file ./rand_float_10D_1K_norm1.0.bin --recall_at 10 --result_path temp --gt_file ./l2_rand_float_10D_10K_norm1.0_10D_1K_norm1.0_gt100 -L 16 32
-    - name: build and search in-memory index with MIPS metric
-      run: |
-        ${{ env.diskann_built_tests }}/build_memory_index --data_type float --dist_fn mips --data_path ./rand_float_10D_10K_norm1.0.bin --index_path_prefix ./index_mips_rand_float_10D_10K_norm1.0
-        ${{ env.diskann_built_tests }}/search_memory_index --data_type float --dist_fn mips --fail_if_recall_below 70 --index_path_prefix ./index_l2_rand_float_10D_10K_norm1.0 --query_file ./rand_float_10D_1K_norm1.0.bin --recall_at 10 --result_path temp --gt_file ./mips_rand_float_10D_10K_norm1.0_10D_1K_norm1.0_gt100 -L 16 32
-    - name: build and search in-memory index with cosine metric
-      run: |
-        ${{ env.diskann_built_tests }}/build_memory_index --data_type float --dist_fn cosine --data_path ./rand_float_10D_10K_norm1.0.bin --index_path_prefix ./index_cosine_rand_float_10D_10K_norm1.0
-        ${{ env.diskann_built_tests }}/search_memory_index --data_type float --dist_fn cosine --fail_if_recall_below 70 --index_path_prefix ./index_l2_rand_float_10D_10K_norm1.0 --query_file ./rand_float_10D_1K_norm1.0.bin --recall_at 10 --result_path temp --gt_file ./cosine_rand_float_10D_10K_norm1.0_10D_1K_norm1.0_gt100 -L 16 32
-    - name: build and search in-memory index with L2 metric with PQ based distance comparisons
-      run: |
-        ${{ env.diskann_built_tests }}/build_memory_index --data_type float --dist_fn l2 --data_path ./rand_float_10D_10K_norm1.0.bin --index_path_prefix ./index_l2_rand_float_10D_10K_norm1.0_buildpq5 --build_PQ_bytes 5
-        ${{ env.diskann_built_tests }}/search_memory_index --data_type float --dist_fn l2 --fail_if_recall_below 70 --index_path_prefix ./index_l2_rand_float_10D_10K_norm1.0_buildpq5 --query_file ./rand_float_10D_1K_norm1.0.bin --recall_at 10 --result_path temp --gt_file ./l2_rand_float_10D_10K_norm1.0_10D_1K_norm1.0_gt100 -L 16 32
-    - name: build and search disk index (one shot graph build, L2, no diskPQ)
-      run: |
-        ${{ env.diskann_built_tests }}/build_disk_index --data_type float --dist_fn l2 --data_path ./rand_float_10D_10K_norm1.0.bin --index_path_prefix ./disk_index_l2_rand_float_10D_10K_norm1.0_diskfull_oneshot -R 16 -L 32 -B 0.00003 -M 1
-        ${{ env.diskann_built_tests }}/search_disk_index --data_type float --dist_fn l2 --fail_if_recall_below 70 --index_path_prefix ./disk_index_l2_rand_float_10D_10K_norm1.0_diskfull_oneshot --result_path /tmp/res --query_file ./rand_float_10D_1K_norm1.0.bin --gt_file ./l2_rand_float_10D_10K_norm1.0_10D_1K_norm1.0_gt100 --recall_at 5 -L 5 12 -W 2 --num_nodes_to_cache 10 -T 16
-    - name: build and search disk index (one shot graph build, L2, no diskPQ, build with PQ distance comparisons)
-      run: |
-        ${{ env.diskann_built_tests }}/build_disk_index --data_type float --dist_fn l2 --data_path ./rand_float_10D_10K_norm1.0.bin --index_path_prefix ./disk_index_l2_rand_float_10D_10K_norm1.0_diskfull_oneshot_buildpq5 -R 16 -L 32 -B 0.00003 -M 1 --build_PQ_bytes 5
-        ${{ env.diskann_built_tests }}/search_disk_index --data_type float --dist_fn l2 --fail_if_recall_below 70 --index_path_prefix ./disk_index_l2_rand_float_10D_10K_norm1.0_diskfull_oneshot_buildpq5 --result_path /tmp/res --query_file ./rand_float_10D_1K_norm1.0.bin --gt_file ./l2_rand_float_10D_10K_norm1.0_10D_1K_norm1.0_gt100 --recall_at 5 -L 5 12 -W 2 --num_nodes_to_cache 10 -T 16
-    - name: build and search disk index (sharded graph build, L2, no diskPQ)
-      run: |
-        ${{ env.diskann_built_tests }}/build_disk_index --data_type float --dist_fn l2 --data_path ./rand_float_10D_10K_norm1.0.bin --index_path_prefix ./disk_index_l2_rand_float_10D_10K_norm1.0_diskfull_sharded -R 16 -L 32 -B 0.00003 -M 0.00006
-        ${{ env.diskann_built_tests }}/search_disk_index --data_type float --dist_fn l2 --fail_if_recall_below 70 --index_path_prefix ./disk_index_l2_rand_float_10D_10K_norm1.0_diskfull_sharded --result_path /tmp/res --query_file ./rand_float_10D_1K_norm1.0.bin --gt_file ./l2_rand_float_10D_10K_norm1.0_10D_1K_norm1.0_gt100 --recall_at 5 -L 5 12 -W 2 --num_nodes_to_cache 10 -T 16
-    - name: build and search disk index (one shot graph build, L2, diskPQ)
-      run: |
-        ${{ env.diskann_built_tests }}/build_disk_index --data_type float --dist_fn l2 --data_path ./rand_float_10D_10K_norm1.0.bin --index_path_prefix ./disk_index_l2_rand_float_10D_10K_norm1.0_diskpq_oneshot -R 16 -L 32 -B 0.00003 -M 1 --PQ_disk_bytes 5
-        ${{ env.diskann_built_tests }}/search_disk_index --data_type float --dist_fn l2 --fail_if_recall_below 70 --index_path_prefix ./disk_index_l2_rand_float_10D_10K_norm1.0_diskpq_oneshot --result_path /tmp/res --query_file ./rand_float_10D_1K_norm1.0.bin --gt_file ./l2_rand_float_10D_10K_norm1.0_10D_1K_norm1.0_gt100 --recall_at 5 -L 5 12 -W 2 --num_nodes_to_cache 10 -T 16
-    - name: build and search disk index (sharded graph build, MIPS, diskPQ)
-      run: |
-        ${{ env.diskann_built_tests }}/build_disk_index --data_type float --dist_fn mips --data_path ./rand_float_10D_10K_norm1.0.bin --index_path_prefix ./disk_index_mips_rand_float_10D_10K_norm1.0_diskpq_sharded -R 16 -L 32 -B 0.00003 -M 0.00006 --PQ_disk_bytes 5
-        ${{ env.diskann_built_tests }}/search_disk_index --data_type float --dist_fn l2 --fail_if_recall_below 70 --index_path_prefix ./disk_index_mips_rand_float_10D_10K_norm1.0_diskpq_sharded --result_path /tmp/res --query_file ./rand_float_10D_1K_norm1.0.bin --gt_file ./mips_rand_float_10D_10K_norm1.0_10D_1K_norm1.0_gt100 --recall_at 5 -L 5 12 -W 2 --num_nodes_to_cache 10 -T 16
-    - name: build and search an incremental index
-      run: |
-        ${{ env.diskann_built_tests }}/test_insert_deletes_consolidate --data_type float --dist_fn l2 --data_path rand_float_10D_10K_norm1.0.bin --index_path_prefix index_ins_del -R 64 -L 300 --alpha 1.2 -T 8 --points_to_skip 0 --max_points_to_insert 7500 --beginning_index_size 0 --points_per_checkpoint 1000 --checkpoints_per_snapshot 0 --points_to_delete_from_beginning 2500 --start_deletes_after 5000 --do_concurrent true --start_point_norm 3.2;
-        ${{ env.diskann_built_utils }}/compute_groundtruth --data_type float --dist_fn l2 --base_file index_ins_del.after-concurrent-delete-del2500-7500.data --query_file rand_float_10D_1K_norm1.0.bin --K 100 --gt_file gt100_random10D_1K-conc-2500-7500 --tags_file index_ins_del.after-concurrent-delete-del2500-7500.tags
-        ${{ env.diskann_built_tests }}/search_memory_index --data_type float --dist_fn l2 --fail_if_recall_below 70 --index_path_prefix index_ins_del.after-concurrent-delete-del2500-7500 --result_path res_ins_del --query_file ./rand_float_10D_1K_norm1.0.bin --gt_file gt100_random10D_1K-conc-2500-7500 -K 10 -L 20 40 60 80 100 -T 8 --dynamic true --tags 1
-    - name: test a streaming index
-      run: |
-        ${{ env.diskann_built_tests }}/test_streaming_scenario --data_type float --dist_fn l2 --data_path rand_float_10D_10K_norm1.0.bin --index_path_prefix index_stream -R 64 -L 600 --alpha 1.2 --insert_threads 4 --consolidate_threads 4 --max_points_to_insert 10000 --active_window 4000 --consolidate_interval 2000 --start_point_norm 3.2
-        ${{ env.diskann_built_utils }}/compute_groundtruth --data_type float --dist_fn l2 --base_file index_stream.after-streaming-act4000-cons2000-max10000.data --query_file rand_float_10D_1K_norm1.0.bin --K 100 --gt_file gt100_base-act4000-cons2000-max10000 --tags_file index_stream.after-streaming-act4000-cons2000-max10000.tags
res_stream --query_file ./rand_float_10D_1K_norm1.0.bin --gt_file gt100_base-act4000-cons2000-max10000 -K 10 -L 20 40 60 80 100 -T 64 --dynamic true --tags 1 - - - - name: Generate 10K random int8 index vectors, 1K query vectors, in 10 dims and compute GT - run: | - ${{ env.diskann_built_utils }}/rand_data_gen --data_type int8 --output_file ./rand_int8_10D_10K_norm50.0.bin -D 10 -N 10000 --norm 50.0 - ${{ env.diskann_built_utils }}/rand_data_gen --data_type int8 --output_file ./rand_int8_10D_1K_norm50.0.bin -D 10 -N 1000 --norm 50.0 - ${{ env.diskann_built_utils }}/compute_groundtruth --data_type int8 --dist_fn l2 --base_file ./rand_int8_10D_10K_norm50.0.bin --query_file ./rand_int8_10D_1K_norm50.0.bin --gt_file ./l2_rand_int8_10D_10K_norm50.0_10D_1K_norm50.0_gt100 --K 100 - ${{ env.diskann_built_utils }}/compute_groundtruth --data_type int8 --dist_fn mips --base_file ./rand_int8_10D_10K_norm50.0.bin --query_file ./rand_int8_10D_1K_norm50.0.bin --gt_file ./mips_rand_int8_10D_10K_norm50.0_10D_1K_norm50.0_gt100 --K 100 - ${{ env.diskann_built_utils }}/compute_groundtruth --data_type int8 --dist_fn cosine --base_file ./rand_int8_10D_10K_norm50.0.bin --query_file ./rand_int8_10D_1K_norm50.0.bin --gt_file ./cosine_rand_int8_10D_10K_norm50.0_10D_1K_norm50.0_gt100 --K 100 - - name: build and search in-memory index with L2 metrics - run: | - ${{ env.diskann_built_tests }}/build_memory_index --data_type int8 --dist_fn l2 --data_path ./rand_int8_10D_10K_norm50.0.bin --index_path_prefix ./index_l2_rand_int8_10D_10K_norm50.0 - ${{ env.diskann_built_tests }}/search_memory_index --data_type int8 --dist_fn l2 --fail_if_recall_below 70 --index_path_prefix ./index_l2_rand_int8_10D_10K_norm50.0 --query_file ./rand_int8_10D_1K_norm50.0.bin --recall_at 10 --result_path temp --gt_file ./l2_rand_int8_10D_10K_norm50.0_10D_1K_norm50.0_gt100 -L 16 32 - - name: build and search in-memory index with cosine metric - run: | - ${{ env.diskann_built_tests }}/build_memory_index --data_type int8 --dist_fn cosine --data_path ./rand_int8_10D_10K_norm50.0.bin --index_path_prefix ./index_cosine_rand_int8_10D_10K_norm50.0 - ${{ env.diskann_built_tests }}/search_memory_index --data_type int8 --dist_fn cosine --fail_if_recall_below 70 --index_path_prefix ./index_l2_rand_int8_10D_10K_norm50.0 --query_file ./rand_int8_10D_1K_norm50.0.bin --recall_at 10 --result_path temp --gt_file ./cosine_rand_int8_10D_10K_norm50.0_10D_1K_norm50.0_gt100 -L 16 32 - - name: build and search in-memory index with L2 metrics with PQ base distance comparisons - run: | - ${{ env.diskann_built_tests }}/build_memory_index --data_type int8 --dist_fn l2 --data_path ./rand_int8_10D_10K_norm50.0.bin --index_path_prefix ./index_l2_rand_int8_10D_10K_norm50.0_buildpq5 --build_PQ_bytes 5 - ${{ env.diskann_built_tests }}/search_memory_index --data_type int8 --dist_fn l2 --fail_if_recall_below 70 --index_path_prefix ./index_l2_rand_int8_10D_10K_norm50.0_buildpq5 --query_file ./rand_int8_10D_1K_norm50.0.bin --recall_at 10 --result_path temp --gt_file ./l2_rand_int8_10D_10K_norm50.0_10D_1K_norm50.0_gt100 -L 16 32 - - name: build and search disk index (one shot graph build, L2, no diskPQ) - run: | - ${{ env.diskann_built_tests }}/build_disk_index --data_type int8 --dist_fn l2 --data_path ./rand_int8_10D_10K_norm50.0.bin --index_path_prefix ./disk_index_l2_rand_int8_10D_10K_norm50.0_diskfull_oneshot -R 16 -L 32 -B 0.00003 -M 1 - ${{ env.diskann_built_tests }}/search_disk_index --data_type int8 --dist_fn l2 --fail_if_recall_below 70 --index_path_prefix 
./disk_index_l2_rand_int8_10D_10K_norm50.0_diskfull_oneshot --result_path /tmp/res --query_file ./rand_int8_10D_1K_norm50.0.bin --gt_file ./l2_rand_int8_10D_10K_norm50.0_10D_1K_norm50.0_gt100 --recall_at 5 -L 5 12 -W 2 --num_nodes_to_cache 10 -T 16 - - name: build and search disk index (one shot graph build, L2, no diskPQ, build with PQ distance comparisons) - run: | - ${{ env.diskann_built_tests }}/build_disk_index --data_type int8 --dist_fn l2 --data_path ./rand_int8_10D_10K_norm50.0.bin --index_path_prefix ./disk_index_l2_rand_int8_10D_10K_norm50.0_diskfull_oneshot_buildpq5 -R 16 -L 32 -B 0.00003 -M 1 --build_PQ_bytes 5 - ${{ env.diskann_built_tests }}/search_disk_index --data_type int8 --dist_fn l2 --fail_if_recall_below 70 --index_path_prefix ./disk_index_l2_rand_int8_10D_10K_norm50.0_diskfull_oneshot_buildpq5 --result_path /tmp/res --query_file ./rand_int8_10D_1K_norm50.0.bin --gt_file ./l2_rand_int8_10D_10K_norm50.0_10D_1K_norm50.0_gt100 --recall_at 5 -L 5 12 -W 2 --num_nodes_to_cache 10 -T 16 - - name: build and search disk index (sharded graph build, L2, no diskPQ) - run: | - ${{ env.diskann_built_tests }}/build_disk_index --data_type int8 --dist_fn l2 --data_path ./rand_int8_10D_10K_norm50.0.bin --index_path_prefix ./disk_index_l2_rand_int8_10D_10K_norm50.0_diskfull_sharded -R 16 -L 32 -B 0.00003 -M 0.00006 - ${{ env.diskann_built_tests }}/search_disk_index --data_type int8 --dist_fn l2 --fail_if_recall_below 70 --index_path_prefix ./disk_index_l2_rand_int8_10D_10K_norm50.0_diskfull_sharded --result_path /tmp/res --query_file ./rand_int8_10D_1K_norm50.0.bin --gt_file ./l2_rand_int8_10D_10K_norm50.0_10D_1K_norm50.0_gt100 --recall_at 5 -L 5 12 -W 2 --num_nodes_to_cache 10 -T 16 - - name: build and search disk index (one shot graph build, L2, diskPQ) - run: | - ${{ env.diskann_built_tests }}/build_disk_index --data_type int8 --dist_fn l2 --data_path ./rand_int8_10D_10K_norm50.0.bin --index_path_prefix ./disk_index_l2_rand_int8_10D_10K_norm50.0_diskpq_oneshot -R 16 -L 32 -B 0.00003 -M 1 --PQ_disk_bytes 5 - ${{ env.diskann_built_tests }}/search_disk_index --data_type int8 --dist_fn l2 --fail_if_recall_below 70 --index_path_prefix ./disk_index_l2_rand_int8_10D_10K_norm50.0_diskpq_oneshot --result_path /tmp/res --query_file ./rand_int8_10D_1K_norm50.0.bin --gt_file ./l2_rand_int8_10D_10K_norm50.0_10D_1K_norm50.0_gt100 --recall_at 5 -L 5 12 -W 2 --num_nodes_to_cache 10 -T 16 - - name: build and search an incremental index - run: | - ${{ env.diskann_built_tests }}/test_insert_deletes_consolidate --data_type int8 --dist_fn l2 --data_path rand_int8_10D_10K_norm50.0.bin --index_path_prefix index_ins_del -R 64 -L 300 --alpha 1.2 -T 8 --points_to_skip 0 --max_points_to_insert 7500 --beginning_index_size 0 --points_per_checkpoint 1000 --checkpoints_per_snapshot 0 --points_to_delete_from_beginning 2500 --start_deletes_after 5000 --do_concurrent true --start_point_norm 200 - ${{ env.diskann_built_utils }}/compute_groundtruth --data_type int8 --dist_fn l2 --base_file index_ins_del.after-concurrent-delete-del2500-7500.data --query_file rand_int8_10D_1K_norm50.0.bin --K 100 --gt_file gt100_random10D_1K-conc-2500-7500 --tags_file index_ins_del.after-concurrent-delete-del2500-7500.tags - ${{ env.diskann_built_tests }}/search_memory_index --data_type int8 --dist_fn l2 --fail_if_recall_below 70 --index_path_prefix index_ins_del.after-concurrent-delete-del2500-7500 --result_path res_ins_del --query_file ./rand_int8_10D_1K_norm50.0.bin --gt_file gt100_random10D_1K-conc-2500-7500 -K 10 -L 20 40 60 80 100 
-T 8 --dynamic true --tags 1 - - name: test a streaming index - if: success() || failure() - run: | - ${{ env.diskann_built_tests }}/test_streaming_scenario --data_type int8 --dist_fn l2 --data_path rand_int8_10D_10K_norm50.0.bin --index_path_prefix index_stream -R 64 -L 600 --alpha 1.2 --insert_threads 4 --consolidate_threads 4 --max_points_to_insert 10000 --active_window 4000 --consolidate_interval 2000 --start_point_norm 200 - ${{ env.diskann_built_utils }}/compute_groundtruth --data_type int8 --dist_fn l2 --base_file index_stream.after-streaming-act4000-cons2000-max10000.data --query_file rand_int8_10D_1K_norm50.0.bin --K 100 --gt_file gt100_base-act4000-cons2000-max10000 --tags_file index_stream.after-streaming-act4000-cons2000-max10000.tags - ${{ env.diskann_built_tests }}/search_memory_index --data_type int8 --dist_fn l2 --fail_if_recall_below 70 --index_path_prefix index_stream.after-streaming-act4000-cons2000-max10000 --result_path res_stream --query_file ./rand_int8_10D_1K_norm50.0.bin --gt_file gt100_base-act4000-cons2000-max10000 -K 10 -L 20 40 60 80 100 -T 64 --dynamic true --tags 1 - - - - name: Generate 10K random uint8 index vectors, 1K query vectors, in 10 dims and compute GT - if: success() || failure() - run: | - ${{ env.diskann_built_utils }}/rand_data_gen --data_type uint8 --output_file ./rand_uint8_10D_10K_norm50.0.bin -D 10 -N 10000 --norm 50.0 - ${{ env.diskann_built_utils }}/rand_data_gen --data_type uint8 --output_file ./rand_uint8_10D_1K_norm50.0.bin -D 10 -N 1000 --norm 50.0 - ${{ env.diskann_built_utils }}/compute_groundtruth --data_type uint8 --dist_fn l2 --base_file ./rand_uint8_10D_10K_norm50.0.bin --query_file ./rand_uint8_10D_1K_norm50.0.bin --gt_file ./l2_rand_uint8_10D_10K_norm50.0_10D_1K_norm50.0_gt100 --K 100 - ${{ env.diskann_built_utils }}/compute_groundtruth --data_type uint8 --dist_fn mips --base_file ./rand_uint8_10D_10K_norm50.0.bin --query_file ./rand_uint8_10D_1K_norm50.0.bin --gt_file ./mips_rand_uint8_10D_10K_norm50.0_10D_1K_norm50.0_gt100 --K 100 - ${{ env.diskann_built_utils }}/compute_groundtruth --data_type uint8 --dist_fn cosine --base_file ./rand_uint8_10D_10K_norm50.0.bin --query_file ./rand_uint8_10D_1K_norm50.0.bin --gt_file ./cosine_rand_uint8_10D_10K_norm50.0_10D_1K_norm50.0_gt100 --K 100 - - name: build and search in-memory index with L2 metrics - if: success() || failure() - run: | - ${{ env.diskann_built_tests }}/build_memory_index --data_type uint8 --dist_fn l2 --data_path ./rand_uint8_10D_10K_norm50.0.bin --index_path_prefix ./index_l2_rand_uint8_10D_10K_norm50.0 - ${{ env.diskann_built_tests }}/search_memory_index --data_type uint8 --dist_fn l2 --fail_if_recall_below 70 --index_path_prefix ./index_l2_rand_uint8_10D_10K_norm50.0 --query_file ./rand_uint8_10D_1K_norm50.0.bin --recall_at 10 --result_path temp --gt_file ./l2_rand_uint8_10D_10K_norm50.0_10D_1K_norm50.0_gt100 -L 16 32 - - name: build and search in-memory index with cosine metric - if: success() || failure() - run: | - ${{ env.diskann_built_tests }}/build_memory_index --data_type uint8 --dist_fn cosine --data_path ./rand_uint8_10D_10K_norm50.0.bin --index_path_prefix ./index_cosine_rand_uint8_10D_10K_norm50.0 - ${{ env.diskann_built_tests }}/search_memory_index --data_type uint8 --dist_fn cosine --fail_if_recall_below 70 --index_path_prefix ./index_l2_rand_uint8_10D_10K_norm50.0 --query_file ./rand_uint8_10D_1K_norm50.0.bin --recall_at 10 --result_path temp --gt_file ./cosine_rand_uint8_10D_10K_norm50.0_10D_1K_norm50.0_gt100 -L 16 32 - - name: build and search 
in-memory index with L2 metrics with PQ base distance comparisons - if: success() || failure() - run: | - ${{ env.diskann_built_tests }}/build_memory_index --data_type uint8 --dist_fn l2 --data_path ./rand_uint8_10D_10K_norm50.0.bin --index_path_prefix ./index_l2_rand_uint8_10D_10K_norm50.0_buildpq5 --build_PQ_bytes 5 - ${{ env.diskann_built_tests }}/search_memory_index --data_type uint8 --dist_fn l2 --fail_if_recall_below 70 --index_path_prefix ./index_l2_rand_uint8_10D_10K_norm50.0_buildpq5 --query_file ./rand_uint8_10D_1K_norm50.0.bin --recall_at 10 --result_path temp --gt_file ./l2_rand_uint8_10D_10K_norm50.0_10D_1K_norm50.0_gt100 -L 16 32 - - name: build and search disk index (one shot graph build, L2, no diskPQ) - if: success() || failure() - run: | - ${{ env.diskann_built_tests }}/build_disk_index --data_type uint8 --dist_fn l2 --data_path ./rand_uint8_10D_10K_norm50.0.bin --index_path_prefix ./disk_index_l2_rand_uint8_10D_10K_norm50.0_diskfull_oneshot -R 16 -L 32 -B 0.00003 -M 1 - ${{ env.diskann_built_tests }}/search_disk_index --data_type uint8 --dist_fn l2 --fail_if_recall_below 70 --index_path_prefix ./disk_index_l2_rand_uint8_10D_10K_norm50.0_diskfull_oneshot --result_path /tmp/res --query_file ./rand_uint8_10D_1K_norm50.0.bin --gt_file ./l2_rand_uint8_10D_10K_norm50.0_10D_1K_norm50.0_gt100 --recall_at 5 -L 5 12 -W 2 --num_nodes_to_cache 10 -T 16 - - name: build and search disk index (one shot graph build, L2, no diskPQ, build with PQ distance comparisons) - if: success() || failure() - run: | - ${{ env.diskann_built_tests }}/build_disk_index --data_type uint8 --dist_fn l2 --data_path ./rand_uint8_10D_10K_norm50.0.bin --index_path_prefix ./disk_index_l2_rand_uint8_10D_10K_norm50.0_diskfull_oneshot_buildpq5 -R 16 -L 32 -B 0.00003 -M 1 --build_PQ_bytes 5 - ${{ env.diskann_built_tests }}/search_disk_index --data_type uint8 --dist_fn l2 --fail_if_recall_below 70 --index_path_prefix ./disk_index_l2_rand_uint8_10D_10K_norm50.0_diskfull_oneshot_buildpq5 --result_path /tmp/res --query_file ./rand_uint8_10D_1K_norm50.0.bin --gt_file ./l2_rand_uint8_10D_10K_norm50.0_10D_1K_norm50.0_gt100 --recall_at 5 -L 5 12 -W 2 --num_nodes_to_cache 10 -T 16 - - name: build and search disk index (sharded graph build, L2, no diskPQ) - if: success() || failure() - run: | - ${{ env.diskann_built_tests }}/build_disk_index --data_type uint8 --dist_fn l2 --data_path ./rand_uint8_10D_10K_norm50.0.bin --index_path_prefix ./disk_index_l2_rand_uint8_10D_10K_norm50.0_diskfull_sharded -R 16 -L 32 -B 0.00003 -M 0.00006 - ${{ env.diskann_built_tests }}/search_disk_index --data_type uint8 --dist_fn l2 --fail_if_recall_below 70 --index_path_prefix ./disk_index_l2_rand_uint8_10D_10K_norm50.0_diskfull_sharded --result_path /tmp/res --query_file ./rand_uint8_10D_1K_norm50.0.bin --gt_file ./l2_rand_uint8_10D_10K_norm50.0_10D_1K_norm50.0_gt100 --recall_at 5 -L 5 12 -W 2 --num_nodes_to_cache 10 -T 16 - - name: build and search disk index (one shot graph build, L2, diskPQ) - if: success() || failure() - run: | - ${{ env.diskann_built_tests }}/build_disk_index --data_type uint8 --dist_fn l2 --data_path ./rand_uint8_10D_10K_norm50.0.bin --index_path_prefix ./disk_index_l2_rand_uint8_10D_10K_norm50.0_diskpq_oneshot -R 16 -L 32 -B 0.00003 -M 1 --PQ_disk_bytes 5 - ${{ env.diskann_built_tests }}/search_disk_index --data_type uint8 --dist_fn l2 --fail_if_recall_below 70 --index_path_prefix ./disk_index_l2_rand_uint8_10D_10K_norm50.0_diskpq_oneshot --result_path /tmp/res --query_file ./rand_uint8_10D_1K_norm50.0.bin --gt_file 
./l2_rand_uint8_10D_10K_norm50.0_10D_1K_norm50.0_gt100 --recall_at 5 -L 5 12 -W 2 --num_nodes_to_cache 10 -T 16 - - name: build and search an incremental index - if: success() || failure() - run: | - ${{ env.diskann_built_tests }}/test_insert_deletes_consolidate --data_type uint8 --dist_fn l2 --data_path rand_uint8_10D_10K_norm50.0.bin --index_path_prefix index_ins_del -R 64 -L 300 --alpha 1.2 -T 8 --points_to_skip 0 --max_points_to_insert 7500 --beginning_index_size 0 --points_per_checkpoint 1000 --checkpoints_per_snapshot 0 --points_to_delete_from_beginning 2500 --start_deletes_after 5000 --do_concurrent true --start_point_norm 200; - ${{ env.diskann_built_utils }}/compute_groundtruth --data_type uint8 --dist_fn l2 --base_file index_ins_del.after-concurrent-delete-del2500-7500.data --query_file rand_uint8_10D_1K_norm50.0.bin --K 100 --gt_file gt100_random10D_10K-conc-2500-7500 --tags_file index_ins_del.after-concurrent-delete-del2500-7500.tags - ${{ env.diskann_built_tests }}/search_memory_index --data_type uint8 --dist_fn l2 --fail_if_recall_below 70 --index_path_prefix index_ins_del.after-concurrent-delete-del2500-7500 --result_path res_ins_del --query_file ./rand_uint8_10D_1K_norm50.0.bin --gt_file gt100_random10D_10K-conc-2500-7500 -K 10 -L 20 40 60 80 100 -T 8 --dynamic true --tags 1 - - name: test a streaming index - if: success() || failure() - run: | - ${{ env.diskann_built_tests }}/test_streaming_scenario --data_type uint8 --dist_fn l2 --data_path rand_uint8_10D_10K_norm50.0.bin --index_path_prefix index_stream -R 64 -L 600 --alpha 1.2 --insert_threads 4 --consolidate_threads 4 --max_points_to_insert 10000 --active_window 4000 --consolidate_interval 2000 --start_point_norm 200 - ${{ env.diskann_built_utils }}/compute_groundtruth --data_type uint8 --dist_fn l2 --base_file index_stream.after-streaming-act4000-cons2000-max10000.data --query_file rand_uint8_10D_1K_norm50.0.bin --K 100 --gt_file gt100_base-act4000-cons2000-max10000 --tags_file index_stream.after-streaming-act4000-cons2000-max10000.tags - ${{ env.diskann_built_tests }}/search_memory_index --data_type uint8 --dist_fn l2 --fail_if_recall_below 70 --index_path_prefix index_stream.after-streaming-act4000-cons2000-max10000 --result_path res_stream --query_file ./rand_uint8_10D_1K_norm50.0.bin --gt_file gt100_base-act4000-cons2000-max10000 -K 10 -L 20 40 60 80 100 -T 64 --dynamic true --tags 1 - - - name: Generate 10K random uint8 index vectors, 1K query vectors, 10K Label Points (50 unique labels), in 10 dims and compute GT - if: success() || failure() - run: | - ${{ env.diskann_built_utils }}/rand_data_gen --data_type uint8 --output_file ./rand_uint8_10D_10K_norm50.0.bin -D 10 -N 10000 --norm 50.0 - ${{ env.diskann_built_utils }}/rand_data_gen --data_type uint8 --output_file ./rand_uint8_10D_1K_norm50.0.bin -D 10 -N 1000 --norm 50.0 - ${{ env.diskann_built_utils }}/generate_synthetic_labels --num_labels 50 --num_points 10000 --output_file ./rand_labels_10_10K.txt - ${{ env.diskann_built_utils }}/compute_groundtruth --data_type uint8 --dist_fn l2 --universal_label 0 --filter_label 10 --base_file ./rand_uint8_10D_10K_norm50.0.bin --query_file ./rand_uint8_10D_1K_norm50.0.bin --label_file ./rand_labels_10_10K.txt --gt_file ./l2_rand_uint8_10D_10K_norm50.0_10D_1K_norm50.0_gt100_wlabel --K 100 - - name: build and search in-memory index with labels using L2 metrics - if: success() || failure() - run: | - ${{ env.diskann_built_tests }}/build_memory_index --data_type uint8 --dist_fn l2 --FilteredLbuild 90 --universal_label 0 
--data_path ./rand_uint8_10D_10K_norm50.0.bin --label_file ./rand_labels_10_10K.txt --index_path_prefix ./index_l2_rand_uint8_10D_10K_norm50_wlabel - ${{ env.diskann_built_tests }}/search_memory_index --data_type uint8 --dist_fn l2 --filter_label 10 --fail_if_recall_below 70 --index_path_prefix ./index_l2_rand_uint8_10D_10K_norm50_wlabel --query_file ./rand_uint8_10D_1K_norm50.0.bin --recall_at 10 --result_path temp --gt_file ./l2_rand_uint8_10D_10K_norm50.0_10D_1K_norm50.0_gt100_wlabel -L 16 32 - - - name: build and search in-memory index with pq_dist of 5 with 10 dimensions - if: success() || failure() - run: | - ${{ env.diskann_built_tests }}/build_memory_index --data_type uint8 --dist_fn l2 --FilteredLbuild 90 --universal_label 0 --data_path ./rand_uint8_10D_10K_norm50.0.bin --label_file ./rand_labels_10_10K.txt --index_path_prefix ./index_l2_rand_uint8_10D_10K_norm50_wlabel --build_PQ_bytes 5 - ${{ env.diskann_built_tests }}/search_memory_index --data_type uint8 --dist_fn l2 --filter_label 10 --fail_if_recall_below 70 --index_path_prefix ./index_l2_rand_uint8_10D_10K_norm50_wlabel --query_file ./rand_uint8_10D_1K_norm50.0.bin --recall_at 10 --result_path temp --gt_file ./l2_rand_uint8_10D_10K_norm50.0_10D_1K_norm50.0_gt100_wlabel -L 16 32 - - - name: Build and search stitched vamana - if: success() || failure() - run: | - ${{ env.diskann_built_tests }}/build_stitched_index --num_threads 48 --data_type uint8 --data_path ./rand_uint8_10D_10K_norm50.0.bin --label_file ./rand_labels_10_10K.txt -R 32 -L 100 --alpha 1.2 --stitched_R 64 --index_path_prefix ./stit_32_100_64_new --universal_label 0 - ${{ env.diskann_built_tests }}/search_memory_index --num_threads 48 --data_type uint8 --dist_fn l2 --filter_label 10 --index_path_prefix ./stit_32_100_64_new --query_file ./rand_uint8_10D_1K_norm50.0.bin --result_path ./rand_stit_96_10_90_new --gt_file ./l2_rand_uint8_10D_10K_norm50.0_10D_1K_norm50.0_gt100_wlabel -K 10 -L 10 10 10 10 10 30 50 70 90 110 130 150 170 190 210 230 250 270 290 310 330 350 370 390 410 - - uses: actions/setup-python@v3 - - name: Install cibuildwheel - run: python -m pip install cibuildwheel==2.11.3 - - name: build wheels - run: python -m cibuildwheel --output-dir wheelhouse - env: - CIBW_ARCHS_WINDOWS: AMD64 - CIBW_ARCHS_LINUX: x86_64 + fail-fast: true + name: DiskANN Common Build Checks + uses: ./.github/workflows/common.yml + unit-tests: + name: Unit tests + uses: ./.github/workflows/unit-tests.yml + in-mem-pq: + name: In-Memory with PQ + uses: ./.github/workflows/in-mem-pq.yml + in-mem-no-pq: + name: In-Memory without PQ + uses: ./.github/workflows/in-mem-no-pq.yml + disk-pq: + name: Disk with PQ + uses: ./.github/workflows/disk-pq.yml + labels: + name: Labels + uses: ./.github/workflows/labels.yml + dynamic: + name: Dynamic + uses: ./.github/workflows/dynamic.yml + python: + name: Python + uses: ./.github/workflows/build-python.yml diff --git a/.github/workflows/push-test.yml b/.github/workflows/push-test.yml index cd186af9d..4de999014 100644 --- a/.github/workflows/push-test.yml +++ b/.github/workflows/push-test.yml @@ -1,55 +1,35 @@ name: DiskANN Push Build on: [push] jobs: - ubuntu-latest-build: - runs-on: ubuntu-latest - steps: - - name: Checkout repository - uses: actions/checkout@v2 - - name: Install deps - run: | - sudo apt install cmake g++ libaio-dev libgoogle-perftools-dev libunwind-dev clang-format libboost-dev libboost-program-options-dev libmkl-full-dev libcpprest-dev - - name: Clang Format Check - run: | - mkdir build && cd build && cmake -DRESTAPI=True 
-DCMAKE_BUILD_TYPE=Release .. - make checkformat - - name: build - run: | - cd build && make -j - - uses: actions/setup-python@v3 - with: - python-version: "3.10" - - name: Install Python Build Module - run: python -m pip install build - - name: Python Build - run: python -m build - - windows-build: - name: Build for ${{ matrix.os }} - runs-on: ${{ matrix.os }} - + common: + strategy: + fail-fast: true + name: DiskANN Common Build Checks + uses: ./.github/workflows/common.yml + build: strategy: + fail-fast: false matrix: - os: [windows-2019, windows-latest] - + os: [ ubuntu-latest, windows-2019, windows-latest ] + name: Build for ${{matrix.os}} + runs-on: ${{matrix.os}} + defaults: + run: + shell: bash steps: - - name: Checkout repository - uses: actions/checkout@v2 - with: + - name: Checkout repository + if: ${{ runner.os == 'Linux' }} + uses: actions/checkout@v3 + with: + fetch-depth: 1 + - name: Checkout repository + if: ${{ runner.os == 'Windows' }} + uses: actions/checkout@v3 + with: + fetch-depth: 1 submodules: true - - - name: Add VisualStudio command line tools into path - uses: ilammy/msvc-dev-cmd@v1 - - - name: Run configure and build - run: | - mkdir build && cd build && cmake .. && msbuild diskann.sln /m /nologo /t:Build /p:Configuration="Release" /property:Platform="x64" -consoleloggerparameters:"ErrorsOnly;Summary" - shell: cmd - - - uses: actions/setup-python@v3 - with: - python-version: "3.10" - - name: Install Python Build Module - run: python -m pip install build - - name: Python Build - run: python -m build + - name: DiskANN Build CLI Applications + uses: ./.github/actions/build +# python: +# name: DiskANN Build Python Wheel +# uses: ./.github/workflows/build-python.yml diff --git a/.github/workflows/python-release.yml b/.github/workflows/python-release.yml index 2651328b8..a1e72ad90 100644 --- a/.github/workflows/python-release.yml +++ b/.github/workflows/python-release.yml @@ -3,39 +3,20 @@ on: release: types: [published] jobs: - build_wheels: - name: Build wheels on ${{ matrix.os }} - runs-on: ${{ matrix.os }} - strategy: - matrix: - os: [ubuntu-latest, windows-latest] - steps: - - uses: actions/checkout@v3 - with: - submodules: true - - uses: actions/setup-python@v3 - - name: Install cibuildwheel - run: python -m pip install cibuildwheel==2.11.3 - - name: build wheels - run: python -m cibuildwheel --output-dir wheelhouse - env: - CIBW_ARCHS_WINDOWS: AMD64 - CIBW_ARCHS_LINUX: x86_64 - - uses: actions/upload-artifact@v3 - with: - name: wheelhouse - path: ./wheelhouse/*.whl + python-release-wheels: + name: Python + uses: ./.github/workflows/build-python.yml release: runs-on: ubuntu-latest - needs: build_wheels + needs: python-release-wheels steps: - uses: actions/download-artifact@v3 with: - name: wheelhouse - path: wheelhouse/ + name: wheels + path: dist/ - name: Generate SHA256 files for each wheel run: | - sha256sum wheelhouse/*.whl > checksums.txt + sha256sum dist/*.whl > checksums.txt cat checksums.txt - uses: actions/setup-python@v3 - name: Install twine @@ -45,11 +26,11 @@ jobs: TWINE_USERNAME: ${{ secrets.PYPI_USERNAME }} TWINE_PASSWORD: ${{ secrets.PYPI_PASSWORD }} run: | - twine upload wheelhouse/*.whl + twine upload dist/*.whl - name: Update release with SHA256 and Artifacts uses: softprops/action-gh-release@v1 with: token: ${{ secrets.GITHUB_TOKEN }} files: | - wheelhouse/*.whl + dist/*.whl checksums.txt diff --git a/.github/workflows/unit-tests.yml b/.github/workflows/unit-tests.yml new file mode 100644 index 000000000..6ae6877b8 --- /dev/null +++ 
b/.github/workflows/unit-tests.yml @@ -0,0 +1,32 @@ +name: Unit Tests +on: [workflow_call] +jobs: + acceptance-tests-labels: + name: Unit Tests + strategy: + fail-fast: false + matrix: + os: [ubuntu-latest, windows-2019, windows-latest] + runs-on: ${{matrix.os}} + defaults: + run: + shell: bash + steps: + - name: Checkout repository + if: ${{ runner.os == 'Linux' }} + uses: actions/checkout@v3 + with: + fetch-depth: 1 + - name: Checkout repository + if: ${{ runner.os == 'Windows' }} + uses: actions/checkout@v3 + with: + fetch-depth: 1 + submodules: true + - name: DiskANN Build CLI Applications + uses: ./.github/actions/build + + - name: Run Unit Tests + run: | + cd build + ctest -C Release \ No newline at end of file diff --git a/.gitignore b/.gitignore index d7a8e4741..f80e5c682 100644 --- a/.gitignore +++ b/.gitignore @@ -357,6 +357,8 @@ MigrationBackup/ cscope* build/ +build_linux/ +!.github/actions/build # jetbrains specific stuff .idea/ @@ -371,3 +373,8 @@ wheelhouse/* dist/* venv*/** *.swp + +gperftools + +# Rust +rust/target diff --git a/CMakeLists.txt b/CMakeLists.txt index f29e13f3a..89530f818 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -15,16 +15,19 @@ # Contact for this feature: gopalrs. # Some variables like MSVC are defined only after project(), so put that first. -project(diskann) - cmake_minimum_required(VERSION 3.15) +project(diskann) -set(CMAKE_STANDARD 14) +set(CMAKE_STANDARD 17) +set(CMAKE_CXX_STANDARD 17) +set(CMAKE_CXX_STANDARD_REQUIRED ON) if(NOT MSVC) set(CMAKE_CXX_COMPILER g++) endif() +set(CMAKE_MODULE_PATH "${PROJECT_SOURCE_DIR}/cmake;${CMAKE_MODULE_PATH}") + # Install nuget packages for dependencies. if (MSVC) find_program(NUGET_EXE NAMES nuget) @@ -42,7 +45,7 @@ if (MSVC) if (RESTAPI) set(DISKANN_MSVC_RESTAPI_PACKAGES_CONFIG ${CMAKE_BINARY_DIR}/restapi/packages.config) configure_file(${PROJECT_SOURCE_DIR}/windows/packages_restapi.config.in ${DISKANN_MSVC_RESTAPI_PACKAGES_CONFIG}) - exec_program(${NUGET_EXE} ARGS install \"${DISKANN_MSVC_RESTAPI_PACKAGES_CONFIG}\" -ExcludeVersion -OutputDirectory \"${DISKANN_MSVC_PACKAGES}\") + exec_program(${NUGET_EXE} ARGS install \"${DISKANN_MSVC_RESTAPI_PACKAGES_CONFIG}\" -ExcludeVersion -OutputDirectory \"${DISKANN_MSVC_PACKAGES}\") endif() message(STATUS "Finished setting up nuget dependencies") endif() @@ -136,18 +139,63 @@ if (MSVC) "${DISKANN_MKL_LIB_PATH}/mkl_intel_thread.lib") else() # expected path for manual intel mkl installs - set(OMP_PATH /opt/intel/oneapi/compiler/2022.0.2/linux/compiler/lib/intel64_lin/ CACHE PATH "Intel OneAPI OpenMP library implementation path") - set(MKL_ROOT /opt/intel/oneapi/mkl/latest CACHE PATH "Intel OneAPI MKL library implementation path") - link_directories(${OMP_PATH} ${MKL_ROOT}/lib/intel64) - include_directories(${MKL_ROOT}/include) - - # expected path for apt packaged intel mkl installs - link_directories(/usr/lib/x86_64-linux-gnu/mkl) - include_directories(/usr/include/mkl) + set(POSSIBLE_OMP_PATHS "/opt/intel/oneapi/compiler/latest/linux/compiler/lib/intel64_lin/libiomp5.so;/usr/lib/x86_64-linux-gnu/libiomp5.so;/opt/intel/lib/intel64_lin/libiomp5.so") + foreach(POSSIBLE_OMP_PATH ${POSSIBLE_OMP_PATHS}) + if (EXISTS ${POSSIBLE_OMP_PATH}) + get_filename_component(OMP_PATH ${POSSIBLE_OMP_PATH} DIRECTORY) + endif() + endforeach() + + if(NOT OMP_PATH) + message(FATAL_ERROR "Could not find Intel OMP in standard locations; use -DOMP_PATH to specify the install location for your environment") + endif() + link_directories(${OMP_PATH}) + + set(POSSIBLE_MKL_LIB_PATHS 
"/opt/intel/oneapi/mkl/latest/lib/intel64/libmkl_core.so;/usr/lib/x86_64-linux-gnu/libmkl_core.so;/opt/intel/mkl/lib/intel64/libmkl_core.so") + foreach(POSSIBLE_MKL_LIB_PATH ${POSSIBLE_MKL_LIB_PATHS}) + if (EXISTS ${POSSIBLE_MKL_LIB_PATH}) + get_filename_component(MKL_PATH ${POSSIBLE_MKL_LIB_PATH} DIRECTORY) + endif() + endforeach() + + set(POSSIBLE_MKL_INCLUDE_PATHS "/opt/intel/oneapi/mkl/latest/include;/usr/include/mkl;/opt/intel/mkl/include/;") + foreach(POSSIBLE_MKL_INCLUDE_PATH ${POSSIBLE_MKL_INCLUDE_PATHS}) + if (EXISTS ${POSSIBLE_MKL_INCLUDE_PATH}) + set(MKL_INCLUDE_PATH ${POSSIBLE_MKL_INCLUDE_PATH}) + endif() + endforeach() + if(NOT MKL_PATH) + message(FATAL_ERROR "Could not find Intel MKL in standard locations; use -DMKL_PATH to specify the install location for your environment") + elseif(NOT MKL_INCLUDE_PATH) + message(FATAL_ERROR "Could not find Intel MKL in standard locations; use -DMKL_INCLUDE_PATH to specify the install location for headers for your environment") + endif() + if (EXISTS ${MKL_PATH}/libmkl_def.so.2) + set(MKL_DEF_SO ${MKL_PATH}/libmkl_def.so.2) + elseif(EXISTS ${MKL_PATH}/libmkl_def.so) + set(MKL_DEF_SO ${MKL_PATH}/libmkl_def.so) + else() + message(FATAL_ERROR "Despite finding MKL, libmkl_def.so was not found in expected locations.") + endif() + link_directories(${MKL_PATH}) + include_directories(${MKL_INCLUDE_PATH}) # compile flags and link libraries add_compile_options(-m64 -Wl,--no-as-needed) - link_libraries(mkl_intel_ilp64 mkl_intel_thread mkl_core iomp5 pthread m dl) + if (NOT PYBIND) + link_libraries(mkl_intel_ilp64 mkl_intel_thread mkl_core iomp5 pthread m dl) + else() + # static linking for python so as to minimize customer dependency issues + link_libraries( + ${MKL_PATH}/libmkl_intel_ilp64.a + ${MKL_PATH}/libmkl_intel_thread.a + ${MKL_PATH}/libmkl_core.a + ${MKL_DEF_SO} + iomp5 + pthread + m + dl + ) + endif() endif() add_definitions(-DMKL_ILP64) @@ -193,7 +241,7 @@ if (MSVC) add_dependencies(libtcmalloc_minimal_for_exe build_libtcmalloc_minimal) set(DISKANN_TOOLS_TCMALLOC_LINK_OPTIONS libtcmalloc_minimal_for_exe) -else() +elseif(NOT PYBIND) set(DISKANN_TOOLS_TCMALLOC_LINK_OPTIONS "-ltcmalloc") endif() @@ -212,7 +260,7 @@ endif() #Main compiler/linker settings if(MSVC) #language options - add_compile_options(/permissive- /openmp:experimental /Zc:twoPhase- /Zc:inline /WX- /std:c++14 /Gd /W3 /MP /Zi /FC /nologo) + add_compile_options(/permissive- /openmp:experimental /Zc:twoPhase- /Zc:inline /WX- /std:c++17 /Gd /W3 /MP /Zi /FC /nologo) #code generation options add_compile_options(/arch:AVX2 /fp:fast /fp:except- /EHsc /GS- /Gy) #optimization options @@ -231,19 +279,30 @@ if(MSVC) set(CMAKE_RUNTIME_OUTPUT_DIRECTORY_RELEASE ${PROJECT_SOURCE_DIR}/x64/Release) set(CMAKE_ARCHIVE_OUTPUT_DIRECTORY_RELEASE ${PROJECT_SOURCE_DIR}/x64/Release) else() - set(ENV{TCMALLOC_LARGE_ALLOC_REPORT_THRESHOLD} 500000000000) - # set(CMAKE_CXX_FLAGS_DEBUG "${CMAKE_CXX_FLAGS_DEBUG} -g -DDEBUG -O0 -fsanitize=address -fsanitize=leak -fsanitize=undefined") - set(CMAKE_CXX_FLAGS_DEBUG "${CMAKE_CXX_FLAGS_DEBUG} -g -DDEBUG -Wall -Wextra") - set(CMAKE_CXX_FLAGS_RELEASE "${CMAKE_CXX_FLAGS_RELEASE} -Ofast -DNDEBUG -march=native -mtune=native -ftree-vectorize") - add_compile_options(-march=native -Wall -fno-builtin-malloc -fno-builtin-calloc -fno-builtin-realloc -fno-builtin-free -fopenmp -fopenmp-simd -funroll-loops -Wfatal-errors -DUSE_AVX2) - if (PYBIND) - add_compile_options(-fPIC) + set(ENV{TCMALLOC_LARGE_ALLOC_REPORT_THRESHOLD} 500000000000) + set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} 
-mavx2 -mfma -msse2 -ftree-vectorize -fno-builtin-malloc -fno-builtin-calloc -fno-builtin-realloc -fno-builtin-free -fopenmp -fopenmp-simd -funroll-loops -Wfatal-errors -DUSE_AVX2") + set(CMAKE_CXX_FLAGS_DEBUG "${CMAKE_CXX_FLAGS_DEBUG} -g -DDEBUG") + if (NOT PYBIND) + set(CMAKE_CXX_FLAGS_RELEASE "${CMAKE_CXX_FLAGS_RELEASE} -DNDEBUG -Ofast") + if (NOT PORTABLE) + set(CMAKE_CXX_FLAGS_RELEASE "${CMAKE_CXX_FLAGS_RELEASE} -march=native -mtune=native") + endif() + else() + # -Ofast is not supported in a python extension module + set(CMAKE_CXX_FLAGS_RELEASE "${CMAKE_CXX_FLAGS_RELEASE} -DNDEBUG -fPIC") endif() endif() add_subdirectory(src) -add_subdirectory(tests) -add_subdirectory(tests/utils) +if (NOT PYBIND) + add_subdirectory(apps) + add_subdirectory(apps/utils) +endif() + +if (UNIT_TEST) + enable_testing() + add_subdirectory(tests) +endif() if (MSVC) message(STATUS "The ${PROJECT_NAME}.sln has been created, opened it from VisualStudio to build Release or Debug configurations.\n" @@ -258,12 +317,11 @@ if (RESTAPI) link_libraries("${DISKANN_CPPRESTSDK}/x64/lib/cpprest142_2_10.lib") include_directories("${DISKANN_CPPRESTSDK}/include") endif() - add_subdirectory(tests/restapi) + add_subdirectory(apps/restapi) endif() include(clang-format.cmake) - if(PYBIND) add_subdirectory(python) endif() diff --git a/Dockerfile b/Dockerfile index 95c5f3494..ea1979f3f 100644 --- a/Dockerfile +++ b/Dockerfile @@ -1,13 +1,17 @@ -FROM ubuntu:16.04 -MAINTAINER Changxu Wang +#Copyright(c) Microsoft Corporation.All rights reserved. +#Licensed under the MIT license. -RUN apt-get update -y -RUN apt-get install -y g++ cmake libboost-dev libgoogle-perftools-dev +FROM ubuntu:jammy -COPY . /opt/nsg +RUN apt update +RUN apt install -y software-properties-common +RUN add-apt-repository -y ppa:git-core/ppa +RUN apt update +RUN DEBIAN_FRONTEND=noninteractive apt install -y git make cmake g++ libaio-dev libgoogle-perftools-dev libunwind-dev clang-format libboost-dev libboost-program-options-dev libmkl-full-dev libcpprest-dev python3.10 -WORKDIR /opt/nsg - -RUN mkdir -p build && cd build && \ - cmake -DCMAKE_BUILD_TYPE=Release .. && \ - make -j $(nproc) +WORKDIR /app +RUN git clone https://github.com/microsoft/DiskANN.git +WORKDIR /app/DiskANN +RUN mkdir build +RUN cmake -S . -B build -DCMAKE_BUILD_TYPE=Release +RUN cmake --build build -- -j diff --git a/DockerfileDev b/DockerfileDev new file mode 100644 index 000000000..0e95e405f --- /dev/null +++ b/DockerfileDev @@ -0,0 +1,17 @@ +#Copyright(c) Microsoft Corporation.All rights reserved. +#Licensed under the MIT license. + +FROM ubuntu:jammy + +RUN apt update +RUN apt install -y software-properties-common +RUN add-apt-repository -y ppa:git-core/ppa +RUN apt update +RUN DEBIAN_FRONTEND=noninteractive apt install -y git make cmake g++ libaio-dev libgoogle-perftools-dev libunwind-dev clang-format libboost-dev libboost-program-options-dev libboost-test-dev libmkl-full-dev libcpprest-dev python3.10 + +WORKDIR /app +RUN git clone https://github.com/microsoft/DiskANN.git +WORKDIR /app/DiskANN +RUN mkdir build +RUN cmake -S . 
-B build -DCMAKE_BUILD_TYPE=Release -DUNIT_TEST=True
+RUN cmake --build build -- -j
diff --git a/MANIFEST.in b/MANIFEST.in
index f6b8b82c0..0735c2783 100644
--- a/MANIFEST.in
+++ b/MANIFEST.in
@@ -10,4 +10,3 @@
 recursive-include python *
 recursive-include windows *
 prune python/tests
 recursive-include src *
-recursive-include tests *
diff --git a/README.md b/README.md
index 41f2fae4f..2922c16c1 100644
--- a/README.md
+++ b/README.md
@@ -1,6 +1,12 @@
 # DiskANN
-[![DiskANN Pull Request Build and Test](https://github.com/microsoft/DiskANN/actions/workflows/pr-test.yml/badge.svg)](https://github.com/microsoft/DiskANN/actions/workflows/pr-test.yml)
+[![DiskANN Paper](https://img.shields.io/badge/Paper-NeurIPS%3A_DiskANN-blue)](https://papers.nips.cc/paper/9527-rand-nsg-fast-accurate-billion-point-nearest-neighbor-search-on-a-single-node.pdf)
+[![DiskANN Paper](https://img.shields.io/badge/Paper-Arxiv%3A_Fresh--DiskANN-blue)](https://arxiv.org/abs/2105.09613)
+[![DiskANN Paper](https://img.shields.io/badge/Paper-Filtered--DiskANN-blue)](https://harsha-simhadri.org/pubs/Filtered-DiskANN23.pdf)
+[![DiskANN Main](https://github.com/microsoft/DiskANN/actions/workflows/push-test.yml/badge.svg?branch=main)](https://github.com/microsoft/DiskANN/actions/workflows/push-test.yml)
+[![PyPI version](https://img.shields.io/pypi/v/diskannpy.svg)](https://pypi.org/project/diskannpy/)
+[![Downloads shield](https://pepy.tech/badge/diskannpy)](https://pepy.tech/project/diskannpy)
+[![License: MIT](https://img.shields.io/badge/License-MIT-yellow.svg)](https://opensource.org/licenses/MIT)

 DiskANN is a suite of scalable, accurate and cost-effective approximate nearest neighbor search algorithms for large-scale vector search that support real-time changes and simple filters. This code is based on ideas from the [DiskANN](https://papers.nips.cc/paper/9527-rand-nsg-fast-accurate-billion-point-nearest-neighbor-search-on-a-single-node.pdf), [Fresh-DiskANN](https://arxiv.org/abs/2105.09613) and the [Filtered-DiskANN](https://harsha-simhadri.org/pubs/Filtered-DiskANN23.pdf) papers with further improvements.
@@ -12,8 +18,6 @@ contact [opencode@microsoft.com](mailto:opencode@microsoft.com) with any additio
 See [guidelines](CONTRIBUTING.md) for contributing to this project.
-
-
 ## Linux build:

 Install the following packages through apt-get

 ```bash
 sudo apt install make cmake g++ libaio-dev libgoogle-perftools-dev clang-format libboost-dev libboost-program-options-dev
 ```

 ### Install Intel MKL
-#### Ubuntu 20.04
+#### Ubuntu 20.04 or newer
 ```bash
 sudo apt install libmkl-full-dev
 ```
@@ -71,12 +75,16 @@ OR for Visual Studio 2017 and earlier:
 ```
 <full-path-to-installed-cmake>\cmake ..
 ```
-* This will create a diskann.sln solution. Open it from VisualStudio and build either Release or Debug configuration.
- * Alternatively, use MSBuild:
+**This will create a diskann.sln solution**. Now you can:
+
+- Open it from Visual Studio and build either Release or Debug configuration.
+- `<full-path-to-installed-cmake>\cmake --build build`
+- Use MSBuild:
 ```
 msbuild.exe diskann.sln /m /nologo /t:Build /p:Configuration="Release" /property:Platform="x64"
 ```
- * This will also build gperftools submodule for libtcmalloc_minimal dependency.
+
+* This will also build the gperftools submodule for the libtcmalloc_minimal dependency.
 * Generated binaries are stored in the x64/Release or x64/Debug directories.
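With the CMake and workflow changes in this diff, the whole Linux edit-build-test loop reduces to a handful of commands. A minimal sketch, assuming the `UNIT_TEST` option and the Boost-based tests wired up above (the `dist` prefix is an illustrative choice, not a requirement):

```bash
# Configure an out-of-source Release build with unit tests enabled.
cmake -S . -B build -DCMAKE_BUILD_TYPE=Release -DUNIT_TEST=True
# Compile on all available cores.
cmake --build build -- -j
# Run the tests registered via enable_testing()/add_subdirectory(tests).
(cd build && ctest -C Release)
# Stage the command-line apps under an install prefix.
cmake --install build --prefix dist
```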
 ## Usage:
@@ -88,16 +96,16 @@ Please see the following pages on using the compiled code:
 - [Commandline examples for using in-memory streaming indices](workflows/dynamic_index.md)
 - [Commandline interface for building and search in memory indices with label data and filters](workflows/filtered_in_memory.md)
 - [Commandline interface for building and search SSD based indices with label data and filters](workflows/filtered_ssd_index.md)
-- To be added: Python interfaces and docker files
+- [diskannpy - DiskANN as a python extension module](python/README.md)

 Please cite this software in your work as:

 ```
 @misc{diskann-github,
-   author = {Simhadri, Harsha Vardhan and Krishnaswamy, Ravishankar and Srinivasa, Gopal and Subramanya, Suhas Jayaram and Antonijevic, Andrija and Pryce, Dax and Kaczynski, David and Williams, Shane and Gollapudi, Siddarth and Sivashankar, Varun and Karia, Neel and Singh, Aditi and Jaiswal, Shikhar and Mahapatro, Neelam and Adams, Philip and Tower, Bryan}},
+   author = {Simhadri, Harsha Vardhan and Krishnaswamy, Ravishankar and Srinivasa, Gopal and Subramanya, Suhas Jayaram and Antonijevic, Andrija and Pryce, Dax and Kaczynski, David and Williams, Shane and Gollapudi, Siddarth and Sivashankar, Varun and Karia, Neel and Singh, Aditi and Jaiswal, Shikhar and Mahapatro, Neelam and Adams, Philip and Tower, Bryan and Patel, Yash},
    title = {{DiskANN: Graph-structured Indices for Scalable, Fast, Fresh and Filtered Approximate Nearest Neighbor Search}},
    url = {https://github.com/Microsoft/DiskANN},
-   version = {0.5},
+   version = {0.6.0},
    year = {2023}
 }
 ```
diff --git a/apps/CMakeLists.txt b/apps/CMakeLists.txt
new file mode 100644
index 000000000..e42c0b6cb
--- /dev/null
+++ b/apps/CMakeLists.txt
@@ -0,0 +1,42 @@
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT license.
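Since the README's usage list now points at diskannpy instead of a "to be added" placeholder, the wheels published by the release workflow above can be smoke-tested directly; a minimal check, assuming the `diskannpy` package on PyPI referenced by the new badges:

```bash
python3 -m pip install diskannpy
python3 -c "import diskannpy"   # a clean import is the whole test
```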
+ +set(CMAKE_CXX_STANDARD 17) +set(CMAKE_COMPILE_WARNING_AS_ERROR ON) + +add_executable(build_memory_index build_memory_index.cpp) +target_link_libraries(build_memory_index ${PROJECT_NAME} ${DISKANN_TOOLS_TCMALLOC_LINK_OPTIONS} Boost::program_options) + +add_executable(build_stitched_index build_stitched_index.cpp) +target_link_libraries(build_stitched_index ${PROJECT_NAME} ${DISKANN_TOOLS_TCMALLOC_LINK_OPTIONS} Boost::program_options) + +add_executable(search_memory_index search_memory_index.cpp) +target_link_libraries(search_memory_index ${PROJECT_NAME} ${DISKANN_ASYNC_LIB} ${DISKANN_TOOLS_TCMALLOC_LINK_OPTIONS} Boost::program_options) + +add_executable(build_disk_index build_disk_index.cpp) +target_link_libraries(build_disk_index ${PROJECT_NAME} ${DISKANN_TOOLS_TCMALLOC_LINK_OPTIONS} ${DISKANN_ASYNC_LIB} Boost::program_options) + +add_executable(search_disk_index search_disk_index.cpp) +target_link_libraries(search_disk_index ${PROJECT_NAME} ${DISKANN_ASYNC_LIB} ${DISKANN_TOOLS_TCMALLOC_LINK_OPTIONS} Boost::program_options) + +add_executable(range_search_disk_index range_search_disk_index.cpp) +target_link_libraries(range_search_disk_index ${PROJECT_NAME} ${DISKANN_ASYNC_LIB} ${DISKANN_TOOLS_TCMALLOC_LINK_OPTIONS} Boost::program_options) + +add_executable(test_streaming_scenario test_streaming_scenario.cpp) +target_link_libraries(test_streaming_scenario ${PROJECT_NAME} ${DISKANN_TOOLS_TCMALLOC_LINK_OPTIONS} Boost::program_options) + +add_executable(test_insert_deletes_consolidate test_insert_deletes_consolidate.cpp) +target_link_libraries(test_insert_deletes_consolidate ${PROJECT_NAME} ${DISKANN_TOOLS_TCMALLOC_LINK_OPTIONS} Boost::program_options) + +if (NOT MSVC) + install(TARGETS build_memory_index + build_stitched_index + search_memory_index + build_disk_index + search_disk_index + range_search_disk_index + test_streaming_scenario + test_insert_deletes_consolidate + RUNTIME + ) +endif() diff --git a/apps/build_disk_index.cpp b/apps/build_disk_index.cpp new file mode 100644 index 000000000..b617a5f4a --- /dev/null +++ b/apps/build_disk_index.cpp @@ -0,0 +1,189 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT license. 
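A note on the apps/CMakeLists.txt just added: on non-MSVC platforms the `install(TARGETS ... RUNTIME)` block stages every CLI app into the install prefix's `bin/` directory. A sketch of the resulting layout (the `dist` prefix mirrors the CI convention; the listing is simply the eight targets declared above):

```bash
cmake --install build --prefix dist
ls dist/bin
# build_disk_index   build_memory_index   build_stitched_index   range_search_disk_index
# search_disk_index  search_memory_index  test_insert_deletes_consolidate  test_streaming_scenario
dist/bin/build_memory_index --help   # every app exposes --help via program_options
```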
+
+#include <omp.h>
+#include <boost/program_options.hpp>
+
+#include "utils.h"
+#include "disk_utils.h"
+#include "math_utils.h"
+#include "index.h"
+#include "partition.h"
+#include "program_options_utils.hpp"
+
+namespace po = boost::program_options;
+
+int main(int argc, char **argv)
+{
+    std::string data_type, dist_fn, data_path, index_path_prefix, codebook_prefix, label_file, universal_label,
+        label_type;
+    uint32_t num_threads, R, L, disk_PQ, build_PQ, QD, Lf, filter_threshold;
+    float B, M;
+    bool append_reorder_data = false;
+    bool use_opq = false;
+
+    po::options_description desc{
+        program_options_utils::make_program_description("build_disk_index", "Build a disk-based index.")};
+    try
+    {
+        desc.add_options()("help,h", "Print information on arguments");
+
+        // Required parameters
+        po::options_description required_configs("Required");
+        required_configs.add_options()("data_type", po::value<std::string>(&data_type)->required(),
+                                       program_options_utils::DATA_TYPE_DESCRIPTION);
+        required_configs.add_options()("dist_fn", po::value<std::string>(&dist_fn)->required(),
+                                       program_options_utils::DISTANCE_FUNCTION_DESCRIPTION);
+        required_configs.add_options()("index_path_prefix", po::value<std::string>(&index_path_prefix)->required(),
+                                       program_options_utils::INDEX_PATH_PREFIX_DESCRIPTION);
+        required_configs.add_options()("data_path", po::value<std::string>(&data_path)->required(),
+                                       program_options_utils::INPUT_DATA_PATH);
+        required_configs.add_options()("search_DRAM_budget,B", po::value<float>(&B)->required(),
+                                       "DRAM budget in GB for searching the index to set the "
+                                       "compressed level for data while search happens");
+        required_configs.add_options()("build_DRAM_budget,M", po::value<float>(&M)->required(),
+                                       "DRAM budget in GB for building the index");
+
+        // Optional parameters
+        po::options_description optional_configs("Optional");
+        optional_configs.add_options()("num_threads,T",
+                                       po::value<uint32_t>(&num_threads)->default_value(omp_get_num_procs()),
+                                       program_options_utils::NUMBER_THREADS_DESCRIPTION);
+        optional_configs.add_options()("max_degree,R", po::value<uint32_t>(&R)->default_value(64),
+                                       program_options_utils::MAX_BUILD_DEGREE);
+        optional_configs.add_options()("Lbuild,L", po::value<uint32_t>(&L)->default_value(100),
+                                       program_options_utils::GRAPH_BUILD_COMPLEXITY);
+        optional_configs.add_options()("QD", po::value<uint32_t>(&QD)->default_value(0),
+                                       " Quantized Dimension for compression");
+        optional_configs.add_options()("codebook_prefix", po::value<std::string>(&codebook_prefix)->default_value(""),
+                                       "Path prefix for pre-trained codebook");
+        optional_configs.add_options()("PQ_disk_bytes", po::value<uint32_t>(&disk_PQ)->default_value(0),
+                                       "Number of bytes to which vectors should be compressed "
+                                       "on SSD; 0 for no compression");
+        optional_configs.add_options()("append_reorder_data", po::bool_switch()->default_value(false),
+                                       "Include full precision data in the index. Use only in "
+                                       "conjunction with compressed data on SSD.");
+        optional_configs.add_options()("build_PQ_bytes", po::value<uint32_t>(&build_PQ)->default_value(0),
+                                       program_options_utils::BUIlD_GRAPH_PQ_BYTES);
+        optional_configs.add_options()("use_opq", po::bool_switch()->default_value(false),
+                                       program_options_utils::USE_OPQ);
+        optional_configs.add_options()("label_file", po::value<std::string>(&label_file)->default_value(""),
+                                       program_options_utils::LABEL_FILE);
+        optional_configs.add_options()("universal_label", po::value<std::string>(&universal_label)->default_value(""),
+                                       program_options_utils::UNIVERSAL_LABEL);
+        optional_configs.add_options()("FilteredLbuild", po::value<uint32_t>(&Lf)->default_value(0),
+                                       program_options_utils::FILTERED_LBUILD);
+        optional_configs.add_options()("filter_threshold,F", po::value<uint32_t>(&filter_threshold)->default_value(0),
+                                       "Threshold to break up the existing nodes to generate a new graph "
+                                       "internally where each node has a maximum of F labels.");
+        optional_configs.add_options()("label_type", po::value<std::string>(&label_type)->default_value("uint"),
+                                       program_options_utils::LABEL_TYPE_DESCRIPTION);
+
+        // Merge required and optional parameters
+        desc.add(required_configs).add(optional_configs);
+
+        po::variables_map vm;
+        po::store(po::parse_command_line(argc, argv, desc), vm);
+        if (vm.count("help"))
+        {
+            std::cout << desc;
+            return 0;
+        }
+        po::notify(vm);
+        if (vm["append_reorder_data"].as<bool>())
+            append_reorder_data = true;
+        if (vm["use_opq"].as<bool>())
+            use_opq = true;
+    }
+    catch (const std::exception &ex)
+    {
+        std::cerr << ex.what() << '\n';
+        return -1;
+    }
+
+    bool use_filters = (label_file != "") ? true : false;
+    diskann::Metric metric;
+    if (dist_fn == std::string("l2"))
+        metric = diskann::Metric::L2;
+    else if (dist_fn == std::string("mips"))
+        metric = diskann::Metric::INNER_PRODUCT;
+    else
+    {
+        std::cout << "Error. Only l2 and mips distance functions are supported" << std::endl;
+        return -1;
+    }
+
+    if (append_reorder_data)
+    {
+        if (disk_PQ == 0)
+        {
+            std::cout << "Error: It is not necessary to append data for reordering "
+                         "when vectors are not compressed on disk."
+                      << std::endl;
+            return -1;
+        }
+        if (data_type != std::string("float"))
+        {
+            std::cout << "Error: Appending data for reordering currently only "
+                         "supported for float data type."
+                      << std::endl;
+            return -1;
+        }
+    }
+
+    std::string params = std::string(std::to_string(R)) + " " + std::string(std::to_string(L)) + " " +
+                         std::string(std::to_string(B)) + " " + std::string(std::to_string(M)) + " " +
+                         std::string(std::to_string(num_threads)) + " " + std::string(std::to_string(disk_PQ)) + " " +
+                         std::string(std::to_string(append_reorder_data)) + " " +
+                         std::string(std::to_string(build_PQ)) + " " + std::string(std::to_string(QD));
+
+    try
+    {
+        if (label_file != "" && label_type == "ushort")
+        {
+            if (data_type == std::string("int8"))
+                return diskann::build_disk_index<int8_t, uint16_t>(
+                    data_path.c_str(), index_path_prefix.c_str(), params.c_str(), metric, use_opq, codebook_prefix,
+                    use_filters, label_file, universal_label, filter_threshold, Lf);
+            else if (data_type == std::string("uint8"))
+                return diskann::build_disk_index<uint8_t, uint16_t>(
+                    data_path.c_str(), index_path_prefix.c_str(), params.c_str(), metric, use_opq, codebook_prefix,
+                    use_filters, label_file, universal_label, filter_threshold, Lf);
+            else if (data_type == std::string("float"))
+                return diskann::build_disk_index<float, uint16_t>(
+                    data_path.c_str(), index_path_prefix.c_str(), params.c_str(), metric, use_opq, codebook_prefix,
+                    use_filters, label_file, universal_label, filter_threshold, Lf);
+            else
+            {
+                diskann::cerr << "Error. Unsupported data type" << std::endl;
+                return -1;
+            }
+        }
+        else
+        {
+            if (data_type == std::string("int8"))
+                return diskann::build_disk_index<int8_t>(data_path.c_str(), index_path_prefix.c_str(), params.c_str(),
+                                                         metric, use_opq, codebook_prefix, use_filters, label_file,
+                                                         universal_label, filter_threshold, Lf);
+            else if (data_type == std::string("uint8"))
+                return diskann::build_disk_index<uint8_t>(data_path.c_str(), index_path_prefix.c_str(), params.c_str(),
+                                                          metric, use_opq, codebook_prefix, use_filters, label_file,
+                                                          universal_label, filter_threshold, Lf);
+            else if (data_type == std::string("float"))
+                return diskann::build_disk_index<float>(data_path.c_str(), index_path_prefix.c_str(), params.c_str(),
+                                                        metric, use_opq, codebook_prefix, use_filters, label_file,
+                                                        universal_label, filter_threshold, Lf);
+            else
+            {
+                diskann::cerr << "Error. Unsupported data type" << std::endl;
+                return -1;
+            }
+        }
+    }
+    catch (const std::exception &e)
+    {
+        std::cout << std::string(e.what()) << std::endl;
+        diskann::cerr << "Index build failed." << std::endl;
+        return -1;
+    }
+}
diff --git a/apps/build_memory_index.cpp b/apps/build_memory_index.cpp
new file mode 100644
index 000000000..92b269f4f
--- /dev/null
+++ b/apps/build_memory_index.cpp
@@ -0,0 +1,206 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT license.
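Before moving on to build_memory_index.cpp: to make build_disk_index's option set concrete, here is a representative build-and-search pair. File names mirror the random-data smoke tests elsewhere in this diff; all paths and parameter values are illustrative only:

```bash
# -B caps search-time DRAM (GB) and thus sets the PQ compression level; -M caps
# build-time DRAM (GB), and small values force a sharded (multi-pass) build.
dist/bin/build_disk_index --data_type float --dist_fn l2 \
    --data_path data/rand_float_10D_10K_norm1.0.bin \
    --index_path_prefix data/disk_index_l2_rand_float \
    -R 16 -L 32 -B 0.00003 -M 1

dist/bin/search_disk_index --data_type float --dist_fn l2 \
    --index_path_prefix data/disk_index_l2_rand_float \
    --query_file data/rand_float_10D_1K_norm1.0.bin \
    --gt_file data/l2_rand_float_10D_10K_norm1.0_10D_1K_norm1.0_gt100 \
    --recall_at 5 -L 5 12 -W 2 --num_nodes_to_cache 10 -T 16 \
    --result_path /tmp/res
```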
+
+#include <omp.h>
+#include <cstring>
+#include <boost/program_options.hpp>
+
+#include "index.h"
+#include "utils.h"
+#include "program_options_utils.hpp"
+
+#ifndef _WINDOWS
+#include <sys/mman.h>
+#include <unistd.h>
+#else
+#include <Windows.h>
+#endif
+
+#include "memory_mapper.h"
+#include "ann_exception.h"
+#include "index_factory.h"
+
+namespace po = boost::program_options;
+
+template <typename T, typename LabelT>
+int build_in_memory_index(const diskann::Metric &metric, const std::string &data_path, const uint32_t R,
+                          const uint32_t L, const float alpha, const std::string &save_path,
+                          const uint32_t num_threads, const bool use_pq_build, const size_t num_pq_bytes,
+                          const bool use_opq, const std::string &label_file, const std::string &universal_label,
+                          const uint32_t Lf)
+{
+    diskann::IndexWriteParameters paras = diskann::IndexWriteParametersBuilder(L, R)
+                                              .with_filter_list_size(Lf)
+                                              .with_alpha(alpha)
+                                              .with_saturate_graph(false)
+                                              .with_num_threads(num_threads)
+                                              .build();
+    std::string labels_file_to_use = save_path + "_label_formatted.txt";
+    std::string mem_labels_int_map_file = save_path + "_labels_map.txt";
+
+    size_t data_num, data_dim;
+    diskann::get_bin_metadata(data_path, data_num, data_dim);
+
+    diskann::Index<T, uint32_t, LabelT> index(metric, data_dim, data_num, false, false, false, use_pq_build,
+                                              num_pq_bytes, use_opq);
+    auto s = std::chrono::high_resolution_clock::now();
+    if (label_file == "")
+    {
+        index.build(data_path.c_str(), data_num, paras);
+    }
+    else
+    {
+        convert_labels_string_to_int(label_file, labels_file_to_use, mem_labels_int_map_file, universal_label);
+        if (universal_label != "")
+        {
+            LabelT unv_label_as_num = 0;
+            index.set_universal_label(unv_label_as_num);
+        }
+        index.build_filtered_index(data_path.c_str(), labels_file_to_use, data_num, paras);
+    }
+    std::chrono::duration<double> diff = std::chrono::high_resolution_clock::now() - s;
+
+    std::cout << "Indexing time: " << diff.count() << "\n";
+    index.save(save_path.c_str());
+    if (label_file != "")
+        std::remove(labels_file_to_use.c_str());
+    return 0;
+}
+
+int main(int argc, char **argv)
+{
+    std::string data_type, dist_fn, data_path, index_path_prefix, label_file, universal_label, label_type;
+    uint32_t num_threads, R, L, Lf, build_PQ_bytes;
+    float alpha;
+    bool use_pq_build, use_opq;
+
+    po::options_description desc{
+        program_options_utils::make_program_description("build_memory_index", "Build a memory-based DiskANN index.")};
+    try
+    {
+        desc.add_options()("help,h", "Print information on arguments");
+
+        // Required parameters
+        po::options_description required_configs("Required");
+        required_configs.add_options()("data_type", po::value<std::string>(&data_type)->required(),
+                                       program_options_utils::DATA_TYPE_DESCRIPTION);
+        required_configs.add_options()("dist_fn", po::value<std::string>(&dist_fn)->required(),
+                                       program_options_utils::DISTANCE_FUNCTION_DESCRIPTION);
+        required_configs.add_options()("index_path_prefix", po::value<std::string>(&index_path_prefix)->required(),
+                                       program_options_utils::INDEX_PATH_PREFIX_DESCRIPTION);
+        required_configs.add_options()("data_path", po::value<std::string>(&data_path)->required(),
+                                       program_options_utils::INPUT_DATA_PATH);
+
+        // Optional parameters
+        po::options_description optional_configs("Optional");
+        optional_configs.add_options()("num_threads,T",
+                                       po::value<uint32_t>(&num_threads)->default_value(omp_get_num_procs()),
+                                       program_options_utils::NUMBER_THREADS_DESCRIPTION);
+        optional_configs.add_options()("max_degree,R", po::value<uint32_t>(&R)->default_value(64),
+                                       program_options_utils::MAX_BUILD_DEGREE);
+        optional_configs.add_options()("Lbuild,L", po::value<uint32_t>(&L)->default_value(100),
+                                       program_options_utils::GRAPH_BUILD_COMPLEXITY);
+        optional_configs.add_options()("alpha", po::value<float>(&alpha)->default_value(1.2f),
+                                       program_options_utils::GRAPH_BUILD_ALPHA);
+        optional_configs.add_options()("build_PQ_bytes", po::value<uint32_t>(&build_PQ_bytes)->default_value(0),
+                                       program_options_utils::BUIlD_GRAPH_PQ_BYTES);
+        optional_configs.add_options()("use_opq", po::bool_switch()->default_value(false),
+                                       program_options_utils::USE_OPQ);
+        optional_configs.add_options()("label_file", po::value<std::string>(&label_file)->default_value(""),
+                                       program_options_utils::LABEL_FILE);
+        optional_configs.add_options()("universal_label", po::value<std::string>(&universal_label)->default_value(""),
+                                       program_options_utils::UNIVERSAL_LABEL);
+
+        optional_configs.add_options()("FilteredLbuild", po::value<uint32_t>(&Lf)->default_value(0),
+                                       program_options_utils::FILTERED_LBUILD);
+        optional_configs.add_options()("label_type", po::value<std::string>(&label_type)->default_value("uint"),
+                                       program_options_utils::LABEL_TYPE_DESCRIPTION);
+
+        // Merge required and optional parameters
+        desc.add(required_configs).add(optional_configs);
+
+        po::variables_map vm;
+        po::store(po::parse_command_line(argc, argv, desc), vm);
+        if (vm.count("help"))
+        {
+            std::cout << desc;
+            return 0;
+        }
+        po::notify(vm);
+        use_pq_build = (build_PQ_bytes > 0);
+        use_opq = vm["use_opq"].as<bool>();
+    }
+    catch (const std::exception &ex)
+    {
+        std::cerr << ex.what() << '\n';
+        return -1;
+    }
+
+    diskann::Metric metric;
+    if (dist_fn == std::string("mips"))
+    {
+        metric = diskann::Metric::INNER_PRODUCT;
+    }
+    else if (dist_fn == std::string("l2"))
+    {
+        metric = diskann::Metric::L2;
+    }
+    else if (dist_fn == std::string("cosine"))
+    {
+        metric = diskann::Metric::COSINE;
+    }
+    else
+    {
+        std::cout << "Unsupported distance function. Currently only L2/Inner "
+                     "Product/Cosine are supported."
+                  << std::endl;
+        return -1;
+    }
+
+    try
+    {
+        diskann::cout << "Starting index build with R: " << R << " Lbuild: " << L << " alpha: " << alpha
+                      << " #threads: " << num_threads << std::endl;
+
+        size_t data_num, data_dim;
+        diskann::get_bin_metadata(data_path, data_num, data_dim);
+
+        auto config = diskann::IndexConfigBuilder()
+                          .with_metric(metric)
+                          .with_dimension(data_dim)
+                          .with_max_points(data_num)
+                          .with_data_load_store_strategy(diskann::MEMORY)
+                          .with_data_type(data_type)
+                          .with_label_type(label_type)
+                          .is_dynamic_index(false)
+                          .is_enable_tags(false)
+                          .is_use_opq(use_opq)
+                          .is_pq_dist_build(use_pq_build)
+                          .with_num_pq_chunks(build_PQ_bytes)
+                          .build();
+
+        auto index_build_params = diskann::IndexWriteParametersBuilder(L, R)
+                                      .with_filter_list_size(Lf)
+                                      .with_alpha(alpha)
+                                      .with_saturate_graph(false)
+                                      .with_num_threads(num_threads)
+                                      .build();
+
+        auto build_params = diskann::IndexBuildParamsBuilder(index_build_params)
+                                .with_universal_label(universal_label)
+                                .with_label_file(label_file)
+                                .with_save_path_prefix(index_path_prefix)
+                                .build();
+        auto index_factory = diskann::IndexFactory(config);
+        auto index = index_factory.create_instance();
+        index->build(data_path, data_num, build_params);
+        index->save(index_path_prefix.c_str());
+        index.reset();
+        return 0;
+    }
+    catch (const std::exception &e)
+    {
+        std::cout << std::string(e.what()) << std::endl;
+        diskann::cerr << "Index build failed." << std::endl;
+        return -1;
+    }
+}
diff --git a/apps/build_stitched_index.cpp b/apps/build_stitched_index.cpp
new file mode 100644
index 000000000..80481f8b0
--- /dev/null
+++ b/apps/build_stitched_index.cpp
@@ -0,0 +1,439 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT license.
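build_memory_index above is also the entry point for filtered (label-aware) builds when `--label_file` is supplied. A sketch of that path, reusing the synthetic-label inputs from the CI steps earlier in this diff (paths illustrative; `--universal_label` names the label treated as matching every filter):

```bash
dist/bin/build_memory_index --data_type uint8 --dist_fn l2 \
    --data_path data/rand_uint8_10D_10K_norm50.0.bin \
    --label_file data/rand_labels_10_10K.txt \
    --universal_label 0 --FilteredLbuild 90 \
    --index_path_prefix data/index_l2_uint8_wlabel

# Search restricted to points carrying label 10 (or the universal label).
dist/bin/search_memory_index --data_type uint8 --dist_fn l2 --filter_label 10 \
    --index_path_prefix data/index_l2_uint8_wlabel \
    --query_file data/rand_uint8_10D_1K_norm50.0.bin \
    --gt_file data/l2_rand_uint8_10D_10K_norm50.0_10D_1K_norm50.0_gt100_wlabel \
    --recall_at 10 --result_path /tmp/res -L 16 32
```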
diff --git a/apps/build_stitched_index.cpp b/apps/build_stitched_index.cpp
new file mode 100644
index 000000000..80481f8b0
--- /dev/null
+++ b/apps/build_stitched_index.cpp
@@ -0,0 +1,439 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT license.
+
+#include <boost/program_options.hpp>
+#include <ctime>
+#include <functional>
+#include <iomanip>
+#include <omp.h>
+#include <random>
+#include <set>
+#include "filter_utils.h"
+#include <string.h>
+#ifndef _WINDOWS
+#include <sys/uio.h>
+#endif
+
+#include "index.h"
+#include "memory_mapper.h"
+#include "parameters.h"
+#include "utils.h"
+#include "program_options_utils.hpp"
+
+namespace po = boost::program_options;
+typedef std::tuple<std::vector<std::vector<uint32_t>>, uint64_t> stitch_indices_return_values;
+
+/*
+ * Inline function to display progress bar.
+ */
+inline void print_progress(double percentage)
+{
+    int val = (int)(percentage * 100);
+    int lpad = (int)(percentage * PBWIDTH);
+    int rpad = PBWIDTH - lpad;
+    printf("\r%3d%% [%.*s%*s]", val, lpad, PBSTR, rpad, "");
+    fflush(stdout);
+}
+
+/*
+ * Inline function to generate a random integer in a range.
+ */
+inline size_t random(size_t range_from, size_t range_to)
+{
+    std::random_device rand_dev;
+    std::mt19937 generator(rand_dev());
+    std::uniform_int_distribution<size_t> distr(range_from, range_to);
+    return distr(generator);
+}
+
+/*
+ * Function to handle command-line parsing.
+ *
+ * Arguments are merely the inputs from the command line.
+ */
+void handle_args(int argc, char **argv, std::string &data_type, path &input_data_path, path &final_index_path_prefix,
+                 path &label_data_path, std::string &universal_label, uint32_t &num_threads, uint32_t &R, uint32_t &L,
+                 uint32_t &stitched_R, float &alpha)
+{
+    po::options_description desc{
+        program_options_utils::make_program_description("build_stitched_index", "Build a stitched DiskANN index.")};
+    try
+    {
+        desc.add_options()("help,h", "Print information on arguments");
+
+        // Required parameters
+        po::options_description required_configs("Required");
+        required_configs.add_options()("data_type", po::value<std::string>(&data_type)->required(),
+                                       program_options_utils::DATA_TYPE_DESCRIPTION);
+        required_configs.add_options()("index_path_prefix",
+                                       po::value<path>(&final_index_path_prefix)->required(),
+                                       program_options_utils::INDEX_PATH_PREFIX_DESCRIPTION);
+        required_configs.add_options()("data_path", po::value<path>(&input_data_path)->required(),
+                                       program_options_utils::INPUT_DATA_PATH);
+
+        // Optional parameters
+        po::options_description optional_configs("Optional");
+        optional_configs.add_options()("num_threads,T",
+                                       po::value<uint32_t>(&num_threads)->default_value(omp_get_num_procs()),
+                                       program_options_utils::NUMBER_THREADS_DESCRIPTION);
+        optional_configs.add_options()("max_degree,R", po::value<uint32_t>(&R)->default_value(64),
+                                       program_options_utils::MAX_BUILD_DEGREE);
+        optional_configs.add_options()("Lbuild,L", po::value<uint32_t>(&L)->default_value(100),
+                                       program_options_utils::GRAPH_BUILD_COMPLEXITY);
+        optional_configs.add_options()("alpha", po::value<float>(&alpha)->default_value(1.2f),
+                                       program_options_utils::GRAPH_BUILD_ALPHA);
+        optional_configs.add_options()("label_file", po::value<path>(&label_data_path)->default_value(""),
+                                       program_options_utils::LABEL_FILE);
+        optional_configs.add_options()("universal_label", po::value<std::string>(&universal_label)->default_value(""),
+                                       program_options_utils::UNIVERSAL_LABEL);
+        optional_configs.add_options()("stitched_R", po::value<uint32_t>(&stitched_R)->default_value(100),
+                                       "Degree to prune final graph down to");
+
+        // Merge required and optional parameters
+        desc.add(required_configs).add(optional_configs);
+
+        po::variables_map vm;
+        po::store(po::parse_command_line(argc, argv, desc), vm);
+        if (vm.count("help"))
+        {
+            std::cout << desc;
+            exit(0);
+        }
+        po::notify(vm);
+    }
+    catch (const std::exception &ex)
+    {
+        std::cerr << ex.what() << '\n';
+        throw;
+    }
+}
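+
+// handle_args() prints the merged option table and exits(0) on --help; any
+// parse error is reported to stderr and rethrown to the caller.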
+
+/*
+ * Custom index save to write the in-memory index to disk.
+ * Also writes required files for diskANN API -
+ *  1. labels_to_medoids
+ *  2. universal_label
+ *  3. data (redundant for static indices)
+ *  4. labels (redundant for static indices)
+ */
+void save_full_index(path final_index_path_prefix, path input_data_path, uint64_t final_index_size,
+                     std::vector<std::vector<uint32_t>> stitched_graph,
+                     tsl::robin_map<std::string, uint32_t> entry_points, std::string universal_label,
+                     path label_data_path)
+{
+    // aux. file 1
+    auto saving_index_timer = std::chrono::high_resolution_clock::now();
+    std::ifstream original_label_data_stream;
+    original_label_data_stream.exceptions(std::ios::badbit | std::ios::failbit);
+    original_label_data_stream.open(label_data_path, std::ios::binary);
+    std::ofstream new_label_data_stream;
+    new_label_data_stream.exceptions(std::ios::badbit | std::ios::failbit);
+    new_label_data_stream.open(final_index_path_prefix + "_labels.txt", std::ios::binary);
+    new_label_data_stream << original_label_data_stream.rdbuf();
+    original_label_data_stream.close();
+    new_label_data_stream.close();
+
+    // aux. file 2
+    std::ifstream original_input_data_stream;
+    original_input_data_stream.exceptions(std::ios::badbit | std::ios::failbit);
+    original_input_data_stream.open(input_data_path, std::ios::binary);
+    std::ofstream new_input_data_stream;
+    new_input_data_stream.exceptions(std::ios::badbit | std::ios::failbit);
+    new_input_data_stream.open(final_index_path_prefix + ".data", std::ios::binary);
+    new_input_data_stream << original_input_data_stream.rdbuf();
+    original_input_data_stream.close();
+    new_input_data_stream.close();
+
+    // aux. file 3
+    std::ofstream labels_to_medoids_writer;
+    labels_to_medoids_writer.exceptions(std::ios::badbit | std::ios::failbit);
+    labels_to_medoids_writer.open(final_index_path_prefix + "_labels_to_medoids.txt");
+    for (auto iter : entry_points)
+        labels_to_medoids_writer << iter.first << ", " << iter.second << std::endl;
+    labels_to_medoids_writer.close();
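+    // (each line of _labels_to_medoids.txt is a "label, medoid_id" pair, which
+    // the filtered-search loader presumably uses to pick per-label entry points)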
+
+    // aux. file 4 (only if we're using a universal label)
+    if (universal_label != "")
+    {
+        std::ofstream universal_label_writer;
+        universal_label_writer.exceptions(std::ios::badbit | std::ios::failbit);
+        universal_label_writer.open(final_index_path_prefix + "_universal_label.txt");
+        universal_label_writer << universal_label << std::endl;
+        universal_label_writer.close();
+    }
+
+    // main index
+    uint64_t index_num_frozen_points = 0, index_num_edges = 0;
+    uint32_t index_max_observed_degree = 0, index_entry_point = 0;
+    const size_t METADATA = 2 * sizeof(uint64_t) + 2 * sizeof(uint32_t);
+    for (auto &point_neighbors : stitched_graph)
+    {
+        index_max_observed_degree = std::max(index_max_observed_degree, (uint32_t)point_neighbors.size());
+    }
+
+    std::ofstream stitched_graph_writer;
+    stitched_graph_writer.exceptions(std::ios::badbit | std::ios::failbit);
+    stitched_graph_writer.open(final_index_path_prefix, std::ios_base::binary);
+
+    stitched_graph_writer.write((char *)&final_index_size, sizeof(uint64_t));
+    stitched_graph_writer.write((char *)&index_max_observed_degree, sizeof(uint32_t));
+    stitched_graph_writer.write((char *)&index_entry_point, sizeof(uint32_t));
+    stitched_graph_writer.write((char *)&index_num_frozen_points, sizeof(uint64_t));
+
+    size_t bytes_written = METADATA;
+    for (uint32_t node_point = 0; node_point < stitched_graph.size(); node_point++)
+    {
+        uint32_t current_node_num_neighbors = (uint32_t)stitched_graph[node_point].size();
+        std::vector<uint32_t> current_node_neighbors = stitched_graph[node_point];
+        stitched_graph_writer.write((char *)&current_node_num_neighbors, sizeof(uint32_t));
+        bytes_written += sizeof(uint32_t);
+        for (const auto &current_node_neighbor : current_node_neighbors)
+        {
+            stitched_graph_writer.write((char *)&current_node_neighbor, sizeof(uint32_t));
+            bytes_written += sizeof(uint32_t);
+        }
+        index_num_edges += current_node_num_neighbors;
+    }
+
+    if (bytes_written != final_index_size)
+    {
+        std::cerr << "Error: written bytes does not match allocated space" << std::endl;
+        // note: no exception is in flight here, so this bare throw terminates the program
+        throw;
+    }
+
+    stitched_graph_writer.close();
+
+    std::chrono::duration<double> saving_index_time = std::chrono::high_resolution_clock::now() - saving_index_timer;
+    std::cout << "Stitched graph written in " << saving_index_time.count() << " seconds" << std::endl;
+    std::cout << "Stitched graph average degree: " << ((float)index_num_edges) / ((float)(stitched_graph.size()))
+              << std::endl;
+    std::cout << "Stitched graph max degree: " << index_max_observed_degree << std::endl << std::endl;
+}
+
+/*
+ * Unions the per-label graph indices together via the following policy:
+ *  - any two nodes can only have at most one edge between them -
+ *
+ * Returns the "stitched" graph and its expected file size.
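+ *
+ * For example, if points 7 and 9 share two labels and both per-label graphs
+ * contain the edge 7 -> 9 (in original ids), the stitched graph keeps it once.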
+ */
+template <typename T>
+stitch_indices_return_values stitch_label_indices(
+    path final_index_path_prefix, uint32_t total_number_of_points, label_set all_labels,
+    tsl::robin_map<std::string, uint32_t> labels_to_number_of_points,
+    tsl::robin_map<std::string, uint32_t> &label_entry_points,
+    tsl::robin_map<std::string, std::vector<uint32_t>> label_id_to_orig_id_map)
+{
+    size_t final_index_size = 0;
+    std::vector<std::vector<uint32_t>> stitched_graph(total_number_of_points);
+
+    auto stitching_index_timer = std::chrono::high_resolution_clock::now();
+    for (const auto &lbl : all_labels)
+    {
+        path curr_label_index_path(final_index_path_prefix + "_" + lbl);
+        std::vector<std::vector<uint32_t>> curr_label_index;
+        uint64_t curr_label_index_size;
+        uint32_t curr_label_entry_point;
+
+        std::tie(curr_label_index, curr_label_index_size) =
+            diskann::load_label_index(curr_label_index_path, labels_to_number_of_points[lbl]);
+        curr_label_entry_point = (uint32_t)random(0, curr_label_index.size());
+        label_entry_points[lbl] = label_id_to_orig_id_map[lbl][curr_label_entry_point];
+
+        for (uint32_t node_point = 0; node_point < curr_label_index.size(); node_point++)
+        {
+            uint32_t original_point_id = label_id_to_orig_id_map[lbl][node_point];
+            for (auto &node_neighbor : curr_label_index[node_point])
+            {
+                uint32_t original_neighbor_id = label_id_to_orig_id_map[lbl][node_neighbor];
+                std::vector<uint32_t> curr_point_neighbors = stitched_graph[original_point_id];
+                if (std::find(curr_point_neighbors.begin(), curr_point_neighbors.end(), original_neighbor_id) ==
+                    curr_point_neighbors.end())
+                {
+                    stitched_graph[original_point_id].push_back(original_neighbor_id);
+                    final_index_size += sizeof(uint32_t);
+                }
+            }
+        }
+    }
+
+    const size_t METADATA = 2 * sizeof(uint64_t) + 2 * sizeof(uint32_t);
+    final_index_size += (total_number_of_points * sizeof(uint32_t) + METADATA);
+
+    std::chrono::duration<double> stitching_index_time =
+        std::chrono::high_resolution_clock::now() - stitching_index_timer;
+    std::cout << "stitched graph generated in memory in " << stitching_index_time.count() << " seconds" << std::endl;
+
+    return std::make_tuple(stitched_graph, final_index_size);
+}
+
+/*
+ * Applies the prune_neighbors function from src/index.cpp to
+ * every node in the stitched graph.
+ *
+ * This is an optional step, hence the saving of both the full
+ * and pruned graph.
+ */
+template <typename T>
+void prune_and_save(path final_index_path_prefix, path full_index_path_prefix, path input_data_path,
+                    std::vector<std::vector<uint32_t>> stitched_graph, uint32_t stitched_R,
+                    tsl::robin_map<std::string, uint32_t> label_entry_points, std::string universal_label,
+                    path label_data_path, uint32_t num_threads)
+{
+    size_t dimension, number_of_label_points;
+    auto diskann_cout_buffer = diskann::cout.rdbuf(nullptr);
+    auto std_cout_buffer = std::cout.rdbuf(nullptr);
+    auto pruning_index_timer = std::chrono::high_resolution_clock::now();
+
+    diskann::get_bin_metadata(input_data_path, number_of_label_points, dimension);
+    diskann::Index<T> index(diskann::Metric::L2, dimension, number_of_label_points, false, false);
+
+    // we never search this index, so the search_l argument (here 1) is unused
+    index.load(full_index_path_prefix.c_str(), num_threads, 1);
+
+    std::cout << "parsing labels" << std::endl;
+
+    index.prune_all_neighbors(stitched_R, 750, 1.2);
+    index.save((final_index_path_prefix).c_str());
+
+    diskann::cout.rdbuf(diskann_cout_buffer);
+    std::cout.rdbuf(std_cout_buffer);
+    std::chrono::duration<double> pruning_index_time = std::chrono::high_resolution_clock::now() - pruning_index_timer;
+    std::cout << "pruning performed in " << pruning_index_time.count() << " seconds\n" << std::endl;
+}
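+
+// (prune_and_save() above temporarily null-redirects diskann::cout and
+// std::cout so the load/prune pass does not flood the console, restoring both
+// stream buffers before printing its timing line)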
+
+/*
+ * Delete all temporary artifacts.
+ * In the process of creating the stitched index, some temporary artifacts are
+ * created:
+ *  1. the separate bin files for each label's points
+ *  2. the separate diskANN indices built for each label
+ *  3. the '.data' file created while generating the indices
+ */
+void clean_up_artifacts(path input_data_path, path final_index_path_prefix, label_set all_labels)
+{
+    for (const auto &lbl : all_labels)
+    {
+        path curr_label_input_data_path(input_data_path + "_" + lbl);
+        path curr_label_index_path(final_index_path_prefix + "_" + lbl);
+        path curr_label_index_path_data(curr_label_index_path + ".data");
+
+        if (std::remove(curr_label_index_path.c_str()) != 0)
+            throw;
+        if (std::remove(curr_label_input_data_path.c_str()) != 0)
+            throw;
+        if (std::remove(curr_label_index_path_data.c_str()) != 0)
+            throw;
+    }
+}
+
+int main(int argc, char **argv)
+{
+    // 1. handle cmdline inputs
+    std::string data_type;
+    path input_data_path, final_index_path_prefix, label_data_path;
+    std::string universal_label;
+    uint32_t num_threads, R, L, stitched_R;
+    float alpha;
+
+    auto index_timer = std::chrono::high_resolution_clock::now();
+    handle_args(argc, argv, data_type, input_data_path, final_index_path_prefix, label_data_path, universal_label,
+                num_threads, R, L, stitched_R, alpha);
+
+    path labels_file_to_use = final_index_path_prefix + "_label_formatted.txt";
+    path labels_map_file = final_index_path_prefix + "_labels_map.txt";
+
+    convert_labels_string_to_int(label_data_path, labels_file_to_use, labels_map_file, universal_label);
+
+    // 2. parse label file and create necessary data structures
+    std::vector<label_set> point_ids_to_labels;
+    tsl::robin_map<std::string, uint32_t> labels_to_number_of_points;
+    label_set all_labels;
+
+    std::tie(point_ids_to_labels, labels_to_number_of_points, all_labels) =
+        diskann::parse_label_file(labels_file_to_use, universal_label);
+
+    // 3. for each label, make a separate data file
+    tsl::robin_map<std::string, std::vector<uint32_t>> label_id_to_orig_id_map;
+    uint32_t total_number_of_points = (uint32_t)point_ids_to_labels.size();
+
+#ifndef _WINDOWS
+    if (data_type == "uint8")
+        label_id_to_orig_id_map = diskann::generate_label_specific_vector_files<uint8_t>(
+            input_data_path, labels_to_number_of_points, point_ids_to_labels, all_labels);
+    else if (data_type == "int8")
+        label_id_to_orig_id_map = diskann::generate_label_specific_vector_files<int8_t>(
+            input_data_path, labels_to_number_of_points, point_ids_to_labels, all_labels);
+    else if (data_type == "float")
+        label_id_to_orig_id_map = diskann::generate_label_specific_vector_files<float>(
+            input_data_path, labels_to_number_of_points, point_ids_to_labels, all_labels);
+    else
+        throw;
+#else
+    if (data_type == "uint8")
+        label_id_to_orig_id_map = diskann::generate_label_specific_vector_files_compat<uint8_t>(
+            input_data_path, labels_to_number_of_points, point_ids_to_labels, all_labels);
+    else if (data_type == "int8")
+        label_id_to_orig_id_map = diskann::generate_label_specific_vector_files_compat<int8_t>(
+            input_data_path, labels_to_number_of_points, point_ids_to_labels, all_labels);
+    else if (data_type == "float")
+        label_id_to_orig_id_map = diskann::generate_label_specific_vector_files_compat<float>(
+            input_data_path, labels_to_number_of_points, point_ids_to_labels, all_labels);
+    else
+        throw;
+#endif
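+
+    // at this point label_id_to_orig_id_map[lbl][local_id] is the original
+    // (global) id of the local_id-th point in label lbl's data file; steps 5
+    // and 5a below use it to translate per-label edges and entry points back
+    // to global ids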
+
+    // 4. for each created data file, create a vanilla diskANN index
+    if (data_type == "uint8")
+        diskann::generate_label_indices<uint8_t>(input_data_path, final_index_path_prefix, all_labels, R, L, alpha,
+                                                 num_threads);
+    else if (data_type == "int8")
+        diskann::generate_label_indices<int8_t>(input_data_path, final_index_path_prefix, all_labels, R, L, alpha,
+                                                num_threads);
+    else if (data_type == "float")
+        diskann::generate_label_indices<float>(input_data_path, final_index_path_prefix, all_labels, R, L, alpha,
+                                               num_threads);
+    else
+        throw;
+
+    // 5. "stitch" the indices together
+    std::vector<std::vector<uint32_t>> stitched_graph;
+    tsl::robin_map<std::string, uint32_t> label_entry_points;
+    uint64_t stitched_graph_size;
+
+    if (data_type == "uint8")
+        std::tie(stitched_graph, stitched_graph_size) =
+            stitch_label_indices<uint8_t>(final_index_path_prefix, total_number_of_points, all_labels,
+                                          labels_to_number_of_points, label_entry_points, label_id_to_orig_id_map);
+    else if (data_type == "int8")
+        std::tie(stitched_graph, stitched_graph_size) =
+            stitch_label_indices<int8_t>(final_index_path_prefix, total_number_of_points, all_labels,
+                                         labels_to_number_of_points, label_entry_points, label_id_to_orig_id_map);
+    else if (data_type == "float")
+        std::tie(stitched_graph, stitched_graph_size) =
+            stitch_label_indices<float>(final_index_path_prefix, total_number_of_points, all_labels,
+                                        labels_to_number_of_points, label_entry_points, label_id_to_orig_id_map);
+    else
+        throw;
+    path full_index_path_prefix = final_index_path_prefix + "_full";
+    // 5a. save the stitched graph to disk
+    save_full_index(full_index_path_prefix, input_data_path, stitched_graph_size, stitched_graph, label_entry_points,
+                    universal_label, labels_file_to_use);
+
+    // 6. run a prune on the stitched index, and save to disk
+    if (data_type == "uint8")
+        prune_and_save<uint8_t>(final_index_path_prefix, full_index_path_prefix, input_data_path, stitched_graph,
+                                stitched_R, label_entry_points, universal_label, labels_file_to_use, num_threads);
+    else if (data_type == "int8")
+        prune_and_save<int8_t>(final_index_path_prefix, full_index_path_prefix, input_data_path, stitched_graph,
+                               stitched_R, label_entry_points, universal_label, labels_file_to_use, num_threads);
+    else if (data_type == "float")
+        prune_and_save<float>(final_index_path_prefix, full_index_path_prefix, input_data_path, stitched_graph,
+                              stitched_R, label_entry_points, universal_label, labels_file_to_use, num_threads);
+    else
+        throw;
+
+    std::chrono::duration<double> index_time = std::chrono::high_resolution_clock::now() - index_timer;
+    std::cout << "pruned/stitched graph generated in " << index_time.count() << " seconds" << std::endl;
+
+    clean_up_artifacts(input_data_path, final_index_path_prefix, all_labels);
+}
diff --git a/tests/python/README.md b/apps/python/README.md
similarity index 100%
rename from tests/python/README.md
rename to apps/python/README.md
diff --git a/tests/python/requirements.txt b/apps/python/requirements.txt
similarity index 100%
rename from tests/python/requirements.txt
rename to apps/python/requirements.txt
diff --git a/tests/python/restapi/__init__.py b/apps/python/restapi/__init__.py
similarity index 100%
rename from tests/python/restapi/__init__.py
rename to apps/python/restapi/__init__.py
diff --git a/tests/python/restapi/disk_ann_util.py b/apps/python/restapi/disk_ann_util.py
similarity index 93%
rename from tests/python/restapi/disk_ann_util.py
rename to apps/python/restapi/disk_ann_util.py
index 624a68538..ec8931035 100644
--- a/tests/python/restapi/disk_ann_util.py
+++ b/apps/python/restapi/disk_ann_util.py
@@ -20,7 +20,7 @@ def output_vectors(
     #
there is probably a clever way to have numpy write out C++ friendly floats, so feel free to remove this in # favor of something more sane later vectors_as_bin_path = os.path.join(temporary_file_path, "vectors.bin") - tsv_to_bin_path = os.path.join(diskann_build_path, "tests", "utils", "tsv_to_bin") + tsv_to_bin_path = os.path.join(diskann_build_path, "apps", "utils", "tsv_to_bin") number_of_points, dimensions = vectors.shape args = [ @@ -45,7 +45,7 @@ def build_ssd_index( ): vectors_as_bin_path = output_vectors(diskann_build_path, temporary_file_path, vectors, timeout=per_process_timeout) - ssd_builder_path = os.path.join(diskann_build_path, "tests", "build_disk_index") + ssd_builder_path = os.path.join(diskann_build_path, "apps", "build_disk_index") args = [ ssd_builder_path, "--data_type", "float", diff --git a/tests/python/restapi/test_ssd_rest_api.py b/apps/python/restapi/test_ssd_rest_api.py similarity index 98% rename from tests/python/restapi/test_ssd_rest_api.py rename to apps/python/restapi/test_ssd_rest_api.py index 6493893c7..281d246d3 100644 --- a/tests/python/restapi/test_ssd_rest_api.py +++ b/apps/python/restapi/test_ssd_rest_api.py @@ -67,7 +67,7 @@ def setUpClass(cls): rest_port = rng.integers(10000, 10100) cls._rest_address = f"http://127.0.0.1:{rest_port}/" - ssd_server_path = os.path.join(diskann_build_dir, "tests", "restapi", "ssd_server") + ssd_server_path = os.path.join(diskann_build_dir, "apps", "restapi", "ssd_server") args = [ ssd_server_path, diff --git a/tests/range_search_disk_index.cpp b/apps/range_search_disk_index.cpp similarity index 72% rename from tests/range_search_disk_index.cpp rename to apps/range_search_disk_index.cpp index 2120c2f6c..31675724b 100644 --- a/tests/range_search_disk_index.cpp +++ b/apps/range_search_disk_index.cpp @@ -15,6 +15,7 @@ #include "pq_flash_index.h" #include "partition.h" #include "timer.h" +#include "program_options_utils.hpp" #ifndef _WINDOWS #include @@ -51,8 +52,8 @@ void print_stats(std::string category, std::vector percentiles, std::vect template int search_disk_index(diskann::Metric &metric, const std::string &index_path_prefix, const std::string &query_file, - std::string >_file, const unsigned num_threads, const float search_range, - const unsigned beamwidth, const unsigned num_nodes_to_cache, const std::vector &Lvec) + std::string >_file, const uint32_t num_threads, const float search_range, + const uint32_t beamwidth, const uint32_t num_nodes_to_cache, const std::vector &Lvec) { std::string pq_prefix = index_path_prefix + "_pq"; std::string disk_index_file = index_path_prefix + "_disk.index"; @@ -66,7 +67,7 @@ int search_disk_index(diskann::Metric &metric, const std::string &index_path_pre // load query bin T *query = nullptr; - std::vector> groundtruth_ids; + std::vector> groundtruth_ids; size_t query_num, query_dim, query_aligned_dim, gt_num; diskann::load_aligned_bin(query_file, query, query_num, query_dim, query_aligned_dim); @@ -110,7 +111,8 @@ int search_disk_index(diskann::Metric &metric, const std::string &index_path_pre diskann::cout << "Caching " << num_nodes_to_cache << " BFS nodes around medoid(s)" << std::endl; _pFlashIndex->cache_bfs_levels(num_nodes_to_cache, node_list); // _pFlashIndex->generate_cache_list_from_sample_queries( - // warmup_query_file, 15, 6, num_nodes_to_cache, num_threads, node_list); + // warmup_query_file, 15, 6, num_nodes_to_cache, num_threads, + // node_list); _pFlashIndex->load_cache_list(node_list); node_list.clear(); node_list.shrink_to_fit(); @@ -129,7 +131,7 @@ int 
search_disk_index(diskann::Metric &metric, const std::string &index_path_pre } else { - warmup_num = (std::min)((_u32)150000, (_u32)15000 * num_threads); + warmup_num = (std::min)((uint32_t)150000, (uint32_t)15000 * num_threads); warmup_dim = query_dim; warmup_aligned_dim = query_aligned_dim; diskann::alloc_aligned(((void **)&warmup), warmup_num * warmup_aligned_dim * sizeof(T), 8 * sizeof(T)); @@ -150,7 +152,7 @@ int search_disk_index(diskann::Metric &metric, const std::string &index_path_pre std::vector warmup_result_dists(warmup_num, 0); #pragma omp parallel for schedule(dynamic, 1) - for (_s64 i = 0; i < (int64_t)warmup_num; i++) + for (int64_t i = 0; i < (int64_t)warmup_num; i++) { _pFlashIndex->cached_beam_search(warmup + (i * warmup_aligned_dim), 1, warmup_L, warmup_result_ids_64.data() + (i * 1), @@ -183,7 +185,7 @@ int search_disk_index(diskann::Metric &metric, const std::string &index_path_pre for (uint32_t test_id = 0; test_id < Lvec.size(); test_id++) { - _u64 L = Lvec[test_id]; + uint32_t L = Lvec[test_id]; if (beamwidth <= 0) { @@ -200,16 +202,17 @@ int search_disk_index(diskann::Metric &metric, const std::string &index_path_pre auto s = std::chrono::high_resolution_clock::now(); #pragma omp parallel for schedule(dynamic, 1) - for (_s64 i = 0; i < (int64_t)query_num; i++) + for (int64_t i = 0; i < (int64_t)query_num; i++) { - std::vector<_u64> indices; + std::vector indices; std::vector distances; - _u32 res_count = _pFlashIndex->range_search(query + (i * query_aligned_dim), search_range, L, max_list_size, - indices, distances, optimized_beamwidth, stats + i); + uint32_t res_count = + _pFlashIndex->range_search(query + (i * query_aligned_dim), search_range, L, max_list_size, indices, + distances, optimized_beamwidth, stats + i); query_result_ids[test_id][i].reserve(res_count); query_result_ids[test_id][i].resize(res_count); - for (_u32 idx = 0; idx < res_count; idx++) - query_result_ids[test_id][i][idx] = indices[idx]; + for (uint32_t idx = 0; idx < res_count; idx++) + query_result_ids[test_id][i][idx] = (uint32_t)indices[idx]; } auto e = std::chrono::high_resolution_clock::now(); std::chrono::duration diff = e - s; @@ -221,24 +224,25 @@ int search_disk_index(diskann::Metric &metric, const std::string &index_path_pre auto latency_999 = diskann::get_percentile_stats( stats, query_num, 0.999, [](const diskann::QueryStats &stats) { return stats.total_us; }); - auto mean_ios = diskann::get_mean_stats(stats, query_num, + auto mean_ios = diskann::get_mean_stats(stats, query_num, [](const diskann::QueryStats &stats) { return stats.n_ios; }); - float mean_cpuus = diskann::get_mean_stats( + double mean_cpuus = diskann::get_mean_stats( stats, query_num, [](const diskann::QueryStats &stats) { return stats.cpu_us; }); - float recall = 0; - float ratio_of_sums = 0; + double recall = 0; + double ratio_of_sums = 0; if (calc_recall_flag) { - recall = diskann::calculate_range_search_recall(query_num, groundtruth_ids, query_result_ids[test_id]); + recall = + diskann::calculate_range_search_recall((uint32_t)query_num, groundtruth_ids, query_result_ids[test_id]); - _u32 total_true_positive = 0; - _u32 total_positive = 0; - for (_u32 i = 0; i < query_num; i++) + uint32_t total_true_positive = 0; + uint32_t total_positive = 0; + for (uint32_t i = 0; i < query_num; i++) { - total_true_positive += query_result_ids[test_id][i].size(); - total_positive += groundtruth_ids[i].size(); + total_true_positive += (uint32_t)query_result_ids[test_id][i].size(); + total_positive += 
(uint32_t)groundtruth_ids[i].size(); } ratio_of_sums = (1.0 * total_true_positive) / (1.0 * total_positive); @@ -266,33 +270,46 @@ int search_disk_index(diskann::Metric &metric, const std::string &index_path_pre int main(int argc, char **argv) { std::string data_type, dist_fn, index_path_prefix, result_path_prefix, query_file, gt_file; - unsigned num_threads, W, num_nodes_to_cache; - std::vector Lvec; + uint32_t num_threads, W, num_nodes_to_cache; + std::vector Lvec; float range; - po::options_description desc{"Arguments"}; + po::options_description desc{program_options_utils::make_program_description( + "range_search_disk_index", "Searches disk DiskANN indexes using ranges")}; try { desc.add_options()("help,h", "Print information on arguments"); - desc.add_options()("data_type", po::value(&data_type)->required(), "data type "); - desc.add_options()("dist_fn", po::value(&dist_fn)->required(), - "distance function "); - desc.add_options()("index_path_prefix", po::value(&index_path_prefix)->required(), - "Path prefix to the index"); - desc.add_options()("query_file", po::value(&query_file)->required(), - "Query file in binary format"); - desc.add_options()("gt_file", po::value(>_file)->default_value(std::string("null")), - "ground truth file for the queryset"); - desc.add_options()("range_threshold,K", po::value(&range)->required(), - "Number of neighbors to be returned"); - desc.add_options()("search_list,L", po::value>(&Lvec)->multitoken(), - "List of L values of search"); - desc.add_options()("beamwidth,W", po::value(&W)->default_value(2), "Beamwidth for search"); - desc.add_options()("num_nodes_to_cache", po::value(&num_nodes_to_cache)->default_value(100000), - "Beamwidth for search"); - desc.add_options()("num_threads,T", po::value(&num_threads)->default_value(omp_get_num_procs()), - "Number of threads used for building index (defaults to " - "omp_get_num_procs())"); + + // Required parameters + po::options_description required_configs("Required"); + required_configs.add_options()("data_type", po::value(&data_type)->required(), + program_options_utils::DATA_TYPE_DESCRIPTION); + required_configs.add_options()("dist_fn", po::value(&dist_fn)->required(), + program_options_utils::DISTANCE_FUNCTION_DESCRIPTION); + required_configs.add_options()("index_path_prefix", po::value(&index_path_prefix)->required(), + program_options_utils::INDEX_PATH_PREFIX_DESCRIPTION); + required_configs.add_options()("query_file", po::value(&query_file)->required(), + program_options_utils::QUERY_FILE_DESCRIPTION); + required_configs.add_options()("search_list,L", + po::value>(&Lvec)->multitoken()->required(), + program_options_utils::SEARCH_LIST_DESCRIPTION); + required_configs.add_options()("range_threshold,K", po::value(&range)->required(), + "Number of neighbors to be returned"); + + // Optional parameters + po::options_description optional_configs("Optional"); + optional_configs.add_options()("num_threads,T", + po::value(&num_threads)->default_value(omp_get_num_procs()), + program_options_utils::NUMBER_THREADS_DESCRIPTION); + optional_configs.add_options()("gt_file", po::value(>_file)->default_value(std::string("null")), + program_options_utils::GROUND_TRUTH_FILE_DESCRIPTION); + optional_configs.add_options()("num_nodes_to_cache", po::value(&num_nodes_to_cache)->default_value(0), + program_options_utils::NUMBER_OF_NODES_TO_CACHE); + optional_configs.add_options()("beamwidth,W", po::value(&W)->default_value(2), + program_options_utils::BEAMWIDTH); + + // Merge required and optional parameters + 
desc.add(required_configs).add(optional_configs); po::variables_map vm; po::store(po::parse_command_line(argc, argv, desc), vm); diff --git a/tests/restapi/CMakeLists.txt b/apps/restapi/CMakeLists.txt similarity index 98% rename from tests/restapi/CMakeLists.txt rename to apps/restapi/CMakeLists.txt index e0f31a9c1..c73b427d2 100644 --- a/tests/restapi/CMakeLists.txt +++ b/apps/restapi/CMakeLists.txt @@ -1,7 +1,7 @@ # Copyright (c) Microsoft Corporation. All rights reserved. # Licensed under the MIT license. -set(CMAKE_CXX_STANDARD 14) +set(CMAKE_CXX_STANDARD 17) add_executable(inmem_server inmem_server.cpp) if(MSVC) diff --git a/tests/restapi/client.cpp b/apps/restapi/client.cpp similarity index 100% rename from tests/restapi/client.cpp rename to apps/restapi/client.cpp diff --git a/tests/restapi/inmem_server.cpp b/apps/restapi/inmem_server.cpp similarity index 100% rename from tests/restapi/inmem_server.cpp rename to apps/restapi/inmem_server.cpp diff --git a/tests/restapi/main.cpp b/apps/restapi/main.cpp similarity index 100% rename from tests/restapi/main.cpp rename to apps/restapi/main.cpp diff --git a/tests/restapi/multiple_ssdindex_server.cpp b/apps/restapi/multiple_ssdindex_server.cpp similarity index 100% rename from tests/restapi/multiple_ssdindex_server.cpp rename to apps/restapi/multiple_ssdindex_server.cpp diff --git a/tests/restapi/ssd_server.cpp b/apps/restapi/ssd_server.cpp similarity index 100% rename from tests/restapi/ssd_server.cpp rename to apps/restapi/ssd_server.cpp diff --git a/tests/search_disk_index.cpp b/apps/search_disk_index.cpp similarity index 66% rename from tests/search_disk_index.cpp rename to apps/search_disk_index.cpp index b1d04d7c6..b46b37aef 100644 --- a/tests/search_disk_index.cpp +++ b/apps/search_disk_index.cpp @@ -12,6 +12,7 @@ #include "pq_flash_index.h" #include "timer.h" #include "percentile_stats.h" +#include "program_options_utils.hpp" #ifndef _WINDOWS #include @@ -49,34 +50,43 @@ void print_stats(std::string category, std::vector percentiles, std::vect template int search_disk_index(diskann::Metric &metric, const std::string &index_path_prefix, const std::string &result_output_prefix, const std::string &query_file, std::string >_file, - const unsigned num_threads, const unsigned recall_at, const unsigned beamwidth, - const unsigned num_nodes_to_cache, const _u32 search_io_limit, const std::vector &Lvec, - const float fail_if_recall_below, const bool use_reorder_data = false, - const std::string &filter_label = "") + const uint32_t num_threads, const uint32_t recall_at, const uint32_t beamwidth, + const uint32_t num_nodes_to_cache, const uint32_t search_io_limit, + const std::vector &Lvec, const float fail_if_recall_below, + const std::vector &query_filters, const bool use_reorder_data = false) { diskann::cout << "Search parameters: #threads: " << num_threads << ", "; if (beamwidth <= 0) diskann::cout << "beamwidth to be optimized for each L value" << std::flush; else diskann::cout << " beamwidth: " << beamwidth << std::flush; - if (search_io_limit == std::numeric_limits<_u32>::max()) + if (search_io_limit == std::numeric_limits::max()) diskann::cout << "." << std::endl; else diskann::cout << ", io_limit: " << search_io_limit << "." 
<< std::endl; - bool filtered_search = false; - if (filter_label != "") - filtered_search = true; - std::string warmup_query_file = index_path_prefix + "_sample_data.bin"; // load query bin T *query = nullptr; - unsigned *gt_ids = nullptr; + uint32_t *gt_ids = nullptr; float *gt_dists = nullptr; size_t query_num, query_dim, query_aligned_dim, gt_num, gt_dim; diskann::load_aligned_bin(query_file, query, query_num, query_dim, query_aligned_dim); + bool filtered_search = false; + if (!query_filters.empty()) + { + filtered_search = true; + if (query_filters.size() != 1 && query_filters.size() != query_num) + { + std::cout << "Error. Mismatch in number of queries and size of query " + "filters file" + << std::endl; + return -1; // To return -1 or some other error handling? + } + } + bool calc_recall_flag = false; if (gt_file != std::string("null") && gt_file != std::string("NULL") && file_exists(gt_file)) { @@ -133,7 +143,7 @@ int search_disk_index(diskann::Metric &metric, const std::string &index_path_pre } else { - warmup_num = (std::min)((_u32)150000, (_u32)15000 * num_threads); + warmup_num = (std::min)((uint32_t)150000, (uint32_t)15000 * num_threads); warmup_dim = query_dim; warmup_aligned_dim = query_aligned_dim; diskann::alloc_aligned(((void **)&warmup), warmup_num * warmup_aligned_dim * sizeof(T), 8 * sizeof(T)); @@ -154,7 +164,7 @@ int search_disk_index(diskann::Metric &metric, const std::string &index_path_pre std::vector warmup_result_dists(warmup_num, 0); #pragma omp parallel for schedule(dynamic, 1) - for (_s64 i = 0; i < (int64_t)warmup_num; i++) + for (int64_t i = 0; i < (int64_t)warmup_num; i++) { _pFlashIndex->cached_beam_search(warmup + (i * warmup_aligned_dim), 1, warmup_L, warmup_result_ids_64.data() + (i * 1), @@ -185,11 +195,11 @@ int search_disk_index(diskann::Metric &metric, const std::string &index_path_pre uint32_t optimized_beamwidth = 2; - float best_recall = 0.0; + double best_recall = 0.0; for (uint32_t test_id = 0; test_id < Lvec.size(); test_id++) { - _u64 L = Lvec[test_id]; + uint32_t L = Lvec[test_id]; if (L < recall_at) { @@ -215,7 +225,7 @@ int search_disk_index(diskann::Metric &metric, const std::string &index_path_pre auto s = std::chrono::high_resolution_clock::now(); #pragma omp parallel for schedule(dynamic, 1) - for (_s64 i = 0; i < (int64_t)query_num; i++) + for (int64_t i = 0; i < (int64_t)query_num; i++) { if (!filtered_search) { @@ -226,7 +236,15 @@ int search_disk_index(diskann::Metric &metric, const std::string &index_path_pre } else { - LabelT label_for_search = _pFlashIndex->get_converted_label(filter_label); + LabelT label_for_search; + if (query_filters.size() == 1) + { // one label for all queries + label_for_search = _pFlashIndex->get_converted_label(query_filters[0]); + } + else + { // one label for each query + label_for_search = _pFlashIndex->get_converted_label(query_filters[i]); + } _pFlashIndex->cached_beam_search( query + (i * query_aligned_dim), recall_at, L, query_result_ids_64.data() + (i * recall_at), query_result_dists[test_id].data() + (i * recall_at), optimized_beamwidth, true, label_for_search, @@ -235,7 +253,7 @@ int search_disk_index(diskann::Metric &metric, const std::string &index_path_pre } auto e = std::chrono::high_resolution_clock::now(); std::chrono::duration diff = e - s; - float qps = (1.0 * query_num) / (1.0 * diff.count()); + double qps = (1.0 * query_num) / (1.0 * diff.count()); diskann::convert_types(query_result_ids_64.data(), query_result_ids[test_id].data(), query_num, recall_at); @@ -246,17 +264,17 @@ int 
search_disk_index(diskann::Metric &metric, const std::string &index_path_pre auto latency_999 = diskann::get_percentile_stats( stats, query_num, 0.999, [](const diskann::QueryStats &stats) { return stats.total_us; }); - auto mean_ios = diskann::get_mean_stats(stats, query_num, + auto mean_ios = diskann::get_mean_stats(stats, query_num, [](const diskann::QueryStats &stats) { return stats.n_ios; }); auto mean_cpuus = diskann::get_mean_stats(stats, query_num, [](const diskann::QueryStats &stats) { return stats.cpu_us; }); - float recall = 0; + double recall = 0; if (calc_recall_flag) { - recall = diskann::calculate_recall(query_num, gt_ids, gt_dists, gt_dim, query_result_ids[test_id].data(), - recall_at, recall_at); + recall = diskann::calculate_recall((uint32_t)query_num, gt_ids, gt_dists, (uint32_t)gt_dim, + query_result_ids[test_id].data(), recall_at, recall_at); best_recall = std::max(recall, best_recall); } @@ -273,14 +291,14 @@ int search_disk_index(diskann::Metric &metric, const std::string &index_path_pre } diskann::cout << "Done searching. Now saving results " << std::endl; - _u64 test_id = 0; + uint64_t test_id = 0; for (auto L : Lvec) { if (L < recall_at) continue; std::string cur_result_path = result_output_prefix + "_" + std::to_string(L) + "_idx_uint32.bin"; - diskann::save_bin<_u32>(cur_result_path, query_result_ids[test_id].data(), query_num, recall_at); + diskann::save_bin(cur_result_path, query_result_ids[test_id].data(), query_num, recall_at); cur_result_path = result_output_prefix + "_" + std::to_string(L) + "_dists_float.bin"; diskann::save_bin(cur_result_path, query_result_dists[test_id++].data(), query_num, recall_at); @@ -295,51 +313,68 @@ int search_disk_index(diskann::Metric &metric, const std::string &index_path_pre int main(int argc, char **argv) { std::string data_type, dist_fn, index_path_prefix, result_path_prefix, query_file, gt_file, filter_label, - label_type; - unsigned num_threads, K, W, num_nodes_to_cache, search_io_limit; - std::vector Lvec; + label_type, query_filters_file; + uint32_t num_threads, K, W, num_nodes_to_cache, search_io_limit; + std::vector Lvec; bool use_reorder_data = false; float fail_if_recall_below = 0.0f; - po::options_description desc{"Arguments"}; + po::options_description desc{ + program_options_utils::make_program_description("search_disk_index", "Searches on-disk DiskANN indexes")}; try { desc.add_options()("help,h", "Print information on arguments"); - desc.add_options()("data_type", po::value(&data_type)->required(), "data type "); - desc.add_options()("dist_fn", po::value(&dist_fn)->required(), - "distance function "); - desc.add_options()("index_path_prefix", po::value(&index_path_prefix)->required(), - "Path prefix to the index"); - desc.add_options()("result_path", po::value(&result_path_prefix)->required(), - "Path prefix for saving results of the queries"); - desc.add_options()("query_file", po::value(&query_file)->required(), - "Query file in binary format"); - desc.add_options()("gt_file", po::value(>_file)->default_value(std::string("null")), - "ground truth file for the queryset"); - desc.add_options()("recall_at,K", po::value(&K)->required(), "Number of neighbors to be returned"); - desc.add_options()("search_list,L", po::value>(&Lvec)->multitoken(), - "List of L values of search"); - desc.add_options()("beamwidth,W", po::value(&W)->default_value(2), - "Beamwidth for search. 
Set 0 to optimize internally."); - desc.add_options()("num_nodes_to_cache", po::value(&num_nodes_to_cache)->default_value(0), - "Beamwidth for search"); - desc.add_options()("search_io_limit", - po::value(&search_io_limit)->default_value(std::numeric_limits<_u32>::max()), - "Max #IOs for search"); - desc.add_options()("num_threads,T", po::value(&num_threads)->default_value(omp_get_num_procs()), - "Number of threads used for building index (defaults to " - "omp_get_num_procs())"); - desc.add_options()("use_reorder_data", po::bool_switch()->default_value(false), - "Include full precision data in the index. Use only in " - "conjuction with compressed data on SSD."); - desc.add_options()("filter_label", po::value(&filter_label)->default_value(std::string("")), - "Filter Label for Filtered Search"); - desc.add_options()("label_type", po::value(&label_type)->default_value("uint"), - "Storage type of Labels , default value is uint which " - "will consume memory 4 bytes per filter"); - desc.add_options()("fail_if_recall_below", po::value(&fail_if_recall_below)->default_value(0.0f), - "If set to a value >0 and <100%, program returns -1 if best recall " - "found is below this threshold. "); + + // Required parameters + po::options_description required_configs("Required"); + required_configs.add_options()("data_type", po::value(&data_type)->required(), + program_options_utils::DATA_TYPE_DESCRIPTION); + required_configs.add_options()("dist_fn", po::value(&dist_fn)->required(), + program_options_utils::DISTANCE_FUNCTION_DESCRIPTION); + required_configs.add_options()("index_path_prefix", po::value(&index_path_prefix)->required(), + program_options_utils::INDEX_PATH_PREFIX_DESCRIPTION); + required_configs.add_options()("result_path", po::value(&result_path_prefix)->required(), + program_options_utils::RESULT_PATH_DESCRIPTION); + required_configs.add_options()("query_file", po::value(&query_file)->required(), + program_options_utils::QUERY_FILE_DESCRIPTION); + required_configs.add_options()("recall_at,K", po::value(&K)->required(), + program_options_utils::NUMBER_OF_RESULTS_DESCRIPTION); + required_configs.add_options()("search_list,L", + po::value>(&Lvec)->multitoken()->required(), + program_options_utils::SEARCH_LIST_DESCRIPTION); + + // Optional parameters + po::options_description optional_configs("Optional"); + optional_configs.add_options()("gt_file", po::value(>_file)->default_value(std::string("null")), + program_options_utils::GROUND_TRUTH_FILE_DESCRIPTION); + optional_configs.add_options()("beamwidth,W", po::value(&W)->default_value(2), + program_options_utils::BEAMWIDTH); + optional_configs.add_options()("num_nodes_to_cache", po::value(&num_nodes_to_cache)->default_value(0), + program_options_utils::NUMBER_OF_NODES_TO_CACHE); + optional_configs.add_options()( + "search_io_limit", + po::value(&search_io_limit)->default_value(std::numeric_limits::max()), + "Max #IOs for search. Default value: uint32::max()"); + optional_configs.add_options()("num_threads,T", + po::value(&num_threads)->default_value(omp_get_num_procs()), + program_options_utils::NUMBER_THREADS_DESCRIPTION); + optional_configs.add_options()("use_reorder_data", po::bool_switch()->default_value(false), + "Include full precision data in the index. Use only in " + "conjuction with compressed data on SSD. 
Default value: false"); + optional_configs.add_options()("filter_label", + po::value(&filter_label)->default_value(std::string("")), + program_options_utils::FILTER_LABEL_DESCRIPTION); + optional_configs.add_options()("query_filters_file", + po::value(&query_filters_file)->default_value(std::string("")), + program_options_utils::FILTERS_FILE_DESCRIPTION); + optional_configs.add_options()("label_type", po::value(&label_type)->default_value("uint"), + program_options_utils::LABEL_TYPE_DESCRIPTION); + optional_configs.add_options()("fail_if_recall_below", + po::value(&fail_if_recall_below)->default_value(0.0f), + program_options_utils::FAIL_IF_RECALL_BELOW); + + // Merge required and optional parameters + desc.add(required_configs).add(optional_configs); po::variables_map vm; po::store(po::parse_command_line(argc, argv, desc), vm); @@ -393,22 +428,38 @@ int main(int argc, char **argv) return -1; } + if (filter_label != "" && query_filters_file != "") + { + std::cerr << "Only one of filter_label and query_filters_file should be provided" << std::endl; + return -1; + } + + std::vector query_filters; + if (filter_label != "") + { + query_filters.push_back(filter_label); + } + else if (query_filters_file != "") + { + query_filters = read_file_to_vector_of_strings(query_filters_file); + } + try { - if (filter_label != "" && label_type == "ushort") + if (!query_filters.empty() && label_type == "ushort") { if (data_type == std::string("float")) return search_disk_index( metric, index_path_prefix, result_path_prefix, query_file, gt_file, num_threads, K, W, - num_nodes_to_cache, search_io_limit, Lvec, fail_if_recall_below, use_reorder_data, filter_label); + num_nodes_to_cache, search_io_limit, Lvec, fail_if_recall_below, query_filters, use_reorder_data); else if (data_type == std::string("int8")) return search_disk_index( metric, index_path_prefix, result_path_prefix, query_file, gt_file, num_threads, K, W, - num_nodes_to_cache, search_io_limit, Lvec, fail_if_recall_below, use_reorder_data, filter_label); + num_nodes_to_cache, search_io_limit, Lvec, fail_if_recall_below, query_filters, use_reorder_data); else if (data_type == std::string("uint8")) return search_disk_index( metric, index_path_prefix, result_path_prefix, query_file, gt_file, num_threads, K, W, - num_nodes_to_cache, search_io_limit, Lvec, fail_if_recall_below, use_reorder_data, filter_label); + num_nodes_to_cache, search_io_limit, Lvec, fail_if_recall_below, query_filters, use_reorder_data); else { std::cerr << "Unsupported data type. 
Use float or int8 or uint8" << std::endl; @@ -420,15 +471,15 @@ int main(int argc, char **argv) if (data_type == std::string("float")) return search_disk_index(metric, index_path_prefix, result_path_prefix, query_file, gt_file, num_threads, K, W, num_nodes_to_cache, search_io_limit, Lvec, - fail_if_recall_below, use_reorder_data, filter_label); + fail_if_recall_below, query_filters, use_reorder_data); else if (data_type == std::string("int8")) return search_disk_index(metric, index_path_prefix, result_path_prefix, query_file, gt_file, num_threads, K, W, num_nodes_to_cache, search_io_limit, Lvec, - fail_if_recall_below, use_reorder_data, filter_label); + fail_if_recall_below, query_filters, use_reorder_data); else if (data_type == std::string("uint8")) return search_disk_index(metric, index_path_prefix, result_path_prefix, query_file, gt_file, num_threads, K, W, num_nodes_to_cache, search_io_limit, Lvec, - fail_if_recall_below, use_reorder_data, filter_label); + fail_if_recall_below, query_filters, use_reorder_data); else { std::cerr << "Unsupported data type. Use float or int8 or uint8" << std::endl; diff --git a/tests/search_memory_index.cpp b/apps/search_memory_index.cpp similarity index 52% rename from tests/search_memory_index.cpp rename to apps/search_memory_index.cpp index 4e093998d..44817242c 100644 --- a/tests/search_memory_index.cpp +++ b/apps/search_memory_index.cpp @@ -20,24 +20,26 @@ #include "index.h" #include "memory_mapper.h" #include "utils.h" +#include "program_options_utils.hpp" +#include "index_factory.h" namespace po = boost::program_options; template int search_memory_index(diskann::Metric &metric, const std::string &index_path, const std::string &result_path_prefix, - const std::string &query_file, const std::string &truthset_file, const unsigned num_threads, - const unsigned recall_at, const bool print_all_recalls, const std::vector &Lvec, + const std::string &query_file, const std::string &truthset_file, const uint32_t num_threads, + const uint32_t recall_at, const bool print_all_recalls, const std::vector &Lvec, const bool dynamic, const bool tags, const bool show_qps_per_thread, - const std::string &filter_label, const float fail_if_recall_below) + const std::vector &query_filters, const float fail_if_recall_below) { + using TagT = uint32_t; // Load the query file T *query = nullptr; - unsigned *gt_ids = nullptr; + uint32_t *gt_ids = nullptr; float *gt_dists = nullptr; size_t query_num, query_dim, query_aligned_dim, gt_num, gt_dim; diskann::load_aligned_bin(query_file, query, query_num, query_dim, query_aligned_dim); - // Check for ground truth bool calc_recall_flag = false; if (truthset_file != std::string("null") && file_exists(truthset_file)) { @@ -54,30 +56,50 @@ int search_memory_index(diskann::Metric &metric, const std::string &index_path, } bool filtered_search = false; - if (filter_label != "") + if (!query_filters.empty()) { filtered_search = true; + if (query_filters.size() != 1 && query_filters.size() != query_num) + { + std::cout << "Error. Mismatch in number of queries and size of query " + "filters file" + << std::endl; + return -1; // To return -1 or some other error handling? 
+ } } - using TagT = uint32_t; - const bool concurrent = false, pq_dist_build = false, use_opq = false; - const size_t num_pq_chunks = 0; - using IndexType = diskann::Index; - const size_t num_frozen_pts = IndexType::get_graph_num_frozen_points(index_path); - IndexType index(metric, query_dim, 0, dynamic, tags, concurrent, pq_dist_build, num_pq_chunks, use_opq, - num_frozen_pts); - std::cout << "Index class instantiated" << std::endl; - index.load(index_path.c_str(), num_threads, *(std::max_element(Lvec.begin(), Lvec.end()))); + const size_t num_frozen_pts = diskann::get_graph_num_frozen_points(index_path); + + auto config = diskann::IndexConfigBuilder() + .with_metric(metric) + .with_dimension(query_dim) + .with_max_points(0) + .with_data_load_store_strategy(diskann::MEMORY) + .with_data_type(diskann_type_to_name()) + .with_label_type(diskann_type_to_name()) + .with_tag_type(diskann_type_to_name()) + .is_dynamic_index(dynamic) + .is_enable_tags(tags) + .is_concurrent_consolidate(false) + .is_pq_dist_build(false) + .is_use_opq(false) + .with_num_pq_chunks(0) + .with_num_frozen_pts(num_frozen_pts) + .build(); + + auto index_factory = diskann::IndexFactory(config); + auto index = index_factory.create_instance(); + index->load(index_path.c_str(), num_threads, *(std::max_element(Lvec.begin(), Lvec.end()))); std::cout << "Index loaded" << std::endl; + if (metric == diskann::FAST_L2) - index.optimize_index_layout(); + index->optimize_index_layout(); std::cout << "Using " << num_threads << " threads to search" << std::endl; - diskann::Parameters paras; std::cout.setf(std::ios_base::fixed, std::ios_base::floatfield); std::cout.precision(2); const std::string qps_title = show_qps_per_thread ? "QPS/thread" : "QPS"; - unsigned table_width = 0; + uint32_t table_width = 0; if (tags) { std::cout << std::setw(4) << "Ls" << std::setw(12) << qps_title << std::setw(20) << "Mean Latency (mus)" @@ -90,11 +112,11 @@ int search_memory_index(diskann::Metric &metric, const std::string &index_path, << std::setw(20) << "Mean Latency (mus)" << std::setw(15) << "99.9 Latency"; table_width += 4 + 12 + 18 + 20 + 15; } - unsigned recalls_to_print = 0; - const unsigned first_recall = print_all_recalls ? 1 : recall_at; + uint32_t recalls_to_print = 0; + const uint32_t first_recall = print_all_recalls ? 
1 : recall_at; if (calc_recall_flag) { - for (unsigned curr_recall = first_recall; curr_recall <= recall_at; curr_recall++) + for (uint32_t curr_recall = first_recall; curr_recall <= recall_at; curr_recall++) { std::cout << std::setw(12) << ("Recall@" + std::to_string(curr_recall)); } @@ -107,10 +129,10 @@ int search_memory_index(diskann::Metric &metric, const std::string &index_path, std::vector> query_result_ids(Lvec.size()); std::vector> query_result_dists(Lvec.size()); std::vector latency_stats(query_num, 0); - std::vector cmp_stats; + std::vector cmp_stats; if (not tags) { - cmp_stats = std::vector(query_num, 0); + cmp_stats = std::vector(query_num, 0); } std::vector query_result_tags; @@ -119,11 +141,11 @@ int search_memory_index(diskann::Metric &metric, const std::string &index_path, query_result_tags.resize(recall_at * query_num); } - float best_recall = 0.0; + double best_recall = 0.0; for (uint32_t test_id = 0; test_id < Lvec.size(); test_id++) { - _u64 L = Lvec[test_id]; + uint32_t L = Lvec[test_id]; if (L < recall_at) { diskann::cout << "Ignoring search with L:" << L << " since it's smaller than K:" << recall_at << std::endl; @@ -142,21 +164,22 @@ int search_memory_index(diskann::Metric &metric, const std::string &index_path, auto qs = std::chrono::high_resolution_clock::now(); if (filtered_search) { - LabelT filter_label_as_num = index.get_converted_label(filter_label); - auto retval = index.search_with_filters(query + i * query_aligned_dim, filter_label_as_num, recall_at, - L, query_result_ids[test_id].data() + i * recall_at, - query_result_dists[test_id].data() + i * recall_at); + std::string raw_filter = query_filters.size() == 1 ? query_filters[0] : query_filters[i]; + + auto retval = index->search_with_filters(query + i * query_aligned_dim, raw_filter, recall_at, L, + query_result_ids[test_id].data() + i * recall_at, + query_result_dists[test_id].data() + i * recall_at); cmp_stats[i] = retval.second; } else if (metric == diskann::FAST_L2) { - index.search_with_optimized_layout(query + i * query_aligned_dim, recall_at, L, - query_result_ids[test_id].data() + i * recall_at); + index->search_with_optimized_layout(query + i * query_aligned_dim, recall_at, L, + query_result_ids[test_id].data() + i * recall_at); } else if (tags) { - index.search_with_tags(query + i * query_aligned_dim, recall_at, L, - query_result_tags.data() + i * recall_at, nullptr, res); + index->search_with_tags(query + i * query_aligned_dim, recall_at, L, + query_result_tags.data() + i * recall_at, nullptr, res); for (int64_t r = 0; r < (int64_t)recall_at; r++) { query_result_ids[test_id][recall_at * i + r] = query_result_tags[recall_at * i + r]; @@ -165,34 +188,34 @@ int search_memory_index(diskann::Metric &metric, const std::string &index_path, else { cmp_stats[i] = index - .search(query + i * query_aligned_dim, recall_at, L, - query_result_ids[test_id].data() + i * recall_at) + ->search(query + i * query_aligned_dim, recall_at, L, + query_result_ids[test_id].data() + i * recall_at) .second; } auto qe = std::chrono::high_resolution_clock::now(); std::chrono::duration diff = qe - qs; - latency_stats[i] = diff.count() * 1000000; + latency_stats[i] = (float)(diff.count() * 1000000); } std::chrono::duration diff = std::chrono::high_resolution_clock::now() - s; - float displayed_qps = static_cast(query_num) / diff.count(); + double displayed_qps = query_num / diff.count(); if (show_qps_per_thread) displayed_qps /= num_threads; - std::vector recalls; + std::vector recalls; if (calc_recall_flag) { 
recalls.reserve(recalls_to_print); - for (unsigned curr_recall = first_recall; curr_recall <= recall_at; curr_recall++) + for (uint32_t curr_recall = first_recall; curr_recall <= recall_at; curr_recall++) { - recalls.push_back(diskann::calculate_recall(query_num, gt_ids, gt_dists, gt_dim, + recalls.push_back(diskann::calculate_recall((uint32_t)query_num, gt_ids, gt_dists, (uint32_t)gt_dim, query_result_ids[test_id].data(), recall_at, curr_recall)); } } std::sort(latency_stats.begin(), latency_stats.end()); - float mean_latency = + double mean_latency = std::accumulate(latency_stats.begin(), latency_stats.end(), 0.0) / static_cast(query_num); float avg_cmps = (float)std::accumulate(cmp_stats.begin(), cmp_stats.end(), 0) / (float)query_num; @@ -200,15 +223,15 @@ int search_memory_index(diskann::Metric &metric, const std::string &index_path, if (tags) { std::cout << std::setw(4) << L << std::setw(12) << displayed_qps << std::setw(20) << (float)mean_latency - << std::setw(15) << (float)latency_stats[(_u64)(0.999 * query_num)]; + << std::setw(15) << (float)latency_stats[(uint64_t)(0.999 * query_num)]; } else { std::cout << std::setw(4) << L << std::setw(12) << displayed_qps << std::setw(18) << avg_cmps << std::setw(20) << (float)mean_latency << std::setw(15) - << (float)latency_stats[(_u64)(0.999 * query_num)]; + << (float)latency_stats[(uint64_t)(0.999 * query_num)]; } - for (float recall : recalls) + for (double recall : recalls) { std::cout << std::setw(12) << recall; best_recall = std::max(recall, best_recall); @@ -217,7 +240,7 @@ int search_memory_index(diskann::Metric &metric, const std::string &index_path, } std::cout << "Done searching. Now saving results " << std::endl; - _u64 test_id = 0; + uint64_t test_id = 0; for (auto L : Lvec) { if (L < recall_at) @@ -225,63 +248,89 @@ int search_memory_index(diskann::Metric &metric, const std::string &index_path, diskann::cout << "Ignoring search with L:" << L << " since it's smaller than K:" << recall_at << std::endl; continue; } - std::string cur_result_path = result_path_prefix + "_" + std::to_string(L) + "_idx_uint32.bin"; - diskann::save_bin<_u32>(cur_result_path, query_result_ids[test_id].data(), query_num, recall_at); + std::string cur_result_path_prefix = result_path_prefix + "_" + std::to_string(L); + + std::string cur_result_path = cur_result_path_prefix + "_idx_uint32.bin"; + diskann::save_bin(cur_result_path, query_result_ids[test_id].data(), query_num, recall_at); + + cur_result_path = cur_result_path_prefix + "_dists_float.bin"; + diskann::save_bin(cur_result_path, query_result_dists[test_id].data(), query_num, recall_at); + test_id++; } diskann::aligned_free(query); - return best_recall >= fail_if_recall_below ? 
0 : -1; } int main(int argc, char **argv) { - std::string data_type, dist_fn, index_path_prefix, result_path, query_file, gt_file, filter_label, label_type; - unsigned num_threads, K; - std::vector Lvec; + std::string data_type, dist_fn, index_path_prefix, result_path, query_file, gt_file, filter_label, label_type, + query_filters_file; + uint32_t num_threads, K; + std::vector Lvec; bool print_all_recalls, dynamic, tags, show_qps_per_thread; float fail_if_recall_below = 0.0f; - po::options_description desc{"Arguments"}; + po::options_description desc{ + program_options_utils::make_program_description("search_memory_index", "Searches in-memory DiskANN indexes")}; try { - desc.add_options()("help,h", "Print information on arguments"); - desc.add_options()("data_type", po::value(&data_type)->required(), "data type "); - desc.add_options()("dist_fn", po::value(&dist_fn)->required(), - "distance function "); - desc.add_options()("index_path_prefix", po::value(&index_path_prefix)->required(), - "Path prefix to the index"); - desc.add_options()("result_path", po::value(&result_path)->required(), - "Path prefix for saving results of the queries"); - desc.add_options()("query_file", po::value(&query_file)->required(), - "Query file in binary format"); - desc.add_options()("filter_label", po::value(&filter_label)->default_value(std::string("")), - "Filter Label for Filtered Search"); - desc.add_options()("label_type", po::value(&label_type)->default_value("uint"), - "Storage type of Labels , default value is uint which " - "will consume memory 4 bytes per filter"); - desc.add_options()("gt_file", po::value(>_file)->default_value(std::string("null")), - "ground truth file for the queryset"); - desc.add_options()("recall_at,K", po::value(&K)->required(), "Number of neighbors to be returned"); - desc.add_options()("print_all_recalls", po::bool_switch(&print_all_recalls), - "Print recalls at all positions, from 1 up to specified " - "recall_at value"); - desc.add_options()("search_list,L", po::value>(&Lvec)->multitoken(), - "List of L values of search"); - desc.add_options()("num_threads,T", po::value(&num_threads)->default_value(omp_get_num_procs()), - "Number of threads used for building index (defaults to " - "omp_get_num_procs())"); - desc.add_options()("dynamic", po::value(&dynamic)->default_value(false), - "Whether the index is dynamic. Default false."); - desc.add_options()("tags", po::value(&tags)->default_value(false), - "Whether to search with tags. Default false."); - desc.add_options()("qps_per_thread", po::bool_switch(&show_qps_per_thread), - "Print overall QPS divided by the number of threads in " - "the output table"); - desc.add_options()("fail_if_recall_below", po::value(&fail_if_recall_below)->default_value(0.0f), - "If set to a value >0 and <100%, program returns -1 if best recall " - "found is below this threshold. 
"); + desc.add_options()("help,h", "Print this information on arguments"); + + // Required parameters + po::options_description required_configs("Required"); + required_configs.add_options()("data_type", po::value(&data_type)->required(), + program_options_utils::DATA_TYPE_DESCRIPTION); + required_configs.add_options()("dist_fn", po::value(&dist_fn)->required(), + program_options_utils::DISTANCE_FUNCTION_DESCRIPTION); + required_configs.add_options()("index_path_prefix", po::value(&index_path_prefix)->required(), + program_options_utils::INDEX_PATH_PREFIX_DESCRIPTION); + required_configs.add_options()("result_path", po::value(&result_path)->required(), + program_options_utils::RESULT_PATH_DESCRIPTION); + required_configs.add_options()("query_file", po::value(&query_file)->required(), + program_options_utils::QUERY_FILE_DESCRIPTION); + required_configs.add_options()("recall_at,K", po::value(&K)->required(), + program_options_utils::NUMBER_OF_RESULTS_DESCRIPTION); + required_configs.add_options()("search_list,L", + po::value>(&Lvec)->multitoken()->required(), + program_options_utils::SEARCH_LIST_DESCRIPTION); + + // Optional parameters + po::options_description optional_configs("Optional"); + optional_configs.add_options()("filter_label", + po::value(&filter_label)->default_value(std::string("")), + program_options_utils::FILTER_LABEL_DESCRIPTION); + optional_configs.add_options()("query_filters_file", + po::value(&query_filters_file)->default_value(std::string("")), + program_options_utils::FILTERS_FILE_DESCRIPTION); + optional_configs.add_options()("label_type", po::value(&label_type)->default_value("uint"), + program_options_utils::LABEL_TYPE_DESCRIPTION); + optional_configs.add_options()("gt_file", po::value(>_file)->default_value(std::string("null")), + program_options_utils::GROUND_TRUTH_FILE_DESCRIPTION); + optional_configs.add_options()("num_threads,T", + po::value(&num_threads)->default_value(omp_get_num_procs()), + program_options_utils::NUMBER_THREADS_DESCRIPTION); + optional_configs.add_options()( + "dynamic", po::value(&dynamic)->default_value(false), + "Whether the index is dynamic. Dynamic indices must have associated tags. Default false."); + optional_configs.add_options()("tags", po::value(&tags)->default_value(false), + "Whether to search with external identifiers (tags). 
Default false."); + optional_configs.add_options()("fail_if_recall_below", + po::value(&fail_if_recall_below)->default_value(0.0f), + program_options_utils::FAIL_IF_RECALL_BELOW); + + // Output controls + po::options_description output_controls("Output controls"); + output_controls.add_options()("print_all_recalls", po::bool_switch(&print_all_recalls), + "Print recalls at all positions, from 1 up to specified " + "recall_at value"); + output_controls.add_options()("print_qps_per_thread", po::bool_switch(&show_qps_per_thread), + "Print overall QPS divided by the number of threads in " + "the output table"); + + // Merge required and optional parameters + desc.add(required_configs).add(optional_configs).add(output_controls); po::variables_map vm; po::store(po::parse_command_line(argc, argv, desc), vm); @@ -336,27 +385,43 @@ int main(int argc, char **argv) return -1; } + if (filter_label != "" && query_filters_file != "") + { + std::cerr << "Only one of filter_label and query_filters_file should be provided" << std::endl; + return -1; + } + + std::vector query_filters; + if (filter_label != "") + { + query_filters.push_back(filter_label); + } + else if (query_filters_file != "") + { + query_filters = read_file_to_vector_of_strings(query_filters_file); + } + try { - if (filter_label != "" && label_type == "ushort") + if (!query_filters.empty() && label_type == "ushort") { if (data_type == std::string("int8")) { return search_memory_index( metric, index_path_prefix, result_path, query_file, gt_file, num_threads, K, print_all_recalls, - Lvec, dynamic, tags, show_qps_per_thread, filter_label, fail_if_recall_below); + Lvec, dynamic, tags, show_qps_per_thread, query_filters, fail_if_recall_below); } else if (data_type == std::string("uint8")) { return search_memory_index( metric, index_path_prefix, result_path, query_file, gt_file, num_threads, K, print_all_recalls, - Lvec, dynamic, tags, show_qps_per_thread, filter_label, fail_if_recall_below); + Lvec, dynamic, tags, show_qps_per_thread, query_filters, fail_if_recall_below); } else if (data_type == std::string("float")) { return search_memory_index(metric, index_path_prefix, result_path, query_file, gt_file, num_threads, K, print_all_recalls, Lvec, dynamic, tags, - show_qps_per_thread, filter_label, fail_if_recall_below); + show_qps_per_thread, query_filters, fail_if_recall_below); } else { @@ -370,19 +435,19 @@ int main(int argc, char **argv) { return search_memory_index(metric, index_path_prefix, result_path, query_file, gt_file, num_threads, K, print_all_recalls, Lvec, dynamic, tags, - show_qps_per_thread, filter_label, fail_if_recall_below); + show_qps_per_thread, query_filters, fail_if_recall_below); } else if (data_type == std::string("uint8")) { return search_memory_index(metric, index_path_prefix, result_path, query_file, gt_file, num_threads, K, print_all_recalls, Lvec, dynamic, tags, - show_qps_per_thread, filter_label, fail_if_recall_below); + show_qps_per_thread, query_filters, fail_if_recall_below); } else if (data_type == std::string("float")) { return search_memory_index(metric, index_path_prefix, result_path, query_file, gt_file, num_threads, K, print_all_recalls, Lvec, dynamic, tags, - show_qps_per_thread, filter_label, fail_if_recall_below); + show_qps_per_thread, query_filters, fail_if_recall_below); } else { diff --git a/tests/test_insert_deletes_consolidate.cpp b/apps/test_insert_deletes_consolidate.cpp similarity index 61% rename from tests/test_insert_deletes_consolidate.cpp rename to apps/test_insert_deletes_consolidate.cpp 
index ef598c659..700f4d7b6 100644 --- a/tests/test_insert_deletes_consolidate.cpp +++ b/apps/test_insert_deletes_consolidate.cpp @@ -11,6 +11,8 @@ #include #include "utils.h" +#include "program_options_utils.hpp" +#include "index_factory.h" #ifndef _WINDOWS #include @@ -37,8 +39,8 @@ inline void load_aligned_bin_part(const std::string &bin_file, T *data, size_t o int npts_i32, dim_i32; reader.read((char *)&npts_i32, sizeof(int)); reader.read((char *)&dim_i32, sizeof(int)); - size_t npts = (unsigned)npts_i32; - size_t dim = (unsigned)dim_i32; + size_t npts = (uint32_t)npts_i32; + size_t dim = (uint32_t)dim_i32; size_t expected_actual_file_size = npts * dim * sizeof(T) + 2 * sizeof(uint32_t); if (actual_file_size != expected_actual_file_size) @@ -90,7 +92,7 @@ std::string get_save_filename(const std::string &save_path, size_t points_to_ski } template -void insert_till_next_checkpoint(diskann::Index &index, size_t start, size_t end, size_t thread_count, T *data, +void insert_till_next_checkpoint(diskann::AbstractIndex &index, size_t start, size_t end, int32_t thread_count, T *data, size_t aligned_dim) { diskann::Timer insert_timer; @@ -106,8 +108,8 @@ void insert_till_next_checkpoint(diskann::Index &index, size_t start, s } template -void delete_from_beginning(diskann::Index &index, diskann::Parameters &delete_params, size_t points_to_skip, - size_t points_to_delete_from_beginning) +void delete_from_beginning(diskann::AbstractIndex &index, diskann::IndexWriteParameters &delete_params, + size_t points_to_skip, size_t points_to_delete_from_beginning) { try { @@ -115,7 +117,7 @@ void delete_from_beginning(diskann::Index &index, diskann::Parameters & << "Lazy deleting points " << points_to_skip << " to " << points_to_skip + points_to_delete_from_beginning << "... "; for (size_t i = points_to_skip; i < points_to_skip + points_to_delete_from_beginning; ++i) - index.lazy_delete(i + 1); // Since tags are data location + 1 + index.lazy_delete(static_cast(i + 1)); // Since tags are data location + 1 std::cout << "done." 
<< std::endl; auto report = index.consolidate_deletes(delete_params); @@ -125,8 +127,8 @@ void delete_from_beginning(diskann::Index &index, diskann::Parameters & << "deletes processed: " << report._slots_released << std::endl << "latest delete size: " << report._delete_set_size << std::endl << "rate: (" << points_to_delete_from_beginning / report._time << " points/second overall, " - << points_to_delete_from_beginning / report._time / delete_params.Get("num_threads") - << " per thread)" << std::endl; + << points_to_delete_from_beginning / report._time / delete_params.num_threads << " per thread)" + << std::endl; } catch (std::system_error &e) { @@ -135,34 +137,39 @@ void delete_from_beginning(diskann::Index &index, diskann::Parameters & } template -void build_incremental_index(const std::string &data_path, const unsigned L, const unsigned R, const float alpha, - const unsigned thread_count, size_t points_to_skip, size_t max_points_to_insert, - size_t beginning_index_size, float start_point_norm, unsigned num_start_pts, - size_t points_per_checkpoint, size_t checkpoints_per_snapshot, +void build_incremental_index(const std::string &data_path, diskann::IndexWriteParameters ¶ms, size_t points_to_skip, + size_t max_points_to_insert, size_t beginning_index_size, float start_point_norm, + uint32_t num_start_pts, size_t points_per_checkpoint, size_t checkpoints_per_snapshot, const std::string &save_path, size_t points_to_delete_from_beginning, size_t start_deletes_after, bool concurrent) { - const unsigned C = 500; - const bool saturate_graph = false; - - diskann::Parameters params; - params.Set("L", L); - params.Set("R", R); - params.Set("C", C); - params.Set("alpha", alpha); - params.Set("saturate_graph", saturate_graph); - params.Set("num_rnds", 1); - params.Set("num_threads", thread_count); - params.Set("Lf", 0); // TODO: get this from params and default to some - // value to make it backward compatible. 
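
Note: the params.Set(...) block being deleted here is the old string-keyed diskann::Parameters bag; the rest of this patch replaces it with the typed IndexWriteParametersBuilder. Collecting the builder calls that appear in the hunks below into one place, the equivalent construction is roughly as follows (the header path is an assumption; the builder methods themselves are taken from this diff):

#include <cstdint>

#include "parameters.h" // assumed header for IndexWriteParameters/Builder

// Typed replacement for params.Set("L", L), params.Set("R", R), etc.
diskann::IndexWriteParameters make_build_params(uint32_t L, uint32_t R, float alpha, uint32_t num_threads,
                                                uint32_t num_start_pts)
{
    return diskann::IndexWriteParametersBuilder(L, R)
        .with_max_occlusion_size(500) // the old hard-coded "C"
        .with_alpha(alpha)
        .with_saturate_graph(false)
        .with_num_threads(num_threads)
        .with_num_frozen_points(num_start_pts)
        .build();
}

Consumers then read typed members such as params.num_threads and params.search_list_size instead of params.Get<unsigned>("num_threads"), which is what the delete_from_beginning change above relies on.
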
- params.Set("num_frozen_pts", num_start_pts); - size_t dim, aligned_dim; size_t num_points; - diskann::get_bin_metadata(data_path, num_points, dim); aligned_dim = ROUND_UP(dim, 8); + bool enable_tags = true; + using TagT = uint32_t; + auto data_type = diskann_type_to_name(); + auto tag_type = diskann_type_to_name(); + diskann::IndexConfig index_config = diskann::IndexConfigBuilder() + .with_metric(diskann::L2) + .with_dimension(dim) + .with_max_points(max_points_to_insert) + .is_dynamic_index(true) + .with_index_write_params(params) + .with_search_threads(params.num_threads) + .with_initial_search_list_size(params.search_list_size) + .with_data_type(data_type) + .with_tag_type(tag_type) + .with_data_load_store_strategy(diskann::MEMORY) + .is_enable_tags(enable_tags) + .is_concurrent_consolidate(concurrent) + .build(); + + diskann::IndexFactory index_factory = diskann::IndexFactory(index_config); + auto index = index_factory.create_instance(); + if (points_to_skip > num_points) { throw diskann::ANNException("Asked to skip more points than in data file", -1, __FUNCSIG__, __FILE__, __LINE__); @@ -180,12 +187,6 @@ void build_incremental_index(const std::string &data_path, const unsigned L, con << " points since the data file has only that many" << std::endl; } - using TagT = uint32_t; - const bool enable_tags = true; - - diskann::Index index(diskann::L2, dim, max_points_to_insert, true, params, params, enable_tags, - concurrent); - size_t current_point_offset = points_to_skip; const size_t last_point_threshold = points_to_skip + max_points_to_insert; @@ -214,13 +215,11 @@ void build_incremental_index(const std::string &data_path, const unsigned L, con if (beginning_index_size > 0) { - index.build(data, beginning_index_size, params, tags); - index.enable_delete(); + index->build(data, beginning_index_size, params, tags); } else { - index.set_start_points_at_random(static_cast(start_point_norm)); - index.enable_delete(); + index->set_start_points_at_random(static_cast(start_point_norm)); } const double elapsedSeconds = timer.elapsed() / 1000000.0; @@ -231,14 +230,14 @@ void build_incremental_index(const std::string &data_path, const unsigned L, con if (points_to_delete_from_beginning > max_points_to_insert) { - points_to_delete_from_beginning = static_cast(max_points_to_insert); + points_to_delete_from_beginning = static_cast(max_points_to_insert); std::cerr << "WARNING: Reducing points to delete from beginning to " << points_to_delete_from_beginning << " points since the data file has only that many" << std::endl; } if (concurrent) { - int sub_threads = (thread_count + 1) / 2; + int32_t sub_threads = (params.num_threads + 1) / 2; bool delete_launched = false; std::future delete_task; @@ -252,7 +251,7 @@ void build_incremental_index(const std::string &data_path, const unsigned L, con auto insert_task = std::async(std::launch::async, [&]() { load_aligned_bin_part(data_path, data, start, end - start); - insert_till_next_checkpoint(index, start, end, sub_threads, data, aligned_dim); + insert_till_next_checkpoint(*index, start, end, sub_threads, data, aligned_dim); }); insert_task.wait(); @@ -260,10 +259,12 @@ void build_incremental_index(const std::string &data_path, const unsigned L, con end >= points_to_skip + points_to_delete_from_beginning) { delete_launched = true; - params.Set("num_threads", sub_threads); + diskann::IndexWriteParameters delete_params = + diskann::IndexWriteParametersBuilder(params).with_num_threads(sub_threads).build(); delete_task = std::async(std::launch::async, [&]() { - 
delete_from_beginning(index, params, points_to_skip, points_to_delete_from_beginning); + delete_from_beginning(*index, delete_params, points_to_skip, + points_to_delete_from_beginning); }); } } @@ -272,7 +273,7 @@ void build_incremental_index(const std::string &data_path, const unsigned L, con std::cout << "Time Elapsed " << timer.elapsed() / 1000 << "ms\n"; const auto save_path_inc = get_save_filename(save_path + ".after-concurrent-delete-", points_to_skip, points_to_delete_from_beginning, last_point_threshold); - index.save(save_path_inc.c_str(), true); + index->save(save_path_inc.c_str(), true); } else { @@ -286,7 +287,7 @@ void build_incremental_index(const std::string &data_path, const unsigned L, con std::cout << std::endl << "Inserting from " << start << " to " << end << std::endl; load_aligned_bin_part(data_path, data, start, end - start); - insert_till_next_checkpoint(index, start, end, thread_count, data, aligned_dim); + insert_till_next_checkpoint(*index, start, end, (int32_t)params.num_threads, data, aligned_dim); if (checkpoints_per_snapshot > 0 && --num_checkpoints_till_snapshot == 0) { @@ -294,7 +295,7 @@ void build_incremental_index(const std::string &data_path, const unsigned L, con const auto save_path_inc = get_save_filename(save_path + ".inc-", points_to_skip, points_to_delete_from_beginning, end); - index.save(save_path_inc.c_str(), false); + index->save(save_path_inc.c_str(), false); const double elapsedSeconds = save_timer.elapsed() / 1000000.0; const size_t points_saved = end - points_to_skip; @@ -317,11 +318,11 @@ void build_incremental_index(const std::string &data_path, const unsigned L, con if (points_to_delete_from_beginning > 0) { - delete_from_beginning(index, params, points_to_skip, points_to_delete_from_beginning); + delete_from_beginning(*index, params, points_to_skip, points_to_delete_from_beginning); } const auto save_path_inc = get_save_filename(save_path + ".after-delete-", points_to_skip, points_to_delete_from_beginning, last_point_threshold); - index.save(save_path_inc.c_str(), true); + index->save(save_path_inc.c_str(), true); } diskann::aligned_free(data); @@ -330,50 +331,68 @@ void build_incremental_index(const std::string &data_path, const unsigned L, con int main(int argc, char **argv) { std::string data_type, dist_fn, data_path, index_path_prefix; - unsigned num_threads, R, L, num_start_pts; + uint32_t num_threads, R, L, num_start_pts; float alpha, start_point_norm; size_t points_to_skip, max_points_to_insert, beginning_index_size, points_per_checkpoint, checkpoints_per_snapshot, points_to_delete_from_beginning, start_deletes_after; bool concurrent; - po::options_description desc{"Arguments"}; + po::options_description desc{program_options_utils::make_program_description("test_insert_deletes_consolidate", + "Test insert deletes & consolidate")}; try { desc.add_options()("help,h", "Print information on arguments"); - desc.add_options()("data_type", po::value(&data_type)->required(), "data type "); - desc.add_options()("dist_fn", po::value(&dist_fn)->required(), "distance function "); - desc.add_options()("data_path", po::value(&data_path)->required(), - "Input data file in bin format"); - desc.add_options()("index_path_prefix", po::value(&index_path_prefix)->required(), - "Path prefix for saving index file components"); - desc.add_options()("max_degree,R", po::value(&R)->default_value(64), "Maximum graph degree"); - desc.add_options()("Lbuild,L", po::value(&L)->default_value(100), - "Build complexity, higher value results in better graphs"); - 
desc.add_options()("alpha", po::value(&alpha)->default_value(1.2f), - "alpha controls density and diameter of graph, set 1 for sparse graph, " - "1.2 or 1.4 for denser graphs with lower diameter"); - desc.add_options()("num_threads,T", po::value(&num_threads)->default_value(omp_get_num_procs()), - "Number of threads used for building index (defaults to " - "omp_get_num_procs())"); - desc.add_options()("points_to_skip", po::value(&points_to_skip)->required(), - "Skip these first set of points from file"); - desc.add_options()("max_points_to_insert", po::value(&max_points_to_insert)->default_value(0), - "These number of points from the file are inserted after " - "points_to_skip"); - desc.add_options()("beginning_index_size", po::value(&beginning_index_size)->required(), - "Batch build will be called on these set of points"); - desc.add_options()("points_per_checkpoint", po::value(&points_per_checkpoint)->required(), - "Insertions are done in batches of points_per_checkpoint"); - desc.add_options()("checkpoints_per_snapshot", po::value(&checkpoints_per_snapshot)->required(), - "Save the index to disk every few checkpoints"); - desc.add_options()("points_to_delete_from_beginning", - po::value(&points_to_delete_from_beginning)->required(), ""); - desc.add_options()("do_concurrent", po::value(&concurrent)->default_value(false), ""); - desc.add_options()("start_deletes_after", po::value(&start_deletes_after)->default_value(0), ""); - desc.add_options()("start_point_norm", po::value(&start_point_norm)->default_value(0), - "Set the start point to a random point on a sphere of this radius"); - desc.add_options()("num_start_points", po::value(&num_start_pts)->default_value(0), - "Set the number of random start (frozen) points to use when inserting and searching"); + + // Required parameters + po::options_description required_configs("Required"); + required_configs.add_options()("data_type", po::value(&data_type)->required(), + program_options_utils::DATA_TYPE_DESCRIPTION); + required_configs.add_options()("dist_fn", po::value(&dist_fn)->required(), + program_options_utils::DISTANCE_FUNCTION_DESCRIPTION); + required_configs.add_options()("index_path_prefix", po::value(&index_path_prefix)->required(), + program_options_utils::INDEX_PATH_PREFIX_DESCRIPTION); + required_configs.add_options()("data_path", po::value(&data_path)->required(), + program_options_utils::INPUT_DATA_PATH); + required_configs.add_options()("points_to_skip", po::value(&points_to_skip)->required(), + "Skip these first set of points from file"); + required_configs.add_options()("beginning_index_size", po::value(&beginning_index_size)->required(), + "Batch build will be called on these set of points"); + required_configs.add_options()("points_per_checkpoint", po::value(&points_per_checkpoint)->required(), + "Insertions are done in batches of points_per_checkpoint"); + required_configs.add_options()("checkpoints_per_snapshot", + po::value(&checkpoints_per_snapshot)->required(), + "Save the index to disk every few checkpoints"); + required_configs.add_options()("points_to_delete_from_beginning", + po::value(&points_to_delete_from_beginning)->required(), ""); + + // Optional parameters + po::options_description optional_configs("Optional"); + optional_configs.add_options()("num_threads,T", + po::value(&num_threads)->default_value(omp_get_num_procs()), + program_options_utils::NUMBER_THREADS_DESCRIPTION); + optional_configs.add_options()("max_degree,R", po::value(&R)->default_value(64), + program_options_utils::MAX_BUILD_DEGREE); + 
optional_configs.add_options()("Lbuild,L", po::value(&L)->default_value(100), + program_options_utils::GRAPH_BUILD_COMPLEXITY); + optional_configs.add_options()("alpha", po::value(&alpha)->default_value(1.2f), + program_options_utils::GRAPH_BUILD_ALPHA); + optional_configs.add_options()("max_points_to_insert", + po::value(&max_points_to_insert)->default_value(0), + "These number of points from the file are inserted after " + "points_to_skip"); + optional_configs.add_options()("do_concurrent", po::value(&concurrent)->default_value(false), ""); + optional_configs.add_options()("start_deletes_after", + po::value(&start_deletes_after)->default_value(0), ""); + optional_configs.add_options()("start_point_norm", po::value(&start_point_norm)->default_value(0), + "Set the start point to a random point on a sphere of this radius"); + optional_configs.add_options()( + "num_start_points", + po::value(&num_start_pts)->default_value(diskann::defaults::NUM_FROZEN_POINTS_DYNAMIC), + "Set the number of random start (frozen) points to use when " + "inserting and searching"); + + // Merge required and optional parameters + desc.add(required_configs).add(optional_configs); po::variables_map vm; po::store(po::parse_command_line(argc, argv, desc), vm); @@ -386,7 +405,8 @@ int main(int argc, char **argv) if (beginning_index_size == 0) if (start_point_norm == 0) { - std::cout << "When beginning_index_size is 0, use a start point with " + std::cout << "When beginning_index_size is 0, use a start " + "point with " "appropriate norm" << std::endl; return -1; @@ -400,18 +420,25 @@ int main(int argc, char **argv) try { + diskann::IndexWriteParameters params = diskann::IndexWriteParametersBuilder(L, R) + .with_max_occlusion_size(500) + .with_alpha(alpha) + .with_num_threads(num_threads) + .with_num_frozen_points(num_start_pts) + .build(); + if (data_type == std::string("int8")) - build_incremental_index(data_path, L, R, alpha, num_threads, points_to_skip, max_points_to_insert, + build_incremental_index(data_path, params, points_to_skip, max_points_to_insert, beginning_index_size, start_point_norm, num_start_pts, points_per_checkpoint, checkpoints_per_snapshot, index_path_prefix, points_to_delete_from_beginning, start_deletes_after, concurrent); else if (data_type == std::string("uint8")) - build_incremental_index(data_path, L, R, alpha, num_threads, points_to_skip, max_points_to_insert, + build_incremental_index(data_path, params, points_to_skip, max_points_to_insert, beginning_index_size, start_point_norm, num_start_pts, points_per_checkpoint, checkpoints_per_snapshot, index_path_prefix, points_to_delete_from_beginning, start_deletes_after, concurrent); else if (data_type == std::string("float")) - build_incremental_index(data_path, L, R, alpha, num_threads, points_to_skip, max_points_to_insert, + build_incremental_index(data_path, params, points_to_skip, max_points_to_insert, beginning_index_size, start_point_norm, num_start_pts, points_per_checkpoint, checkpoints_per_snapshot, index_path_prefix, points_to_delete_from_beginning, start_deletes_after, concurrent); diff --git a/tests/test_streaming_scenario.cpp b/apps/test_streaming_scenario.cpp similarity index 56% rename from tests/test_streaming_scenario.cpp rename to apps/test_streaming_scenario.cpp index b1f655162..55e4e61cf 100644 --- a/tests/test_streaming_scenario.cpp +++ b/apps/test_streaming_scenario.cpp @@ -9,8 +9,11 @@ #include #include #include +#include +#include #include "utils.h" +#include "program_options_utils.hpp" #ifndef _WINDOWS #include @@ -36,8 
+39,8 @@ inline void load_aligned_bin_part(const std::string &bin_file, T *data, size_t o int npts_i32, dim_i32; reader.read((char *)&npts_i32, sizeof(int)); reader.read((char *)&dim_i32, sizeof(int)); - size_t npts = (unsigned)npts_i32; - size_t dim = (unsigned)dim_i32; + size_t npts = (uint32_t)npts_i32; + size_t dim = (uint32_t)dim_i32; size_t expected_actual_file_size = npts * dim * sizeof(T) + 2 * sizeof(uint32_t); if (actual_file_size != expected_actual_file_size) @@ -81,8 +84,8 @@ std::string get_save_filename(const std::string &save_path, size_t active_window return final_path; } -template -void insert_next_batch(diskann::Index &index, size_t start, size_t end, size_t insert_threads, T *data, +template +void insert_next_batch(diskann::AbstractIndex &index, size_t start, size_t end, size_t insert_threads, T *data, size_t aligned_dim) { try @@ -91,7 +94,7 @@ void insert_next_batch(diskann::Index &index, size_t start, siz std::cout << std::endl << "Inserting from " << start << " to " << end << std::endl; size_t num_failed = 0; -#pragma omp parallel for num_threads(insert_threads) schedule(dynamic) reduction(+ : num_failed) +#pragma omp parallel for num_threads((int32_t)insert_threads) schedule(dynamic) reduction(+ : num_failed) for (int64_t j = start; j < (int64_t)end; j++) { if (index.insert_point(&data[(j - start) * aligned_dim], 1 + static_cast(j)) != 0) @@ -113,15 +116,15 @@ void insert_next_batch(diskann::Index &index, size_t start, siz } } -template -void delete_and_consolidate(diskann::Index &index, diskann::Parameters &delete_params, size_t start, +template +void delete_and_consolidate(diskann::AbstractIndex &index, diskann::IndexWriteParameters &delete_params, size_t start, size_t end) { try { std::cout << std::endl << "Lazy deleting points " << start << " to " << end << "... "; for (size_t i = start; i < end; ++i) - index.lazy_delete(1 + i); + index.lazy_delete(static_cast(1 + i)); std::cout << "lazy delete done." 
<< std::endl; auto report = index.consolidate_deletes(delete_params); @@ -149,14 +152,13 @@ void delete_and_consolidate(diskann::Index &index, diskann::Par } auto points_processed = report._active_points + report._slots_released; auto deletion_rate = points_processed / report._time; - std::cout << "#active points: " << report._active_points << std::endl + std::cout << "#active points: " << report._active_points << std::endl << "max points: " << report._max_points << std::endl << "empty slots: " << report._empty_slots << std::endl << "deletes processed: " << report._slots_released << std::endl << "latest delete size: " << report._delete_set_size << std::endl << "Deletion rate: " << deletion_rate << "/sec " - << "Deletion rate: " << deletion_rate / delete_params.Get("num_threads") << "/thread/sec " - << std::endl; + << "Deletion rate: " << deletion_rate / delete_params.num_threads << "/thread/sec " << std::endl; } catch (std::system_error &e) { @@ -166,62 +168,79 @@ void delete_and_consolidate(diskann::Index &index, diskann::Par } template -void build_incremental_index(const std::string &data_path, const unsigned L, const unsigned R, const float alpha, - const unsigned insert_threads, const unsigned consolidate_threads, +void build_incremental_index(const std::string &data_path, const uint32_t L, const uint32_t R, const float alpha, + const uint32_t insert_threads, const uint32_t consolidate_threads, size_t max_points_to_insert, size_t active_window, size_t consolidate_interval, - const float start_point_norm, unsigned num_start_pts, const std::string &save_path) + const float start_point_norm, uint32_t num_start_pts, const std::string &save_path) { - const unsigned C = 500; + const uint32_t C = 500; const bool saturate_graph = false; + using TagT = uint32_t; + using LabelT = uint32_t; - diskann::Parameters params; - params.Set("L", L); - params.Set("R", R); - params.Set("C", C); - params.Set("alpha", alpha); - params.Set("saturate_graph", saturate_graph); - params.Set("num_rnds", 1); - params.Set("num_threads", insert_threads); - params.Set("Lf", 0); - params.Set("num_frozen_pts", num_start_pts); - diskann::Parameters delete_params; - delete_params.Set("L", L); - delete_params.Set("R", R); - delete_params.Set("C", C); - delete_params.Set("alpha", alpha); - delete_params.Set("saturate_graph", saturate_graph); - delete_params.Set("num_rnds", 1); - delete_params.Set("num_threads", consolidate_threads); + diskann::IndexWriteParameters params = diskann::IndexWriteParametersBuilder(L, R) + .with_max_occlusion_size(C) + .with_alpha(alpha) + .with_saturate_graph(saturate_graph) + .with_num_threads(insert_threads) + .with_num_frozen_points(num_start_pts) + .build(); + + diskann::IndexWriteParameters delete_params = diskann::IndexWriteParametersBuilder(L, R) + .with_max_occlusion_size(C) + .with_alpha(alpha) + .with_saturate_graph(saturate_graph) + .with_num_threads(consolidate_threads) + .build(); size_t dim, aligned_dim; size_t num_points; diskann::get_bin_metadata(data_path, num_points, dim); + diskann::cout << "metadata: file " << data_path << " has " << num_points << " points in " << dim << " dims" + << std::endl; aligned_dim = ROUND_UP(dim, 8); + auto index_config = diskann::IndexConfigBuilder() + .with_metric(diskann::L2) + .with_dimension(dim) + .with_max_points(active_window + 4 * consolidate_interval) + .is_dynamic_index(true) + .is_enable_tags(true) + .is_use_opq(false) + .with_num_pq_chunks(0) + .is_pq_dist_build(false) + .with_search_threads(insert_threads) + 
.with_initial_search_list_size(L) + .with_tag_type(diskann_type_to_name()) + .with_label_type(diskann_type_to_name()) + .with_data_type(diskann_type_to_name()) + .with_index_write_params(params) + .with_data_load_store_strategy(diskann::MEMORY) + .build(); + + diskann::IndexFactory index_factory = diskann::IndexFactory(index_config); + auto index = index_factory.create_instance(); + if (max_points_to_insert == 0) { max_points_to_insert = num_points; } if (num_points < max_points_to_insert) - throw diskann::ANNException("num_points < max_points_to_insert", -1, __FUNCSIG__, __FILE__, __LINE__); + throw diskann::ANNException(std::string("num_points(") + std::to_string(num_points) + + ") < max_points_to_insert(" + std::to_string(max_points_to_insert) + ")", + -1, __FUNCSIG__, __FILE__, __LINE__); if (max_points_to_insert < active_window + consolidate_interval) - throw diskann::ANNException("ERROR: max_points_to_insert < active_window + consolidate_interval", -1, - __FUNCSIG__, __FILE__, __LINE__); + throw diskann::ANNException("ERROR: max_points_to_insert < " + "active_window + consolidate_interval", + -1, __FUNCSIG__, __FILE__, __LINE__); if (consolidate_interval < max_points_to_insert / 1000) throw diskann::ANNException("ERROR: consolidate_interval is too small", -1, __FUNCSIG__, __FILE__, __LINE__); - using TagT = uint32_t; - using LabelT = uint32_t; - const bool enable_tags = true; - - diskann::Index index(diskann::L2, dim, active_window + 4 * consolidate_interval, true, params, - params, enable_tags, true); - index.set_start_points_at_random(static_cast(start_point_norm)); - index.enable_delete(); + index->set_start_points_at_random(static_cast(start_point_norm)); T *data = nullptr; diskann::alloc_aligned((void **)&data, std::max(consolidate_interval, active_window) * aligned_dim * sizeof(T), @@ -236,7 +255,7 @@ void build_incremental_index(const std::string &data_path, const unsigned L, con auto insert_task = std::async(std::launch::async, [&]() { load_aligned_bin_part(data_path, data, 0, active_window); - insert_next_batch(index, 0, active_window, insert_threads, data, aligned_dim); + insert_next_batch(*index, (size_t)0, active_window, params.num_threads, data, aligned_dim); }); insert_task.wait(); @@ -246,7 +265,7 @@ void build_incremental_index(const std::string &data_path, const unsigned L, con auto end = std::min(start + consolidate_interval, max_points_to_insert); auto insert_task = std::async(std::launch::async, [&]() { load_aligned_bin_part(data_path, data, start, end - start); - insert_next_batch(index, start, end, insert_threads, data, aligned_dim); + insert_next_batch(*index, start, end, params.num_threads, data, aligned_dim); }); insert_task.wait(); @@ -257,10 +276,9 @@ void build_incremental_index(const std::string &data_path, const unsigned L, con auto start_del = start - active_window - consolidate_interval; auto end_del = start - active_window; - params.Set("num_threads", consolidate_threads); - - delete_tasks.emplace_back(std::async( - std::launch::async, [&]() { delete_and_consolidate(index, delete_params, start_del, end_del); })); + delete_tasks.emplace_back(std::async(std::launch::async, [&]() { + delete_and_consolidate(*index, delete_params, (size_t)start_del, (size_t)end_del); + })); } } if (delete_tasks.size() > 0) @@ -269,7 +287,7 @@ void build_incremental_index(const std::string &data_path, const unsigned L, con std::cout << "Time Elapsed " << timer.elapsed() / 1000 << "ms\n"; const auto save_path_inc = get_save_filename(save_path + ".after-streaming-", 
active_window, consolidate_interval, max_points_to_insert); - index.save(save_path_inc.c_str(), true); + index->save(save_path_inc.c_str(), true); diskann::aligned_free(data); } @@ -277,48 +295,65 @@ void build_incremental_index(const std::string &data_path, const unsigned L, con int main(int argc, char **argv) { std::string data_type, dist_fn, data_path, index_path_prefix; - unsigned insert_threads, consolidate_threads; - unsigned R, L, num_start_pts; + uint32_t insert_threads, consolidate_threads; + uint32_t R, L, num_start_pts; float alpha, start_point_norm; size_t max_points_to_insert, active_window, consolidate_interval; - po::options_description desc{"Arguments"}; + po::options_description desc{program_options_utils::make_program_description("test_streaming_scenario", + "Test insert deletes & consolidate")}; try { desc.add_options()("help,h", "Print information on arguments"); - desc.add_options()("data_type", po::value(&data_type)->required(), "data type "); - desc.add_options()("dist_fn", po::value(&dist_fn)->required(), "distance function "); - desc.add_options()("data_path", po::value(&data_path)->required(), - "Input data file in bin format"); - desc.add_options()("index_path_prefix", po::value(&index_path_prefix)->required(), - "Path prefix for saving index file components"); - desc.add_options()("max_degree,R", po::value(&R)->default_value(64), "Maximum graph degree"); - desc.add_options()("Lbuild,L", po::value(&L)->default_value(100), - "Build complexity, higher value results in better graphs"); - desc.add_options()("alpha", po::value(&alpha)->default_value(1.2f), - "alpha controls density and diameter of graph, set 1 for sparse graph, " - "1.2 or 1.4 for denser graphs with lower diameter"); - desc.add_options()("insert_threads", - po::value(&insert_threads)->default_value(omp_get_num_procs() / 2), - "Number of threads used for inserting into the index (defaults to " - "omp_get_num_procs()/2)"); - desc.add_options()("consolidate_threads", - po::value(&consolidate_threads)->default_value(omp_get_num_procs() / 2), - "Number of threads used for consolidating deletes to " - "the index (defaults to omp_get_num_procs()/2)"); - - desc.add_options()("max_points_to_insert", po::value(&max_points_to_insert)->default_value(0), - "The number of points from the file that the program streams over "); - desc.add_options()("active_window", po::value(&active_window)->required(), - "Program maintains an index over an active window of " - "this size that slides through the data"); - desc.add_options()("consolidate_interval", po::value(&consolidate_interval)->required(), - "The program simultaneously adds this number of points to the right of " - "the window while deleting the same number from the left"); - desc.add_options()("start_point_norm", po::value(&start_point_norm)->required(), - "Set the start point to a random point on a sphere of this radius"); - desc.add_options()("num_start_points", po::value(&num_start_pts)->default_value(0), - "Set the number of random start (frozen) points to use when inserting and searching"); + + // Required parameters + po::options_description required_configs("Required"); + required_configs.add_options()("data_type", po::value(&data_type)->required(), + program_options_utils::DATA_TYPE_DESCRIPTION); + required_configs.add_options()("dist_fn", po::value(&dist_fn)->required(), + program_options_utils::DISTANCE_FUNCTION_DESCRIPTION); + required_configs.add_options()("index_path_prefix", po::value(&index_path_prefix)->required(), + 
program_options_utils::INDEX_PATH_PREFIX_DESCRIPTION);
+        required_configs.add_options()("data_path", po::value<std::string>(&data_path)->required(),
+                                       program_options_utils::INPUT_DATA_PATH);
+        required_configs.add_options()("active_window", po::value<uint64_t>(&active_window)->required(),
+                                       "Program maintains an index over an active window of "
+                                       "this size that slides through the data");
+        required_configs.add_options()("consolidate_interval", po::value<uint64_t>(&consolidate_interval)->required(),
+                                       "The program simultaneously adds this number of points to the "
+                                       "right of "
+                                       "the window while deleting the same number from the left");
+        required_configs.add_options()("start_point_norm", po::value<float>(&start_point_norm)->required(),
+                                       "Set the start point to a random point on a sphere of this radius");
+
+        // Optional parameters
+        po::options_description optional_configs("Optional");
+        optional_configs.add_options()("max_degree,R", po::value<uint32_t>(&R)->default_value(64),
+                                       program_options_utils::MAX_BUILD_DEGREE);
+        optional_configs.add_options()("Lbuild,L", po::value<uint32_t>(&L)->default_value(100),
+                                       program_options_utils::GRAPH_BUILD_COMPLEXITY);
+        optional_configs.add_options()("alpha", po::value<float>(&alpha)->default_value(1.2f),
+                                       program_options_utils::GRAPH_BUILD_ALPHA);
+        optional_configs.add_options()("insert_threads",
+                                       po::value<uint32_t>(&insert_threads)->default_value(omp_get_num_procs() / 2),
+                                       "Number of threads used for inserting into the index (defaults to "
+                                       "omp_get_num_procs()/2)");
+        optional_configs.add_options()(
+            "consolidate_threads", po::value<uint32_t>(&consolidate_threads)->default_value(omp_get_num_procs() / 2),
+            "Number of threads used for consolidating deletes to "
+            "the index (defaults to omp_get_num_procs()/2)");
+        optional_configs.add_options()("max_points_to_insert",
+                                       po::value<uint64_t>(&max_points_to_insert)->default_value(0),
+                                       "The number of points from the file that the program streams "
+                                       "over ");
+        optional_configs.add_options()(
+            "num_start_points",
+            po::value<uint32_t>(&num_start_pts)->default_value(diskann::defaults::NUM_FROZEN_POINTS_DYNAMIC),
+            "Set the number of random start (frozen) points to use when "
+            "inserting and searching");
+
+        // Merge required and optional parameters
+        desc.add(required_configs).add(optional_configs);

         po::variables_map vm;
         po::store(po::parse_command_line(argc, argv, desc), vm);
diff --git a/tests/utils/CMakeLists.txt b/apps/utils/CMakeLists.txt
similarity index 69%
rename from tests/utils/CMakeLists.txt
rename to apps/utils/CMakeLists.txt
index 40aa050c8..3b8cf223c 100644
--- a/tests/utils/CMakeLists.txt
+++ b/apps/utils/CMakeLists.txt
@@ -1,7 +1,9 @@
 # Copyright (c) Microsoft Corporation. All rights reserved.
 # Licensed under the MIT license.

-set(CMAKE_CXX_STANDARD 14)
+set(CMAKE_CXX_STANDARD 17)
+set(CMAKE_COMPILE_WARNING_AS_ERROR ON)
+

 add_executable(fvecs_to_bin fvecs_to_bin.cpp)
@@ -44,11 +46,15 @@ add_executable(simulate_aggregate_recall simulate_aggregate_recall.cpp)
 add_executable(calculate_recall calculate_recall.cpp)
 target_link_libraries(calculate_recall ${PROJECT_NAME} ${DISKANN_ASYNC_LIB} ${DISKANN_TOOLS_TCMALLOC_LINK_OPTIONS})

-# This is the only thing outside of DiskANN main source that depends on MKL.
+# The compute_groundtruth tools are the only things outside of DiskANN main source that depend on MKL.
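
Note: the compute_groundtruth tools carry the MKL dependency because exact_knn (later in this patch) evaluates query-to-base distances in large batches over a q_batch_size x npoints dist_matrix, then keeps the k best per query in a bounded max-heap. A self-contained sketch of that select-k-best idea, with a naive distance loop standing in for the batched matrix computation (illustration only, not the patched code):

#include <algorithm>
#include <cstddef>
#include <cstdint>
#include <queue>
#include <utility>
#include <vector>

// Exact k-NN for one query: keep a max-heap of the k smallest L2 distances,
// mirroring the point_dist heap in exact_knn. top() is the worst of the
// current best-k, so any closer point replaces it.
std::vector<std::pair<uint32_t, float>> brute_force_knn(const float *base, size_t npts, size_t dim,
                                                        const float *query, size_t k)
{
    std::priority_queue<std::pair<float, uint32_t>> best; // max-heap on distance
    for (size_t p = 0; p < npts; p++)
    {
        float d = 0;
        for (size_t j = 0; j < dim; j++)
        {
            const float diff = base[p * dim + j] - query[j];
            d += diff * diff;
        }
        if (best.size() < k)
            best.emplace(d, (uint32_t)p);
        else if (d < best.top().first)
        {
            best.pop();
            best.emplace(d, (uint32_t)p);
        }
    }
    std::vector<std::pair<uint32_t, float>> out;
    while (!best.empty())
    {
        out.emplace_back(best.top().second, best.top().first);
        best.pop();
    }
    std::reverse(out.begin(), out.end()); // nearest first
    return out;
}

Replacing the inner distance loop with one matrix product per query batch, as the MKL-linked tools do, is what keeps this tractable at the dataset sizes compute_groundtruth targets.
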
add_executable(compute_groundtruth compute_groundtruth.cpp) target_include_directories(compute_groundtruth PRIVATE ${DISKANN_MKL_INCLUDE_DIRECTORIES}) target_link_libraries(compute_groundtruth ${PROJECT_NAME} ${DISKANN_MKL_LINK_LIBRARIES} ${DISKANN_ASYNC_LIB} Boost::program_options) +add_executable(compute_groundtruth_for_filters compute_groundtruth_for_filters.cpp) +target_include_directories(compute_groundtruth_for_filters PRIVATE ${DISKANN_MKL_INCLUDE_DIRECTORIES}) +target_link_libraries(compute_groundtruth_for_filters ${PROJECT_NAME} ${DISKANN_MKL_LINK_LIBRARIES} ${DISKANN_ASYNC_LIB} Boost::program_options) + add_executable(generate_pq generate_pq.cpp) target_link_libraries(generate_pq ${PROJECT_NAME} ${DISKANN_TOOLS_TCMALLOC_LINK_OPTIONS}) @@ -71,3 +77,34 @@ target_link_libraries(generate_synthetic_labels ${PROJECT_NAME} Boost::program_o add_executable(stats_label_data stats_label_data.cpp) target_link_libraries(stats_label_data ${PROJECT_NAME} Boost::program_options) + +if (NOT MSVC) + include(GNUInstallDirs) + install(TARGETS fvecs_to_bin + fvecs_to_bvecs + rand_data_gen + float_bin_to_int8 + ivecs_to_bin + count_bfs_levels + tsv_to_bin + bin_to_tsv + int8_to_float + int8_to_float_scale + uint8_to_float + uint32_to_uint8 + vector_analysis + gen_random_slice + simulate_aggregate_recall + calculate_recall + compute_groundtruth + compute_groundtruth_for_filters + generate_pq + partition_data + partition_with_ram_budget + merge_shards + create_disk_layout + generate_synthetic_labels + stats_label_data + RUNTIME + ) +endif() \ No newline at end of file diff --git a/tests/utils/bin_to_fvecs.cpp b/apps/utils/bin_to_fvecs.cpp similarity index 70% rename from tests/utils/bin_to_fvecs.cpp rename to apps/utils/bin_to_fvecs.cpp index 13803b74c..e9a6a8ecc 100644 --- a/tests/utils/bin_to_fvecs.cpp +++ b/apps/utils/bin_to_fvecs.cpp @@ -4,11 +4,12 @@ #include #include "util.h" -void block_convert(std::ifstream &writr, std::ofstream &readr, float *read_buf, float *write_buf, _u64 npts, _u64 ndims) +void block_convert(std::ifstream &writr, std::ofstream &readr, float *read_buf, float *write_buf, uint64_t npts, + uint64_t ndims) { writr.write((char *)read_buf, npts * (ndims * sizeof(float) + sizeof(unsigned))); #pragma omp parallel for - for (_u64 i = 0; i < npts; i++) + for (uint64_t i = 0; i < npts; i++) { memcpy(write_buf + i * ndims, (read_buf + i * (ndims + 1)) + 1, ndims * sizeof(float)); } @@ -25,31 +26,31 @@ int main(int argc, char **argv) std::ifstream readr(argv[1], std::ios::binary); int npts_s32; int ndims_s32; - readr.read((char *)&npts_s32, sizeof(_s32)); - readr.read((char *)&ndims_s32, sizeof(_s32)); + readr.read((char *)&npts_s32, sizeof(int32_t)); + readr.read((char *)&ndims_s32, sizeof(int32_t)); size_t npts = npts_s32; size_t ndims = ndims_s32; - _u32 ndims_u32 = (_u32)ndims_s32; - // _u64 fsize = writr.tellg(); + uint32_t ndims_u32 = (uint32_t)ndims_s32; + // uint64_t fsize = writr.tellg(); readr.seekg(0, std::ios::beg); unsigned ndims_u32; writr.write((char *)&ndims_u32, sizeof(unsigned)); writr.seekg(0, std::ios::beg); - _u64 ndims = (_u64)ndims_u32; - _u64 npts = fsize / ((ndims + 1) * sizeof(float)); + uint64_t ndims = (uint64_t)ndims_u32; + uint64_t npts = fsize / ((ndims + 1) * sizeof(float)); std::cout << "Dataset: #pts = " << npts << ", # dims = " << ndims << std::endl; - _u64 blk_size = 131072; - _u64 nblks = ROUND_UP(npts, blk_size) / blk_size; + uint64_t blk_size = 131072; + uint64_t nblks = ROUND_UP(npts, blk_size) / blk_size; std::cout << "# blks: " << nblks << 
std::endl; std::ofstream writr(argv[2], std::ios::binary); float *read_buf = new float[npts * (ndims + 1)]; float *write_buf = new float[npts * ndims]; - for (_u64 i = 0; i < nblks; i++) + for (uint64_t i = 0; i < nblks; i++) { - _u64 cblk_size = std::min(npts - i * blk_size, blk_size); + uint64_t cblk_size = std::min(npts - i * blk_size, blk_size); block_convert(writr, readr, read_buf, write_buf, cblk_size, ndims); std::cout << "Block #" << i << " written" << std::endl; } diff --git a/tests/utils/bin_to_tsv.cpp b/apps/utils/bin_to_tsv.cpp similarity index 75% rename from tests/utils/bin_to_tsv.cpp rename to apps/utils/bin_to_tsv.cpp index 813fa6e9f..7851bef6d 100644 --- a/tests/utils/bin_to_tsv.cpp +++ b/apps/utils/bin_to_tsv.cpp @@ -4,13 +4,14 @@ #include #include "utils.h" -template void block_convert(std::ofstream &writer, std::ifstream &reader, T *read_buf, _u64 npts, _u64 ndims) +template +void block_convert(std::ofstream &writer, std::ifstream &reader, T *read_buf, size_t npts, size_t ndims) { reader.read((char *)read_buf, npts * ndims * sizeof(float)); - for (_u64 i = 0; i < npts; i++) + for (size_t i = 0; i < npts; i++) { - for (_u64 d = 0; d < ndims; d++) + for (size_t d = 0; d < ndims; d++) { writer << read_buf[d + i * ndims]; if (d < ndims - 1) @@ -36,22 +37,22 @@ int main(int argc, char **argv) } std::ifstream reader(argv[2], std::ios::binary); - _u32 npts_u32; - _u32 ndims_u32; - reader.read((char *)&npts_u32, sizeof(_s32)); - reader.read((char *)&ndims_u32, sizeof(_s32)); + uint32_t npts_u32; + uint32_t ndims_u32; + reader.read((char *)&npts_u32, sizeof(uint32_t)); + reader.read((char *)&ndims_u32, sizeof(uint32_t)); size_t npts = npts_u32; size_t ndims = ndims_u32; std::cout << "Dataset: #pts = " << npts << ", # dims = " << ndims << std::endl; - _u64 blk_size = 131072; - _u64 nblks = ROUND_UP(npts, blk_size) / blk_size; + size_t blk_size = 131072; + size_t nblks = ROUND_UP(npts, blk_size) / blk_size; std::ofstream writer(argv[3]); char *read_buf = new char[blk_size * ndims * 4]; - for (_u64 i = 0; i < nblks; i++) + for (size_t i = 0; i < nblks; i++) { - _u64 cblk_size = std::min(npts - i * blk_size, blk_size); + size_t cblk_size = std::min(npts - i * blk_size, blk_size); if (type_string == std::string("float")) block_convert(writer, reader, (float *)read_buf, cblk_size, ndims); else if (type_string == std::string("int8")) diff --git a/tests/utils/calculate_recall.cpp b/apps/utils/calculate_recall.cpp similarity index 79% rename from tests/utils/calculate_recall.cpp rename to apps/utils/calculate_recall.cpp index 3307104ce..dc76252cc 100644 --- a/tests/utils/calculate_recall.cpp +++ b/apps/utils/calculate_recall.cpp @@ -19,9 +19,9 @@ int main(int argc, char **argv) std::cout << argv[0] << " " << std::endl; return -1; } - unsigned *gold_std = NULL; + uint32_t *gold_std = NULL; float *gs_dist = nullptr; - unsigned *our_results = NULL; + uint32_t *our_results = NULL; float *or_dist = nullptr; size_t points_num, points_num_gs, points_num_or; size_t dim_gs; @@ -31,7 +31,9 @@ int main(int argc, char **argv) if (points_num_gs != points_num_or) { - std::cout << "Error. Number of queries mismatch in ground truth and our results" << std::endl; + std::cout << "Error. 
Number of queries mismatch in ground truth and " + "our results" + << std::endl; return -1; } points_num = points_num_gs; @@ -45,7 +47,8 @@ int main(int argc, char **argv) return -1; } std::cout << "Calculating recall@" << recall_at << std::endl; - float recall_val = diskann::calculate_recall(points_num, gold_std, gs_dist, dim_gs, our_results, dim_or, recall_at); + double recall_val = diskann::calculate_recall((uint32_t)points_num, gold_std, gs_dist, (uint32_t)dim_gs, + our_results, (uint32_t)dim_or, (uint32_t)recall_at); // double avg_recall = (recall*1.0)/(points_num*1.0); std::cout << "Avg. recall@" << recall_at << " is " << recall_val << "\n"; diff --git a/tests/utils/compute_groundtruth.cpp b/apps/utils/compute_groundtruth.cpp similarity index 63% rename from tests/utils/compute_groundtruth.cpp rename to apps/utils/compute_groundtruth.cpp index 231579d77..f33a26b84 100644 --- a/tests/utils/compute_groundtruth.cpp +++ b/apps/utils/compute_groundtruth.cpp @@ -18,13 +18,15 @@ #include #include #include +#include +#include #ifdef _WINDOWS #include #else #include #endif - +#include "filter_utils.h" #include "utils.h" // WORKS FOR UPTO 2 BILLION POINTS (as we use INT INSTEAD OF UNSIGNED) @@ -32,6 +34,10 @@ #define PARTSIZE 10000000 #define ALIGNMENT 512 +// custom types (for readability) +typedef tsl::robin_set label_set; +typedef std::string path; + namespace po = boost::program_options; template T div_round_up(const T numerator, const T denominator) @@ -39,7 +45,7 @@ template T div_round_up(const T numerator, const T denominator) return (numerator % denominator == 0) ? (numerator / denominator) : 1 + (numerator / denominator); } -using pairIF = std::pair; +using pairIF = std::pair; struct cmpmaxstruct { bool operator()(const pairIF &l, const pairIF &r) @@ -64,13 +70,13 @@ inline bool custom_dist(const std::pair &a, const std::pair::epsilon(); } - for (_u32 j = 0; j < dim; j++) + for (uint32_t j = 0; j < dim; j++) { points[i * dim + j] = points_in[i * dim + j] / norm; } } #pragma omp parallel for schedule(static, 4096) - for (_s64 i = 0; i < (_s64)nqueries; i++) + for (int64_t i = 0; i < (int64_t)nqueries; i++) { float norm = std::sqrt(queries_l2sq[i]); if (norm == 0) { norm = std::numeric_limits::epsilon(); } - for (_u32 j = 0; j < dim; j++) + for (uint32_t j = 0; j < dim; j++) { queries[i * dim + j] = queries_in[i * dim + j] / norm; } @@ -186,7 +192,7 @@ void exact_knn(const size_t dim, const size_t k, size_t q_batch_size = (1 << 9); float *dist_matrix = new float[(size_t)q_batch_size * (size_t)npoints]; - for (_u64 b = 0; b < div_round_up(nqueries, q_batch_size); ++b) + for (size_t b = 0; b < div_round_up(nqueries, q_batch_size); ++b) { int64_t q_b = b * q_batch_size; int64_t q_e = ((b + 1) * q_batch_size > nqueries) ? 
nqueries : (b + 1) * q_batch_size; @@ -207,9 +213,9 @@ void exact_knn(const size_t dim, const size_t k, for (long long q = q_b; q < q_e; q++) { maxPQIFCS point_dist; - for (_u64 p = 0; p < k; p++) + for (size_t p = 0; p < k; p++) point_dist.emplace(p, dist_matrix[(ptrdiff_t)p + (ptrdiff_t)(q - q_b) * (ptrdiff_t)npoints]); - for (_u64 p = k; p < npoints; p++) + for (size_t p = k; p < npoints; p++) { if (point_dist.top().second > dist_matrix[(ptrdiff_t)p + (ptrdiff_t)(q - q_b) * (ptrdiff_t)npoints]) point_dist.emplace(p, dist_matrix[(ptrdiff_t)p + (ptrdiff_t)(q - q_b) * (ptrdiff_t)npoints]); @@ -251,13 +257,14 @@ template inline int get_num_parts(const char *filename) reader.read((char *)&ndims_i32, sizeof(int)); std::cout << "#pts = " << npts_i32 << ", #dims = " << ndims_i32 << std::endl; reader.close(); - int num_parts = (npts_i32 % PARTSIZE) == 0 ? npts_i32 / PARTSIZE : std::floor(npts_i32 / PARTSIZE) + 1; + uint32_t num_parts = + (npts_i32 % PARTSIZE) == 0 ? npts_i32 / PARTSIZE : (uint32_t)std::floor(npts_i32 / PARTSIZE) + 1; std::cout << "Number of parts: " << num_parts << std::endl; return num_parts; } template -inline void load_bin_as_float(const char *filename, float *&data, size_t &npts_u64, size_t &ndims_u64, int part_num) +inline void load_bin_as_float(const char *filename, float *&data, size_t &npts, size_t &ndims, int part_num) { std::ifstream reader; reader.exceptions(std::ios::failbit | std::ios::badbit); @@ -268,87 +275,28 @@ inline void load_bin_as_float(const char *filename, float *&data, size_t &npts_u reader.read((char *)&ndims_i32, sizeof(int)); uint64_t start_id = part_num * PARTSIZE; uint64_t end_id = (std::min)(start_id + PARTSIZE, (uint64_t)npts_i32); - npts_u64 = end_id - start_id; - ndims_u64 = (uint64_t)ndims_i32; - std::cout << "#pts in part = " << npts_u64 << ", #dims = " << ndims_u64 - << ", size = " << npts_u64 * ndims_u64 * sizeof(T) << "B" << std::endl; - - reader.seekg(start_id * ndims_u64 * sizeof(T) + 2 * sizeof(uint32_t), std::ios::beg); - T *data_T = new T[npts_u64 * ndims_u64]; - reader.read((char *)data_T, sizeof(T) * npts_u64 * ndims_u64); - std::cout << "Finished reading part of the bin file." << std::endl; - reader.close(); - data = aligned_malloc(npts_u64 * ndims_u64, ALIGNMENT); -#pragma omp parallel for schedule(dynamic, 32768) - for (int64_t i = 0; i < (int64_t)npts_u64; i++) - { - for (int64_t j = 0; j < (int64_t)ndims_u64; j++) - { - float cur_val_float = (float)data_T[i * ndims_u64 + j]; - std::memcpy((char *)(data + i * ndims_u64 + j), (char *)&cur_val_float, sizeof(float)); - } - } - delete[] data_T; - std::cout << "Finished converting part data to float." 
<< std::endl; -} - -template -inline std::vector load_filtered_bin_as_float(const char *filename, float *&data, size_t &npts, size_t &ndims, - int part_num, const char *label_file, - const std::string &filter_label, - const std::string &universal_label, size_t &npoints_filt, - std::vector> &pts_to_labels) -{ - std::ifstream reader(filename, std::ios::binary); - if (reader.fail()) - { - throw diskann::ANNException(std::string("Failed to open file ") + filename, -1); - } - - std::cout << "Reading bin file " << filename << " ...\n"; - int npts_i32, ndims_i32; - std::vector rev_map; - reader.read((char *)&npts_i32, sizeof(int)); - reader.read((char *)&ndims_i32, sizeof(int)); - uint64_t start_id = part_num * PARTSIZE; - uint64_t end_id = (std::min)(start_id + PARTSIZE, (uint64_t)npts_i32); npts = end_id - start_id; - ndims = (unsigned)ndims_i32; - uint64_t nptsuint64_t = (uint64_t)npts; - uint64_t ndimsuint64_t = (uint64_t)ndims; - npoints_filt = 0; - std::cout << "#pts in part = " << npts << ", #dims = " << ndims - << ", size = " << nptsuint64_t * ndimsuint64_t * sizeof(T) << "B" << std::endl; - std::cout << "start and end ids: " << start_id << ", " << end_id << std::endl; - reader.seekg(start_id * ndims * sizeof(T) + 2 * sizeof(uint32_t), std::ios::beg); + ndims = (uint64_t)ndims_i32; + std::cout << "#pts in part = " << npts << ", #dims = " << ndims << ", size = " << npts * ndims * sizeof(T) << "B" + << std::endl; - T *data_T = new T[nptsuint64_t * ndimsuint64_t]; - reader.read((char *)data_T, sizeof(T) * nptsuint64_t * ndimsuint64_t); + reader.seekg(start_id * ndims * sizeof(T) + 2 * sizeof(uint32_t), std::ios::beg); + T *data_T = new T[npts * ndims]; + reader.read((char *)data_T, sizeof(T) * npts * ndims); std::cout << "Finished reading part of the bin file." << std::endl; reader.close(); - - data = aligned_malloc(nptsuint64_t * ndimsuint64_t, ALIGNMENT); - - for (int64_t i = 0; i < (int64_t)nptsuint64_t; i++) + data = aligned_malloc(npts * ndims, ALIGNMENT); +#pragma omp parallel for schedule(dynamic, 32768) + for (int64_t i = 0; i < (int64_t)npts; i++) { - if (std::find(pts_to_labels[start_id + i].begin(), pts_to_labels[start_id + i].end(), filter_label) != - pts_to_labels[start_id + i].end() || - std::find(pts_to_labels[start_id + i].begin(), pts_to_labels[start_id + i].end(), universal_label) != - pts_to_labels[start_id + i].end()) + for (int64_t j = 0; j < (int64_t)ndims; j++) { - rev_map.push_back(start_id + i); - for (int64_t j = 0; j < (int64_t)ndimsuint64_t; j++) - { - float cur_val_float = (float)data_T[i * ndimsuint64_t + j]; - std::memcpy((char *)(data + npoints_filt * ndimsuint64_t + j), (char *)&cur_val_float, sizeof(float)); - } - npoints_filt++; + float cur_val_float = (float)data_T[i * ndims + j]; + std::memcpy((char *)(data + i * ndims + j), (char *)&cur_val_float, sizeof(float)); } } delete[] data_T; - std::cout << "Finished converting part data to float.. identified " << npoints_filt - << " points matching the filter." << std::endl; - return rev_map; + std::cout << "Finished converting part data to float." 
<< std::endl; } template inline void save_bin(const std::string filename, T *data, size_t npts, size_t ndims) @@ -377,7 +325,7 @@ inline void save_groundtruth_as_one_file(const std::string filename, int32_t *da writer.write((char *)&ndims_i32, sizeof(int)); std::cout << "Saving truthset in one file (npts, dim, npts*dim id-matrix, " "npts*dim dist-matrix) with npts = " - << npts << ", dim = " << ndims << ", size = " << 2 * npts * ndims * sizeof(unsigned) + 2 * sizeof(int) + << npts << ", dim = " << ndims << ", size = " << 2 * npts * ndims * sizeof(uint32_t) + 2 * sizeof(int) << "B" << std::endl; writer.write((char *)data, npts * ndims * sizeof(uint32_t)); @@ -386,153 +334,73 @@ inline void save_groundtruth_as_one_file(const std::string filename, int32_t *da std::cout << "Finished writing truthset" << std::endl; } -inline void parse_label_file_into_vec(size_t &line_cnt, const std::string &map_file, - std::vector> &pts_to_labels) +template +std::vector>> processUnfilteredParts(const std::string &base_file, + size_t &nqueries, size_t &npoints, + size_t &dim, size_t &k, float *query_data, + const diskann::Metric &metric, + std::vector &location_to_tag) { - std::ifstream infile(map_file); - std::string line, token; - std::set labels; - infile.clear(); - infile.seekg(0, std::ios::beg); - while (std::getline(infile, line)) + float *base_data = nullptr; + int num_parts = get_num_parts(base_file.c_str()); + std::vector>> res(nqueries); + for (int p = 0; p < num_parts; p++) { - std::istringstream iss(line); - std::vector lbls(0); + size_t start_id = p * PARTSIZE; + load_bin_as_float(base_file.c_str(), base_data, npoints, dim, p); - getline(iss, token, '\t'); - std::istringstream new_iss(token); - while (getline(new_iss, token, ',')) - { - token.erase(std::remove(token.begin(), token.end(), '\n'), token.end()); - token.erase(std::remove(token.begin(), token.end(), '\r'), token.end()); - lbls.push_back(token); - labels.insert(token); - } - if (lbls.size() <= 0) + size_t *closest_points_part = new size_t[nqueries * k]; + float *dist_closest_points_part = new float[nqueries * k]; + + auto part_k = k < npoints ? 
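/* Editor's note: processUnfilteredParts streams the base file in chunks of
   PARTSIZE points so the full dataset never has to fit in memory; each part
   contributes at most part_k = min(k, npoints_in_part) candidates per query,
   with ids offset by the part's start_id, and aux_main later sorts the
   concatenated candidates and keeps the global top k. A minimal sketch of
   that final merge step (names hypothetical, not part of this patch):

   std::vector<std::pair<uint32_t, float>> cand = ...; // all parts, ids already global
   std::sort(cand.begin(), cand.end(),
             [](const auto &a, const auto &b) { return a.second < b.second; });
   if (cand.size() > k)
       cand.resize(k); // global top-k across parts
*/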
k : npoints;
+        exact_knn(dim, part_k, closest_points_part, dist_closest_points_part, npoints, base_data, nqueries, query_data,
+                  metric);
+
+        for (size_t i = 0; i < nqueries; i++)
         {
-            std::cout << "No label found";
-            exit(-1);
+            for (size_t j = 0; j < part_k; j++)
+            {
+                if (!location_to_tag.empty())
+                    if (location_to_tag[closest_points_part[i * part_k + j] + start_id] == 0)
+                        continue;
+
+                res[i].push_back(std::make_pair((uint32_t)(closest_points_part[i * part_k + j] + start_id),
+                                                dist_closest_points_part[i * part_k + j]));
+            }
         }
-        std::sort(lbls.begin(), lbls.end());
-        pts_to_labels.push_back(lbls);
+
+        delete[] closest_points_part;
+        delete[] dist_closest_points_part;
+
+        diskann::aligned_free(base_data);
     }
-    std::cout << "Identified " << labels.size() << " distinct label(s), and populated labels for "
-              << pts_to_labels.size() << " points" << std::endl;
-}
+    return res;
+};

 template <typename T>
-int aux_main(const std::string &base_file, const std::string &label_file, const std::string &query_file,
-             const std::string &gt_file, size_t k, const std::string &filter_label, const std::string &universal_label,
+int aux_main(const std::string &base_file, const std::string &query_file, const std::string &gt_file, size_t k,
              const diskann::Metric &metric, const std::string &tags_file = std::string(""))
 {
-    size_t npoints, nqueries, dim, npoints_filt;
+    size_t npoints, nqueries, dim;

-    float *base_data;
     float *query_data;

-    const bool tags_enabled = tags_file.empty() ? false : true;
-
-    int num_parts = get_num_parts<T>(base_file.c_str());
     load_bin_as_float<T>(query_file.c_str(), query_data, nqueries, dim, 0);
     if (nqueries > PARTSIZE)
         std::cerr << "WARNING: #Queries provided (" << nqueries << ") is greater than " << PARTSIZE
                   << ". Computing GT only for the first " << PARTSIZE << " queries." << std::endl;

     // load tags
-    std::vector<std::uint32_t> location_to_tag;
-    if (tags_enabled)
-    {
-        size_t tag_file_ndims, tag_file_npts;
-        std::uint32_t *tag_data;
-        diskann::load_bin<std::uint32_t>(tags_file, tag_data, tag_file_npts, tag_file_ndims);
-        if (tag_file_ndims != 1)
-        {
-            diskann::cerr << "tags file error" << std::endl;
-            throw diskann::ANNException("tag file error", -1, __FUNCSIG__, __FILE__, __LINE__);
-        }
-
-        // check if the point counts match
-        size_t base_file_npts, base_file_ndims;
-        diskann::get_bin_metadata(base_file, base_file_npts, base_file_ndims);
-        if (base_file_npts != tag_file_npts)
-        {
-            diskann::cerr << "point num in tags file mismatch" << std::endl;
-            throw diskann::ANNException("point num in tags file mismatch", -1, __FUNCSIG__, __FILE__, __LINE__);
-        }
-
-        location_to_tag.assign(tag_data, tag_data + tag_file_npts);
-        delete[] tag_data;
-    }
-
-    std::vector<std::vector<std::pair<uint32_t, float>>> results(nqueries);
+    const bool tags_enabled = tags_file.empty() ?
false : true; + std::vector location_to_tag = diskann::loadTags(tags_file, base_file); int *closest_points = new int[nqueries * k]; float *dist_closest_points = new float[nqueries * k]; - std::vector> pts_to_labels; - if (filter_label != "") - parse_label_file_into_vec(npoints, label_file, pts_to_labels); - std::vector rev_map; + std::vector>> results = + processUnfilteredParts(base_file, nqueries, npoints, dim, k, query_data, metric, location_to_tag); - for (int p = 0; p < num_parts; p++) - { - size_t start_id = p * PARTSIZE; - if (filter_label == "") - { - load_bin_as_float(base_file.c_str(), base_data, npoints, dim, p); - } - else - { - rev_map = load_filtered_bin_as_float(base_file.c_str(), base_data, npoints, dim, p, label_file.c_str(), - filter_label, universal_label, npoints_filt, pts_to_labels); - } - int *closest_points_part = new int[nqueries * k]; - float *dist_closest_points_part = new float[nqueries * k]; - - _u32 part_k; - if (filter_label == "") - { - part_k = k < npoints ? k : npoints; - exact_knn(dim, part_k, closest_points_part, dist_closest_points_part, npoints, base_data, nqueries, - query_data, metric); - } - else - { - part_k = k < npoints_filt ? k : npoints_filt; - if (npoints_filt > 0) - { - exact_knn(dim, part_k, closest_points_part, dist_closest_points_part, npoints_filt, base_data, nqueries, - query_data, metric); - } - } - - for (_u64 i = 0; i < nqueries; i++) - { - for (_u64 j = 0; j < part_k; j++) - { - if (tags_enabled) - if (location_to_tag[closest_points_part[i * k + j] + start_id] == 0) - continue; - if (filter_label == "") - { - results[i].push_back(std::make_pair((uint32_t)(closest_points_part[i * part_k + j] + start_id), - dist_closest_points_part[i * part_k + j])); - } - else - { - results[i].push_back(std::make_pair((uint32_t)(rev_map[closest_points_part[i * part_k + j]]), - dist_closest_points_part[i * part_k + j])); - } - } - } - - delete[] closest_points_part; - delete[] dist_closest_points_part; - - diskann::aligned_free(base_data); - } - - for (_u64 i = 0; i < nqueries; i++) + for (size_t i = 0; i < nqueries; i++) { std::vector> &cur_res = results[i]; std::sort(cur_res.begin(), cur_res.end(), custom_dist); @@ -563,16 +431,65 @@ int aux_main(const std::string &base_file, const std::string &label_file, const } save_groundtruth_as_one_file(gt_file, closest_points, dist_closest_points, nqueries, k); - diskann::aligned_free(query_data); delete[] closest_points; delete[] dist_closest_points; + diskann::aligned_free(query_data); + return 0; } +void load_truthset(const std::string &bin_file, uint32_t *&ids, float *&dists, size_t &npts, size_t &dim) +{ + size_t read_blk_size = 64 * 1024 * 1024; + cached_ifstream reader(bin_file, read_blk_size); + diskann::cout << "Reading truthset file " << bin_file.c_str() << " ..." << std::endl; + size_t actual_file_size = reader.get_file_size(); + + int npts_i32, dim_i32; + reader.read((char *)&npts_i32, sizeof(int)); + reader.read((char *)&dim_i32, sizeof(int)); + npts = (uint32_t)npts_i32; + dim = (uint32_t)dim_i32; + + diskann::cout << "Metadata: #pts = " << npts << ", #dims = " << dim << "... 
" << std::endl; + + int truthset_type = -1; // 1 means truthset has ids and distances, 2 means + // only ids, -1 is error + size_t expected_file_size_with_dists = 2 * npts * dim * sizeof(uint32_t) + 2 * sizeof(uint32_t); + + if (actual_file_size == expected_file_size_with_dists) + truthset_type = 1; + + size_t expected_file_size_just_ids = npts * dim * sizeof(uint32_t) + 2 * sizeof(uint32_t); + + if (actual_file_size == expected_file_size_just_ids) + truthset_type = 2; + + if (truthset_type == -1) + { + std::stringstream stream; + stream << "Error. File size mismatch. File should have bin format, with " + "npts followed by ngt followed by npts*ngt ids and optionally " + "followed by npts*ngt distance values; actual size: " + << actual_file_size << ", expected: " << expected_file_size_with_dists << " or " + << expected_file_size_just_ids; + diskann::cout << stream.str(); + throw diskann::ANNException(stream.str(), -1, __FUNCSIG__, __FILE__, __LINE__); + } + + ids = new uint32_t[npts * dim]; + reader.read((char *)ids, npts * dim * sizeof(uint32_t)); + + if (truthset_type == 1) + { + dists = new float[npts * dim]; + reader.read((char *)dists, npts * dim * sizeof(float)); + } +} + int main(int argc, char **argv) { - std::string data_type, dist_fn, base_file, query_file, gt_file, tags_file, label_file, filter_label, - universal_label; + std::string data_type, dist_fn, base_file, query_file, gt_file, tags_file; uint64_t K; try @@ -587,15 +504,12 @@ int main(int argc, char **argv) "File containing the base vectors in binary format"); desc.add_options()("query_file", po::value(&query_file)->required(), "File containing the query vectors in binary format"); - desc.add_options()("label_file", po::value(&label_file)->default_value(""), - "Input labels file in txt format if present"); - desc.add_options()("filter_label", po::value(&filter_label)->default_value(""), - "Input filter label if doing filtered groundtruth"); - desc.add_options()("universal_label", po::value(&universal_label)->default_value(""), - "Universal label, if using it, only in conjunction with label_file"); - desc.add_options()("gt_file", po::value(>_file)->required(), - "File name for the writing ground truth in binary format"); + "File name for the writing ground truth in binary " + "format, please don' append .bin at end if " + "no filter_label or filter_label_file is provided it " + "will save the file with '.bin' at end." 
+ "else it will save the file as filename_label.bin"); desc.add_options()("K", po::value(&K)->required(), "Number of ground truth nearest neighbors to compute"); desc.add_options()("tags_file", po::value(&tags_file)->default_value(std::string()), @@ -644,14 +558,11 @@ int main(int argc, char **argv) try { if (data_type == std::string("float")) - aux_main(base_file, label_file, query_file, gt_file, K, filter_label, universal_label, metric, - tags_file); + aux_main(base_file, query_file, gt_file, K, metric, tags_file); if (data_type == std::string("int8")) - aux_main(base_file, label_file, query_file, gt_file, K, filter_label, universal_label, metric, - tags_file); + aux_main(base_file, query_file, gt_file, K, metric, tags_file); if (data_type == std::string("uint8")) - aux_main(base_file, label_file, query_file, gt_file, K, filter_label, universal_label, metric, - tags_file); + aux_main(base_file, query_file, gt_file, K, metric, tags_file); } catch (const std::exception &e) { diff --git a/apps/utils/compute_groundtruth_for_filters.cpp b/apps/utils/compute_groundtruth_for_filters.cpp new file mode 100644 index 000000000..5be7135e1 --- /dev/null +++ b/apps/utils/compute_groundtruth_for_filters.cpp @@ -0,0 +1,924 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT license. + +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#ifdef _WINDOWS +#include +#else +#include +#endif + +#include "filter_utils.h" +#include "utils.h" + +// WORKS FOR UPTO 2 BILLION POINTS (as we use INT INSTEAD OF UNSIGNED) + +#define PARTSIZE 10000000 +#define ALIGNMENT 512 + +// custom types (for readability) +typedef tsl::robin_set label_set; +typedef std::string path; + +namespace po = boost::program_options; + +template T div_round_up(const T numerator, const T denominator) +{ + return (numerator % denominator == 0) ? (numerator / denominator) : 1 + (numerator / denominator); +} + +using pairIF = std::pair; +struct cmpmaxstruct +{ + bool operator()(const pairIF &l, const pairIF &r) + { + return l.second < r.second; + }; +}; + +using maxPQIFCS = std::priority_queue, cmpmaxstruct>; + +template T *aligned_malloc(const size_t n, const size_t alignment) +{ +#ifdef _WINDOWS + return (T *)_aligned_malloc(sizeof(T) * n, alignment); +#else + return static_cast(aligned_alloc(alignment, sizeof(T) * n)); +#endif +} + +inline bool custom_dist(const std::pair &a, const std::pair &b) +{ + return a.second < b.second; +} + +void compute_l2sq(float *const points_l2sq, const float *const matrix, const int64_t num_points, const uint64_t dim) +{ + assert(points_l2sq != NULL); +#pragma omp parallel for schedule(static, 65536) + for (int64_t d = 0; d < num_points; ++d) + points_l2sq[d] = cblas_sdot((int64_t)dim, matrix + (ptrdiff_t)d * (ptrdiff_t)dim, 1, + matrix + (ptrdiff_t)d * (ptrdiff_t)dim, 1); +} + +void distsq_to_points(const size_t dim, + float *dist_matrix, // Col Major, cols are queries, rows are points + size_t npoints, const float *const points, + const float *const points_l2sq, // points in Col major + size_t nqueries, const float *const queries, + const float *const queries_l2sq, // queries in Col major + float *ones_vec = NULL) // Scratchspace of num_data size and init to 1.0 +{ + bool ones_vec_alloc = false; + if (ones_vec == NULL) + { + ones_vec = new float[nqueries > npoints ? nqueries : npoints]; + std::fill_n(ones_vec, nqueries > npoints ? 
nqueries : npoints, (float)1.0);
+        ones_vec_alloc = true;
+    }
+    cblas_sgemm(CblasColMajor, CblasTrans, CblasNoTrans, npoints, nqueries, dim, (float)-2.0, points, dim, queries, dim,
+                (float)0.0, dist_matrix, npoints);
+    cblas_sgemm(CblasColMajor, CblasNoTrans, CblasTrans, npoints, nqueries, 1, (float)1.0, points_l2sq, npoints,
+                ones_vec, nqueries, (float)1.0, dist_matrix, npoints);
+    cblas_sgemm(CblasColMajor, CblasNoTrans, CblasTrans, npoints, nqueries, 1, (float)1.0, ones_vec, npoints,
+                queries_l2sq, nqueries, (float)1.0, dist_matrix, npoints);
+    if (ones_vec_alloc)
+        delete[] ones_vec;
+}
+
+void inner_prod_to_points(const size_t dim,
+                          float *dist_matrix, // Col Major, cols are queries, rows are points
+                          size_t npoints, const float *const points, size_t nqueries, const float *const queries,
+                          float *ones_vec = NULL) // Scratch space of num_data size, initialized to 1.0
+{
+    bool ones_vec_alloc = false;
+    if (ones_vec == NULL)
+    {
+        ones_vec = new float[nqueries > npoints ? nqueries : npoints];
+        std::fill_n(ones_vec, nqueries > npoints ? nqueries : npoints, (float)1.0);
+        ones_vec_alloc = true;
+    }
+    cblas_sgemm(CblasColMajor, CblasTrans, CblasNoTrans, npoints, nqueries, dim, (float)-1.0, points, dim, queries, dim,
+                (float)0.0, dist_matrix, npoints);
+
+    if (ones_vec_alloc)
+        delete[] ones_vec;
+}
+
+void exact_knn(const size_t dim, const size_t k,
+               size_t *const closest_points,     // k * num_queries preallocated, col
+                                                 // major, queries columns
+               float *const dist_closest_points, // k * num_queries
+                                                 // preallocated, Dist to
+                                                 // corresponding closest_points
+               size_t npoints,
+               float *points_in, // points in Col major
+               size_t nqueries, float *queries_in,
+               diskann::Metric metric = diskann::Metric::L2) // queries in Col major
+{
+    float *points_l2sq = new float[npoints];
+    float *queries_l2sq = new float[nqueries];
+    compute_l2sq(points_l2sq, points_in, npoints, dim);
+    compute_l2sq(queries_l2sq, queries_in, nqueries, dim);
+
+    float *points = points_in;
+    float *queries = queries_in;
+
+    if (metric == diskann::Metric::COSINE)
+    { // we convert cosine distance to
+      // normalized L2 distance
+        points = new float[npoints * dim];
+        queries = new float[nqueries * dim];
+#pragma omp parallel for schedule(static, 4096)
+        for (int64_t i = 0; i < (int64_t)npoints; i++)
+        {
+            float norm = std::sqrt(points_l2sq[i]);
+            if (norm == 0)
+            {
+                norm = std::numeric_limits<float>::epsilon();
+            }
+            for (uint32_t j = 0; j < dim; j++)
+            {
+                points[i * dim + j] = points_in[i * dim + j] / norm;
+            }
+        }
+
+#pragma omp parallel for schedule(static, 4096)
+        for (int64_t i = 0; i < (int64_t)nqueries; i++)
+        {
+            float norm = std::sqrt(queries_l2sq[i]);
+            if (norm == 0)
+            {
+                norm = std::numeric_limits<float>::epsilon();
+            }
+            for (uint32_t j = 0; j < dim; j++)
+            {
+                queries[i * dim + j] = queries_in[i * dim + j] / norm;
+            }
+        }
+        // recalculate norms after normalizing; they should all be one.
+        compute_l2sq(points_l2sq, points, npoints, dim);
+        compute_l2sq(queries_l2sq, queries, nqueries, dim);
+    }
+
+    std::cout << "Going to compute " << k << " NNs for " << nqueries << " queries over " << npoints << " points in "
+              << dim << " dimensions using";
+    if (metric == diskann::Metric::INNER_PRODUCT)
+        std::cout << " MIPS ";
+    else if (metric == diskann::Metric::COSINE)
+        std::cout << " Cosine ";
+    else
+        std::cout << " L2 ";
+    std::cout << "distance fn. 
" << std::endl; + + size_t q_batch_size = (1 << 9); + float *dist_matrix = new float[(size_t)q_batch_size * (size_t)npoints]; + + for (uint64_t b = 0; b < div_round_up(nqueries, q_batch_size); ++b) + { + int64_t q_b = b * q_batch_size; + int64_t q_e = ((b + 1) * q_batch_size > nqueries) ? nqueries : (b + 1) * q_batch_size; + + if (metric == diskann::Metric::L2 || metric == diskann::Metric::COSINE) + { + distsq_to_points(dim, dist_matrix, npoints, points, points_l2sq, q_e - q_b, + queries + (ptrdiff_t)q_b * (ptrdiff_t)dim, queries_l2sq + q_b); + } + else + { + inner_prod_to_points(dim, dist_matrix, npoints, points, q_e - q_b, + queries + (ptrdiff_t)q_b * (ptrdiff_t)dim); + } + std::cout << "Computed distances for queries: [" << q_b << "," << q_e << ")" << std::endl; + +#pragma omp parallel for schedule(dynamic, 16) + for (long long q = q_b; q < q_e; q++) + { + maxPQIFCS point_dist; + for (size_t p = 0; p < k; p++) + point_dist.emplace(p, dist_matrix[(ptrdiff_t)p + (ptrdiff_t)(q - q_b) * (ptrdiff_t)npoints]); + for (size_t p = k; p < npoints; p++) + { + if (point_dist.top().second > dist_matrix[(ptrdiff_t)p + (ptrdiff_t)(q - q_b) * (ptrdiff_t)npoints]) + point_dist.emplace(p, dist_matrix[(ptrdiff_t)p + (ptrdiff_t)(q - q_b) * (ptrdiff_t)npoints]); + if (point_dist.size() > k) + point_dist.pop(); + } + for (ptrdiff_t l = 0; l < (ptrdiff_t)k; ++l) + { + closest_points[(ptrdiff_t)(k - 1 - l) + (ptrdiff_t)q * (ptrdiff_t)k] = point_dist.top().first; + dist_closest_points[(ptrdiff_t)(k - 1 - l) + (ptrdiff_t)q * (ptrdiff_t)k] = point_dist.top().second; + point_dist.pop(); + } + assert(std::is_sorted(dist_closest_points + (ptrdiff_t)q * (ptrdiff_t)k, + dist_closest_points + (ptrdiff_t)(q + 1) * (ptrdiff_t)k)); + } + std::cout << "Computed exact k-NN for queries: [" << q_b << "," << q_e << ")" << std::endl; + } + + delete[] dist_matrix; + + delete[] points_l2sq; + delete[] queries_l2sq; + + if (metric == diskann::Metric::COSINE) + { + delete[] points; + delete[] queries; + } +} + +template inline int get_num_parts(const char *filename) +{ + std::ifstream reader; + reader.exceptions(std::ios::failbit | std::ios::badbit); + reader.open(filename, std::ios::binary); + std::cout << "Reading bin file " << filename << " ...\n"; + int npts_i32, ndims_i32; + reader.read((char *)&npts_i32, sizeof(int)); + reader.read((char *)&ndims_i32, sizeof(int)); + std::cout << "#pts = " << npts_i32 << ", #dims = " << ndims_i32 << std::endl; + reader.close(); + int num_parts = (npts_i32 % PARTSIZE) == 0 ? 
npts_i32 / PARTSIZE : (uint32_t)std::floor(npts_i32 / PARTSIZE) + 1; + std::cout << "Number of parts: " << num_parts << std::endl; + return num_parts; +} + +template +inline void load_bin_as_float(const char *filename, float *&data, size_t &npts_u64, size_t &ndims_u64, int part_num) +{ + std::ifstream reader; + reader.exceptions(std::ios::failbit | std::ios::badbit); + reader.open(filename, std::ios::binary); + std::cout << "Reading bin file " << filename << " ...\n"; + int npts_i32, ndims_i32; + reader.read((char *)&npts_i32, sizeof(int)); + reader.read((char *)&ndims_i32, sizeof(int)); + uint64_t start_id = part_num * PARTSIZE; + uint64_t end_id = (std::min)(start_id + PARTSIZE, (uint64_t)npts_i32); + npts_u64 = end_id - start_id; + ndims_u64 = (uint64_t)ndims_i32; + std::cout << "#pts in part = " << npts_u64 << ", #dims = " << ndims_u64 + << ", size = " << npts_u64 * ndims_u64 * sizeof(T) << "B" << std::endl; + + reader.seekg(start_id * ndims_u64 * sizeof(T) + 2 * sizeof(uint32_t), std::ios::beg); + T *data_T = new T[npts_u64 * ndims_u64]; + reader.read((char *)data_T, sizeof(T) * npts_u64 * ndims_u64); + std::cout << "Finished reading part of the bin file." << std::endl; + reader.close(); + data = aligned_malloc(npts_u64 * ndims_u64, ALIGNMENT); +#pragma omp parallel for schedule(dynamic, 32768) + for (int64_t i = 0; i < (int64_t)npts_u64; i++) + { + for (int64_t j = 0; j < (int64_t)ndims_u64; j++) + { + float cur_val_float = (float)data_T[i * ndims_u64 + j]; + std::memcpy((char *)(data + i * ndims_u64 + j), (char *)&cur_val_float, sizeof(float)); + } + } + delete[] data_T; + std::cout << "Finished converting part data to float." << std::endl; +} + +template +inline std::vector load_filtered_bin_as_float(const char *filename, float *&data, size_t &npts, size_t &ndims, + int part_num, const char *label_file, + const std::string &filter_label, + const std::string &universal_label, size_t &npoints_filt, + std::vector> &pts_to_labels) +{ + std::ifstream reader(filename, std::ios::binary); + if (reader.fail()) + { + throw diskann::ANNException(std::string("Failed to open file ") + filename, -1); + } + + std::cout << "Reading bin file " << filename << " ...\n"; + int npts_i32, ndims_i32; + std::vector rev_map; + reader.read((char *)&npts_i32, sizeof(int)); + reader.read((char *)&ndims_i32, sizeof(int)); + uint64_t start_id = part_num * PARTSIZE; + uint64_t end_id = (std::min)(start_id + PARTSIZE, (uint64_t)npts_i32); + npts = end_id - start_id; + ndims = (uint32_t)ndims_i32; + uint64_t nptsuint64_t = (uint64_t)npts; + uint64_t ndimsuint64_t = (uint64_t)ndims; + npoints_filt = 0; + std::cout << "#pts in part = " << npts << ", #dims = " << ndims + << ", size = " << nptsuint64_t * ndimsuint64_t * sizeof(T) << "B" << std::endl; + std::cout << "start and end ids: " << start_id << ", " << end_id << std::endl; + reader.seekg(start_id * ndims * sizeof(T) + 2 * sizeof(uint32_t), std::ios::beg); + + T *data_T = new T[nptsuint64_t * ndimsuint64_t]; + reader.read((char *)data_T, sizeof(T) * nptsuint64_t * ndimsuint64_t); + std::cout << "Finished reading part of the bin file." 
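/* Editor's note: all of these utilities assume the DiskANN .bin layout of
   [int32 npts][int32 ndims][npts * ndims row-major values], so part p of the
   file starts at byte offset 2 * sizeof(uint32_t) + p * PARTSIZE * ndims * sizeof(T),
   which is exactly what the seekg above computes. For example, with
   T = float, ndims = 10 and PARTSIZE = 10000000, part 1 begins at byte
   8 + 10000000 * 10 * 4 = 400000008. */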
<< std::endl; + reader.close(); + + data = aligned_malloc(nptsuint64_t * ndimsuint64_t, ALIGNMENT); + + for (int64_t i = 0; i < (int64_t)nptsuint64_t; i++) + { + if (std::find(pts_to_labels[start_id + i].begin(), pts_to_labels[start_id + i].end(), filter_label) != + pts_to_labels[start_id + i].end() || + std::find(pts_to_labels[start_id + i].begin(), pts_to_labels[start_id + i].end(), universal_label) != + pts_to_labels[start_id + i].end()) + { + rev_map.push_back(start_id + i); + for (int64_t j = 0; j < (int64_t)ndimsuint64_t; j++) + { + float cur_val_float = (float)data_T[i * ndimsuint64_t + j]; + std::memcpy((char *)(data + npoints_filt * ndimsuint64_t + j), (char *)&cur_val_float, sizeof(float)); + } + npoints_filt++; + } + } + delete[] data_T; + std::cout << "Finished converting part data to float.. identified " << npoints_filt + << " points matching the filter." << std::endl; + return rev_map; +} + +template inline void save_bin(const std::string filename, T *data, size_t npts, size_t ndims) +{ + std::ofstream writer; + writer.exceptions(std::ios::failbit | std::ios::badbit); + writer.open(filename, std::ios::binary | std::ios::out); + std::cout << "Writing bin: " << filename << "\n"; + int npts_i32 = (int)npts, ndims_i32 = (int)ndims; + writer.write((char *)&npts_i32, sizeof(int)); + writer.write((char *)&ndims_i32, sizeof(int)); + std::cout << "bin: #pts = " << npts << ", #dims = " << ndims + << ", size = " << npts * ndims * sizeof(T) + 2 * sizeof(int) << "B" << std::endl; + + writer.write((char *)data, npts * ndims * sizeof(T)); + writer.close(); + std::cout << "Finished writing bin" << std::endl; +} + +inline void save_groundtruth_as_one_file(const std::string filename, int32_t *data, float *distances, size_t npts, + size_t ndims) +{ + std::ofstream writer(filename, std::ios::binary | std::ios::out); + int npts_i32 = (int)npts, ndims_i32 = (int)ndims; + writer.write((char *)&npts_i32, sizeof(int)); + writer.write((char *)&ndims_i32, sizeof(int)); + std::cout << "Saving truthset in one file (npts, dim, npts*dim id-matrix, " + "npts*dim dist-matrix) with npts = " + << npts << ", dim = " << ndims << ", size = " << 2 * npts * ndims * sizeof(uint32_t) + 2 * sizeof(int) + << "B" << std::endl; + + writer.write((char *)data, npts * ndims * sizeof(uint32_t)); + writer.write((char *)distances, npts * ndims * sizeof(float)); + writer.close(); + std::cout << "Finished writing truthset" << std::endl; +} + +inline void parse_label_file_into_vec(size_t &line_cnt, const std::string &map_file, + std::vector> &pts_to_labels) +{ + std::ifstream infile(map_file); + std::string line, token; + std::set labels; + infile.clear(); + infile.seekg(0, std::ios::beg); + while (std::getline(infile, line)) + { + std::istringstream iss(line); + std::vector lbls(0); + + getline(iss, token, '\t'); + std::istringstream new_iss(token); + while (getline(new_iss, token, ',')) + { + token.erase(std::remove(token.begin(), token.end(), '\n'), token.end()); + token.erase(std::remove(token.begin(), token.end(), '\r'), token.end()); + lbls.push_back(token); + labels.insert(token); + } + if (lbls.size() <= 0) + { + std::cout << "No label found"; + exit(-1); + } + std::sort(lbls.begin(), lbls.end()); + pts_to_labels.push_back(lbls); + } + std::cout << "Identified " << labels.size() << " distinct label(s), and populated labels for " + << pts_to_labels.size() << " points" << std::endl; +} + +template +std::vector>> processUnfilteredParts(const std::string &base_file, + size_t &nqueries, size_t &npoints, + size_t &dim, size_t 
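/* Editor's note: parse_label_file_into_vec above expects one line per base
   point, whose first tab-separated field is a comma-separated label list;
   each list is sorted and kept as a std::vector of strings, and filtering
   later tests membership against filter_label / universal_label with
   std::find. A plausible three-point label file, for illustration only:

   red,round
   blue
   red,metal,heavy
*/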
&k, float *query_data,
+                                                                           const diskann::Metric &metric,
+                                                                           std::vector<uint32_t> &location_to_tag)
+{
+    float *base_data = nullptr;
+    int num_parts = get_num_parts<T>(base_file.c_str());
+    std::vector<std::vector<std::pair<uint32_t, float>>> res(nqueries);
+    for (int p = 0; p < num_parts; p++)
+    {
+        size_t start_id = p * PARTSIZE;
+        load_bin_as_float<T>(base_file.c_str(), base_data, npoints, dim, p);
+
+        size_t *closest_points_part = new size_t[nqueries * k];
+        float *dist_closest_points_part = new float[nqueries * k];
+
+        auto part_k = k < npoints ? k : npoints;
+        exact_knn(dim, part_k, closest_points_part, dist_closest_points_part, npoints, base_data, nqueries, query_data,
+                  metric);
+
+        for (size_t i = 0; i < nqueries; i++)
+        {
+            for (uint64_t j = 0; j < part_k; j++)
+            {
+                if (!location_to_tag.empty())
+                    if (location_to_tag[closest_points_part[i * part_k + j] + start_id] == 0)
+                        continue;
+
+                res[i].push_back(std::make_pair((uint32_t)(closest_points_part[i * part_k + j] + start_id),
+                                                dist_closest_points_part[i * part_k + j]));
+            }
+        }
+
+        delete[] closest_points_part;
+        delete[] dist_closest_points_part;
+
+        diskann::aligned_free(base_data);
+    }
+    return res;
+};
+
+template <typename T>
+std::vector<std::vector<std::pair<uint32_t, float>>> processFilteredParts(
+    const std::string &base_file, const std::string &label_file, const std::string &filter_label,
+    const std::string &universal_label, size_t &nqueries, size_t &npoints, size_t &dim, size_t &k, float *query_data,
+    const diskann::Metric &metric, std::vector<uint32_t> &location_to_tag)
+{
+    size_t npoints_filt = 0;
+    float *base_data = nullptr;
+    std::vector<std::vector<std::pair<uint32_t, float>>> res(nqueries);
+    int num_parts = get_num_parts<T>(base_file.c_str());
+
+    std::vector<std::vector<std::string>> pts_to_labels;
+    if (filter_label != "")
+        parse_label_file_into_vec(npoints, label_file, pts_to_labels);
+
+    for (int p = 0; p < num_parts; p++)
+    {
+        size_t start_id = p * PARTSIZE;
+        std::vector rev_map;
+        if (filter_label != "")
+            rev_map = load_filtered_bin_as_float<T>(base_file.c_str(), base_data, npoints, dim, p, label_file.c_str(),
+                                                    filter_label, universal_label, npoints_filt, pts_to_labels);
+        size_t *closest_points_part = new size_t[nqueries * k];
+        float *dist_closest_points_part = new float[nqueries * k];
+
+        auto part_k = k < npoints_filt ? k : npoints_filt;
+        if (npoints_filt > 0)
+        {
+            exact_knn(dim, part_k, closest_points_part, dist_closest_points_part, npoints_filt, base_data, nqueries,
+                      query_data, metric);
+        }
+
+        for (size_t i = 0; i < nqueries; i++)
+        {
+            for (uint64_t j = 0; j < part_k; j++)
+            {
+                if (!location_to_tag.empty())
+                    if (location_to_tag[closest_points_part[i * part_k + j] + start_id] == 0)
+                        continue;
+
+                res[i].push_back(std::make_pair((uint32_t)(rev_map[closest_points_part[i * part_k + j]]),
+                                                dist_closest_points_part[i * part_k + j]));
+            }
+        }
+
+        delete[] closest_points_part;
+        delete[] dist_closest_points_part;
+
+        diskann::aligned_free(base_data);
+    }
+    return res;
+};
+
+template <typename T>
+int aux_main(const std::string &base_file, const std::string &label_file, const std::string &query_file,
+             const std::string &gt_file, size_t k, const std::string &universal_label, const diskann::Metric &metric,
+             const std::string &filter_label, const std::string &tags_file = std::string(""))
+{
+    size_t npoints, nqueries, dim;
+
+    float *query_data = nullptr;
+
+    load_bin_as_float<T>(query_file.c_str(), query_data, nqueries, dim, 0);
+    if (nqueries > PARTSIZE)
+        std::cerr << "WARNING: #Queries provided (" << nqueries << ") is greater than " << PARTSIZE
+                  << ". Computing GT only for the first " << PARTSIZE << " queries." << std::endl;
+
+    // load tags
+    const bool tags_enabled = tags_file.empty() ?
false : true; + std::vector location_to_tag = diskann::loadTags(tags_file, base_file); + + int *closest_points = new int[nqueries * k]; + float *dist_closest_points = new float[nqueries * k]; + + std::vector>> results; + if (filter_label == "") + { + results = processUnfilteredParts(base_file, nqueries, npoints, dim, k, query_data, metric, location_to_tag); + } + else + { + results = processFilteredParts(base_file, label_file, filter_label, universal_label, nqueries, npoints, dim, + k, query_data, metric, location_to_tag); + } + + for (size_t i = 0; i < nqueries; i++) + { + std::vector> &cur_res = results[i]; + std::sort(cur_res.begin(), cur_res.end(), custom_dist); + size_t j = 0; + for (auto iter : cur_res) + { + if (j == k) + break; + if (tags_enabled) + { + std::uint32_t index_with_tag = location_to_tag[iter.first]; + closest_points[i * k + j] = (int32_t)index_with_tag; + } + else + { + closest_points[i * k + j] = (int32_t)iter.first; + } + + if (metric == diskann::Metric::INNER_PRODUCT) + dist_closest_points[i * k + j] = -iter.second; + else + dist_closest_points[i * k + j] = iter.second; + + ++j; + } + if (j < k) + std::cout << "WARNING: found less than k GT entries for query " << i << std::endl; + } + + save_groundtruth_as_one_file(gt_file, closest_points, dist_closest_points, nqueries, k); + delete[] closest_points; + delete[] dist_closest_points; + diskann::aligned_free(query_data); + + return 0; +} + +void load_truthset(const std::string &bin_file, uint32_t *&ids, float *&dists, size_t &npts, size_t &dim) +{ + size_t read_blk_size = 64 * 1024 * 1024; + cached_ifstream reader(bin_file, read_blk_size); + diskann::cout << "Reading truthset file " << bin_file.c_str() << " ..." << std::endl; + size_t actual_file_size = reader.get_file_size(); + + int npts_i32, dim_i32; + reader.read((char *)&npts_i32, sizeof(int)); + reader.read((char *)&dim_i32, sizeof(int)); + npts = (uint32_t)npts_i32; + dim = (uint32_t)dim_i32; + + diskann::cout << "Metadata: #pts = " << npts << ", #dims = " << dim << "... " << std::endl; + + int truthset_type = -1; // 1 means truthset has ids and distances, 2 means + // only ids, -1 is error + size_t expected_file_size_with_dists = 2 * npts * dim * sizeof(uint32_t) + 2 * sizeof(uint32_t); + + if (actual_file_size == expected_file_size_with_dists) + truthset_type = 1; + + size_t expected_file_size_just_ids = npts * dim * sizeof(uint32_t) + 2 * sizeof(uint32_t); + + if (actual_file_size == expected_file_size_just_ids) + truthset_type = 2; + + if (truthset_type == -1) + { + std::stringstream stream; + stream << "Error. File size mismatch. 
File should have bin format, with "
+                  "npts followed by ngt followed by npts*ngt ids and optionally "
+                  "followed by npts*ngt distance values; actual size: "
+               << actual_file_size << ", expected: " << expected_file_size_with_dists << " or "
+               << expected_file_size_just_ids;
+        diskann::cout << stream.str();
+        throw diskann::ANNException(stream.str(), -1, __FUNCSIG__, __FILE__, __LINE__);
+    }
+
+    ids = new uint32_t[npts * dim];
+    reader.read((char *)ids, npts * dim * sizeof(uint32_t));
+
+    if (truthset_type == 1)
+    {
+        dists = new float[npts * dim];
+        reader.read((char *)dists, npts * dim * sizeof(float));
+    }
+}
+
+int main(int argc, char **argv)
+{
+    std::string data_type, dist_fn, base_file, query_file, gt_file, tags_file, label_file, filter_label,
+        universal_label, filter_label_file;
+    uint64_t K;
+
+    try
+    {
+        po::options_description desc{"Arguments"};
+
+        desc.add_options()("help,h", "Print information on arguments");
+
+        desc.add_options()("data_type", po::value<std::string>(&data_type)->required(), "data type <int8/uint8/float>");
+        desc.add_options()("dist_fn", po::value<std::string>(&dist_fn)->required(), "distance function <l2/mips/cosine>");
+        desc.add_options()("base_file", po::value<std::string>(&base_file)->required(),
+                           "File containing the base vectors in binary format");
+        desc.add_options()("query_file", po::value<std::string>(&query_file)->required(),
+                           "File containing the query vectors in binary format");
+        desc.add_options()("label_file", po::value<std::string>(&label_file)->default_value(""),
+                           "Input labels file in txt format if present");
+        desc.add_options()("filter_label", po::value<std::string>(&filter_label)->default_value(""),
+                           "Input filter label if doing filtered groundtruth");
+        desc.add_options()("universal_label", po::value<std::string>(&universal_label)->default_value(""),
+                           "Universal label, if using it, only in conjunction with label_file");
+        desc.add_options()("gt_file", po::value<std::string>(&gt_file)->required(),
+                           "File name for writing the ground truth in binary "
+                           "format; please don't append .bin at the end: if "
+                           "no filter_label or filter_label_file is provided, it "
+                           "will save the file with '.bin' appended, "
+                           "otherwise it will save the file as filename_label.bin");
+        desc.add_options()("K", po::value<uint64_t>(&K)->required(),
+                           "Number of ground truth nearest neighbors to compute");
+        desc.add_options()("tags_file", po::value<std::string>(&tags_file)->default_value(std::string()),
+                           "File containing the tags in binary format");
+        desc.add_options()("filter_label_file",
+                           po::value<std::string>(&filter_label_file)->default_value(std::string("")),
+                           "Filter file for queries for filtered search");
+
+        po::variables_map vm;
+        po::store(po::parse_command_line(argc, argv, desc), vm);
+        if (vm.count("help"))
+        {
+            std::cout << desc;
+            return 0;
+        }
+        po::notify(vm);
+    }
+    catch (const std::exception &ex)
+    {
+        std::cerr << ex.what() << '\n';
+        return -1;
+    }
+
+    if (data_type != std::string("float") && data_type != std::string("int8") && data_type != std::string("uint8"))
+    {
+        std::cout << "Unsupported type. float, int8 and uint8 types are supported." << std::endl;
+        return -1;
+    }
+
+    if (filter_label != "" && filter_label_file != "")
+    {
+        std::cerr << "Only one of filter_label and filter_label_file should be provided" << std::endl;
+        return -1;
+    }
+
+    diskann::Metric metric;
+    if (dist_fn == std::string("l2"))
+    {
+        metric = diskann::Metric::L2;
+    }
+    else if (dist_fn == std::string("mips"))
+    {
+        metric = diskann::Metric::INNER_PRODUCT;
+    }
+    else if (dist_fn == std::string("cosine"))
+    {
+        metric = diskann::Metric::COSINE;
+    }
+    else
+    {
+        std::cerr << "Unsupported distance function. Use l2/mips/cosine. 
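/* Editor's note: load_truthset above distinguishes the two accepted
   ground-truth layouts purely by file size. With npts queries and
   ngt (= dim) neighbors per query, the two legal sizes in bytes are

   ids + distances : 2 * sizeof(uint32_t) + npts * ngt * (4 + 4)
   ids only        : 2 * sizeof(uint32_t) + npts * ngt * 4

   e.g. npts = 1000, ngt = 100 gives 800008 vs 400008 bytes; any other size
   raises the ANNException shown earlier. */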
<< std::endl; + return -1; + } + + std::vector filter_labels; + if (filter_label != "") + { + filter_labels.push_back(filter_label); + } + else if (filter_label_file != "") + { + filter_labels = read_file_to_vector_of_strings(filter_label_file, false); + } + + // only when there is no filter label or 1 filter label for all queries + if (filter_labels.size() == 1) + { + try + { + if (data_type == std::string("float")) + aux_main(base_file, label_file, query_file, gt_file, K, universal_label, metric, + filter_labels[0], tags_file); + if (data_type == std::string("int8")) + aux_main(base_file, label_file, query_file, gt_file, K, universal_label, metric, + filter_labels[0], tags_file); + if (data_type == std::string("uint8")) + aux_main(base_file, label_file, query_file, gt_file, K, universal_label, metric, + filter_labels[0], tags_file); + } + catch (const std::exception &e) + { + std::cout << std::string(e.what()) << std::endl; + diskann::cerr << "Compute GT failed." << std::endl; + return -1; + } + } + else + { // Each query has its own filter label + // Split up data and query bins into label specific ones + tsl::robin_map labels_to_number_of_points; + tsl::robin_map labels_to_number_of_queries; + + label_set all_labels; + for (size_t i = 0; i < filter_labels.size(); i++) + { + std::string label = filter_labels[i]; + all_labels.insert(label); + + if (labels_to_number_of_queries.find(label) == labels_to_number_of_queries.end()) + { + labels_to_number_of_queries[label] = 0; + } + labels_to_number_of_queries[label] += 1; + } + + size_t npoints; + std::vector> point_to_labels; + parse_label_file_into_vec(npoints, label_file, point_to_labels); + std::vector point_ids_to_labels(point_to_labels.size()); + std::vector query_ids_to_labels(filter_labels.size()); + + for (size_t i = 0; i < point_to_labels.size(); i++) + { + for (size_t j = 0; j < point_to_labels[i].size(); j++) + { + std::string label = point_to_labels[i][j]; + if (all_labels.find(label) != all_labels.end()) + { + point_ids_to_labels[i].insert(point_to_labels[i][j]); + if (labels_to_number_of_points.find(label) == labels_to_number_of_points.end()) + { + labels_to_number_of_points[label] = 0; + } + labels_to_number_of_points[label] += 1; + } + } + } + + for (size_t i = 0; i < filter_labels.size(); i++) + { + query_ids_to_labels[i].insert(filter_labels[i]); + } + + tsl::robin_map> label_id_to_orig_id; + tsl::robin_map> label_query_id_to_orig_id; + + if (data_type == std::string("float")) + { + label_id_to_orig_id = diskann::generate_label_specific_vector_files_compat( + base_file, labels_to_number_of_points, point_ids_to_labels, all_labels); + + label_query_id_to_orig_id = diskann::generate_label_specific_vector_files_compat( + query_file, labels_to_number_of_queries, query_ids_to_labels, + all_labels); // query_filters acts like query_ids_to_labels + } + else if (data_type == std::string("int8")) + { + label_id_to_orig_id = diskann::generate_label_specific_vector_files_compat( + base_file, labels_to_number_of_points, point_ids_to_labels, all_labels); + + label_query_id_to_orig_id = diskann::generate_label_specific_vector_files_compat( + query_file, labels_to_number_of_queries, query_ids_to_labels, + all_labels); // query_filters acts like query_ids_to_labels + } + else if (data_type == std::string("uint8")) + { + label_id_to_orig_id = diskann::generate_label_specific_vector_files_compat( + base_file, labels_to_number_of_points, point_ids_to_labels, all_labels); + + label_query_id_to_orig_id = 
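/* Editor's note: for per-query filters the driver first splits the base and
   query sets into one sub-file per label, and the two maps populated here
   record, for every label, how local ids map back to original ids: reading
   from the names, label_id_to_orig_id["red"][3] would be the original base
   id of the 4th "red" base vector, and label_query_id_to_orig_id["red"][0]
   the original query id of the first query filtered on "red" (the label
   "red" is illustrative, not from this patch). These maps are what allow
   the per-label ground truths computed next to be stitched back together in
   the original query order. */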
diskann::generate_label_specific_vector_files_compat( + query_file, labels_to_number_of_queries, query_ids_to_labels, + all_labels); // query_filters acts like query_ids_to_labels + } + else + { + diskann::cerr << "Invalid data type" << std::endl; + return -1; + } + + // Generate label specific ground truths + + try + { + for (const auto &label : all_labels) + { + std::string filtered_base_file = base_file + "_" + label; + std::string filtered_query_file = query_file + "_" + label; + std::string filtered_gt_file = gt_file + "_" + label; + if (data_type == std::string("float")) + aux_main(filtered_base_file, "", filtered_query_file, filtered_gt_file, K, "", metric, ""); + if (data_type == std::string("int8")) + aux_main(filtered_base_file, "", filtered_query_file, filtered_gt_file, K, "", metric, ""); + if (data_type == std::string("uint8")) + aux_main(filtered_base_file, "", filtered_query_file, filtered_gt_file, K, "", metric, ""); + } + } + catch (const std::exception &e) + { + std::cout << std::string(e.what()) << std::endl; + diskann::cerr << "Compute GT failed." << std::endl; + return -1; + } + + // Combine the label specific ground truths to produce a single GT file + + uint32_t *gt_ids = nullptr; + float *gt_dists = nullptr; + size_t gt_num, gt_dim; + + std::vector> final_gt_ids; + std::vector> final_gt_dists; + + uint32_t query_num = 0; + for (const auto &lbl : all_labels) + { + query_num += labels_to_number_of_queries[lbl]; + } + + for (uint32_t i = 0; i < query_num; i++) + { + final_gt_ids.push_back(std::vector(K)); + final_gt_dists.push_back(std::vector(K)); + } + + for (const auto &lbl : all_labels) + { + std::string filtered_gt_file = gt_file + "_" + lbl; + load_truthset(filtered_gt_file, gt_ids, gt_dists, gt_num, gt_dim); + + for (uint32_t i = 0; i < labels_to_number_of_queries[lbl]; i++) + { + uint32_t orig_query_id = label_query_id_to_orig_id[lbl][i]; + for (uint64_t j = 0; j < K; j++) + { + final_gt_ids[orig_query_id][j] = label_id_to_orig_id[lbl][gt_ids[i * K + j]]; + final_gt_dists[orig_query_id][j] = gt_dists[i * K + j]; + } + } + } + + int32_t *closest_points = new int32_t[query_num * K]; + float *dist_closest_points = new float[query_num * K]; + + for (uint32_t i = 0; i < query_num; i++) + { + for (uint32_t j = 0; j < K; j++) + { + closest_points[i * K + j] = final_gt_ids[i][j]; + dist_closest_points[i * K + j] = final_gt_dists[i][j]; + } + } + + save_groundtruth_as_one_file(gt_file, closest_points, dist_closest_points, query_num, K); + + // cleanup artifacts + std::cout << "Cleaning up artifacts..." 
<< std::endl; + tsl::robin_set paths_to_clean{gt_file, base_file, query_file}; + clean_up_artifacts(paths_to_clean, all_labels); + } +} diff --git a/tests/utils/count_bfs_levels.cpp b/apps/utils/count_bfs_levels.cpp similarity index 95% rename from tests/utils/count_bfs_levels.cpp rename to apps/utils/count_bfs_levels.cpp index fa38bef34..ddc4eaf0b 100644 --- a/tests/utils/count_bfs_levels.cpp +++ b/apps/utils/count_bfs_levels.cpp @@ -23,7 +23,7 @@ namespace po = boost::program_options; -template void bfs_count(const std::string &index_path, unsigned data_dims) +template void bfs_count(const std::string &index_path, uint32_t data_dims) { using TagT = uint32_t; using LabelT = uint32_t; @@ -37,7 +37,7 @@ template void bfs_count(const std::string &index_path, unsigned dat int main(int argc, char **argv) { std::string data_type, index_path_prefix; - unsigned data_dims; + uint32_t data_dims; po::options_description desc{"Arguments"}; try @@ -46,7 +46,7 @@ int main(int argc, char **argv) desc.add_options()("data_type", po::value(&data_type)->required(), "data type "); desc.add_options()("index_path_prefix", po::value(&index_path_prefix)->required(), "Path prefix to the index"); - desc.add_options()("data_dims", po::value(&data_dims)->required(), "Dimensionality of the data"); + desc.add_options()("data_dims", po::value(&data_dims)->required(), "Dimensionality of the data"); po::variables_map vm; po::store(po::parse_command_line(argc, argv, desc), vm); diff --git a/tests/utils/create_disk_layout.cpp b/apps/utils/create_disk_layout.cpp similarity index 100% rename from tests/utils/create_disk_layout.cpp rename to apps/utils/create_disk_layout.cpp diff --git a/tests/utils/float_bin_to_int8.cpp b/apps/utils/float_bin_to_int8.cpp similarity index 63% rename from tests/utils/float_bin_to_int8.cpp rename to apps/utils/float_bin_to_int8.cpp index fd63cb353..1982005af 100644 --- a/tests/utils/float_bin_to_int8.cpp +++ b/apps/utils/float_bin_to_int8.cpp @@ -4,14 +4,14 @@ #include #include "utils.h" -void block_convert(std::ofstream &writer, int8_t *write_buf, std::ifstream &reader, float *read_buf, _u64 npts, - _u64 ndims, float bias, float scale) +void block_convert(std::ofstream &writer, int8_t *write_buf, std::ifstream &reader, float *read_buf, size_t npts, + size_t ndims, float bias, float scale) { reader.read((char *)read_buf, npts * ndims * sizeof(float)); - for (_u64 i = 0; i < npts; i++) + for (size_t i = 0; i < npts; i++) { - for (_u64 d = 0; d < ndims; d++) + for (size_t d = 0; d < ndims; d++) { write_buf[d + i * ndims] = (int8_t)((read_buf[d + i * ndims] - bias) * (254.0 / scale)); } @@ -28,29 +28,29 @@ int main(int argc, char **argv) } std::ifstream reader(argv[1], std::ios::binary); - _u32 npts_u32; - _u32 ndims_u32; - reader.read((char *)&npts_u32, sizeof(_s32)); - reader.read((char *)&ndims_u32, sizeof(_s32)); + uint32_t npts_u32; + uint32_t ndims_u32; + reader.read((char *)&npts_u32, sizeof(uint32_t)); + reader.read((char *)&ndims_u32, sizeof(uint32_t)); size_t npts = npts_u32; size_t ndims = ndims_u32; std::cout << "Dataset: #pts = " << npts << ", # dims = " << ndims << std::endl; - _u64 blk_size = 131072; - _u64 nblks = ROUND_UP(npts, blk_size) / blk_size; + size_t blk_size = 131072; + size_t nblks = ROUND_UP(npts, blk_size) / blk_size; std::ofstream writer(argv[2], std::ios::binary); auto read_buf = new float[blk_size * ndims]; auto write_buf = new int8_t[blk_size * ndims]; - float bias = atof(argv[3]); - float scale = atof(argv[4]); + float bias = (float)atof(argv[3]); + float scale = 
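/* Editor's note: block_convert above quantizes with
   int8 = (float_value - bias) * (254.0 / scale), so inputs are expected to
   lie within [bias - scale / 2, bias + scale / 2] to land in [-127, 127];
   values outside that range overflow silently. A hedged sketch of the
   approximate inverse (not part of this tool):

   float dequantized = (float)int8_value * (scale / 254.0f) + bias;
*/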
(float)atof(argv[4]); - writer.write((char *)(&npts_u32), sizeof(_u32)); - writer.write((char *)(&ndims_u32), sizeof(_u32)); + writer.write((char *)(&npts_u32), sizeof(uint32_t)); + writer.write((char *)(&ndims_u32), sizeof(uint32_t)); - for (_u64 i = 0; i < nblks; i++) + for (size_t i = 0; i < nblks; i++) { - _u64 cblk_size = std::min(npts - i * blk_size, blk_size); + size_t cblk_size = std::min(npts - i * blk_size, blk_size); block_convert(writer, write_buf, reader, read_buf, cblk_size, ndims, bias, scale); std::cout << "Block #" << i << " written" << std::endl; } diff --git a/tests/utils/fvecs_to_bin.cpp b/apps/utils/fvecs_to_bin.cpp similarity index 57% rename from tests/utils/fvecs_to_bin.cpp rename to apps/utils/fvecs_to_bin.cpp index 10e28076a..873ad3b0c 100644 --- a/tests/utils/fvecs_to_bin.cpp +++ b/apps/utils/fvecs_to_bin.cpp @@ -5,11 +5,11 @@ #include "utils.h" // Convert float types -void block_convert_float(std::ifstream &reader, std::ofstream &writer, float *read_buf, float *write_buf, _u64 npts, - _u64 ndims) +void block_convert_float(std::ifstream &reader, std::ofstream &writer, float *read_buf, float *write_buf, size_t npts, + size_t ndims) { - reader.read((char *)read_buf, npts * (ndims * sizeof(float) + sizeof(unsigned))); - for (_u64 i = 0; i < npts; i++) + reader.read((char *)read_buf, npts * (ndims * sizeof(float) + sizeof(uint32_t))); + for (size_t i = 0; i < npts; i++) { memcpy(write_buf + i * ndims, (read_buf + i * (ndims + 1)) + 1, ndims * sizeof(float)); } @@ -17,16 +17,16 @@ void block_convert_float(std::ifstream &reader, std::ofstream &writer, float *re } // Convert byte types -void block_convert_byte(std::ifstream &reader, std::ofstream &writer, _u8 *read_buf, _u8 *write_buf, _u64 npts, - _u64 ndims) +void block_convert_byte(std::ifstream &reader, std::ofstream &writer, uint8_t *read_buf, uint8_t *write_buf, + size_t npts, size_t ndims) { - reader.read((char *)read_buf, npts * (ndims * sizeof(_u8) + sizeof(unsigned))); - for (_u64 i = 0; i < npts; i++) + reader.read((char *)read_buf, npts * (ndims * sizeof(uint8_t) + sizeof(uint32_t))); + for (size_t i = 0; i < npts; i++) { - memcpy(write_buf + i * ndims, (read_buf + i * (ndims + sizeof(unsigned))) + sizeof(unsigned), - ndims * sizeof(_u8)); + memcpy(write_buf + i * ndims, (read_buf + i * (ndims + sizeof(uint32_t))) + sizeof(uint32_t), + ndims * sizeof(uint8_t)); } - writer.write((char *)write_buf, npts * ndims * sizeof(_u8)); + writer.write((char *)write_buf, npts * ndims * sizeof(uint8_t)); } int main(int argc, char **argv) @@ -41,7 +41,7 @@ int main(int argc, char **argv) if (strcmp(argv[1], "uint8") == 0 || strcmp(argv[1], "int8") == 0) { - datasize = sizeof(_u8); + datasize = sizeof(uint8_t); } else if (strcmp(argv[1], "float") != 0) { @@ -50,32 +50,32 @@ int main(int argc, char **argv) } std::ifstream reader(argv[2], std::ios::binary | std::ios::ate); - _u64 fsize = reader.tellg(); + size_t fsize = reader.tellg(); reader.seekg(0, std::ios::beg); - unsigned ndims_u32; - reader.read((char *)&ndims_u32, sizeof(unsigned)); + uint32_t ndims_u32; + reader.read((char *)&ndims_u32, sizeof(uint32_t)); reader.seekg(0, std::ios::beg); - _u64 ndims = (_u64)ndims_u32; - _u64 npts = fsize / ((ndims * datasize) + sizeof(unsigned)); + size_t ndims = (size_t)ndims_u32; + size_t npts = fsize / ((ndims * datasize) + sizeof(uint32_t)); std::cout << "Dataset: #pts = " << npts << ", # dims = " << ndims << std::endl; - _u64 blk_size = 131072; - _u64 nblks = ROUND_UP(npts, blk_size) / blk_size; + size_t blk_size = 131072; + 
size_t nblks = ROUND_UP(npts, blk_size) / blk_size; std::cout << "# blks: " << nblks << std::endl; std::ofstream writer(argv[3], std::ios::binary); - _s32 npts_s32 = (_s32)npts; - _s32 ndims_s32 = (_s32)ndims; - writer.write((char *)&npts_s32, sizeof(_s32)); - writer.write((char *)&ndims_s32, sizeof(_s32)); + int32_t npts_s32 = (int32_t)npts; + int32_t ndims_s32 = (int32_t)ndims; + writer.write((char *)&npts_s32, sizeof(int32_t)); + writer.write((char *)&ndims_s32, sizeof(int32_t)); - _u64 chunknpts = std::min(npts, blk_size); - _u8 *read_buf = new _u8[chunknpts * ((ndims * datasize) + sizeof(unsigned))]; - _u8 *write_buf = new _u8[chunknpts * ndims * datasize]; + size_t chunknpts = std::min(npts, blk_size); + uint8_t *read_buf = new uint8_t[chunknpts * ((ndims * datasize) + sizeof(uint32_t))]; + uint8_t *write_buf = new uint8_t[chunknpts * ndims * datasize]; - for (_u64 i = 0; i < nblks; i++) + for (size_t i = 0; i < nblks; i++) { - _u64 cblk_size = std::min(npts - i * blk_size, blk_size); + size_t cblk_size = std::min(npts - i * blk_size, blk_size); if (datasize == sizeof(float)) { block_convert_float(reader, writer, (float *)read_buf, (float *)write_buf, cblk_size, ndims); diff --git a/tests/utils/fvecs_to_bvecs.cpp b/apps/utils/fvecs_to_bvecs.cpp similarity index 68% rename from tests/utils/fvecs_to_bvecs.cpp rename to apps/utils/fvecs_to_bvecs.cpp index 8324e9df7..f9c2aa71b 100644 --- a/tests/utils/fvecs_to_bvecs.cpp +++ b/apps/utils/fvecs_to_bvecs.cpp @@ -4,14 +4,14 @@ #include #include "utils.h" -void block_convert(std::ifstream &reader, std::ofstream &writer, float *read_buf, uint8_t *write_buf, _u64 npts, - _u64 ndims) +void block_convert(std::ifstream &reader, std::ofstream &writer, float *read_buf, uint8_t *write_buf, size_t npts, + size_t ndims) { - reader.read((char *)read_buf, npts * (ndims * sizeof(float) + sizeof(unsigned))); - for (_u64 i = 0; i < npts; i++) + reader.read((char *)read_buf, npts * (ndims * sizeof(float) + sizeof(uint32_t))); + for (size_t i = 0; i < npts; i++) { - memcpy(write_buf + i * (ndims + 4), read_buf + i * (ndims + 1), sizeof(unsigned)); - for (_u64 d = 0; d < ndims; d++) + memcpy(write_buf + i * (ndims + 4), read_buf + i * (ndims + 1), sizeof(uint32_t)); + for (size_t d = 0; d < ndims; d++) write_buf[i * (ndims + 4) + 4 + d] = (uint8_t)read_buf[i * (ndims + 1) + 1 + d]; } writer.write((char *)write_buf, npts * (ndims * 1 + 4)); @@ -25,25 +25,25 @@ int main(int argc, char **argv) exit(-1); } std::ifstream reader(argv[1], std::ios::binary | std::ios::ate); - _u64 fsize = reader.tellg(); + size_t fsize = reader.tellg(); reader.seekg(0, std::ios::beg); - unsigned ndims_u32; - reader.read((char *)&ndims_u32, sizeof(unsigned)); + uint32_t ndims_u32; + reader.read((char *)&ndims_u32, sizeof(uint32_t)); reader.seekg(0, std::ios::beg); - _u64 ndims = (_u64)ndims_u32; - _u64 npts = fsize / ((ndims + 1) * sizeof(float)); + size_t ndims = (size_t)ndims_u32; + size_t npts = fsize / ((ndims + 1) * sizeof(float)); std::cout << "Dataset: #pts = " << npts << ", # dims = " << ndims << std::endl; - _u64 blk_size = 131072; - _u64 nblks = ROUND_UP(npts, blk_size) / blk_size; + size_t blk_size = 131072; + size_t nblks = ROUND_UP(npts, blk_size) / blk_size; std::cout << "# blks: " << nblks << std::endl; std::ofstream writer(argv[2], std::ios::binary); auto read_buf = new float[npts * (ndims + 1)]; auto write_buf = new uint8_t[npts * (ndims + 4)]; - for (_u64 i = 0; i < nblks; i++) + for (size_t i = 0; i < nblks; i++) { - _u64 cblk_size = std::min(npts - i * blk_size, 
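/* Editor's note: the fvecs/bvecs formats repeat the dimension as a 4-byte
   count before every vector, so these converters recover the point count
   from the file size instead of a header: npts = fsize / (4 + ndims * datasize).
   A 128-dim float file therefore uses 516 bytes per vector, e.g. a
   1M-vector fvecs file is 516000000 bytes. The converters strip the
   per-vector count and emit the single [npts][ndims] .bin header instead. */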
blk_size); + size_t cblk_size = std::min(npts - i * blk_size, blk_size); block_convert(reader, writer, read_buf, write_buf, cblk_size, ndims); std::cout << "Block #" << i << " written" << std::endl; } diff --git a/tests/utils/gen_random_slice.cpp b/apps/utils/gen_random_slice.cpp similarity index 100% rename from tests/utils/gen_random_slice.cpp rename to apps/utils/gen_random_slice.cpp diff --git a/tests/utils/generate_pq.cpp b/apps/utils/generate_pq.cpp similarity index 77% rename from tests/utils/generate_pq.cpp rename to apps/utils/generate_pq.cpp index 761983129..a881b1104 100644 --- a/tests/utils/generate_pq.cpp +++ b/apps/utils/generate_pq.cpp @@ -22,16 +22,16 @@ bool generate_pq(const std::string &data_path, const std::string &index_prefix_p if (opq) { - diskann::generate_opq_pivots(train_data, train_size, train_dim, num_pq_centers, num_pq_chunks, pq_pivots_path, - true); + diskann::generate_opq_pivots(train_data, train_size, (uint32_t)train_dim, (uint32_t)num_pq_centers, + (uint32_t)num_pq_chunks, pq_pivots_path, true); } else { - diskann::generate_pq_pivots(train_data, train_size, train_dim, num_pq_centers, num_pq_chunks, - KMEANS_ITERS_FOR_PQ, pq_pivots_path); + diskann::generate_pq_pivots(train_data, train_size, (uint32_t)train_dim, (uint32_t)num_pq_centers, + (uint32_t)num_pq_chunks, KMEANS_ITERS_FOR_PQ, pq_pivots_path); } - diskann::generate_pq_data_from_pivots(data_path, num_pq_centers, num_pq_chunks, pq_pivots_path, - pq_compressed_vectors_path, true); + diskann::generate_pq_data_from_pivots(data_path, (uint32_t)num_pq_centers, (uint32_t)num_pq_chunks, + pq_pivots_path, pq_compressed_vectors_path, true); delete[] train_data; @@ -55,7 +55,7 @@ int main(int argc, char **argv) const std::string index_prefix_path(argv[3]); const size_t num_pq_centers = 256; const size_t num_pq_chunks = (size_t)atoi(argv[4]); - const float sampling_rate = atof(argv[5]); + const float sampling_rate = (float)atof(argv[5]); const bool opq = atoi(argv[6]) == 0 ? 
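/* Editor's note: generate_pq trains num_pq_chunks sub-codebooks of
   num_pq_centers = 256 centroids each, so the compressed vectors take one
   uint8 code per chunk per point, i.e. roughly npts * num_pq_chunks bytes.
   A quick size estimate, assuming a hypothetical 1M-point, 128-dim float
   dataset with 32 chunks:

   size_t raw_bytes = 1000000ull * 128 * sizeof(float); // 512 MB
   size_t pq_bytes = 1000000ull * 32;                   // 32 MB
*/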
false : true; if (std::string(argv[1]) == std::string("float")) diff --git a/tests/utils/generate_synthetic_labels.cpp b/apps/utils/generate_synthetic_labels.cpp similarity index 73% rename from tests/utils/generate_synthetic_labels.cpp rename to apps/utils/generate_synthetic_labels.cpp index c96d3361d..6741760cb 100644 --- a/tests/utils/generate_synthetic_labels.cpp +++ b/apps/utils/generate_synthetic_labels.cpp @@ -12,19 +12,19 @@ namespace po = boost::program_options; class ZipfDistribution { public: - ZipfDistribution(int num_points, int num_labels) - : uniform_zero_to_one(std::uniform_real_distribution<>(0.0, 1.0)), num_points(num_points), - num_labels(num_labels) + ZipfDistribution(uint64_t num_points, uint32_t num_labels) + : num_labels(num_labels), num_points(num_points), + uniform_zero_to_one(std::uniform_real_distribution<>(0.0, 1.0)) { } - std::unordered_map createDistributionMap() + std::unordered_map createDistributionMap() { - std::unordered_map map; - int primary_label_freq = ceil(num_points * distribution_factor); - for (int i{1}; i < num_labels + 1; i++) + std::unordered_map map; + uint32_t primary_label_freq = (uint32_t)ceil(num_points * distribution_factor); + for (uint32_t i{1}; i < num_labels + 1; i++) { - map[i] = ceil(primary_label_freq / i); + map[i] = (uint32_t)ceil(primary_label_freq / i); } return map; } @@ -32,15 +32,13 @@ class ZipfDistribution int writeDistribution(std::ofstream &outfile) { auto distribution_map = createDistributionMap(); - auto primary_label_frequency = num_points * distribution_factor; - for (int i{0}; i < num_points; i++) + for (uint32_t i{0}; i < num_points; i++) { bool label_written = false; - for (auto it = distribution_map.cbegin(), next_it = it; it != distribution_map.cend(); it = next_it) + for (auto it = distribution_map.cbegin(); it != distribution_map.cend(); it++) { - next_it++; auto label_selection_probability = std::bernoulli_distribution(distribution_factor / (double)it->first); - if (label_selection_probability(rand_engine)) + if (label_selection_probability(rand_engine) && distribution_map[it->first] > 0) { if (label_written) { @@ -50,10 +48,6 @@ class ZipfDistribution label_written = true; // remove label from map if we have used all labels distribution_map[it->first] -= 1; - if (distribution_map[it->first] == 0) - { - distribution_map.erase(it); - } } } if (!label_written) @@ -81,8 +75,8 @@ class ZipfDistribution } private: - int num_labels; - const int num_points; + const uint32_t num_labels; + const uint64_t num_points; const double distribution_factor = 0.7; std::knuth_b rand_engine; const std::uniform_real_distribution uniform_zero_to_one; @@ -91,7 +85,8 @@ class ZipfDistribution int main(int argc, char **argv) { std::string output_file, distribution_type; - _u64 num_labels, num_points; + uint32_t num_labels; + uint64_t num_points; try { @@ -101,10 +96,11 @@ int main(int argc, char **argv) desc.add_options()("output_file,O", po::value(&output_file)->required(), "Filename for saving the label file"); desc.add_options()("num_points,N", po::value(&num_points)->required(), "Number of points in dataset"); - desc.add_options()("num_labels,L", po::value(&num_labels)->required(), + desc.add_options()("num_labels,L", po::value(&num_labels)->required(), "Number of unique labels, up to 5000"); desc.add_options()("distribution_type,DT", po::value(&distribution_type)->default_value("random"), - "Distribution function for labels defaults to random"); + "Distribution function for labels defaults " + "to random"); po::variables_map vm; 
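/* Editor's note: ZipfDistribution above gives label i (1-based) a budget of
   ceil(ceil(0.7 * num_points) / i) occurrences, so label frequencies decay
   harmonically, and each point then receives label i with probability
   0.7 / i while that label's budget lasts. Budgets for num_points = 10000,
   num_labels = 3:

   label 1: 7000, label 2: 3500, label 3: 2333
   (the integer division truncates before ceil is applied)
*/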
po::store(po::parse_command_line(argc, argv, desc), vm); @@ -152,10 +148,10 @@ int main(int argc, char **argv) } else if (distribution_type == "random") { - for (int i = 0; i < num_points; i++) + for (size_t i = 0; i < num_points; i++) { bool label_written = false; - for (int j = 1; j <= num_labels; j++) + for (size_t j = 1; j <= num_labels; j++) { // 50% chance to assign each label if (rand() > (RAND_MAX / 2)) @@ -178,6 +174,19 @@ int main(int argc, char **argv) } } } + else if (distribution_type == "one_per_point") + { + std::random_device rd; // obtain a random number from hardware + std::mt19937 gen(rd()); // seed the generator + std::uniform_int_distribution<> distr(0, num_labels); // define the range + + for (size_t i = 0; i < num_points; i++) + { + outfile << distr(gen); + if (i != num_points - 1) + outfile << '\n'; + } + } if (outfile.is_open()) { outfile.close(); diff --git a/tests/utils/int8_to_float.cpp b/apps/utils/int8_to_float.cpp similarity index 100% rename from tests/utils/int8_to_float.cpp rename to apps/utils/int8_to_float.cpp diff --git a/tests/utils/int8_to_float_scale.cpp b/apps/utils/int8_to_float_scale.cpp similarity index 63% rename from tests/utils/int8_to_float_scale.cpp rename to apps/utils/int8_to_float_scale.cpp index 42aa06d0e..19fbc6c43 100644 --- a/tests/utils/int8_to_float_scale.cpp +++ b/apps/utils/int8_to_float_scale.cpp @@ -4,14 +4,14 @@ #include #include "utils.h" -void block_convert(std::ofstream &writer, float *write_buf, std::ifstream &reader, int8_t *read_buf, _u64 npts, - _u64 ndims, float bias, float scale) +void block_convert(std::ofstream &writer, float *write_buf, std::ifstream &reader, int8_t *read_buf, size_t npts, + size_t ndims, float bias, float scale) { reader.read((char *)read_buf, npts * ndims * sizeof(int8_t)); - for (_u64 i = 0; i < npts; i++) + for (size_t i = 0; i < npts; i++) { - for (_u64 d = 0; d < ndims; d++) + for (size_t d = 0; d < ndims; d++) { write_buf[d + i * ndims] = (((float)read_buf[d + i * ndims] - bias) * scale); } @@ -28,29 +28,29 @@ int main(int argc, char **argv) } std::ifstream reader(argv[1], std::ios::binary); - _u32 npts_u32; - _u32 ndims_u32; - reader.read((char *)&npts_u32, sizeof(_s32)); - reader.read((char *)&ndims_u32, sizeof(_s32)); + uint32_t npts_u32; + uint32_t ndims_u32; + reader.read((char *)&npts_u32, sizeof(uint32_t)); + reader.read((char *)&ndims_u32, sizeof(uint32_t)); size_t npts = npts_u32; size_t ndims = ndims_u32; std::cout << "Dataset: #pts = " << npts << ", # dims = " << ndims << std::endl; - _u64 blk_size = 131072; - _u64 nblks = ROUND_UP(npts, blk_size) / blk_size; + size_t blk_size = 131072; + size_t nblks = ROUND_UP(npts, blk_size) / blk_size; std::ofstream writer(argv[2], std::ios::binary); auto read_buf = new int8_t[blk_size * ndims]; auto write_buf = new float[blk_size * ndims]; - float bias = atof(argv[3]); - float scale = atof(argv[4]); + float bias = (float)atof(argv[3]); + float scale = (float)atof(argv[4]); - writer.write((char *)(&npts_u32), sizeof(_u32)); - writer.write((char *)(&ndims_u32), sizeof(_u32)); + writer.write((char *)(&npts_u32), sizeof(uint32_t)); + writer.write((char *)(&ndims_u32), sizeof(uint32_t)); - for (_u64 i = 0; i < nblks; i++) + for (size_t i = 0; i < nblks; i++) { - _u64 cblk_size = std::min(npts - i * blk_size, blk_size); + size_t cblk_size = std::min(npts - i * blk_size, blk_size); block_convert(writer, write_buf, reader, read_buf, cblk_size, ndims, bias, scale); std::cout << "Block #" << i << " written" << std::endl; } diff --git 
a/tests/utils/ivecs_to_bin.cpp b/apps/utils/ivecs_to_bin.cpp similarity index 50% rename from tests/utils/ivecs_to_bin.cpp rename to apps/utils/ivecs_to_bin.cpp index b7ee2304a..ea8a4a3d2 100644 --- a/tests/utils/ivecs_to_bin.cpp +++ b/apps/utils/ivecs_to_bin.cpp @@ -4,14 +4,15 @@ #include #include "utils.h" -void block_convert(std::ifstream &reader, std::ofstream &writer, _u32 *read_buf, _u32 *write_buf, _u64 npts, _u64 ndims) +void block_convert(std::ifstream &reader, std::ofstream &writer, uint32_t *read_buf, uint32_t *write_buf, size_t npts, + size_t ndims) { - reader.read((char *)read_buf, npts * (ndims * sizeof(_u32) + sizeof(unsigned))); - for (_u64 i = 0; i < npts; i++) + reader.read((char *)read_buf, npts * (ndims * sizeof(uint32_t) + sizeof(uint32_t))); + for (size_t i = 0; i < npts; i++) { - memcpy(write_buf + i * ndims, (read_buf + i * (ndims + 1)) + 1, ndims * sizeof(_u32)); + memcpy(write_buf + i * ndims, (read_buf + i * (ndims + 1)) + 1, ndims * sizeof(uint32_t)); } - writer.write((char *)write_buf, npts * ndims * sizeof(_u32)); + writer.write((char *)write_buf, npts * ndims * sizeof(uint32_t)); } int main(int argc, char **argv) @@ -22,29 +23,29 @@ int main(int argc, char **argv) exit(-1); } std::ifstream reader(argv[1], std::ios::binary | std::ios::ate); - _u64 fsize = reader.tellg(); + size_t fsize = reader.tellg(); reader.seekg(0, std::ios::beg); - unsigned ndims_u32; - reader.read((char *)&ndims_u32, sizeof(unsigned)); + uint32_t ndims_u32; + reader.read((char *)&ndims_u32, sizeof(uint32_t)); reader.seekg(0, std::ios::beg); - _u64 ndims = (_u64)ndims_u32; - _u64 npts = fsize / ((ndims + 1) * sizeof(_u32)); + size_t ndims = (size_t)ndims_u32; + size_t npts = fsize / ((ndims + 1) * sizeof(uint32_t)); std::cout << "Dataset: #pts = " << npts << ", # dims = " << ndims << std::endl; - _u64 blk_size = 131072; - _u64 nblks = ROUND_UP(npts, blk_size) / blk_size; + size_t blk_size = 131072; + size_t nblks = ROUND_UP(npts, blk_size) / blk_size; std::cout << "# blks: " << nblks << std::endl; std::ofstream writer(argv[2], std::ios::binary); - int npts_s32 = (_s32)npts; - int ndims_s32 = (_s32)ndims; - writer.write((char *)&npts_s32, sizeof(_s32)); - writer.write((char *)&ndims_s32, sizeof(_s32)); - _u32 *read_buf = new _u32[npts * (ndims + 1)]; - _u32 *write_buf = new _u32[npts * ndims]; - for (_u64 i = 0; i < nblks; i++) + int npts_s32 = (int)npts; + int ndims_s32 = (int)ndims; + writer.write((char *)&npts_s32, sizeof(int)); + writer.write((char *)&ndims_s32, sizeof(int)); + uint32_t *read_buf = new uint32_t[npts * (ndims + 1)]; + uint32_t *write_buf = new uint32_t[npts * ndims]; + for (size_t i = 0; i < nblks; i++) { - _u64 cblk_size = std::min(npts - i * blk_size, blk_size); + size_t cblk_size = std::min(npts - i * blk_size, blk_size); block_convert(reader, writer, read_buf, write_buf, cblk_size, ndims); std::cout << "Block #" << i << " written" << std::endl; } diff --git a/tests/utils/merge_shards.cpp b/apps/utils/merge_shards.cpp similarity index 83% rename from tests/utils/merge_shards.cpp rename to apps/utils/merge_shards.cpp index 5e14fb14d..106c15eef 100644 --- a/tests/utils/merge_shards.cpp +++ b/apps/utils/merge_shards.cpp @@ -19,8 +19,10 @@ int main(int argc, char **argv) if (argc != 9) { std::cout << argv[0] - << " vamana_index_prefix[1] vamana_index_suffix[2] idmaps_prefix[3] " - "idmaps_suffix[4] n_shards[5] max_degree[6] output_vamana_path[7] " + << " vamana_index_prefix[1] vamana_index_suffix[2] " + "idmaps_prefix[3] " + "idmaps_suffix[4] n_shards[5] max_degree[6] " 
+ "output_vamana_path[7] " "output_medoids_path[8]" << std::endl; exit(-1); @@ -30,8 +32,8 @@ int main(int argc, char **argv) std::string vamana_suffix(argv[2]); std::string idmaps_prefix(argv[3]); std::string idmaps_suffix(argv[4]); - _u64 nshards = (_u64)std::atoi(argv[5]); - _u32 max_degree = (_u64)std::atoi(argv[6]); + uint64_t nshards = (uint64_t)std::atoi(argv[5]); + uint32_t max_degree = (uint64_t)std::atoi(argv[6]); std::string output_index(argv[7]); std::string output_medoids(argv[8]); diff --git a/tests/utils/partition_data.cpp b/apps/utils/partition_data.cpp similarity index 96% rename from tests/utils/partition_data.cpp rename to apps/utils/partition_data.cpp index 2c505315c..2520f3f4a 100644 --- a/tests/utils/partition_data.cpp +++ b/apps/utils/partition_data.cpp @@ -23,7 +23,7 @@ int main(int argc, char **argv) const std::string data_path(argv[2]); const std::string prefix_path(argv[3]); - const float sampling_rate = atof(argv[4]); + const float sampling_rate = (float)atof(argv[4]); const size_t num_partitions = (size_t)std::atoi(argv[5]); const size_t max_reps = 15; const size_t k_index = (size_t)std::atoi(argv[6]); diff --git a/tests/utils/partition_with_ram_budget.cpp b/apps/utils/partition_with_ram_budget.cpp similarity index 96% rename from tests/utils/partition_with_ram_budget.cpp rename to apps/utils/partition_with_ram_budget.cpp index 3c546801a..937b68d2c 100644 --- a/tests/utils/partition_with_ram_budget.cpp +++ b/apps/utils/partition_with_ram_budget.cpp @@ -23,7 +23,7 @@ int main(int argc, char **argv) const std::string data_path(argv[2]); const std::string prefix_path(argv[3]); - const float sampling_rate = atof(argv[4]); + const float sampling_rate = (float)atof(argv[4]); const double ram_budget = (double)std::atof(argv[5]); const size_t graph_degree = (size_t)std::atoi(argv[6]); const size_t k_index = (size_t)std::atoi(argv[7]); diff --git a/tests/utils/rand_data_gen.cpp b/apps/utils/rand_data_gen.cpp similarity index 74% rename from tests/utils/rand_data_gen.cpp rename to apps/utils/rand_data_gen.cpp index c4461137d..a6f9305c8 100644 --- a/tests/utils/rand_data_gen.cpp +++ b/apps/utils/rand_data_gen.cpp @@ -11,7 +11,7 @@ namespace po = boost::program_options; -int block_write_float(std::ofstream &writer, _u64 ndims, _u64 npts, float norm) +int block_write_float(std::ofstream &writer, size_t ndims, size_t npts, float norm) { auto vec = new float[ndims]; @@ -19,14 +19,14 @@ int block_write_float(std::ofstream &writer, _u64 ndims, _u64 npts, float norm) std::mt19937 gen{rd()}; std::normal_distribution<> normal_rand{0, 1}; - for (_u64 i = 0; i < npts; i++) + for (size_t i = 0; i < npts; i++) { float sum = 0; - for (_u64 d = 0; d < ndims; ++d) - vec[d] = normal_rand(gen); - for (_u64 d = 0; d < ndims; ++d) + for (size_t d = 0; d < ndims; ++d) + vec[d] = (float)normal_rand(gen); + for (size_t d = 0; d < ndims; ++d) sum += vec[d] * vec[d]; - for (_u64 d = 0; d < ndims; ++d) + for (size_t d = 0; d < ndims; ++d) vec[d] = vec[d] * norm / std::sqrt(sum); writer.write((char *)vec, ndims * sizeof(float)); @@ -36,7 +36,7 @@ int block_write_float(std::ofstream &writer, _u64 ndims, _u64 npts, float norm) return 0; } -int block_write_int8(std::ofstream &writer, _u64 ndims, _u64 npts, float norm) +int block_write_int8(std::ofstream &writer, size_t ndims, size_t npts, float norm) { auto vec = new float[ndims]; auto vec_T = new int8_t[ndims]; @@ -45,19 +45,19 @@ int block_write_int8(std::ofstream &writer, _u64 ndims, _u64 npts, float norm) std::mt19937 gen{rd()}; 
std::normal_distribution<> normal_rand{0, 1}; - for (_u64 i = 0; i < npts; i++) + for (size_t i = 0; i < npts; i++) { float sum = 0; - for (_u64 d = 0; d < ndims; ++d) - vec[d] = normal_rand(gen); - for (_u64 d = 0; d < ndims; ++d) + for (size_t d = 0; d < ndims; ++d) + vec[d] = (float)normal_rand(gen); + for (size_t d = 0; d < ndims; ++d) sum += vec[d] * vec[d]; - for (_u64 d = 0; d < ndims; ++d) + for (size_t d = 0; d < ndims; ++d) vec[d] = vec[d] * norm / std::sqrt(sum); - for (_u64 d = 0; d < ndims; ++d) + for (size_t d = 0; d < ndims; ++d) { - vec_T[d] = std::round(vec[d]); + vec_T[d] = (int8_t)std::round(vec[d]); } writer.write((char *)vec_T, ndims * sizeof(int8_t)); @@ -68,7 +68,7 @@ int block_write_int8(std::ofstream &writer, _u64 ndims, _u64 npts, float norm) return 0; } -int block_write_uint8(std::ofstream &writer, _u64 ndims, _u64 npts, float norm) +int block_write_uint8(std::ofstream &writer, size_t ndims, size_t npts, float norm) { auto vec = new float[ndims]; auto vec_T = new int8_t[ndims]; @@ -77,19 +77,19 @@ int block_write_uint8(std::ofstream &writer, _u64 ndims, _u64 npts, float norm) std::mt19937 gen{rd()}; std::normal_distribution<> normal_rand{0, 1}; - for (_u64 i = 0; i < npts; i++) + for (size_t i = 0; i < npts; i++) { float sum = 0; - for (_u64 d = 0; d < ndims; ++d) - vec[d] = normal_rand(gen); - for (_u64 d = 0; d < ndims; ++d) + for (size_t d = 0; d < ndims; ++d) + vec[d] = (float)normal_rand(gen); + for (size_t d = 0; d < ndims; ++d) sum += vec[d] * vec[d]; - for (_u64 d = 0; d < ndims; ++d) + for (size_t d = 0; d < ndims; ++d) vec[d] = vec[d] * norm / std::sqrt(sum); - for (_u64 d = 0; d < ndims; ++d) + for (size_t d = 0; d < ndims; ++d) { - vec_T[d] = 128 + std::round(vec[d]); + vec_T[d] = 128 + (int8_t)std::round(vec[d]); } writer.write((char *)vec_T, ndims * sizeof(uint8_t)); @@ -103,7 +103,7 @@ int block_write_uint8(std::ofstream &writer, _u64 ndims, _u64 npts, float norm) int main(int argc, char **argv) { std::string data_type, output_file; - _u64 ndims, npts; + size_t ndims, npts; float norm; try @@ -149,7 +149,8 @@ int main(int argc, char **argv) { if (norm > 127) { - std::cerr << "Error: for int8/uint8 datatypes, L2 norm can not be greater " + std::cerr << "Error: for int8/uint8 datatypes, L2 norm can not be " + "greater " "than 127" << std::endl; return -1; @@ -161,19 +162,19 @@ int main(int argc, char **argv) std::ofstream writer; writer.exceptions(std::ofstream::failbit | std::ofstream::badbit); writer.open(output_file, std::ios::binary); - auto npts_s32 = (_u32)npts; - auto ndims_s32 = (_u32)ndims; - writer.write((char *)&npts_s32, sizeof(_u32)); - writer.write((char *)&ndims_s32, sizeof(_u32)); + auto npts_u32 = (uint32_t)npts; + auto ndims_u32 = (uint32_t)ndims; + writer.write((char *)&npts_u32, sizeof(uint32_t)); + writer.write((char *)&ndims_u32, sizeof(uint32_t)); - _u64 blk_size = 131072; - _u64 nblks = ROUND_UP(npts, blk_size) / blk_size; + size_t blk_size = 131072; + size_t nblks = ROUND_UP(npts, blk_size) / blk_size; std::cout << "# blks: " << nblks << std::endl; int ret = 0; - for (_u64 i = 0; i < nblks; i++) + for (size_t i = 0; i < nblks; i++) { - _u64 cblk_size = std::min(npts - i * blk_size, blk_size); + size_t cblk_size = std::min(npts - i * blk_size, blk_size); if (data_type == std::string("float")) { ret = block_write_float(writer, ndims, cblk_size, norm); diff --git a/tests/utils/simulate_aggregate_recall.cpp b/apps/utils/simulate_aggregate_recall.cpp similarity index 72% rename from tests/utils/simulate_aggregate_recall.cpp rename 
to apps/utils/simulate_aggregate_recall.cpp index bb096cf20..73c4ea0f7 100644 --- a/tests/utils/simulate_aggregate_recall.cpp +++ b/apps/utils/simulate_aggregate_recall.cpp @@ -6,11 +6,11 @@ #include #include -inline float aggregate_recall(const unsigned k_aggr, const unsigned k, const unsigned npart, unsigned *count, +inline float aggregate_recall(const uint32_t k_aggr, const uint32_t k, const uint32_t npart, uint32_t *count, const std::vector &recalls) { float found = 0; - for (unsigned i = 0; i < npart; ++i) + for (uint32_t i = 0; i < npart; ++i) { size_t max_found = std::min(count[i], k); found += recalls[max_found - 1] * max_found; @@ -18,23 +18,23 @@ inline float aggregate_recall(const unsigned k_aggr, const unsigned k, const uns return found / (float)k_aggr; } -void simulate(const unsigned k_aggr, const unsigned k, const unsigned npart, const unsigned nsim, +void simulate(const uint32_t k_aggr, const uint32_t k, const uint32_t npart, const uint32_t nsim, const std::vector &recalls) { std::random_device r; std::default_random_engine randeng(r()); std::uniform_int_distribution uniform_dist(0, npart - 1); - unsigned *count = new unsigned[npart]; + uint32_t *count = new uint32_t[npart]; double aggr_recall = 0; - for (unsigned i = 0; i < nsim; ++i) + for (uint32_t i = 0; i < nsim; ++i) { - for (unsigned p = 0; p < npart; ++p) + for (uint32_t p = 0; p < npart; ++p) { count[p] = 0; } - for (unsigned t = 0; t < k_aggr; ++t) + for (uint32_t t = 0; t < k_aggr; ++t) { count[uniform_dist(randeng)]++; } @@ -53,15 +53,15 @@ int main(int argc, char **argv) exit(-1); } - const unsigned k_aggr = atoi(argv[1]); - const unsigned k = atoi(argv[2]); - const unsigned npart = atoi(argv[3]); - const unsigned nsim = atoi(argv[4]); + const uint32_t k_aggr = atoi(argv[1]); + const uint32_t k = atoi(argv[2]); + const uint32_t npart = atoi(argv[3]); + const uint32_t nsim = atoi(argv[4]); std::vector recalls; for (int ctr = 5; ctr < argc; ctr++) { - recalls.push_back(atof(argv[ctr])); + recalls.push_back((float)atof(argv[ctr])); } if (recalls.size() != k) diff --git a/tests/utils/stats_label_data.cpp b/apps/utils/stats_label_data.cpp similarity index 84% rename from tests/utils/stats_label_data.cpp rename to apps/utils/stats_label_data.cpp index c5aabd5ff..3342672ff 100644 --- a/tests/utils/stats_label_data.cpp +++ b/apps/utils/stats_label_data.cpp @@ -28,26 +28,26 @@ #endif namespace po = boost::program_options; -void stats_analysis(const std::string labels_file, std::string univeral_label, _u32 density = 10) +void stats_analysis(const std::string labels_file, std::string univeral_label, uint32_t density = 10) { std::string token, line; std::ifstream labels_stream(labels_file); - std::unordered_map label_counts; + std::unordered_map label_counts; std::string label_with_max_points; - _u32 max_points = 0; + uint32_t max_points = 0; long long sum = 0; long long point_cnt = 0; - float avg_labels_per_pt, avg_labels_per_pt_incl_0, mean_label_size, mean_label_size_incl_0; + float avg_labels_per_pt, mean_label_size; - std::vector<_u32> labels_per_point; - _u32 dense_pts = 0; + std::vector labels_per_point; + uint32_t dense_pts = 0; if (labels_stream.is_open()) { while (getline(labels_stream, line)) { point_cnt++; std::stringstream iss(line); - _u32 lbl_cnt = 0; + uint32_t lbl_cnt = 0; while (getline(iss, token, ',')) { lbl_cnt++; @@ -69,7 +69,7 @@ void stats_analysis(const std::string labels_file, std::string univeral_label, _ << " labels = " << (float)dense_pts / (float)labels_per_point.size() << std::endl; 
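On the simulate_aggregate_recall tool above: each partition contributes recalls[min(count_i, k) - 1] * min(count_i, k) expected true positives, and the sum is normalized by the aggregate k. A compact, self-contained restatement with hypothetical inputs (it assumes, like the original, that every partition receives at least one result):

```cpp
#include <algorithm>
#include <cstdint>
#include <iostream>
#include <vector>

// recalls[d - 1] is the recall@d of a single partition; count[i] is how many
// of the k_aggr aggregate results were drawn from partition i.
float aggregate_recall(uint32_t k_aggr, uint32_t k, const std::vector<uint32_t> &count,
                       const std::vector<float> &recalls)
{
    float found = 0;
    for (uint32_t c : count)
    {
        uint32_t max_found = std::min(c, k); // assumes c >= 1
        found += recalls[max_found - 1] * max_found;
    }
    return found / (float)k_aggr;
}

int main()
{
    // hypothetical: 10 aggregate results split 6/4 across two partitions, k = 5
    std::vector<float> recalls = {0.99f, 0.98f, 0.97f, 0.96f, 0.95f};
    std::vector<uint32_t> count = {6, 4};
    std::cout << aggregate_recall(10, 5, count, recalls) << std::endl; // 0.859
    return 0;
}
```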
std::sort(labels_per_point.begin(), labels_per_point.end()); - std::vector> label_count_vec; + std::vector> label_count_vec; for (auto it = label_counts.begin(); it != label_counts.end(); it++) { @@ -84,14 +84,14 @@ void stats_analysis(const std::string labels_file, std::string univeral_label, _ } sort(label_count_vec.begin(), label_count_vec.end(), - [](const std::pair &lhs, const std::pair &rhs) { + [](const std::pair &lhs, const std::pair &rhs) { return lhs.second < rhs.second; }); for (float p = 0; p < 1; p += 0.05) { - std::cout << "Percentile " << (100 * p) << "\t" << label_count_vec[(_u32)(p * label_count_vec.size())].first - << " with count=" << label_count_vec[(_u32)(p * label_count_vec.size())].second << std::endl; + std::cout << "Percentile " << (100 * p) << "\t" << label_count_vec[(size_t)(p * label_count_vec.size())].first + << " with count=" << label_count_vec[(size_t)(p * label_count_vec.size())].second << std::endl; } std::cout << "Most common label " @@ -105,8 +105,8 @@ void stats_analysis(const std::string labels_file, std::string univeral_label, _ std::cout << "Third common label " << "\t" << label_count_vec[label_count_vec.size() - 3].first << " with count=" << label_count_vec[label_count_vec.size() - 3].second << std::endl; - avg_labels_per_pt = (sum) / (float)point_cnt; - mean_label_size = (sum) / label_counts.size(); + avg_labels_per_pt = sum / (float)point_cnt; + mean_label_size = sum / (float)label_counts.size(); std::cout << "Total number of points = " << point_cnt << ", number of labels = " << label_counts.size() << std::endl; std::cout << "Average number of labels per point = " << avg_labels_per_pt << std::endl; @@ -117,7 +117,7 @@ void stats_analysis(const std::string labels_file, std::string univeral_label, _ int main(int argc, char **argv) { std::string labels_file, universal_label; - _u32 density; + uint32_t density; po::options_description desc{"Arguments"}; try @@ -127,7 +127,7 @@ int main(int argc, char **argv) "path to labels data file."); desc.add_options()("universal_label", po::value(&universal_label)->required(), "Universal label used in labels file."); - desc.add_options()("density", po::value<_u32>(&density)->default_value(1), + desc.add_options()("density", po::value(&density)->default_value(1), "Number of labels each point in labels file, defaults to 1"); po::variables_map vm; po::store(po::parse_command_line(argc, argv, desc), vm); diff --git a/tests/utils/tsv_to_bin.cpp b/apps/utils/tsv_to_bin.cpp similarity index 75% rename from tests/utils/tsv_to_bin.cpp rename to apps/utils/tsv_to_bin.cpp index 5aaa9a03c..c590a8f73 100644 --- a/tests/utils/tsv_to_bin.cpp +++ b/apps/utils/tsv_to_bin.cpp @@ -4,16 +4,16 @@ #include #include "utils.h" -void block_convert_float(std::ifstream &reader, std::ofstream &writer, _u64 npts, _u64 ndims) +void block_convert_float(std::ifstream &reader, std::ofstream &writer, size_t npts, size_t ndims) { auto read_buf = new float[npts * (ndims + 1)]; auto cursor = read_buf; float val; - for (_u64 i = 0; i < npts; i++) + for (size_t i = 0; i < npts; i++) { - for (_u64 d = 0; d < ndims; ++d) + for (size_t d = 0; d < ndims; ++d) { reader >> val; *cursor = val; @@ -24,16 +24,16 @@ void block_convert_float(std::ifstream &reader, std::ofstream &writer, _u64 npts delete[] read_buf; } -void block_convert_int8(std::ifstream &reader, std::ofstream &writer, _u64 npts, _u64 ndims) +void block_convert_int8(std::ifstream &reader, std::ofstream &writer, size_t npts, size_t ndims) { auto read_buf = new int8_t[npts * (ndims + 1)]; auto 
cursor = read_buf; int val; - for (_u64 i = 0; i < npts; i++) + for (size_t i = 0; i < npts; i++) { - for (_u64 d = 0; d < ndims; ++d) + for (size_t d = 0; d < ndims; ++d) { reader >> val; *cursor = (int8_t)val; @@ -44,16 +44,16 @@ void block_convert_int8(std::ifstream &reader, std::ofstream &writer, _u64 npts, delete[] read_buf; } -void block_convert_uint8(std::ifstream &reader, std::ofstream &writer, _u64 npts, _u64 ndims) +void block_convert_uint8(std::ifstream &reader, std::ofstream &writer, size_t npts, size_t ndims) { auto read_buf = new uint8_t[npts * (ndims + 1)]; auto cursor = read_buf; int val; - for (_u64 i = 0; i < npts; i++) + for (size_t i = 0; i < npts; i++) { - for (_u64 d = 0; d < ndims; ++d) + for (size_t d = 0; d < ndims; ++d) { reader >> val; *cursor = (uint8_t)val; @@ -81,26 +81,26 @@ int main(int argc, char **argv) std::cout << "Unsupported type. float, int8 and uint8 types are supported." << std::endl; } - _u64 ndims = atoi(argv[4]); - _u64 npts = atoi(argv[5]); + size_t ndims = atoi(argv[4]); + size_t npts = atoi(argv[5]); std::ifstream reader(argv[2], std::ios::binary | std::ios::ate); - // _u64 fsize = reader.tellg(); + // size_t fsize = reader.tellg(); reader.seekg(0, std::ios::beg); reader.seekg(0, std::ios::beg); - _u64 blk_size = 131072; - _u64 nblks = ROUND_UP(npts, blk_size) / blk_size; + size_t blk_size = 131072; + size_t nblks = ROUND_UP(npts, blk_size) / blk_size; std::cout << "# blks: " << nblks << std::endl; std::ofstream writer(argv[3], std::ios::binary); - auto npts_s32 = (_u32)npts; - auto ndims_s32 = (_u32)ndims; - writer.write((char *)&npts_s32, sizeof(_u32)); - writer.write((char *)&ndims_s32, sizeof(_u32)); + auto npts_u32 = (uint32_t)npts; + auto ndims_u32 = (uint32_t)ndims; + writer.write((char *)&npts_u32, sizeof(uint32_t)); + writer.write((char *)&ndims_u32, sizeof(uint32_t)); - for (_u64 i = 0; i < nblks; i++) + for (size_t i = 0; i < nblks; i++) { - _u64 cblk_size = std::min(npts - i * blk_size, blk_size); + size_t cblk_size = std::min(npts - i * blk_size, blk_size); if (std::string(argv[1]) == std::string("float")) { block_convert_float(reader, writer, cblk_size, ndims); diff --git a/tests/utils/uint32_to_uint8.cpp b/apps/utils/uint32_to_uint8.cpp similarity index 100% rename from tests/utils/uint32_to_uint8.cpp rename to apps/utils/uint32_to_uint8.cpp diff --git a/tests/utils/uint8_to_float.cpp b/apps/utils/uint8_to_float.cpp similarity index 100% rename from tests/utils/uint8_to_float.cpp rename to apps/utils/uint8_to_float.cpp diff --git a/tests/utils/vector_analysis.cpp b/apps/utils/vector_analysis.cpp similarity index 78% rename from tests/utils/vector_analysis.cpp rename to apps/utils/vector_analysis.cpp index 99e5627b1..009df6d05 100644 --- a/tests/utils/vector_analysis.cpp +++ b/apps/utils/vector_analysis.cpp @@ -24,19 +24,19 @@ template int analyze_norm(std::string base_file) { std::cout << "Analyzing data norms" << std::endl; T *data; - _u64 npts, ndims; + size_t npts, ndims; diskann::load_bin(base_file, data, npts, ndims); std::vector norms(npts, 0); #pragma omp parallel for schedule(dynamic) - for (_s64 i = 0; i < (_s64)npts; i++) + for (int64_t i = 0; i < (int64_t)npts; i++) { - for (_u32 d = 0; d < ndims; d++) + for (size_t d = 0; d < ndims; d++) norms[i] += data[i * ndims + d] * data[i * ndims + d]; norms[i] = std::sqrt(norms[i]); } std::sort(norms.begin(), norms.end()); - for (_u32 p = 0; p < 100; p += 5) - std::cout << "percentile " << p << ": " << norms[std::floor((p / 100.0) * npts)] << std::endl; + for (int p = 0; p < 
100; p += 5) + std::cout << "percentile " << p << ": " << norms[(uint64_t)(std::floor((p / 100.0) * npts))] << std::endl; std::cout << "percentile 100" << ": " << norms[npts - 1] << std::endl; delete[] data; @@ -47,18 +47,18 @@ template int normalize_base(std::string base_file, std::string out_ { std::cout << "Normalizing base" << std::endl; T *data; - _u64 npts, ndims; + size_t npts, ndims; diskann::load_bin(base_file, data, npts, ndims); // std::vector norms(npts, 0); #pragma omp parallel for schedule(dynamic) - for (_s64 i = 0; i < (_s64)npts; i++) + for (int64_t i = 0; i < (int64_t)npts; i++) { float pt_norm = 0; - for (_u32 d = 0; d < ndims; d++) + for (size_t d = 0; d < ndims; d++) pt_norm += data[i * ndims + d] * data[i * ndims + d]; pt_norm = std::sqrt(pt_norm); - for (_u32 d = 0; d < ndims; d++) - data[i * ndims + d] = data[i * ndims + d] / pt_norm; + for (size_t d = 0; d < ndims; d++) + data[i * ndims + d] = static_cast(data[i * ndims + d] / pt_norm); } diskann::save_bin(out_file, data, npts, ndims); delete[] data; @@ -69,14 +69,14 @@ template int augment_base(std::string base_file, std::string out_fi { std::cout << "Analyzing data norms" << std::endl; T *data; - _u64 npts, ndims; + size_t npts, ndims; diskann::load_bin(base_file, data, npts, ndims); std::vector norms(npts, 0); float max_norm = 0; #pragma omp parallel for schedule(dynamic) - for (_s64 i = 0; i < (_s64)npts; i++) + for (int64_t i = 0; i < (int64_t)npts; i++) { - for (_u32 d = 0; d < ndims; d++) + for (size_t d = 0; d < ndims; d++) norms[i] += data[i * ndims + d] * data[i * ndims + d]; max_norm = norms[i] > max_norm ? norms[i] : max_norm; } @@ -84,19 +84,19 @@ template int augment_base(std::string base_file, std::string out_fi max_norm = std::sqrt(max_norm); std::cout << "Max norm: " << max_norm << std::endl; T *new_data; - _u64 newdims = ndims + 1; + size_t newdims = ndims + 1; new_data = new T[npts * newdims]; - for (_u64 i = 0; i < npts; i++) + for (size_t i = 0; i < npts; i++) { if (prep_base) { - for (_u64 j = 0; j < ndims; j++) + for (size_t j = 0; j < ndims; j++) { - new_data[i * newdims + j] = data[i * ndims + j] / max_norm; + new_data[i * newdims + j] = static_cast(data[i * ndims + j] / max_norm); } float diff = 1 - (norms[i] / (max_norm * max_norm)); diff = diff <= 0 ? 0 : std::sqrt(diff); - new_data[i * newdims + ndims] = diff; + new_data[i * newdims + ndims] = static_cast(diff); if (diff <= 0) { std::cout << i << " has large max norm, investigate if needed. 
diff = " << diff << std::endl; @@ -104,9 +104,9 @@ template int augment_base(std::string base_file, std::string out_fi } else { - for (_u64 j = 0; j < ndims; j++) + for (size_t j = 0; j < ndims; j++) { - new_data[i * newdims + j] = data[i * ndims + j] / std::sqrt(norms[i]); + new_data[i * newdims + j] = static_cast(data[i * ndims + j] / std::sqrt(norms[i])); } new_data[i * newdims + ndims] = 0; } @@ -120,7 +120,7 @@ template int augment_base(std::string base_file, std::string out_fi template int aux_main(char **argv) { std::string base_file(argv[2]); - _u32 option = atoi(argv[3]); + uint32_t option = atoi(argv[3]); if (option == 1) analyze_norm(base_file); else if (option == 2) diff --git a/clang-format.cmake b/clang-format.cmake index cbe6694b4..19bb3a850 100644 --- a/clang-format.cmake +++ b/clang-format.cmake @@ -2,7 +2,7 @@ if (NOT MSVC) message(STATUS "Setting up `make format` and `make checkformat`") # additional target to perform clang-format run, requires clang-format # get all project files - file(GLOB_RECURSE ALL_SOURCE_FILES include/*.h python/src/*.cpp src/*.cpp tests/*.cpp) + file(GLOB_RECURSE ALL_SOURCE_FILES include/*.h include/*.hpp python/src/*.cpp src/*.cpp src/*.hpp apps/*.cpp apps/*.hpp) message(status ${ALL_SOURCE_FILES}) diff --git a/include/abstract_data_store.h b/include/abstract_data_store.h new file mode 100644 index 000000000..976174378 --- /dev/null +++ b/include/abstract_data_store.h @@ -0,0 +1,115 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT license. + +#pragma once + +#include +#include + +#include "types.h" +#include "windows_customizations.h" +#include "distance.h" + +namespace diskann +{ + +template class AbstractDataStore +{ + public: + AbstractDataStore(const location_t capacity, const size_t dim); + + // virtual ~AbstractDataStore() = default; + + // Return number of points returned + virtual location_t load(const std::string &filename) = 0; + + // Why does store take num_pts? Since store only has capacity, but we allow + // resizing we can end up in a situation where the store has spare capacity. + // To optimize disk utilization, we pass the number of points that are "true" + // points, so that the store can discard the empty locations before saving. + virtual size_t save(const std::string &filename, const location_t num_pts) = 0; + + DISKANN_DLLEXPORT virtual location_t capacity() const; + + DISKANN_DLLEXPORT virtual size_t get_dims() const; + + // Implementers can choose to return _dim if they are not + // concerned about memory alignment. + // Some distance metrics (like l2) need data vectors to be aligned, so we + // align the dimension by padding zeros. + virtual size_t get_aligned_dim() const = 0; + + // populate the store with vectors (either from a pointer or bin file), + // potentially after pre-processing the vectors if the metric deems so + // e.g., normalizing vectors for cosine distance over floating-point vectors + // useful for bulk or static index building. + virtual void populate_data(const data_t *vectors, const location_t num_pts) = 0; + virtual void populate_data(const std::string &filename, const size_t offset) = 0; + + // save the first num_pts many vectors back to bin file + // note: cannot undo the pre-processing done in populate data + virtual void extract_data_to_bin(const std::string &filename, const location_t num_pts) = 0; + + // Returns the updated capacity of the datastore. 
Clients should check + // if resize actually changed the capacity to new_num_points before + // proceeding with operations. See the code below: + // auto new_capacity = data_store->resize(new_num_points); + // if (new_capacity >= new_num_points) { + // //PROCEED + // } else { + // //ERROR + // } + virtual location_t resize(const location_t new_num_points); + + // operations on vectors + // like the populate_data function, but over one vector at a time; useful for + // the streaming setting + virtual void get_vector(const location_t i, data_t *dest) const = 0; + virtual void set_vector(const location_t i, const data_t *const vector) = 0; + virtual void prefetch_vector(const location_t loc) = 0; + + // internal shuffle operations to move around vectors + // will bulk-move all the vectors in [old_start_loc, old_start_loc + + // num_points) to [new_start_loc, new_start_loc + num_points) and set the old + // positions to zero vectors. + virtual void move_vectors(const location_t old_start_loc, const location_t new_start_loc, + const location_t num_points) = 0; + + // same as above, without resetting the vectors in [from_loc, from_loc + + // num_points) to zero + virtual void copy_vectors(const location_t from_loc, const location_t to_loc, const location_t num_points) = 0; + + // metric specific operations + + virtual float get_distance(const data_t *query, const location_t loc) const = 0; + virtual void get_distance(const data_t *query, const location_t *locations, const uint32_t location_count, + float *distances) const = 0; + virtual float get_distance(const location_t loc1, const location_t loc2) const = 0; + + // stats of the data stored in store + // Returns the point in the dataset that is closest to the mean of all points + // in the dataset + virtual location_t calculate_medoid() const = 0; + + virtual Distance *get_dist_fn() = 0; + + // search helpers + // if the base data is aligned per the request of the metric, this will tell + // how to align the query vector in a consistent manner + virtual size_t get_alignment_factor() const = 0; + + protected: + // Expand the datastore to new_num_points. Returns the new capacity created, + // which should be == new_num_points in the normal case. Implementers can also + // return _capacity to indicate that they are not implementing this method. + virtual location_t expand(const location_t new_num_points) = 0; + + // Shrink the datastore to new_num_points. It is NOT an error if shrink + // doesn't reduce the capacity, so callers need to check this correctly. See + // also the "default" implementation. + virtual location_t shrink(const location_t new_num_points) = 0; + + location_t _capacity; + size_t _dim; +}; + +} // namespace diskann diff --git a/include/abstract_graph_store.h b/include/abstract_graph_store.h new file mode 100644 index 000000000..f7735b79a --- /dev/null +++ b/include/abstract_graph_store.h @@ -0,0 +1,31 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT license.
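The resize() contract documented in abstract_data_store.h above is easy to misuse, so here is a toy caller-side illustration. ToyStore and its 100000-point cap are invented purely for demonstration; the point is that resize() reports the capacity actually achieved and the caller must compare it against the request.

```cpp
#include <cstdint>
#include <iostream>

using location_t = uint32_t; // assumption: matches the typedef in types.h

struct ToyStore
{
    location_t _capacity = 1000;
    location_t resize(location_t new_num_points)
    {
        if (new_num_points <= 100000) // pretend larger allocations fail
            _capacity = new_num_points;
        return _capacity; // the capacity we actually ended up with
    }
};

int main()
{
    ToyStore store;
    const location_t requested = 5000;
    const location_t new_capacity = store.resize(requested);
    if (new_capacity >= requested)
        std::cout << "proceed: capacity is now " << new_capacity << "\n";
    else
        std::cout << "error: resize failed, capacity is " << new_capacity << "\n";
    return 0;
}
```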
+ +#pragma once + +#include +#include + +#include "types.h" + +namespace diskann +{ + +class AbstractGraphStore +{ + public: + AbstractGraphStore(const size_t max_pts) : _capacity(max_pts) + { + } + + virtual int load(const std::string &index_path_prefix) = 0; + virtual int store(const std::string &index_path_prefix) = 0; + + virtual void get_adj_list(const location_t i, std::vector &neighbors) = 0; + virtual void set_adj_list(const location_t i, std::vector &neighbors) = 0; + + private: + size_t _capacity; +}; + +} // namespace diskann diff --git a/include/abstract_index.h b/include/abstract_index.h new file mode 100644 index 000000000..1a32bf8da --- /dev/null +++ b/include/abstract_index.h @@ -0,0 +1,118 @@ +#pragma once +#include "distance.h" +#include "parameters.h" +#include "utils.h" +#include "types.h" +#include "index_config.h" +#include "index_build_params.h" +#include + +namespace diskann +{ +struct consolidation_report +{ + enum status_code + { + SUCCESS = 0, + FAIL = 1, + LOCK_FAIL = 2, + INCONSISTENT_COUNT_ERROR = 3 + }; + status_code _status; + size_t _active_points, _max_points, _empty_slots, _slots_released, _delete_set_size, _num_calls_to_process_delete; + double _time; + + consolidation_report(status_code status, size_t active_points, size_t max_points, size_t empty_slots, + size_t slots_released, size_t delete_set_size, size_t num_calls_to_process_delete, + double time_secs) + : _status(status), _active_points(active_points), _max_points(max_points), _empty_slots(empty_slots), + _slots_released(slots_released), _delete_set_size(delete_set_size), + _num_calls_to_process_delete(num_calls_to_process_delete), _time(time_secs) + { + } +}; + +/* A template-independent class for interaction with Index. Uses type erasure to add virtual implementations of methods +that can take any type (using std::any) and provides a clean API that can be inherited by different types of Index. +*/ +class AbstractIndex +{ + public: + AbstractIndex() = default; + virtual ~AbstractIndex() = default; + + virtual void build(const std::string &data_file, const size_t num_points_to_load, + IndexBuildParams &build_params) = 0; + + template + void build(const data_type *data, const size_t num_points_to_load, const IndexWriteParameters &parameters, + const std::vector &tags); + + virtual void save(const char *filename, bool compact_before_save = false) = 0; + +#ifdef EXEC_ENV_OLS + virtual void load(AlignedFileReader &reader, uint32_t num_threads, uint32_t search_l) = 0; +#else + virtual void load(const char *index_file, uint32_t num_threads, uint32_t search_l) = 0; +#endif + + // For FastL2 search on optimized layout + template + void search_with_optimized_layout(const data_type *query, size_t K, size_t L, uint32_t *indices); + + // Initialize space for res_vectors before calling.
+ template + size_t search_with_tags(const data_type *query, const uint64_t K, const uint32_t L, tag_type *tags, + float *distances, std::vector &res_vectors); + + // Added search overload that takes L as parameter, so that we + // can customize L on a per-query basis without tampering with "Parameters" + // IDtype is either uint32_t or uint64_t + template + std::pair search(const data_type *query, const size_t K, const uint32_t L, IDType *indices, + float *distances = nullptr); + + // Filtered search support + // IndexType is either uint32_t or uint64_t + template + std::pair search_with_filters(const DataType &query, const std::string &raw_label, + const size_t K, const uint32_t L, IndexType *indices, + float *distances); + + template int insert_point(const data_type *point, const tag_type tag); + + template int lazy_delete(const tag_type &tag); + + template + void lazy_delete(const std::vector &tags, std::vector &failed_tags); + + template void get_active_tags(tsl::robin_set &active_tags); + + template void set_start_points_at_random(data_type radius, uint32_t random_seed = 0); + + virtual consolidation_report consolidate_deletes(const IndexWriteParameters &parameters) = 0; + + virtual void optimize_index_layout() = 0; + + // memory should be allocated for vec before calling this function + template int get_vector_by_tag(tag_type &tag, data_type *vec); + + private: + virtual void _build(const DataType &data, const size_t num_points_to_load, const IndexWriteParameters &parameters, + TagVector &tags) = 0; + virtual std::pair _search(const DataType &query, const size_t K, const uint32_t L, + std::any &indices, float *distances = nullptr) = 0; + virtual std::pair _search_with_filters(const DataType &query, const std::string &filter_label, + const size_t K, const uint32_t L, std::any &indices, + float *distances) = 0; + virtual int _insert_point(const DataType &data_point, const TagType tag) = 0; + virtual int _lazy_delete(const TagType &tag) = 0; + virtual void _lazy_delete(TagVector &tags, TagVector &failed_tags) = 0; + virtual void _get_active_tags(TagRobinSet &active_tags) = 0; + virtual void _set_start_points_at_random(DataType radius, uint32_t random_seed = 0) = 0; + virtual int _get_vector_by_tag(TagType &tag, DataType &vec) = 0; + virtual size_t _search_with_tags(const DataType &query, const uint64_t K, const uint32_t L, const TagType &tags, + float *distances, DataVector &res_vectors) = 0; + virtual void _search_with_optimized_layout(const DataType &query, size_t K, size_t L, uint32_t *indices) = 0; +}; +} // namespace diskann diff --git a/include/ann_exception.h b/include/ann_exception.h index ff3bb33d9..6b81373c1 100644 --- a/include/ann_exception.h +++ b/include/ann_exception.h @@ -19,7 +19,7 @@ class ANNException : public std::runtime_error public: DISKANN_DLLEXPORT ANNException(const std::string &message, int errorCode); DISKANN_DLLEXPORT ANNException(const std::string &message, int errorCode, const std::string &funcSig, - const std::string &fileName, unsigned int lineNum); + const std::string &fileName, uint32_t lineNum); private: int _errorCode; @@ -29,6 +29,6 @@ class FileException : public ANNException { public: DISKANN_DLLEXPORT FileException(const std::string &filename, std::system_error &e, const std::string &funcSig, - const std::string &fileName, unsigned int lineNum); + const std::string &fileName, uint32_t lineNum); }; } // namespace diskann diff --git a/include/any_wrappers.h b/include/any_wrappers.h new file mode 100644 index 000000000..da9005cfb --- /dev/null +++ 
b/include/any_wrappers.h @@ -0,0 +1,53 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT license. + +#pragma once + +#include +#include +#include +#include +#include "tsl/robin_set.h" + +namespace AnyWrapper +{ + +/* + * Base struct to hold a reference to the data. + * Note: No memory management; the caller needs to keep the object alive. + */ +struct AnyReference +{ + template AnyReference(Ty &reference) : _data(&reference) + { + } + + template Ty &get() + { + auto ptr = std::any_cast(_data); + return *ptr; + } + + private: + std::any _data; +}; +struct AnyRobinSet : public AnyReference +{ + template AnyRobinSet(const tsl::robin_set &robin_set) : AnyReference(robin_set) + { + } + template AnyRobinSet(tsl::robin_set &robin_set) : AnyReference(robin_set) + { + } +}; + +struct AnyVector : public AnyReference +{ + template AnyVector(const std::vector &vector) : AnyReference(vector) + { + } + template AnyVector(std::vector &vector) : AnyReference(vector) + { + } +}; +} // namespace AnyWrapper diff --git a/include/cached_io.h b/include/cached_io.h index a41c03431..daef2f2f7 100644 --- a/include/cached_io.h +++ b/include/cached_io.h @@ -95,8 +95,8 @@ class cached_ifstream reader.read(cache_buf, cache_size); cur_off = 0; } - // note that if size_left < cache_size, then cur_off = cache_size, so - // subsequent reads will all be directly from file + // note that if size_left < cache_size, then cur_off = cache_size, + // so subsequent reads will all be directly from file } } diff --git a/include/common_includes.h b/include/common_includes.h index 96de5a46c..e1a51bdec 100644 --- a/include/common_includes.h +++ b/include/common_includes.h @@ -23,4 +23,5 @@ #include #include #include +#include #include diff --git a/include/concurrent_queue.h b/include/concurrent_queue.h index 405b90617..1e57bbf0f 100644 --- a/include/concurrent_queue.h +++ b/include/concurrent_queue.h @@ -88,8 +88,8 @@ template class ConcurrentQueue { T ret = this->q.front(); this->q.pop(); - // diskann::cout << "thread_id: " << std::this_thread::get_id() << ", - // ctx: " + // diskann::cout << "thread_id: " << std::this_thread::get_id() << + // ", ctx: " // << ret.ctx << "\n"; lk.unlock(); return ret; diff --git a/include/defaults.h b/include/defaults.h new file mode 100644 index 000000000..2f157cb25 --- /dev/null +++ b/include/defaults.h @@ -0,0 +1,24 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT license.
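A usage sketch for the AnyWrapper types above: AnyReference stores a type-erased pointer to a caller-owned object and recovers it with get<T>(). The struct is repeated here (with template parameters spelled out) only to keep the example self-contained; as the header's comment says, the caller owns the lifetime.

```cpp
#include <any>
#include <iostream>
#include <vector>

struct AnyReference
{
    template <typename Ty> AnyReference(Ty &reference) : _data(&reference)
    {
    }
    template <typename Ty> Ty &get()
    {
        auto ptr = std::any_cast<Ty *>(_data); // throws std::bad_any_cast on mismatch
        return *ptr;
    }

  private:
    std::any _data;
};

int main()
{
    std::vector<int> tags = {1, 2, 3};
    AnyReference any_ref(tags); // erase the concrete type

    auto &recovered = any_ref.get<std::vector<int>>(); // recover it
    recovered.push_back(4);

    std::cout << tags.size() << std::endl; // prints 4: same underlying vector
    return 0;
}
```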
+ +#pragma once +#include + +namespace diskann +{ +namespace defaults +{ +const float ALPHA = 1.2f; +const uint32_t NUM_THREADS = 0; +const uint32_t MAX_OCCLUSION_SIZE = 750; +const uint32_t FILTER_LIST_SIZE = 0; +const uint32_t NUM_FROZEN_POINTS_STATIC = 0; +const uint32_t NUM_FROZEN_POINTS_DYNAMIC = 1; +// following constants should always be specified, but are useful as a +// sensible default at cli / python boundaries +const uint32_t MAX_DEGREE = 64; +const uint32_t BUILD_LIST_SIZE = 100; +const uint32_t SATURATE_GRAPH = false; +const uint32_t SEARCH_LIST_SIZE = 100; +} // namespace defaults +} // namespace diskann diff --git a/include/disk_utils.h b/include/disk_utils.h index ff0619e74..08f046dcd 100644 --- a/include/disk_utils.h +++ b/include/disk_utils.h @@ -49,7 +49,7 @@ DISKANN_DLLEXPORT void add_new_file_to_single_index(std::string index_file, std: DISKANN_DLLEXPORT size_t calculate_num_pq_chunks(double final_index_ram_limit, size_t points_num, uint32_t dim); -DISKANN_DLLEXPORT void read_idmap(const std::string &fname, std::vector &ivecs); +DISKANN_DLLEXPORT void read_idmap(const std::string &fname, std::vector &ivecs); #ifdef EXEC_ENV_OLS template @@ -63,7 +63,7 @@ DISKANN_DLLEXPORT T *load_warmup(const std::string &cache_warmup_file, uint64_t DISKANN_DLLEXPORT int merge_shards(const std::string &vamana_prefix, const std::string &vamana_suffix, const std::string &idmaps_prefix, const std::string &idmaps_suffix, - const _u64 nshards, unsigned max_degree, const std::string &output_vamana, + const uint64_t nshards, uint32_t max_degree, const std::string &output_vamana, const std::string &medoids_file, bool use_filters = false, const std::string &labels_to_medoids_file = std::string("")); @@ -75,27 +75,30 @@ DISKANN_DLLEXPORT std::string preprocess_base_file(const std::string &infile, co diskann::Metric &distMetric); template -DISKANN_DLLEXPORT int build_merged_vamana_index(std::string base_file, diskann::Metric _compareMetric, unsigned L, - unsigned R, double sampling_rate, double ram_budget, +DISKANN_DLLEXPORT int build_merged_vamana_index(std::string base_file, diskann::Metric _compareMetric, uint32_t L, + uint32_t R, double sampling_rate, double ram_budget, std::string mem_index_path, std::string medoids_file, std::string centroids_file, size_t build_pq_bytes, bool use_opq, - bool use_filters = false, + uint32_t num_threads, bool use_filters = false, const std::string &label_file = std::string(""), const std::string &labels_to_medoids_file = std::string(""), - const std::string &universal_label = "", const _u32 Lf = 0); + const std::string &universal_label = "", const uint32_t Lf = 0); template DISKANN_DLLEXPORT uint32_t optimize_beamwidth(std::unique_ptr> &_pFlashIndex, - T *tuning_sample, _u64 tuning_sample_num, _u64 tuning_sample_aligned_dim, - uint32_t L, uint32_t nthreads, uint32_t start_bw = 2); + T *tuning_sample, uint64_t tuning_sample_num, + uint64_t tuning_sample_aligned_dim, uint32_t L, uint32_t nthreads, + uint32_t start_bw = 2); template DISKANN_DLLEXPORT int build_disk_index( const char *dataFilePath, const char *indexFilePath, const char *indexBuildParameters, - diskann::Metric _compareMetric, bool use_opq = false, bool use_filters = false, + diskann::Metric _compareMetric, bool use_opq = false, + const std::string &codebook_prefix = "", // default is empty for no codebook pass in + bool use_filters = false, const std::string &label_file = std::string(""), // default is empty string for no label_file - const std::string &universal_label = "", const _u32 
filter_threshold = 0, - const _u32 Lf = 0); // default is empty string for no universal label + const std::string &universal_label = "", const uint32_t filter_threshold = 0, + const uint32_t Lf = 0); // default is empty string for no universal label template DISKANN_DLLEXPORT void create_disk_layout(const std::string base_file, const std::string mem_index_file, diff --git a/include/distance.h b/include/distance.h index e04be7ee2..8b20e586b 100644 --- a/include/distance.h +++ b/include/distance.h @@ -1,5 +1,6 @@ #pragma once #include "windows_customizations.h" +#include namespace diskann { @@ -14,21 +15,77 @@ enum Metric template class Distance { public: - virtual float compare(const T *a, const T *b, uint32_t length) const = 0; - virtual ~Distance() + DISKANN_DLLEXPORT Distance(diskann::Metric dist_metric) : _distance_metric(dist_metric) { } + + // distance comparison function + DISKANN_DLLEXPORT virtual float compare(const T *a, const T *b, uint32_t length) const = 0; + + // Needed only for COSINE-BYTE and INNER_PRODUCT-BYTE + DISKANN_DLLEXPORT virtual float compare(const T *a, const T *b, const float normA, const float normB, + uint32_t length) const; + + // For MIPS, normalization adds an extra dimension to the vectors. + // This function lets callers know if the normalization process + // changes the dimension. + DISKANN_DLLEXPORT virtual uint32_t post_normalization_dimension(uint32_t orig_dimension) const; + + DISKANN_DLLEXPORT virtual diskann::Metric get_metric() const; + + // This is for efficiency. If no normalization is required, the callers + // can simply ignore the normalize_data_for_build() function. + DISKANN_DLLEXPORT virtual bool preprocessing_required() const; + + // Check the preprocessing_required() function before calling this. + // Clients can call the function like this: + // + // if (metric->preprocessing_required()){ + // T* normalized_data_batch; + // Split data into batches of batch_size and for each, call: + // metric->preprocess_base_points(data_batch, batch_size); + // + // TODO: This does not take into account the case for SSD inner product + // where the dimensions change after normalization. + DISKANN_DLLEXPORT virtual void preprocess_base_points(T *original_data, const size_t orig_dim, + const size_t num_points); + + // Invokes normalization for a single vector during search. The scratch space + // has to be created by the caller keeping track of the fact that + // normalization might change the dimension of the query vector. + DISKANN_DLLEXPORT virtual void preprocess_query(const T *query_vec, const size_t query_dim, T *scratch_query); + + // If an algorithm has a requirement that some data be aligned to a certain + // boundary it can use this function to indicate that requirement. Currently, + // we are setting it to 8 because that works well for AVX2. If we have AVX512 + // implementations of distance algos, they might have to set this to 16 + // (depending on how they are implemented) + DISKANN_DLLEXPORT virtual size_t get_required_alignment() const; + + // Providing a default implementation for the virtual destructor because we + // don't expect most metric implementations to need it. 
+ DISKANN_DLLEXPORT virtual ~Distance(); + + protected: + diskann::Metric _distance_metric; + size_t _alignment_factor = 8; }; class DistanceCosineInt8 : public Distance { public: + DistanceCosineInt8() : Distance(diskann::Metric::COSINE) + { + } DISKANN_DLLEXPORT virtual float compare(const int8_t *a, const int8_t *b, uint32_t length) const; }; class DistanceL2Int8 : public Distance { public: + DistanceL2Int8() : Distance(diskann::Metric::L2) + { + } DISKANN_DLLEXPORT virtual float compare(const int8_t *a, const int8_t *b, uint32_t size) const; }; @@ -36,18 +93,28 @@ class DistanceL2Int8 : public Distance class AVXDistanceL2Int8 : public Distance { public: + AVXDistanceL2Int8() : Distance(diskann::Metric::L2) + { + } DISKANN_DLLEXPORT virtual float compare(const int8_t *a, const int8_t *b, uint32_t length) const; }; class DistanceCosineFloat : public Distance { public: + DistanceCosineFloat() : Distance(diskann::Metric::COSINE) + { + } DISKANN_DLLEXPORT virtual float compare(const float *a, const float *b, uint32_t length) const; }; class DistanceL2Float : public Distance { public: + DistanceL2Float() : Distance(diskann::Metric::L2) + { + } + #ifdef _WINDOWS DISKANN_DLLEXPORT virtual float compare(const float *a, const float *b, uint32_t size) const; #else @@ -58,46 +125,49 @@ class DistanceL2Float : public Distance class AVXDistanceL2Float : public Distance { public: + AVXDistanceL2Float() : Distance(diskann::Metric::L2) + { + } DISKANN_DLLEXPORT virtual float compare(const float *a, const float *b, uint32_t length) const; }; -class SlowDistanceL2Float : public Distance +template class SlowDistanceL2 : public Distance { public: - DISKANN_DLLEXPORT virtual float compare(const float *a, const float *b, uint32_t length) const; + SlowDistanceL2() : Distance(diskann::Metric::L2) + { + } + DISKANN_DLLEXPORT virtual float compare(const T *a, const T *b, uint32_t length) const; }; class SlowDistanceCosineUInt8 : public Distance { public: + SlowDistanceCosineUInt8() : Distance(diskann::Metric::COSINE) + { + } DISKANN_DLLEXPORT virtual float compare(const uint8_t *a, const uint8_t *b, uint32_t length) const; }; class DistanceL2UInt8 : public Distance { public: + DistanceL2UInt8() : Distance(diskann::Metric::L2) + { + } DISKANN_DLLEXPORT virtual float compare(const uint8_t *a, const uint8_t *b, uint32_t size) const; }; -// Simple implementations for non-AVX machines. Compiler can optimize. -template class SlowDistanceL2Int : public Distance +template class DistanceInnerProduct : public Distance { public: - // Implementing here because this is a template function - DISKANN_DLLEXPORT virtual float compare(const T *a, const T *b, uint32_t length) const + DistanceInnerProduct() : Distance(diskann::Metric::INNER_PRODUCT) { - uint32_t result = 0; - for (uint32_t i = 0; i < length; i++) - { - result += ((int32_t)((int16_t)a[i] - (int16_t)b[i])) * ((int32_t)((int16_t)a[i] - (int16_t)b[i])); - } - return (float)result; } -}; -template class DistanceInnerProduct : public Distance -{ - public: + DistanceInnerProduct(diskann::Metric metric) : Distance(metric) + { + } inline float inner_product(const T *a, const T *b, unsigned size) const; inline float compare(const T *a, const T *b, unsigned size) const @@ -115,6 +185,9 @@ template class DistanceFastL2 : public DistanceInnerProduct // currently defined only for float. // templated for future use. 
public: + DistanceFastL2() : DistanceInnerProduct(diskann::Metric::FAST_L2) + { + } float norm(const T *a, unsigned size) const; float compare(const T *a, const T *b, float norm, unsigned size) const; }; @@ -122,6 +195,9 @@ template class DistanceFastL2 : public DistanceInnerProduct class AVXDistanceInnerProductFloat : public Distance { public: + AVXDistanceInnerProductFloat() : Distance(diskann::Metric::INNER_PRODUCT) + { + } DISKANN_DLLEXPORT virtual float compare(const float *a, const float *b, uint32_t length) const; }; @@ -130,13 +206,28 @@ class AVXNormalizedCosineDistanceFloat : public Distance private: AVXDistanceInnerProductFloat _innerProduct; + protected: + void normalize_and_copy(const float *a, uint32_t length, float *a_norm) const; + public: + AVXNormalizedCosineDistanceFloat() : Distance(diskann::Metric::COSINE) + { + } DISKANN_DLLEXPORT virtual float compare(const float *a, const float *b, uint32_t length) const { // Inner product returns negative values to indicate distance. // This will ensure that cosine is between -1 and 1. return 1.0f + _innerProduct.compare(a, b, length); } + DISKANN_DLLEXPORT virtual uint32_t post_normalization_dimension(uint32_t orig_dimension) const override; + + DISKANN_DLLEXPORT virtual bool preprocessing_required() const; + + DISKANN_DLLEXPORT virtual void preprocess_base_points(float *original_data, const size_t orig_dim, + const size_t num_points) override; + + DISKANN_DLLEXPORT virtual void preprocess_query(const float *query_vec, const size_t query_dim, + float *scratch_query_vector) override; }; template Distance *get_distance_function(Metric m); diff --git a/include/filter_utils.h b/include/filter_utils.h new file mode 100644 index 000000000..df1970be4 --- /dev/null +++ b/include/filter_utils.h @@ -0,0 +1,217 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT license. + +#pragma once +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#ifdef __APPLE__ +#else +#include +#endif + +#ifdef _WINDOWS +#include +typedef HANDLE FileHandle; +#else +#include +typedef int FileHandle; +#endif + +#ifndef _WINDOWS +#include +#endif + +#include "cached_io.h" +#include "common_includes.h" +#include "memory_mapper.h" +#include "utils.h" +#include "windows_customizations.h" + +// custom types (for readability) +typedef tsl::robin_set label_set; +typedef std::string path; + +// structs for returning multiple items from a function +typedef std::tuple, tsl::robin_map, tsl::robin_set> + parse_label_file_return_values; +typedef std::tuple>, uint64_t> load_label_index_return_values; + +namespace diskann +{ +template +DISKANN_DLLEXPORT void generate_label_indices(path input_data_path, path final_index_path_prefix, label_set all_labels, + unsigned R, unsigned L, float alpha, unsigned num_threads); + +DISKANN_DLLEXPORT load_label_index_return_values load_label_index(path label_index_path, + uint32_t label_number_of_points); + +DISKANN_DLLEXPORT parse_label_file_return_values parse_label_file(path label_data_path, std::string universal_label); + +template +DISKANN_DLLEXPORT tsl::robin_map> generate_label_specific_vector_files_compat( + path input_data_path, tsl::robin_map labels_to_number_of_points, + std::vector point_ids_to_labels, label_set all_labels); + +/* + * For each label, generates a file containing all vectors that have said label. + * Also copies data from original bin file to new dimension-aligned file. 
+ * + * Utilizes POSIX functions mmap and writev in order to minimize memory + * overhead, so we include an STL version as well. + * + * Each data file is saved under the following format: + * input_data_path + "_" + label + */ +#ifndef _WINDOWS +template +inline tsl::robin_map> generate_label_specific_vector_files( + path input_data_path, tsl::robin_map labels_to_number_of_points, + std::vector point_ids_to_labels, label_set all_labels) +{ +#ifndef _WINDOWS + auto file_writing_timer = std::chrono::high_resolution_clock::now(); + diskann::MemoryMapper input_data(input_data_path); + char *input_start = input_data.getBuf(); + + uint32_t number_of_points, dimension; + std::memcpy(&number_of_points, input_start, sizeof(uint32_t)); + std::memcpy(&dimension, input_start + sizeof(uint32_t), sizeof(uint32_t)); + const uint32_t VECTOR_SIZE = dimension * sizeof(T); + const size_t METADATA = 2 * sizeof(uint32_t); + if (number_of_points != point_ids_to_labels.size()) + { + std::cerr << "Error: number of points in labels file and data file differ." << std::endl; + throw; + } + + tsl::robin_map label_to_iovec_map; + tsl::robin_map label_to_curr_iovec; + tsl::robin_map> label_id_to_orig_id; + + // setup iovec list for each label + for (const auto &lbl : all_labels) + { + iovec *label_iovecs = (iovec *)malloc(labels_to_number_of_points[lbl] * sizeof(iovec)); + if (label_iovecs == nullptr) + { + throw; + } + label_to_iovec_map[lbl] = label_iovecs; + label_to_curr_iovec[lbl] = 0; + label_id_to_orig_id[lbl].reserve(labels_to_number_of_points[lbl]); + } + + // each point added to corresponding per-label iovec list + for (uint32_t point_id = 0; point_id < number_of_points; point_id++) + { + char *curr_point = input_start + METADATA + (VECTOR_SIZE * point_id); + iovec curr_iovec; + + curr_iovec.iov_base = curr_point; + curr_iovec.iov_len = VECTOR_SIZE; + for (const auto &lbl : point_ids_to_labels[point_id]) + { + *(label_to_iovec_map[lbl] + label_to_curr_iovec[lbl]) = curr_iovec; + label_to_curr_iovec[lbl]++; + label_id_to_orig_id[lbl].push_back(point_id); + } + } + + // write each label iovec to resp. 
+
+    // write each label's iovec list to its respective file
+    for (const auto &lbl : all_labels)
+    {
+        int label_input_data_fd;
+        path curr_label_input_data_path(input_data_path + "_" + lbl);
+        uint32_t curr_num_pts = labels_to_number_of_points[lbl];
+
+        label_input_data_fd =
+            open(curr_label_input_data_path.c_str(), O_CREAT | O_WRONLY | O_TRUNC | O_APPEND, (mode_t)0644);
+        if (label_input_data_fd == -1)
+            throw;
+
+        // write metadata
+        uint32_t metadata[2] = {curr_num_pts, dimension};
+        int return_value = write(label_input_data_fd, metadata, sizeof(uint32_t) * 2);
+        if (return_value == -1)
+        {
+            throw;
+        }
+
+        // the limit on the number of iovec structs per writev call (IOV_MAX)
+        // means we may need to perform multiple writevs
+        size_t i = 0;
+        while (curr_num_pts > IOV_MAX)
+        {
+            return_value = writev(label_input_data_fd, (label_to_iovec_map[lbl] + (IOV_MAX * i)), IOV_MAX);
+            if (return_value == -1)
+            {
+                close(label_input_data_fd);
+                throw;
+            }
+            curr_num_pts -= IOV_MAX;
+            i += 1;
+        }
+        return_value = writev(label_input_data_fd, (label_to_iovec_map[lbl] + (IOV_MAX * i)), curr_num_pts);
+        if (return_value == -1)
+        {
+            close(label_input_data_fd);
+            throw;
+        }
+
+        free(label_to_iovec_map[lbl]);
+        close(label_input_data_fd);
+    }
+
+    std::chrono::duration<double> file_writing_time = std::chrono::high_resolution_clock::now() - file_writing_timer;
+    std::cout << "generated " << all_labels.size() << " label-specific vector files for index building in time "
+              << file_writing_time.count() << "\n"
+              << std::endl;
+
+    return label_id_to_orig_id;
+#endif
+}
+#endif
+
+inline std::vector<std::uint32_t> loadTags(const std::string &tags_file, const std::string &base_file)
+{
+    const bool tags_enabled = tags_file.empty() ? false : true;
+    std::vector<std::uint32_t> location_to_tag;
+    if (tags_enabled)
+    {
+        size_t tag_file_ndims, tag_file_npts;
+        std::uint32_t *tag_data;
+        diskann::load_bin(tags_file, tag_data, tag_file_npts, tag_file_ndims);
+        if (tag_file_ndims != 1)
+        {
+            diskann::cerr << "tags file error" << std::endl;
+            throw diskann::ANNException("tag file error", -1, __FUNCSIG__, __FILE__, __LINE__);
+        }
+
+        // check that the point counts match
+        size_t base_file_npts, base_file_ndims;
+        diskann::get_bin_metadata(base_file, base_file_npts, base_file_ndims);
+        if (base_file_npts != tag_file_npts)
+        {
+            diskann::cerr << "point num in tags file mismatch" << std::endl;
+            throw diskann::ANNException("point num in tags file mismatch", -1, __FUNCSIG__, __FILE__, __LINE__);
+        }
+
+        location_to_tag.assign(tag_data, tag_data + tag_file_npts);
+        delete[] tag_data;
+    }
+    return location_to_tag;
+}
+
+} // namespace diskann
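Annotation (not part of the diff): loadTags returns an empty vector when tags_file is empty, so emptiness doubles as a "tags disabled" signal. A hypothetical call, with illustrative paths:

    // location_to_tag[i] is the external tag of internal location i.
    std::vector<std::uint32_t> location_to_tag =
        diskann::loadTags("my_index.tags", "base_data.bin");
    const bool tags_enabled = !location_to_tag.empty();
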
diff --git a/include/in_mem_data_store.h b/include/in_mem_data_store.h
new file mode 100644
index 000000000..0509b3b82
--- /dev/null
+++ b/include/in_mem_data_store.h
@@ -0,0 +1,82 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT license.
+
+#include
+#include
+
+#include "tsl/robin_map.h"
+#include "tsl/robin_set.h"
+#include "tsl/sparse_map.h"
+// #include "boost/dynamic_bitset.hpp"
+
+#include "abstract_data_store.h"
+
+#include "distance.h"
+#include "natural_number_map.h"
+#include "natural_number_set.h"
+#include "aligned_file_reader.h"
+
+namespace diskann
+{
+template <typename data_t> class InMemDataStore : public AbstractDataStore<data_t>
+{
+  public:
+    InMemDataStore(const location_t capacity, const size_t dim, std::shared_ptr<Distance<data_t>> distance_fn);
+    virtual ~InMemDataStore();
+
+    virtual location_t load(const std::string &filename) override;
+    virtual size_t save(const std::string &filename, const location_t num_points) override;
+
+    virtual size_t get_aligned_dim() const override;
+
+    // Populate internal data from unaligned data while doing alignment and any
+    // normalization that is required.
+    virtual void populate_data(const data_t *vectors, const location_t num_pts) override;
+    virtual void populate_data(const std::string &filename, const size_t offset) override;
+
+    virtual void extract_data_to_bin(const std::string &filename, const location_t num_pts) override;
+
+    virtual void get_vector(const location_t i, data_t *target) const override;
+    virtual void set_vector(const location_t i, const data_t *const vector) override;
+    virtual void prefetch_vector(const location_t loc) override;
+
+    virtual void move_vectors(const location_t old_location_start, const location_t new_location_start,
+                              const location_t num_points) override;
+    virtual void copy_vectors(const location_t from_loc, const location_t to_loc, const location_t num_points) override;
+
+    virtual float get_distance(const data_t *query, const location_t loc) const override;
+    virtual float get_distance(const location_t loc1, const location_t loc2) const override;
+    virtual void get_distance(const data_t *query, const location_t *locations, const uint32_t location_count,
+                              float *distances) const override;
+
+    virtual location_t calculate_medoid() const override;
+
+    virtual Distance<data_t> *get_dist_fn() override;
+
+    virtual size_t get_alignment_factor() const override;
+
+  protected:
+    virtual location_t expand(const location_t new_size) override;
+    virtual location_t shrink(const location_t new_size) override;
+
+    virtual location_t load_impl(const std::string &filename);
+#ifdef EXEC_ENV_OLS
+    virtual location_t load_impl(AlignedFileReader &reader);
+#endif
+
+  private:
+    data_t *_data = nullptr;
+
+    size_t _aligned_dim;
+
+    // It may seem odd to put the distance metric along with the data store
+    // class, but this gives us perf benefits as the datastore can do distance
+    // computations during search and compute norms of vectors internally
+    // without having to copy data back and forth.
+    std::shared_ptr<Distance<data_t>> _distance_fn;
+
+    // in case we need to save vector norms for optimization
+    std::shared_ptr<float[]> _pre_computed_norms;
+};
+
+} // namespace diskann
\ No newline at end of file
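Annotation (not part of the diff): a hypothetical construction of the new data store, assuming float data and the L2 metric; sizes and values are illustrative:

    // Build a store for 10,000 vectors of dimension 10 and query it.
    std::shared_ptr<diskann::Distance<float>> dist(
        diskann::get_distance_function<float>(diskann::Metric::L2));
    diskann::InMemDataStore<float> store(/*capacity=*/10000, /*dim=*/10, dist);

    std::vector<float> vectors(10000 * 10, 0.0f); // stand-in for real data
    store.populate_data(vectors.data(), 10000);   // aligns (and normalizes if needed)

    std::vector<float> query(10, 0.0f);
    float d = store.get_distance(query.data(), /*loc=*/42);
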
diff --git a/include/in_mem_graph_store.h b/include/in_mem_graph_store.h
new file mode 100644
index 000000000..98a9e4dc5
--- /dev/null
+++ b/include/in_mem_graph_store.h
@@ -0,0 +1,23 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT license.
+
+#pragma once
+
+#include "abstract_graph_store.h"
+
+namespace diskann
+{
+
+class InMemGraphStore : public AbstractGraphStore
+{
+  public:
+    InMemGraphStore(const size_t max_pts);
+
+    int load(const std::string &index_path_prefix);
+    int store(const std::string &index_path_prefix);
+
+    void get_adj_list(const location_t i, std::vector<location_t> &neighbors);
+    void set_adj_list(const location_t i, std::vector<location_t> &neighbors);
+};
+
+} // namespace diskann
diff --git a/include/index.h b/include/index.h
index 10efe1323..7c38db00e 100644
--- a/include/index.h
+++ b/include/index.h
@@ -18,6 +18,8 @@
 #include "utils.h"
 #include "windows_customizations.h"
 #include "scratch.h"
+#include "in_mem_data_store.h"
+#include "abstract_index.h"
 #include

 #define OVERHEAD_FACTOR 1.1
@@ -27,40 +29,16 @@
 namespace diskann
 {
-inline double estimate_ram_usage(_u64 size, _u32 dim, _u32 datasize, _u32 degree)
+inline double estimate_ram_usage(size_t size, uint32_t dim, uint32_t datasize, uint32_t degree)
 {
     double size_of_data = ((double)size) * ROUND_UP(dim, 8) * datasize;
-    double size_of_graph = ((double)size) * degree * sizeof(unsigned) * GRAPH_SLACK_FACTOR;
+    double size_of_graph = ((double)size) * degree * sizeof(uint32_t) * GRAPH_SLACK_FACTOR;
     double size_of_locks = ((double)size) * sizeof(non_recursive_mutex);
     double size_of_outer_vector = ((double)size) * sizeof(ptrdiff_t);
     return OVERHEAD_FACTOR * (size_of_data + size_of_graph + size_of_locks + size_of_outer_vector);
 }
-
-struct consolidation_report
-{
-    enum status_code
-    {
-        SUCCESS = 0,
-        FAIL = 1,
-        LOCK_FAIL = 2,
-        INCONSISTENT_COUNT_ERROR = 3
-    };
-    status_code _status;
-    size_t _active_points, _max_points, _empty_slots, _slots_released, _delete_set_size, _num_calls_to_process_delete;
-    double _time;
-
-    consolidation_report(status_code status, size_t active_points, size_t max_points, size_t empty_slots,
-                         size_t slots_released, size_t delete_set_size, size_t num_calls_to_process_delete,
-                         double time_secs)
-        : _status(status), _active_points(active_points), _max_points(max_points), _empty_slots(empty_slots),
-          _slots_released(slots_released), _delete_set_size(delete_set_size),
-          _num_calls_to_process_delete(num_calls_to_process_delete), _time(time_secs)
-    {
-    }
-};
-
 struct simple_bitmask_val
 {
     size_t _index = 0;
@@ -172,7 +150,7 @@ class simple_bitmask
     std::uint64_t _bitmask_size;
 };

-template <typename T, typename TagT = uint32_t, typename LabelT = uint32_t> class Index
+template <typename T, typename TagT = uint32_t, typename LabelT = uint32_t> class Index : public AbstractIndex
 {
     /**************************************************************************
      *
      *************************************************************************/
     DISKANN_DLLEXPORT Index(Metric m, const size_t dim, const size_t max_points = 1, const bool dynamic_index = false,
                             const bool enable_tags = false, const bool concurrent_consolidate = false,
                             const bool pq_dist_build = false, const size_t num_pq_chunks = 0,
-                            const bool use_opq = false, const size_t num_frozen_pts = 0);
+                            const bool use_opq = false, const size_t num_frozen_pts = 0,
+                            const bool init_data_store = true);

     // Constructor for incremental index
     DISKANN_DLLEXPORT Index(Metric m, const size_t dim, const size_t max_points, const bool dynamic_index,
-                            const Parameters &indexParameters, const Parameters &searchParameters,
-                            const bool enable_tags = false, const bool concurrent_consolidate = false,
-                            const bool pq_dist_build = false, const size_t num_pq_chunks = 0,
-                            const bool use_opq = false);
+                            const IndexWriteParameters &indexParameters, const uint32_t initial_search_list_size,
+                            const uint32_t search_threads, const bool enable_tags = false,
+                            const bool concurrent_consolidate = false,
+                            const bool pq_dist_build = false,
+                            const size_t num_pq_chunks = 0, const bool use_opq = false);
+
+    DISKANN_DLLEXPORT Index(const IndexConfig &index_config, std::unique_ptr<AbstractDataStore<T>> data_store
+                            /* std::unique_ptr<AbstractGraphStore> graph_store*/);

     DISKANN_DLLEXPORT ~Index();
@@ -216,21 +198,28 @@ template clas
     DISKANN_DLLEXPORT size_t get_num_points();
     DISKANN_DLLEXPORT size_t get_max_points();

+    DISKANN_DLLEXPORT bool detect_common_filters(uint32_t point_id, bool search_invocation,
+                                                 const std::vector<LabelT> &incoming_labels);
+
     // Batch build from a file. Optionally pass tags vector.
-    DISKANN_DLLEXPORT void build(const char *filename, const size_t num_points_to_load, Parameters &parameters,
+    DISKANN_DLLEXPORT void build(const char *filename, const size_t num_points_to_load,
+                                 const IndexWriteParameters &parameters,
                                  const std::vector<TagT> &tags = std::vector<TagT>());

     // Batch build from a file. Optionally pass tags file.
-    DISKANN_DLLEXPORT void build(const char *filename, const size_t num_points_to_load, Parameters &parameters,
-                                 const char *tag_filename);
+    DISKANN_DLLEXPORT void build(const char *filename, const size_t num_points_to_load,
+                                 const IndexWriteParameters &parameters, const char *tag_filename);

     // Batch build from a data array, which must pad vectors to aligned_dim
-    DISKANN_DLLEXPORT void build(const T *data, const size_t num_points_to_load, Parameters &parameters,
+    DISKANN_DLLEXPORT void build(const T *data, const size_t num_points_to_load, const IndexWriteParameters &parameters,
                                  const std::vector<TagT> &tags);

+    DISKANN_DLLEXPORT void build(const std::string &data_file, const size_t num_points_to_load,
+                                 IndexBuildParams &build_params);
+
     // Filtered Support
     DISKANN_DLLEXPORT void build_filtered_index(const char *filename, const std::string &label_file,
-                                                const size_t num_points_to_load, Parameters &parameters,
+                                                const size_t num_points_to_load, IndexWriteParameters &parameters,
                                                 const std::vector<TagT> &tags = std::vector<TagT>());

     DISKANN_DLLEXPORT void set_universal_label(const LabelT &label);
@@ -244,28 +233,28 @@ template clas
     // Set starting points to random points on a sphere of certain radius.
     // A fixed random seed can be specified for scenarios where it's important
     // to have higher consistency between index builds.
-    DISKANN_DLLEXPORT void set_start_points_at_random(T radius, unsigned int random_seed = 0);
+    DISKANN_DLLEXPORT void set_start_points_at_random(T radius, uint32_t random_seed = 0);

     // For FastL2 search on a static index, we interleave the data with graph
     DISKANN_DLLEXPORT void optimize_index_layout();

     // For FastL2 search on optimized layout
-    DISKANN_DLLEXPORT void search_with_optimized_layout(const T *query, size_t K, size_t L, unsigned *indices);
+    DISKANN_DLLEXPORT void search_with_optimized_layout(const T *query, size_t K, size_t L, uint32_t *indices);

     // Added search overload that takes L as parameter, so that we
     // can customize L on a per-query basis without tampering with "Parameters"
     template <typename IDType>
-    DISKANN_DLLEXPORT std::pair<uint32_t, uint32_t> search(const T *query, const size_t K, const unsigned L,
+    DISKANN_DLLEXPORT std::pair<uint32_t, uint32_t> search(const T *query, const size_t K, const uint32_t L,
                                                            IDType *indices, float *distances = nullptr);

     // Initialize space for res_vectors before calling.
-    DISKANN_DLLEXPORT size_t search_with_tags(const T *query, const uint64_t K, const unsigned L, TagT *tags,
+    DISKANN_DLLEXPORT size_t search_with_tags(const T *query, const uint64_t K, const uint32_t L, TagT *tags,
                                               float *distances, std::vector<T *> &res_vectors);

     // Filter support search
     template <typename IndexType>
     DISKANN_DLLEXPORT std::pair<uint32_t, uint32_t> search_with_filters(const T *query, const LabelT &filter_label,
-                                                                        const size_t K, const unsigned L,
+                                                                        const size_t K, const uint32_t L,
                                                                         IndexType *indices, float *distances);

     // Will fail if tag already in the index or if tag=0.
@@ -286,17 +275,18 @@ template clas
     // Returns number of live points left after consolidation
     // If _conc_consolidates is set in the ctor, then this call can be invoked
     // alongside inserts and lazy deletes, else it acquires _update_lock
-    DISKANN_DLLEXPORT consolidation_report consolidate_deletes(const Parameters &parameters);
+    DISKANN_DLLEXPORT consolidation_report consolidate_deletes(const IndexWriteParameters &parameters);

-    DISKANN_DLLEXPORT void prune_all_nbrs(const Parameters &parameters);
+    DISKANN_DLLEXPORT void prune_all_neighbors(const uint32_t max_degree, const uint32_t max_occlusion,
+                                               const float alpha);

     DISKANN_DLLEXPORT bool is_index_saved();

     // repositions frozen points to the end of _data - if they have been moved
     // during deletion
     DISKANN_DLLEXPORT void reposition_frozen_point_to_end();
-    DISKANN_DLLEXPORT void reposition_points(unsigned old_location_start, unsigned new_location_start,
-                                             unsigned num_locations);
+    DISKANN_DLLEXPORT void reposition_points(uint32_t old_location_start, uint32_t new_location_start,
+                                             uint32_t num_locations);

     // DISKANN_DLLEXPORT void save_index_as_one_file(bool flag);
@@ -320,20 +310,48 @@ template clas
     // ********************************

   protected:
+    // overrides of AbstractIndex's virtual methods
+    virtual void _build(const DataType &data, const size_t num_points_to_load, const IndexWriteParameters &parameters,
+                        TagVector &tags) override;
+
+    virtual std::pair<uint32_t, uint32_t> _search(const DataType &query, const size_t K, const uint32_t L,
+                                                  std::any &indices, float *distances = nullptr) override;
+    virtual std::pair<uint32_t, uint32_t> _search_with_filters(const DataType &query,
+                                                               const std::string &filter_label_raw, const size_t K,
+                                                               const uint32_t L, std::any &indices,
+                                                               float *distances) override;
+
+    virtual int _insert_point(const DataType &data_point, const TagType tag) override;
+
+    virtual int _lazy_delete(const TagType &tag) override;
+
+    virtual void _lazy_delete(TagVector &tags, TagVector &failed_tags) override;
+
+    virtual void _get_active_tags(TagRobinSet &active_tags) override;
+
+    virtual void _set_start_points_at_random(DataType radius, uint32_t random_seed = 0) override;
+
+    virtual int _get_vector_by_tag(TagType &tag, DataType &vec) override;
+
+    virtual void _search_with_optimized_layout(const DataType &query, size_t K, size_t L, uint32_t *indices) override;
+
+    virtual size_t _search_with_tags(const DataType &query, const uint64_t K, const uint32_t L, const TagType &tags,
+                                     float *distances, DataVector &res_vectors) override;
+
     // No copy/assign.
     Index(const Index &) = delete;
     Index &operator=(const Index &) = delete;

     // Use after _data and _nd have been populated
     // Acquire exclusive _update_lock before calling
-    void build_with_data_populated(Parameters &parameters, const std::vector<TagT> &tags);
+    void build_with_data_populated(const IndexWriteParameters &parameters, const std::vector<TagT> &tags);

     // generates 1 frozen point that will never be deleted from the graph
     // This is not visible to the user
     void generate_frozen_point();

     // determines navigating node of the graph by calculating medoid of data
-    unsigned calculate_entry_point();
+    uint32_t calculate_entry_point();

     void parse_label_file(const std::string &label_file, size_t &num_pts_labels);
@@ -343,44 +361,47 @@ template clas
     std::unordered_map<std::string, LabelT> load_label_map(const std::string &map_file);

-    // Returns the locations of start point and frozen points suitable for use with iterate_to_fixed_point.
-    std::vector<unsigned> get_init_ids();
+    // Returns the locations of start point and frozen points suitable for use
+    // with iterate_to_fixed_point.
+    std::vector<uint32_t> get_init_ids();

-    std::pair<uint32_t, uint32_t> iterate_to_fixed_point(const T *node_coords, const unsigned Lindex,
-                                                         const std::vector<unsigned> &init_ids,
+    std::pair<uint32_t, uint32_t> iterate_to_fixed_point(const T *node_coords, const uint32_t Lindex,
+                                                         const std::vector<uint32_t> &init_ids,
                                                          InMemQueryScratch<T> *scratch, bool use_filter,
                                                          const std::vector<LabelT> &filters, bool search_invocation);

-    void search_for_point_and_prune(int location, _u32 Lindex, std::vector<unsigned> &pruned_list,
-                                    InMemQueryScratch<T> *scratch, bool use_filter = false, _u32 filteredLindex = 0);
+    void search_for_point_and_prune(int location, uint32_t Lindex, std::vector<uint32_t> &pruned_list,
+                                    InMemQueryScratch<T> *scratch, bool use_filter = false,
+                                    uint32_t filteredLindex = 0);

-    void prune_neighbors(const unsigned location, std::vector<Neighbor> &pool, std::vector<unsigned> &pruned_list,
+    void prune_neighbors(const uint32_t location, std::vector<Neighbor> &pool, std::vector<uint32_t> &pruned_list,
                          InMemQueryScratch<T> *scratch);

-    void prune_neighbors(const unsigned location, std::vector<Neighbor> &pool, const _u32 range,
-                         const _u32 max_candidate_size, const float alpha, std::vector<unsigned> &pruned_list,
+    void prune_neighbors(const uint32_t location, std::vector<Neighbor> &pool, const uint32_t range,
+                         const uint32_t max_candidate_size, const float alpha, std::vector<uint32_t> &pruned_list,
                          InMemQueryScratch<T> *scratch);

     // Prunes candidates in @pool to a shorter list @result
     // @pool must be sorted before calling
-    void occlude_list(const unsigned location, std::vector<Neighbor> &pool, const float alpha, const unsigned degree,
-                      const unsigned maxc, std::vector<unsigned> &result, InMemQueryScratch<T> *scratch,
-                      const tsl::robin_set<unsigned> *const delete_set_ptr = nullptr);
+    void occlude_list(const uint32_t location, std::vector<Neighbor> &pool, const float alpha, const uint32_t degree,
+                      const uint32_t maxc, std::vector<uint32_t> &result, InMemQueryScratch<T> *scratch,
+                      const tsl::robin_set<uint32_t> *const delete_set_ptr = nullptr);
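Annotation (not part of the diff): occlude_list is the alpha-pruning ("RobustPrune") step of Vamana: a candidate is dropped when some already-selected neighbor is sufficiently closer to it than the target is, with alpha > 1 relaxing the test to keep the graph denser. A conceptual sketch, assuming pool is sorted ascending by distance to the target; `dist` stands in for the index metric and is hypothetical here:

    // Keep candidate c only if no already-kept r "occludes" it, i.e.
    // alpha * dist(r, c) < dist(target, c); stop at `degree` results.
    std::vector<uint32_t> result;
    for (const auto &c : pool) // pool sorted ascending by c.distance
    {
        if (result.size() >= degree)
            break;
        bool occluded = false;
        for (uint32_t r : result)
        {
            if (alpha * dist(r, c.id) < c.distance)
            {
                occluded = true;
                break;
            }
        }
        if (!occluded)
            result.push_back(c.id);
    }
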
     // add reverse links from all the visited nodes to node n.
-    void inter_insert(unsigned n, std::vector<unsigned> &pruned_list, const _u32 range, InMemQueryScratch<T> *scratch);
+    void inter_insert(uint32_t n, std::vector<uint32_t> &pruned_list, const uint32_t range,
+                      InMemQueryScratch<T> *scratch);

-    void inter_insert(unsigned n, std::vector<unsigned> &pruned_list, InMemQueryScratch<T> *scratch);
+    void inter_insert(uint32_t n, std::vector<uint32_t> &pruned_list, InMemQueryScratch<T> *scratch);

     // Acquire exclusive _update_lock before calling
-    void link(Parameters &parameters);
+    void link(const IndexWriteParameters &parameters);

     // Acquire exclusive _tag_lock and _delete_lock before calling
     int reserve_location();

     // Acquire exclusive _tag_lock before calling
     size_t release_location(int location);
-    size_t release_locations(const tsl::robin_set<unsigned> &locations);
+    size_t release_locations(const tsl::robin_set<uint32_t> &locations);

     // Resize the index when no slots are left for insertion.
     // Acquire exclusive _update_lock and _tag_lock before calling.
@@ -397,18 +418,18 @@ template clas
     // Remove deleted nodes from adjacency list of node loc
     // Replace removed neighbors with second order neighbors.
     // Also acquires _locks[i] for i = loc and out-neighbors of loc.
-    void process_delete(const tsl::robin_set<unsigned> &old_delete_set, size_t loc, const unsigned range,
-                        const unsigned maxc, const float alpha, InMemQueryScratch<T> *scratch);
+    void process_delete(const tsl::robin_set<uint32_t> &old_delete_set, size_t loc, const uint32_t range,
+                        const uint32_t maxc, const float alpha, InMemQueryScratch<T> *scratch);

     void initialize_query_scratch(uint32_t num_threads, uint32_t search_l, uint32_t indexing_l, uint32_t r,
                                   uint32_t maxc, size_t dim, size_t bitmask_size = 0);

     // Do not call without acquiring appropriate locks
     // call public member functions save and load to invoke these.
-    DISKANN_DLLEXPORT _u64 save_graph(std::string filename);
-    DISKANN_DLLEXPORT _u64 save_data(std::string filename);
-    DISKANN_DLLEXPORT _u64 save_tags(std::string filename);
-    DISKANN_DLLEXPORT _u64 save_delete_list(const std::string &filename);
+    DISKANN_DLLEXPORT size_t save_graph(std::string filename);
+    DISKANN_DLLEXPORT size_t save_data(std::string filename);
+    DISKANN_DLLEXPORT size_t save_tags(std::string filename);
+    DISKANN_DLLEXPORT size_t save_delete_list(const std::string &filename);
 #ifdef EXEC_ENV_OLS
     DISKANN_DLLEXPORT size_t load_graph(AlignedFileReader &reader, size_t expected_num_points);
     DISKANN_DLLEXPORT size_t load_data(AlignedFileReader &reader);
@@ -424,35 +445,37 @@ template clas
   private:
     // Distance functions
     Metric _dist_metric = diskann::L2;
-    Distance<T> *_distance = nullptr;
+    std::shared_ptr<Distance<T>> _distance;

     // Data
-    T *_data = nullptr;
+    std::unique_ptr<AbstractDataStore<T>> _data_store;
     char *_opt_graph = nullptr;

     // Graph related data structures
-    std::vector<std::vector<unsigned>> _final_graph;
+    std::vector<std::vector<uint32_t>> _final_graph;
+    T *_data = nullptr; // coordinates of all base points

     // Dimensions
     size_t _dim = 0;
-    size_t _aligned_dim = 0;
     size_t _nd = 0;         // number of active points i.e. existing in the graph
     size_t _max_points = 0; // total number of points in given data set

-    // Number of points which are used as initial candidates when iterating to
-    // closest point(s). These are not visible externally and won't be returned
-    // by search. DiskANN forces at least 1 frozen point for dynamic index.
-    // The frozen points have consecutive locations. See also _start below.
+
+    // _num_frozen_pts is the number of points which are used as initial
+    // candidates when iterating to closest point(s). These are not visible
+    // externally and won't be returned by search.
+    // At least 1 frozen point is needed for a dynamic index. The frozen points
+    // have consecutive locations. See also _start below.
     size_t _num_frozen_pts = 0;
     size_t _max_range_of_loaded_graph = 0;
     size_t _node_size;
     size_t _data_len;
     size_t _neighbor_len;

-    unsigned _max_observed_degree = 0;
+    uint32_t _max_observed_degree = 0;

     // Start point of the search. When _num_frozen_pts is greater than zero,
     // this is the location of the first frozen point. Otherwise, this is a
     // location of one of the points in index.
-    unsigned _start = 0;
+    uint32_t _start = 0;

     bool _has_built = false;
     bool _saturate_graph = false;
@@ -460,6 +483,7 @@ template clas
     bool _dynamic_index = false;
     bool _enable_tags = false;
     bool _normalize_vecs = false; // Using normalized L2 for cosine.
+    bool _deletes_enabled = false;

     // Filter Support
@@ -467,8 +491,8 @@ template clas
     std::vector<std::vector<LabelT>> _pts_to_labels;
     tsl::robin_set<LabelT> _labels;
     std::string _labels_file;
-    std::unordered_map<LabelT, _u32> _label_to_medoid_id;
-    std::unordered_map<_u32, _u32> _medoid_counts;
+    std::unordered_map<LabelT, uint32_t> _label_to_medoid_id;
+    std::unordered_map<uint32_t, uint32_t> _medoid_counts;
     bool _use_universal_label = false;
     LabelT _universal_label = 0;
     uint32_t _filterIndexingQueueSize;
@@ -487,7 +511,7 @@ template clas
     bool _pq_dist = false;
     bool _use_opq = false;
     size_t _num_pq_chunks = 0;
-    _u8 *_pq_data = nullptr;
+    uint8_t *_pq_data = nullptr;
     bool _pq_generated = false;
     FixedChunkPQTable _pq_table;
@@ -497,18 +521,18 @@ template clas

     // lazy_delete removes entry from _location_to_tag and _tag_to_location. If
     // _location_to_tag does not resolve a location, infer that it was deleted.
-    tsl::sparse_map<TagT, unsigned> _tag_to_location;
-    natural_number_map<unsigned, TagT> _location_to_tag;
+    tsl::sparse_map<TagT, uint32_t> _tag_to_location;
+    natural_number_map<uint32_t, TagT> _location_to_tag;

     // _empty_slots has unallocated slots and those freed by consolidate_delete.
     // _delete_set has locations marked deleted by lazy_delete. Will not be
     // immediately available for insert. consolidate_delete will release these
     // slots to _empty_slots.
-    natural_number_set<unsigned> _empty_slots;
-    std::unique_ptr<tsl::robin_set<unsigned>> _delete_set;
+    natural_number_set<uint32_t> _empty_slots;
+    std::unique_ptr<tsl::robin_set<uint32_t>> _delete_set;

     bool _data_compacted = true; // true if data has been compacted
-    bool _is_saved = false;      // Gopal. Checking if the index is already saved.
+    bool _is_saved = false;      // Checking if the index is already saved.
     bool _conc_consolidate = false; // use _lock while searching

     // Acquire locks in the order below when acquiring multiple locks
diff --git a/include/index_build_params.h b/include/index_build_params.h
new file mode 100644
index 000000000..ff68c5001
--- /dev/null
+++ b/include/index_build_params.h
@@ -0,0 +1,72 @@
+#include "common_includes.h"
+#include "parameters.h"
+
+namespace diskann
+{
+struct IndexBuildParams
+{
+  public:
+    diskann::IndexWriteParameters index_write_params;
+    std::string save_path_prefix;
+    std::string label_file;
+    std::string universal_label;
+    uint32_t filter_threshold = 0;
+
+  private:
+    IndexBuildParams(const IndexWriteParameters &index_write_params, const std::string &save_path_prefix,
+                     const std::string &label_file, const std::string &universal_label, uint32_t filter_threshold)
+        : index_write_params(index_write_params), save_path_prefix(save_path_prefix), label_file(label_file),
+          universal_label(universal_label), filter_threshold(filter_threshold)
+    {
+    }
+
+    friend class IndexBuildParamsBuilder;
+};
+class IndexBuildParamsBuilder
+{
+  public:
+    IndexBuildParamsBuilder(const diskann::IndexWriteParameters &paras) : _index_write_params(paras)
+    {
+    }
+
+    IndexBuildParamsBuilder &with_save_path_prefix(const std::string &save_path_prefix)
+    {
+        if (save_path_prefix.empty())
+            throw ANNException("Error: save_path_prefix can't be empty", -1);
+        this->_save_path_prefix = save_path_prefix;
+        return *this;
+    }
+
+    IndexBuildParamsBuilder &with_label_file(const std::string &label_file)
+    {
+        this->_label_file = label_file;
+        return *this;
+    }
+
+    IndexBuildParamsBuilder &with_universal_label(const std::string &universal_label)
+    {
+        this->_universal_label = universal_label;
+        return *this;
+    }
+
+    IndexBuildParamsBuilder &with_filter_threshold(const std::uint32_t &filter_threshold)
+    {
+        this->_filter_threshold = filter_threshold;
+        return *this;
+    }
+
+    IndexBuildParams build()
+    {
+        return IndexBuildParams(_index_write_params, _save_path_prefix, _label_file, _universal_label,
+                                _filter_threshold);
+    }
+
+    IndexBuildParamsBuilder(const IndexBuildParamsBuilder &) = delete;
+    IndexBuildParamsBuilder &operator=(const IndexBuildParamsBuilder &) = delete;
+
+  private:
+    diskann::IndexWriteParameters _index_write_params;
+    std::string _save_path_prefix;
+    std::string _label_file;
+    std::string _universal_label;
+    uint32_t _filter_threshold = 0;
+};
+} // namespace diskann
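Annotation (not part of the diff): a hypothetical use of IndexBuildParamsBuilder, assuming an IndexWriteParameters value `params` built elsewhere; the path, label file, and threshold are illustrative:

    diskann::IndexBuildParamsBuilder builder(params);
    auto build_params = builder.with_save_path_prefix("/tmp/my_index")
                            .with_label_file("labels.txt")
                            .with_universal_label("0")
                            .with_filter_threshold(5000)
                            .build();
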
diff --git a/include/index_config.h b/include/index_config.h
new file mode 100644
index 000000000..b291c744d
--- /dev/null
+++ b/include/index_config.h
@@ -0,0 +1,224 @@
+#include "common_includes.h"
+#include "parameters.h"
+
+namespace diskann
+{
+enum DataStoreStrategy
+{
+    MEMORY
+};
+
+enum GraphStoreStrategy
+{
+};
+struct IndexConfig
+{
+    DataStoreStrategy data_strategy;
+    GraphStoreStrategy graph_strategy;
+
+    Metric metric;
+    size_t dimension;
+    size_t max_points;
+
+    bool dynamic_index;
+    bool enable_tags;
+    bool pq_dist_build;
+    bool concurrent_consolidate;
+    bool use_opq;
+
+    size_t num_pq_chunks;
+    size_t num_frozen_pts;
+
+    std::string label_type;
+    std::string tag_type;
+    std::string data_type;
+
+    std::shared_ptr<IndexWriteParameters> index_write_params;
+
+    uint32_t search_threads;
+    uint32_t initial_search_list_size;
+
+  private:
+    IndexConfig(DataStoreStrategy data_strategy, GraphStoreStrategy graph_strategy, Metric metric, size_t dimension,
+                size_t max_points, size_t num_pq_chunks, size_t num_frozen_points, bool dynamic_index, bool enable_tags,
+                bool pq_dist_build, bool concurrent_consolidate, bool use_opq, const std::string &data_type,
+                const std::string &tag_type, const std::string &label_type,
+                std::shared_ptr<IndexWriteParameters> index_write_params, uint32_t search_threads,
+                uint32_t initial_search_list_size)
+        : data_strategy(data_strategy), graph_strategy(graph_strategy), metric(metric), dimension(dimension),
+          max_points(max_points), dynamic_index(dynamic_index), enable_tags(enable_tags), pq_dist_build(pq_dist_build),
+          concurrent_consolidate(concurrent_consolidate), use_opq(use_opq), num_pq_chunks(num_pq_chunks),
+          num_frozen_pts(num_frozen_points), label_type(label_type), tag_type(tag_type), data_type(data_type),
+          index_write_params(index_write_params), search_threads(search_threads),
+          initial_search_list_size(initial_search_list_size)
+    {
+    }
+
+    friend class IndexConfigBuilder;
+};
+
+class IndexConfigBuilder
+{
+  public:
+    IndexConfigBuilder()
+    {
+    }
+
+    IndexConfigBuilder &with_metric(Metric m)
+    {
+        this->_metric = m;
+        return *this;
+    }
+
+    IndexConfigBuilder &with_graph_load_store_strategy(GraphStoreStrategy graph_strategy)
+    {
+        this->_graph_strategy = graph_strategy;
+        return *this;
+    }
+
+    IndexConfigBuilder &with_data_load_store_strategy(DataStoreStrategy data_strategy)
+    {
+        this->_data_strategy = data_strategy;
+        return *this;
+    }
+
+    IndexConfigBuilder &with_dimension(size_t dimension)
+    {
+        this->_dimension = dimension;
+        return *this;
+    }
+
+    IndexConfigBuilder &with_max_points(size_t max_points)
+    {
+        this->_max_points = max_points;
+        return *this;
+    }
+
+    IndexConfigBuilder &is_dynamic_index(bool dynamic_index)
+    {
+        this->_dynamic_index = dynamic_index;
+        return *this;
+    }
+
+    IndexConfigBuilder &is_enable_tags(bool enable_tags)
+    {
+        this->_enable_tags = enable_tags;
+        return *this;
+    }
+
+    IndexConfigBuilder &is_pq_dist_build(bool pq_dist_build)
+    {
+        this->_pq_dist_build = pq_dist_build;
+        return *this;
+    }
+
+    IndexConfigBuilder &is_concurrent_consolidate(bool concurrent_consolidate)
+    {
+        this->_concurrent_consolidate = concurrent_consolidate;
+        return *this;
+    }
+
+    IndexConfigBuilder &is_use_opq(bool use_opq)
+    {
+        this->_use_opq = use_opq;
+        return *this;
+    }
+
+    IndexConfigBuilder &with_num_pq_chunks(size_t num_pq_chunks)
+    {
+        this->_num_pq_chunks = num_pq_chunks;
+        return *this;
+    }
+
+    IndexConfigBuilder &with_num_frozen_pts(size_t num_frozen_pts)
+    {
+        this->_num_frozen_pts = num_frozen_pts;
+        return *this;
+    }
+
+    IndexConfigBuilder &with_label_type(const std::string &label_type)
+    {
+        this->_label_type = label_type;
+        return *this;
+    }
+
+    IndexConfigBuilder &with_tag_type(const std::string &tag_type)
+    {
+        this->_tag_type = tag_type;
+        return *this;
+    }
+
+    IndexConfigBuilder &with_data_type(const std::string &data_type)
+    {
+        this->_data_type = data_type;
+        return *this;
+    }
+
+    IndexConfigBuilder &with_index_write_params(IndexWriteParameters &index_write_params)
+    {
+        this->_index_write_params = std::make_shared<IndexWriteParameters>(index_write_params);
+        return *this;
+    }
+
+    IndexConfigBuilder &with_search_threads(uint32_t search_threads)
+    {
+        this->_search_threads = search_threads;
+        return *this;
+    }
+
+    IndexConfigBuilder &with_initial_search_list_size(uint32_t search_list_size)
+    {
+        this->_initial_search_list_size = search_list_size;
+        return *this;
+    }
+
+    IndexConfig build()
+    {
+        if (_data_type.empty())
+            throw ANNException("Error: data_type can not be empty", -1);
+
+        if (_dynamic_index && _index_write_params != nullptr)
+        {
+            if (_search_threads == 0)
+                throw ANNException("Error: please pass search_threads for building dynamic index.", -1);
building dynamic index.", -1); + + if (_initial_search_list_size == 0) + throw ANNException("Error: please pass initial_search_list_size for building dynamic index.", -1); + } + + return IndexConfig(_data_strategy, _graph_strategy, _metric, _dimension, _max_points, _num_pq_chunks, + _num_frozen_pts, _dynamic_index, _enable_tags, _pq_dist_build, _concurrent_consolidate, + _use_opq, _data_type, _tag_type, _label_type, _index_write_params, _search_threads, + _initial_search_list_size); + } + + IndexConfigBuilder(const IndexConfigBuilder &) = delete; + IndexConfigBuilder &operator=(const IndexConfigBuilder &) = delete; + + private: + DataStoreStrategy _data_strategy; + GraphStoreStrategy _graph_strategy; + + Metric _metric; + size_t _dimension; + size_t _max_points; + + bool _dynamic_index = false; + bool _enable_tags = false; + bool _pq_dist_build = false; + bool _concurrent_consolidate = false; + bool _use_opq = false; + + size_t _num_pq_chunks = 0; + size_t _num_frozen_pts = 0; + + std::string _label_type = "uint32"; + std::string _tag_type = "uint32"; + std::string _data_type; + + std::shared_ptr _index_write_params; + + uint32_t _search_threads; + uint32_t _initial_search_list_size; +}; +} // namespace diskann diff --git a/include/index_factory.h b/include/index_factory.h new file mode 100644 index 000000000..3d1eb7992 --- /dev/null +++ b/include/index_factory.h @@ -0,0 +1,37 @@ +#include "index.h" +#include "abstract_graph_store.h" +#include "in_mem_graph_store.h" + +namespace diskann +{ +class IndexFactory +{ + public: + DISKANN_DLLEXPORT explicit IndexFactory(const IndexConfig &config); + DISKANN_DLLEXPORT std::unique_ptr create_instance(); + + private: + void check_config(); + + template + std::unique_ptr> construct_datastore(DataStoreStrategy stratagy, size_t num_points, + size_t dimension); + + std::unique_ptr construct_graphstore(GraphStoreStrategy stratagy, size_t size); + + template + std::unique_ptr create_instance(); + + std::unique_ptr create_instance(const std::string &data_type, const std::string &tag_type, + const std::string &label_type); + + template + std::unique_ptr create_instance(const std::string &tag_type, const std::string &label_type); + + template + std::unique_ptr create_instance(const std::string &label_type); + + std::unique_ptr _config; +}; + +} // namespace diskann diff --git a/include/logger.h b/include/logger.h index 28a9f619c..0b17807db 100644 --- a/include/logger.h +++ b/include/logger.h @@ -6,10 +6,21 @@ #include #include "windows_customizations.h" +#ifdef EXEC_ENV_OLS +#ifndef ENABLE_CUSTOM_LOGGER +#define ENABLE_CUSTOM_LOGGER +#endif // !ENABLE_CUSTOM_LOGGER +#endif // EXEC_ENV_OLS + namespace diskann { +#ifdef ENABLE_CUSTOM_LOGGER DISKANN_DLLEXPORT extern std::basic_ostream cout; DISKANN_DLLEXPORT extern std::basic_ostream cerr; +#else +using std::cerr; +using std::cout; +#endif enum class DISKANN_DLLEXPORT LogLevel { @@ -18,7 +29,7 @@ enum class DISKANN_DLLEXPORT LogLevel LL_Count }; -#ifdef EXEC_ENV_OLS +#ifdef ENABLE_CUSTOM_LOGGER DISKANN_DLLEXPORT void SetCustomLogger(std::function logger); #endif } // namespace diskann diff --git a/include/logger_impl.h b/include/logger_impl.h index af6108c06..03c65e0ce 100644 --- a/include/logger_impl.h +++ b/include/logger_impl.h @@ -11,6 +11,7 @@ namespace diskann { +#ifdef ENABLE_CUSTOM_LOGGER class ANNStreamBuf : public std::basic_streambuf { public: @@ -36,30 +37,25 @@ class ANNStreamBuf : public std::basic_streambuf int flush(); void logImpl(char *str, int numchars); -// Why the two buffer-sizes? 
diff --git a/include/logger.h b/include/logger.h
index 28a9f619c..0b17807db 100644
--- a/include/logger.h
+++ b/include/logger.h
@@ -6,10 +6,21 @@
 #include
 #include "windows_customizations.h"

+#ifdef EXEC_ENV_OLS
+#ifndef ENABLE_CUSTOM_LOGGER
+#define ENABLE_CUSTOM_LOGGER
+#endif // !ENABLE_CUSTOM_LOGGER
+#endif // EXEC_ENV_OLS
+
 namespace diskann
 {
+#ifdef ENABLE_CUSTOM_LOGGER
 DISKANN_DLLEXPORT extern std::basic_ostream<char> cout;
 DISKANN_DLLEXPORT extern std::basic_ostream<char> cerr;
+#else
+using std::cerr;
+using std::cout;
+#endif

 enum class DISKANN_DLLEXPORT LogLevel
 {
@@ -18,7 +29,7 @@ enum class DISKANN_DLLEXPORT LogLevel
     LL_Count
 };

-#ifdef EXEC_ENV_OLS
+#ifdef ENABLE_CUSTOM_LOGGER
 DISKANN_DLLEXPORT void SetCustomLogger(std::function<void(int, char *)> logger);
 #endif
 } // namespace diskann
diff --git a/include/logger_impl.h b/include/logger_impl.h
index af6108c06..03c65e0ce 100644
--- a/include/logger_impl.h
+++ b/include/logger_impl.h
@@ -11,6 +11,7 @@
 namespace diskann
 {
+#ifdef ENABLE_CUSTOM_LOGGER
 class ANNStreamBuf : public std::basic_streambuf<char>
 {
   public:
@@ -36,30 +37,25 @@ class ANNStreamBuf
     int flush();
     void logImpl(char *str, int numchars);

-// Why the two buffer-sizes? If we are running normally, we are basically
-// interacting with a character output system, so we short-circuit the
-// output process by keeping an empty buffer and writing each character
-// to stdout/stderr. But if we are running in OLS, we have to take all
-// the text that is written to diskann::cout/diskann:cerr, consolidate it
-// and push it out in one-shot, because the OLS infra does not give us
-// character based output. Therefore, we use a larger buffer that is large
-// enough to store the longest message, and continuously add characters
-// to it. When the calling code outputs a std::endl or std::flush, sync()
-// will be called and will output a log level, component name, and the text
-// that has been collected. (sync() is also called if the buffer is full, so
-// overflows/missing text are not a concern).
-// This implies calling code _must_ either print std::endl or std::flush
-// to ensure that the message is written immediately.
-#ifdef EXEC_ENV_OLS
+    // Why the two buffer-sizes? If we are running normally, we are basically
+    // interacting with a character output system, so we short-circuit the
+    // output process by keeping an empty buffer and writing each character
+    // to stdout/stderr. But if we are running in OLS, we have to take all
+    // the text that is written to diskann::cout/diskann::cerr, consolidate it
+    // and push it out in one-shot, because the OLS infra does not give us
+    // character based output. Therefore, we use a larger buffer that is large
+    // enough to store the longest message, and continuously add characters
+    // to it. When the calling code outputs a std::endl or std::flush, sync()
+    // will be called and will output a log level, component name, and the text
+    // that has been collected. (sync() is also called if the buffer is full, so
+    // overflows/missing text are not a concern).
+    // This implies calling code _must_ either print std::endl or std::flush
+    // to ensure that the message is written immediately.
+
     static const int BUFFER_SIZE = 1024;
-#else
-    // Allocating an arbitrarily small buffer here because the overflow() and
-    // other function implementations push the BUFFER_SIZE chars into the
-    // buffer before flushing to fwrite.
-    static const int BUFFER_SIZE = 4;
-#endif

     ANNStreamBuf(const ANNStreamBuf &);
     ANNStreamBuf &operator=(const ANNStreamBuf &);
 };
+#endif
 } // namespace diskann
diff --git a/include/natural_number_map.h b/include/natural_number_map.h
index dd4b09e41..820ac3fdf 100644
--- a/include/natural_number_map.h
+++ b/include/natural_number_map.h
@@ -7,7 +7,7 @@
 #include
 #include

-#include "boost_dynamic_bitset_fwd.h"
+#include <boost/dynamic_bitset.hpp>

 namespace diskann
 {
@@ -35,9 +35,9 @@ template <typename Key, typename Value> class natural_number_map
     struct position
     {
         size_t _key;
-        // The number of keys that were enumerated when iterating through the map
-        // so far. Used to early-terminate enumeration when ithere
-        // are no more entries in the map.
+        // The number of keys that were enumerated when iterating through the
+        // map so far. Used to early-terminate enumeration when there are no
+        // more entries in the map.
size_t _keys_already_enumerated; // Returns whether it's valid to access the element at this position in diff --git a/include/parameters.h b/include/parameters.h index e24681e7b..81a336da7 100644 --- a/include/parameters.h +++ b/include/parameters.h @@ -6,91 +6,112 @@ #include #include +#include "omp.h" +#include "defaults.h" + namespace diskann { -class Parameters +class IndexWriteParameters + { public: - Parameters() + const uint32_t search_list_size; // L + const uint32_t max_degree; // R + const bool saturate_graph; + const uint32_t max_occlusion_size; // C + const float alpha; + const uint32_t num_threads; + const uint32_t filter_list_size; // Lf + const uint32_t num_frozen_points; + + private: + IndexWriteParameters(const uint32_t search_list_size, const uint32_t max_degree, const bool saturate_graph, + const uint32_t max_occlusion_size, const float alpha, const uint32_t num_threads, + const uint32_t filter_list_size, const uint32_t num_frozen_points) + : search_list_size(search_list_size), max_degree(max_degree), saturate_graph(saturate_graph), + max_occlusion_size(max_occlusion_size), alpha(alpha), num_threads(num_threads), + filter_list_size(filter_list_size), num_frozen_points(num_frozen_points) { - int *p = new int; - *p = 0; - params["num_threads"] = p; } - template inline void Set(const std::string &name, const ParamType &value) + friend class IndexWriteParametersBuilder; +}; + +class IndexWriteParametersBuilder +{ + /** + * Fluent builder pattern to keep track of the 7 non-default properties + * and their order. The basic ctor was getting unwieldy. + */ + public: + IndexWriteParametersBuilder(const uint32_t search_list_size, // L + const uint32_t max_degree // R + ) + : _search_list_size(search_list_size), _max_degree(max_degree) { - // ParamType *ptr = (ParamType *) malloc(sizeof(ParamType)); - if (params.find(name) != params.end()) - { - free(params[name]); - } - ParamType *ptr = new ParamType; - *ptr = value; - params[name] = (void *)ptr; } - template inline ParamType Get(const std::string &name) const + IndexWriteParametersBuilder &with_max_occlusion_size(const uint32_t max_occlusion_size) { - auto item = params.find(name); - if (item == params.end()) - { - throw std::invalid_argument("Invalid parameter name."); - } - else - { - // return ConvertStrToValue(item->second); - if (item->second == nullptr) - { - throw std::invalid_argument(std::string("Parameter ") + name + " has value null."); - } - else - { - return *(static_cast(item->second)); - } - } + _max_occlusion_size = max_occlusion_size; + return *this; } - template inline ParamType Get(const std::string &name, const ParamType &default_value) const + IndexWriteParametersBuilder &with_saturate_graph(const bool saturate_graph) { - try - { - return Get(name); - } - catch (std::invalid_argument e) - { - return default_value; - } + _saturate_graph = saturate_graph; + return *this; } - ~Parameters() + IndexWriteParametersBuilder &with_alpha(const float alpha) { - for (auto iter = params.begin(); iter != params.end(); iter++) - { - if (iter->second != nullptr) - free(iter->second); - // delete iter->second; - } + _alpha = alpha; + return *this; } - private: - std::unordered_map params; + IndexWriteParametersBuilder &with_num_threads(const uint32_t num_threads) + { + _num_threads = num_threads == 0 ? 
omp_get_num_threads() : num_threads; + return *this; + } - Parameters(const Parameters &); - Parameters &operator=(const Parameters &); + IndexWriteParametersBuilder &with_filter_list_size(const uint32_t filter_list_size) + { + _filter_list_size = filter_list_size == 0 ? _search_list_size : filter_list_size; + return *this; + } + + IndexWriteParametersBuilder &with_num_frozen_points(const uint32_t num_frozen_points) + { + _num_frozen_points = num_frozen_points; + return *this; + } - template inline ParamType ConvertStrToValue(const std::string &str) const + IndexWriteParameters build() const { - std::stringstream sstream(str); - ParamType value; - if (!(sstream >> value) || !sstream.eof()) - { - std::stringstream err; - err << "Failed to convert value '" << str << "' to type: " << typeid(value).name(); - throw std::runtime_error(err.str()); - } - return value; + return IndexWriteParameters(_search_list_size, _max_degree, _saturate_graph, _max_occlusion_size, _alpha, + _num_threads, _filter_list_size, _num_frozen_points); } + + IndexWriteParametersBuilder(const IndexWriteParameters &wp) + : _search_list_size(wp.search_list_size), _max_degree(wp.max_degree), + _max_occlusion_size(wp.max_occlusion_size), _saturate_graph(wp.saturate_graph), _alpha(wp.alpha), + _filter_list_size(wp.filter_list_size), _num_frozen_points(wp.num_frozen_points) + { + } + IndexWriteParametersBuilder(const IndexWriteParametersBuilder &) = delete; + IndexWriteParametersBuilder &operator=(const IndexWriteParametersBuilder &) = delete; + + private: + uint32_t _search_list_size{}; + uint32_t _max_degree{}; + uint32_t _max_occlusion_size{defaults::MAX_OCCLUSION_SIZE}; + bool _saturate_graph{defaults::SATURATE_GRAPH}; + float _alpha{defaults::ALPHA}; + uint32_t _num_threads{defaults::NUM_THREADS}; + uint32_t _filter_list_size{defaults::FILTER_LIST_SIZE}; + uint32_t _num_frozen_points{defaults::NUM_FROZEN_POINTS_STATIC}; }; + } // namespace diskann diff --git a/include/pq.h b/include/pq.h index 9f6e50ee4..acfa1b30a 100644 --- a/include/pq.h +++ b/include/pq.h @@ -10,17 +10,17 @@ #define MAX_OPQ_ITERS 20 #define NUM_KMEANS_REPS_PQ 12 #define MAX_PQ_TRAINING_SET_SIZE 256000 -#define MAX_PQ_CHUNKS 256 +#define MAX_PQ_CHUNKS 512 namespace diskann { class FixedChunkPQTable { float *tables = nullptr; // pq_tables = float array of size [256 * ndims] - _u64 ndims = 0; // ndims = true dimension of vectors - _u64 n_chunks = 0; + uint64_t ndims = 0; // ndims = true dimension of vectors + uint64_t n_chunks = 0; bool use_rotation = false; - _u32 *chunk_offsets = nullptr; + uint32_t *chunk_offsets = nullptr; float *centroid = nullptr; float *tables_tr = nullptr; // same as pq_tables, but col-major float *rotmat_tr = nullptr; @@ -36,19 +36,19 @@ class FixedChunkPQTable void load_pq_centroid_bin(const char *pq_table_file, size_t num_chunks); #endif - _u32 get_num_chunks(); + uint32_t get_num_chunks(); void preprocess_query(float *query_vec); // assumes pre-processed query void populate_chunk_distances(const float *query_vec, float *dist_vec); - float l2_distance(const float *query_vec, _u8 *base_vec); + float l2_distance(const float *query_vec, uint8_t *base_vec); - float inner_product(const float *query_vec, _u8 *base_vec); + float inner_product(const float *query_vec, uint8_t *base_vec); // assumes no rotation is involved - void inflate_vector(_u8 *base_vec, float *out_vec); + void inflate_vector(uint8_t *base_vec, float *out_vec); void populate_chunk_inner_products(const float *query_vec, float *dist_vec); }; @@ -57,16 +57,17 @@ 
template struct PQScratch { float *aligned_pqtable_dist_scratch = nullptr; // MUST BE AT LEAST [256 * NCHUNKS] float *aligned_dist_scratch = nullptr; // MUST BE AT LEAST diskann MAX_DEGREE - _u8 *aligned_pq_coord_scratch = nullptr; // MUST BE AT LEAST [N_CHUNKS * MAX_DEGREE] + uint8_t *aligned_pq_coord_scratch = nullptr; // MUST BE AT LEAST [N_CHUNKS * MAX_DEGREE] float *rotated_query = nullptr; float *aligned_query_float = nullptr; PQScratch(size_t graph_degree, size_t aligned_dim) { diskann::alloc_aligned((void **)&aligned_pq_coord_scratch, - (_u64)graph_degree * (_u64)MAX_PQ_CHUNKS * sizeof(_u8), 256); - diskann::alloc_aligned((void **)&aligned_pqtable_dist_scratch, 256 * (_u64)MAX_PQ_CHUNKS * sizeof(float), 256); - diskann::alloc_aligned((void **)&aligned_dist_scratch, (_u64)graph_degree * sizeof(float), 256); + (size_t)graph_degree * (size_t)MAX_PQ_CHUNKS * sizeof(uint8_t), 256); + diskann::alloc_aligned((void **)&aligned_pqtable_dist_scratch, 256 * (size_t)MAX_PQ_CHUNKS * sizeof(float), + 256); + diskann::alloc_aligned((void **)&aligned_dist_scratch, (size_t)graph_degree * sizeof(float), 256); diskann::alloc_aligned((void **)&aligned_query_float, aligned_dim * sizeof(float), 8 * sizeof(float)); diskann::alloc_aligned((void **)&rotated_query, aligned_dim * sizeof(float), 8 * sizeof(float)); @@ -86,15 +87,16 @@ template struct PQScratch } }; -void aggregate_coords(const std::vector &ids, const _u8 *all_coords, const _u64 ndims, _u8 *out); +void aggregate_coords(const std::vector &ids, const uint8_t *all_coords, const uint64_t ndims, uint8_t *out); -void pq_dist_lookup(const _u8 *pq_ids, const _u64 n_pts, const _u64 pq_nchunks, const float *pq_dists, +void pq_dist_lookup(const uint8_t *pq_ids, const size_t n_pts, const size_t pq_nchunks, const float *pq_dists, std::vector &dists_out); // Need to replace calls to these with calls to vector& based functions above -void aggregate_coords(const unsigned *ids, const _u64 n_ids, const _u8 *all_coords, const _u64 ndims, _u8 *out); +void aggregate_coords(const unsigned *ids, const uint64_t n_ids, const uint8_t *all_coords, const uint64_t ndims, + uint8_t *out); -void pq_dist_lookup(const _u8 *pq_ids, const _u64 n_pts, const _u64 pq_nchunks, const float *pq_dists, +void pq_dist_lookup(const uint8_t *pq_ids, const size_t n_pts, const size_t pq_nchunks, const float *pq_dists, float *dists_out); DISKANN_DLLEXPORT int generate_pq_pivots(const float *const train_data, size_t num_train, unsigned dim, @@ -106,17 +108,18 @@ DISKANN_DLLEXPORT int generate_opq_pivots(const float *train_data, size_t num_tr bool make_zero_mean = false); template -int generate_pq_data_from_pivots(const std::string data_file, unsigned num_centers, unsigned num_pq_chunks, - std::string pq_pivots_path, std::string pq_compressed_vectors_path, +int generate_pq_data_from_pivots(const std::string &data_file, unsigned num_centers, unsigned num_pq_chunks, + const std::string &pq_pivots_path, const std::string &pq_compressed_vectors_path, bool use_opq = false); template -void generate_disk_quantized_data(const std::string data_file_to_use, const std::string disk_pq_pivots_path, - const std::string disk_pq_compressed_vectors_path, +void generate_disk_quantized_data(const std::string &data_file_to_use, const std::string &disk_pq_pivots_path, + const std::string &disk_pq_compressed_vectors_path, const diskann::Metric compareMetric, const double p_val, size_t &disk_pq_dims); template -void generate_quantized_data(const std::string data_file_to_use, const std::string pq_pivots_path, - const 
std::string pq_compressed_vectors_path, const diskann::Metric compareMetric, - const double p_val, const size_t num_pq_chunks, const bool use_opq); +void generate_quantized_data(const std::string &data_file_to_use, const std::string &pq_pivots_path, + const std::string &pq_compressed_vectors_path, const diskann::Metric compareMetric, + const double p_val, const uint64_t num_pq_chunks, const bool use_opq, + const std::string &codebook_prefix = ""); } // namespace diskann diff --git a/include/pq_flash_index.h b/include/pq_flash_index.h index 807658837..ba76cd47e 100644 --- a/include/pq_flash_index.h +++ b/include/pq_flash_index.h @@ -38,55 +38,60 @@ template class PQFlashIndex #ifdef EXEC_ENV_OLS DISKANN_DLLEXPORT int load_from_separate_paths(diskann::MemoryMappedFiles &files, uint32_t num_threads, const char *index_filepath, const char *pivots_filepath, - const char *compressed_filepath); + const char *compressed_filepath, const char* labels_filepath, const char* labels_to_medoids_filepath, + const char* labels_map_filepath, const char* unv_label_filepath); #else DISKANN_DLLEXPORT int load_from_separate_paths(uint32_t num_threads, const char *index_filepath, - const char *pivots_filepath, const char *compressed_filepath); + const char *pivots_filepath, const char *compressed_filepath, + const char *labels_filepath, const char *labels_to_medoids_filepath, + const char *labels_map_filepath, const char* unv_label_filepath); #endif DISKANN_DLLEXPORT void load_cache_list(std::vector &node_list); #ifdef EXEC_ENV_OLS DISKANN_DLLEXPORT void generate_cache_list_from_sample_queries(MemoryMappedFiles &files, std::string sample_bin, - _u64 l_search, _u64 beamwidth, - _u64 num_nodes_to_cache, uint32_t nthreads, + uint64_t l_search, uint64_t beamwidth, + uint64_t num_nodes_to_cache, uint32_t nthreads, std::vector &node_list); #else - DISKANN_DLLEXPORT void generate_cache_list_from_sample_queries(std::string sample_bin, _u64 l_search, - _u64 beamwidth, _u64 num_nodes_to_cache, + DISKANN_DLLEXPORT void generate_cache_list_from_sample_queries(std::string sample_bin, uint64_t l_search, + uint64_t beamwidth, uint64_t num_nodes_to_cache, uint32_t num_threads, std::vector &node_list); #endif - DISKANN_DLLEXPORT void cache_bfs_levels(_u64 num_nodes_to_cache, std::vector &node_list, + DISKANN_DLLEXPORT void cache_bfs_levels(uint64_t num_nodes_to_cache, std::vector &node_list, const bool shuffle = false); - DISKANN_DLLEXPORT void cached_beam_search(const T *query, const _u64 k_search, const _u64 l_search, _u64 *res_ids, - float *res_dists, const _u64 beam_width, + DISKANN_DLLEXPORT void cached_beam_search(const T *query, const uint64_t k_search, const uint64_t l_search, + uint64_t *res_ids, float *res_dists, const uint64_t beam_width, const bool use_reorder_data = false, QueryStats *stats = nullptr); - DISKANN_DLLEXPORT void cached_beam_search(const T *query, const _u64 k_search, const _u64 l_search, _u64 *res_ids, - float *res_dists, const _u64 beam_width, const bool use_filter, - const LabelT &filter_label, const bool use_reorder_data = false, - QueryStats *stats = nullptr); - - DISKANN_DLLEXPORT void cached_beam_search(const T *query, const _u64 k_search, const _u64 l_search, _u64 *res_ids, - float *res_dists, const _u64 beam_width, const _u32 io_limit, + DISKANN_DLLEXPORT void cached_beam_search(const T *query, const uint64_t k_search, const uint64_t l_search, + uint64_t *res_ids, float *res_dists, const uint64_t beam_width, + const bool use_filter, const LabelT &filter_label, const bool use_reorder_data = 
false, QueryStats *stats = nullptr); - DISKANN_DLLEXPORT void cached_beam_search(const T *query, const _u64 k_search, const _u64 l_search, _u64 *res_ids, - float *res_dists, const _u64 beam_width, const bool use_filter, - const LabelT &filter_label, const _u32 io_limit, - const bool use_reorder_data = false, QueryStats *stats = nullptr); + DISKANN_DLLEXPORT void cached_beam_search(const T *query, const uint64_t k_search, const uint64_t l_search, + uint64_t *res_ids, float *res_dists, const uint64_t beam_width, + const uint32_t io_limit, const bool use_reorder_data = false, + QueryStats *stats = nullptr); + + DISKANN_DLLEXPORT void cached_beam_search(const T *query, const uint64_t k_search, const uint64_t l_search, + uint64_t *res_ids, float *res_dists, const uint64_t beam_width, + const bool use_filter, const LabelT &filter_label, + const uint32_t io_limit, const bool use_reorder_data = false, + QueryStats *stats = nullptr); DISKANN_DLLEXPORT LabelT get_converted_label(const std::string &filter_label); - DISKANN_DLLEXPORT _u32 range_search(const T *query1, const double range, const _u64 min_l_search, - const _u64 max_l_search, std::vector<_u64> &indices, - std::vector &distances, const _u64 min_beam_width, - QueryStats *stats = nullptr); + DISKANN_DLLEXPORT uint32_t range_search(const T *query1, const double range, const uint64_t min_l_search, + const uint64_t max_l_search, std::vector &indices, + std::vector &distances, const uint64_t min_beam_width, + QueryStats *stats = nullptr); - DISKANN_DLLEXPORT _u64 get_data_dim(); + DISKANN_DLLEXPORT uint64_t get_data_dim(); std::shared_ptr &reader; @@ -94,16 +99,18 @@ template class PQFlashIndex protected: DISKANN_DLLEXPORT void use_medoids_data_as_centroids(); - DISKANN_DLLEXPORT void setup_thread_data(_u64 nthreads, _u64 visited_reserve = 4096); + DISKANN_DLLEXPORT void setup_thread_data(uint64_t nthreads, uint64_t visited_reserve = 4096); DISKANN_DLLEXPORT void set_universal_label(const LabelT &label); private: - DISKANN_DLLEXPORT inline bool point_has_label(_u32 point_id, _u32 label_id); + DISKANN_DLLEXPORT inline bool point_has_label(uint32_t point_id, uint32_t label_id); std::unordered_map load_label_map(const std::string &map_file); DISKANN_DLLEXPORT void parse_label_file(const std::string &map_file, size_t &num_pts_labels); - DISKANN_DLLEXPORT void get_label_file_metadata(std::string map_file, _u32 &num_pts, _u32 &num_total_labels); + DISKANN_DLLEXPORT void get_label_file_metadata(std::string map_file, uint32_t &num_pts, uint32_t &num_total_labels); DISKANN_DLLEXPORT inline int32_t get_filter_number(const LabelT &filter_label); + DISKANN_DLLEXPORT void generate_random_labels(std::vector &labels, const uint32_t num_labels, + const uint32_t nthreads); // index info // nhood of node `i` is in sector: [i / nnodes_per_sector] @@ -111,10 +118,10 @@ template class PQFlashIndex // nnbrs of node `i`: *(unsigned*) (buf) // nbrs of node `i`: ((unsigned*)buf) + 1 - _u64 max_node_len = 0, nnodes_per_sector = 0, max_degree = 0; + uint64_t max_node_len = 0, nnodes_per_sector = 0, max_degree = 0; // Data used for searching with re-order vectors - _u64 ndims_reorder_vecs = 0, reorder_data_start_sector = 0, nvecs_per_sector = 0; + uint64_t ndims_reorder_vecs = 0, reorder_data_start_sector = 0, nvecs_per_sector = 0; diskann::Metric metric = diskann::Metric::L2; @@ -123,25 +130,25 @@ template class PQFlashIndex float max_base_norm = 0.0f; // data info - _u64 num_points = 0; - _u64 num_frozen_points = 0; - _u64 frozen_location = 0; - _u64 data_dim = 0; - _u64 
disk_data_dim = 0; // will be different from data_dim only if we use - // PQ for disk data (very large dimensionality) - _u64 aligned_dim = 0; - _u64 disk_bytes_per_point = 0; + uint64_t num_points = 0; + uint64_t num_frozen_points = 0; + uint64_t frozen_location = 0; + uint64_t data_dim = 0; + uint64_t disk_data_dim = 0; // will be different from data_dim only if we use + // PQ for disk data (very large dimensionality) + uint64_t aligned_dim = 0; + uint64_t disk_bytes_per_point = 0; std::string disk_index_file; - std::vector> node_visit_counter; + std::vector> node_visit_counter; // PQ data // n_chunks = # of chunks ndims is split into - // data: _u8 * n_chunks + // data: char * n_chunks // chunk_size = chunk size of each dimension chunk // pq_tables = float* [[2^8 * [chunk_size]] * n_chunks] - _u8 *data = nullptr; - _u64 n_chunks; + uint8_t *data = nullptr; + uint64_t n_chunks; FixedChunkPQTable pq_table; // distance comparator @@ -150,7 +157,7 @@ template class PQFlashIndex // for very large datasets: we use PQ even for the disk resident index bool use_disk_index_pq = false; - _u64 disk_pq_n_chunks = 0; + uint64_t disk_pq_n_chunks = 0; FixedChunkPQTable disk_pq_table; // medoid/start info @@ -167,32 +174,32 @@ template class PQFlashIndex // nhood_cache unsigned *nhood_cache_buf = nullptr; - tsl::robin_map<_u32, std::pair<_u32, _u32 *>> nhood_cache; + tsl::robin_map> nhood_cache; // coord_cache T *coord_cache_buf = nullptr; - tsl::robin_map<_u32, T *> coord_cache; + tsl::robin_map coord_cache; // thread-specific scratch ConcurrentQueue *> thread_data; - _u64 max_nthreads; + uint64_t max_nthreads; bool load_flag = false; bool count_visited_nodes = false; bool reorder_data_exists = false; - _u64 reoreder_data_offset = 0; + uint64_t reoreder_data_offset = 0; // filter support - _u32 *_pts_to_label_offsets = nullptr; - _u32 *_pts_to_labels = nullptr; + uint32_t *_pts_to_label_offsets = nullptr; + uint32_t *_pts_to_labels = nullptr; tsl::robin_set _labels; - std::unordered_map _filter_to_medoid_id; + std::unordered_map> _filter_to_medoid_ids; bool _use_universal_label = false; - _u32 _universal_filter_num = 0; + uint32_t _universal_filter_num; std::vector _filter_list; - tsl::robin_set<_u32> _dummy_pts; - tsl::robin_set<_u32> _has_dummy_pts; - tsl::robin_map<_u32, _u32> _dummy_to_real_map; - tsl::robin_map<_u32, std::vector<_u32>> _real_to_dummy_map; + tsl::robin_set _dummy_pts; + tsl::robin_set _has_dummy_pts; + tsl::robin_map _dummy_to_real_map; + tsl::robin_map> _real_to_dummy_map; std::unordered_map _label_map; #ifdef EXEC_ENV_OLS diff --git a/include/program_options_utils.hpp b/include/program_options_utils.hpp new file mode 100644 index 000000000..71077b7b2 --- /dev/null +++ b/include/program_options_utils.hpp @@ -0,0 +1,79 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT license. + +#pragma once + +#include + +namespace program_options_utils +{ +const std::string make_program_description(const char *executable_name, const char *description) +{ + return std::string("\n") + .append(description) + .append("\n\n") + .append("Usage: ") + .append(executable_name) + .append(" [OPTIONS]"); +} + +// Required parameters +const char *DATA_TYPE_DESCRIPTION = "data type, one of {int8, uint8, float} - float is single precision (32 bit)"; +const char *DISTANCE_FUNCTION_DESCRIPTION = + "distance function {l2, mips, fast_l2, cosine}. 'fast l2' and 'mips' only support data_type float"; +const char *INDEX_PATH_PREFIX_DESCRIPTION = "Path prefix to the index, e.g. 
diff --git a/include/program_options_utils.hpp b/include/program_options_utils.hpp
new file mode 100644
index 000000000..71077b7b2
--- /dev/null
+++ b/include/program_options_utils.hpp
@@ -0,0 +1,79 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT license.
+
+#pragma once
+
+#include <string>
+
+namespace program_options_utils
+{
+const std::string make_program_description(const char *executable_name, const char *description)
+{
+    return std::string("\n")
+        .append(description)
+        .append("\n\n")
+        .append("Usage: ")
+        .append(executable_name)
+        .append(" [OPTIONS]");
+}
+
+// Required parameters
+const char *DATA_TYPE_DESCRIPTION = "data type, one of {int8, uint8, float} - float is single precision (32 bit)";
+const char *DISTANCE_FUNCTION_DESCRIPTION =
+    "distance function {l2, mips, fast_l2, cosine}. 'fast l2' and 'mips' only support data_type float";
+const char *INDEX_PATH_PREFIX_DESCRIPTION = "Path prefix to the index, e.g. '/mnt/data/my_ann_index'";
+const char *RESULT_PATH_DESCRIPTION =
+    "Path prefix for saving results of the queries, e.g. '/mnt/data/query_file_X.bin'";
+const char *QUERY_FILE_DESCRIPTION = "Query file in binary format, e.g. '/mnt/data/query_file_X.bin'";
+const char *NUMBER_OF_RESULTS_DESCRIPTION = "Number of neighbors to be returned (K in the DiskANN white paper)";
+const char *SEARCH_LIST_DESCRIPTION =
+    "Size of search list to use. This value is the number of neighbor/distance pairs to keep in memory at the same "
+    "time while performing a query. This can also be described as the size of the working set at query time. This "
+    "must be greater than or equal to the number of results/neighbors to return (K in the white paper). Corresponds "
+    "to L in the DiskANN white paper.";
+const char *INPUT_DATA_PATH = "Input data file in bin format. This is the file you want to build the index over. "
+                              "File format: Shape of the vector followed by the vector of embeddings as binary data.";
+
+// Optional parameters
+const char *FILTER_LABEL_DESCRIPTION =
+    "Filter to use when running a query. 'filter_label' and 'query_filters_file' are mutually exclusive.";
+const char *FILTERS_FILE_DESCRIPTION =
+    "Filter file for Queries for Filtered Search. File format is text with one filter per line. File must "
+    "have exactly one filter OR the same number of filters as there are queries in the 'query_file'.";
+const char *LABEL_TYPE_DESCRIPTION =
+    "Storage type of Labels {uint/uint32, ushort/uint16}, default value is uint which will consume memory 4 bytes per "
+    "filter. 'uint' is an alias for 'uint32' and 'ushort' is an alias for 'uint16'.";
+const char *GROUND_TRUTH_FILE_DESCRIPTION =
+    "ground truth file for the queryset"; // what's the format, what's the requirements? does it need to include an
+                                          // entry for every item or just a small subset? I have so many questions
+                                          // about this file
+const char *NUMBER_THREADS_DESCRIPTION = "Number of threads used for building index. Defaults to number of logical "
+                                         "processor cores on this machine returned by omp_get_num_procs()";
+const char *FAIL_IF_RECALL_BELOW =
+    "Value between 0 (inclusive) and 100 (exclusive) indicating the recall tolerance percentage threshold before "
+    "program fails with a non-zero exit code. The default value of 0 means that the program will complete "
+    "successfully with any recall value. A non-zero value indicates the floor for acceptable recall values. If the "
+    "calculated recall value is below this threshold then the program will write out the results but return a non-zero "
+    "exit code as a signal that the recall was not acceptable."; // does it continue running or die immediately? Will I
+                                                                 // still get my results even if the return code is -1?
+
+const char *NUMBER_OF_NODES_TO_CACHE = "Number of BFS nodes around medoid(s) to cache. Default value: 0";
+const char *BEAMWIDTH = "Beamwidth for search. Set 0 to optimize internally. Default value: 2";
+const char *MAX_BUILD_DEGREE = "Maximum graph degree";
+const char *GRAPH_BUILD_COMPLEXITY =
+    "Size of the search working set during build time. This is the number of neighbor/distance pairs to keep in memory "
+    "while building the index. Higher value results in a higher quality graph but it will take more time to build the "
+    "graph.";
+const char *GRAPH_BUILD_ALPHA = "Alpha controls density and diameter of graph, set 1 for sparse graph, 1.2 or 1.4 for "
+                                "denser graphs with lower diameter";
+const char *BUIlD_GRAPH_PQ_BYTES = "Number of PQ bytes to build the index; 0 for full precision build";
+const char *USE_OPQ = "Use Optimized Product Quantization (OPQ).";
+const char *LABEL_FILE = "Input label file in txt format for Filtered Index build. The file should contain comma "
+                         "separated filters for each node with each line corresponding to a graph node";
+const char *UNIVERSAL_LABEL =
+    "Universal label, use only in conjunction with label file for filtered index build. If a "
+    "graph node has all the labels against it, we can assign a special universal filter to the "
+    "point instead of comma separated filters for that point";
+const char *FILTERED_LBUILD = "Build complexity for filtered points, higher value results in better graphs";
+
+} // namespace program_options_utils
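These shared constants exist so every tool's `--help` text stays consistent across binaries. The same define-once, reuse-everywhere pattern, sketched with Python's argparse as the apps later in this PR use it (the constant text is shortened here and the program name is hypothetical):

```python
# Define help text once, reference it from every CLI that exposes the flag.
import argparse

SEARCH_LIST_DESCRIPTION = (
    "Size of search list (L). Must be >= K, the number of results returned."
)

parser = argparse.ArgumentParser(
    prog="search_memory_index",
    description="Searches an in-memory DiskANN index",
)
parser.add_argument("-L", "--search_list", type=int, default=100,
                    help=SEARCH_LIST_DESCRIPTION)

args = parser.parse_args(["-L", "50"])
print(args.search_list)  # 50
```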
diff --git a/include/scratch.h b/include/scratch.h
index 7a2cbb861..dd84c7f2f 100644
--- a/include/scratch.h
+++ b/include/scratch.h
@@ -6,7 +6,7 @@
 #include <vector>
 #include "boost_dynamic_bitset_fwd.h"
-//#include "boost/dynamic_bitset.hpp"
+// #include "boost/dynamic_bitset.hpp"
 #include "tsl/robin_set.h"
 #include "tsl/robin_map.h"
 #include "tsl/sparse_map.h"
@@ -21,12 +21,12 @@
 // SSD Index related limits
 #define MAX_GRAPH_DEGREE 512
-#define MAX_N_CMPS 16384
-#define SECTOR_LEN (_u64)4096
+#define SECTOR_LEN (size_t)4096
 #define MAX_N_SECTOR_READS 128

 namespace diskann
 {
+
 //
 // Scratch space for in-memory index based search
 //
@@ -34,8 +34,9 @@ template <typename T> class InMemQueryScratch
 {
   public:
     ~InMemQueryScratch();
-    InMemQueryScratch(uint32_t search_l, uint32_t indexing_l, uint32_t r, uint32_t maxc, size_t dim,
-                      bool init_pq_scratch = false, size_t bitmask_size = 0);
+    // REFACTOR TODO: move all parameters to a new class.
+    InMemQueryScratch(uint32_t search_l, uint32_t indexing_l, uint32_t r, uint32_t maxc, size_t dim, size_t aligned_dim,
+                      size_t alignment_factor, bool init_pq_scratch = false, size_t bitmask_size = 0);
     void resize_for_new_L(uint32_t new_search_l);
     void clear();
@@ -71,7 +72,7 @@ template <typename T> class InMemQueryScratch
     {
         return _occlude_factor;
     }
-    inline tsl::robin_set<unsigned> &inserted_into_pool_rs()
+    inline tsl::robin_set<uint32_t> &inserted_into_pool_rs()
     {
         return _inserted_into_pool_rs;
     }
@@ -79,7 +80,7 @@ template <typename T> class InMemQueryScratch
     {
         return *_inserted_into_pool_bs;
     }
-    inline std::vector<unsigned> &id_scratch()
+    inline std::vector<uint32_t> &id_scratch()
    {
         return _id_scratch;
     }
@@ -87,7 +88,7 @@ template <typename T> class InMemQueryScratch
     {
         return _dist_scratch;
     }
-    inline tsl::robin_set<unsigned> &expanded_nodes_set()
+    inline tsl::robin_set<uint32_t> &expanded_nodes_set()
     {
         return _expanded_nodes_set;
     }
@@ -95,7 +96,7 @@ template <typename T> class InMemQueryScratch
     {
         return _expanded_nghrs_vec;
     }
-    inline std::vector<unsigned> &occlude_list_output()
+    inline std::vector<uint32_t> &occlude_list_output()
     {
         return _occlude_list_output;
     }
@@ -129,7 +130,7 @@ template <typename T> class InMemQueryScratch
     std::vector<float> _occlude_factor;

     // Capacity initialized to 20L
-    tsl::robin_set<unsigned> _inserted_into_pool_rs;
+    tsl::robin_set<uint32_t> _inserted_into_pool_rs;

     // Use a pointer here to allow for forward declaration of dynamic_bitset
     // in public headers to avoid making boost a dependency for clients
@@ -137,17 +138,16 @@ template <typename T> class InMemQueryScratch
     boost::dynamic_bitset<> *_inserted_into_pool_bs;

     // _id_scratch.size() must be > R*GRAPH_SLACK_FACTOR for iterate_to_fp
-    std::vector<unsigned> _id_scratch;
+    std::vector<uint32_t> _id_scratch;

     // _dist_scratch must be > R*GRAPH_SLACK_FACTOR for iterate_to_fp
     // _dist_scratch should be at least the size of id_scratch
     std::vector<float> _dist_scratch;

     // Buffers used in process delete, capacity increases as needed
-    tsl::robin_set<unsigned> _expanded_nodes_set;
+    tsl::robin_set<uint32_t> _expanded_nodes_set;
     std::vector<Neighbor> _expanded_nghrs_vec;
-    std::vector<unsigned> _occlude_list_output;
-
+    std::vector<uint32_t> _occlude_list_output;
     // bitmask buffer in searching time
     std::vector<uint64_t> _query_label_bitmask;
 };
@@ -159,17 +159,16 @@ template <typename T> class InMemQueryScratch
 template <typename T> class SSDQueryScratch
 {
   public:
-    T *coord_scratch = nullptr; // MUST BE AT LEAST [MAX_N_CMPS * data_dim]
-    _u64 coord_idx = 0;         // index of next [data_dim] scratch to use
+    T *coord_scratch = nullptr; // MUST BE AT LEAST [sizeof(T) * data_dim]

     char *sector_scratch = nullptr; // MUST BE AT LEAST [MAX_N_SECTOR_READS * SECTOR_LEN]
-    _u64 sector_idx = 0;            // index of next [SECTOR_LEN] scratch to use
+    size_t sector_idx = 0;          // index of next [SECTOR_LEN] scratch to use

     T *aligned_query_T = nullptr;

     PQScratch<T> *_pq_scratch;

-    tsl::robin_set<_u64> visited;
+    tsl::robin_set<size_t> visited;
     NeighborPriorityQueue retset;
     std::vector<Neighbor> full_retset;
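The scratch classes above exist to avoid per-query allocation: buffers are sized once up front and recycled through a concurrent queue, one set per worker thread. A rough Python analogue of that pooling pattern (illustrative only, not the DiskANN implementation):

```python
# Illustrative scratch-pool analogue: pre-allocate per-thread buffers once,
# check a set out for the duration of one query, then return it to the pool.
import queue
import numpy as np

class ScratchPool:
    def __init__(self, n_threads: int, dim: int, search_l: int):
        self._pool: queue.Queue = queue.Queue()
        for _ in range(n_threads):
            self._pool.put({
                "query": np.zeros(dim, dtype=np.float32),
                "ids": np.zeros(search_l, dtype=np.uint32),
                "dists": np.zeros(search_l, dtype=np.float32),
            })

    def take(self):
        return self._pool.get()   # blocks if every scratch set is in use

    def give(self, scratch):
        self._pool.put(scratch)

pool = ScratchPool(n_threads=4, dim=128, search_l=100)
s = pool.take()
# ... run one query using s["query"], s["ids"], s["dists"] ...
pool.give(s)
```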
diff --git a/include/timer.h b/include/timer.h
index 963927bd7..5ddc3c857 100644
--- a/include/timer.h
+++ b/include/timer.h
@@ -27,7 +27,7 @@ class Timer

     float elapsed_seconds() const
     {
-        return (float)elapsed() / 1000000.0;
+        return (float)elapsed() / 1000000.0f;
     }

     std::string elapsed_seconds_for_step(const std::string &step) const
diff --git a/include/types.h b/include/types.h
new file mode 100644
index 000000000..b95848869
--- /dev/null
+++ b/include/types.h
@@ -0,0 +1,21 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT license.
+
+#pragma once
+
+#include <any>
+#include <cstdint>
+#include <vector>
+#include "any_wrappers.h"
+
+namespace diskann
+{
+typedef uint32_t location_t;
+
+using DataType = std::any;
+using TagType = std::any;
+using LabelType = std::any;
+using TagVector = AnyWrapper::AnyVector;
+using DataVector = AnyWrapper::AnyVector;
+using TagRobinSet = AnyWrapper::AnyRobinSet;
+} // namespace diskann
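`DataType`, `TagType`, and `LabelType` are `std::any` so that a single non-template entry point can accept any supported element type and hand off to the matching typed implementation at runtime. A loose Python analogue of that runtime-dtype dispatch (illustrative only, hypothetical function names):

```python
# Dispatch on a runtime dtype tag to the right strongly-typed implementation,
# mirroring how a type-erased API routes to template instantiations.
import numpy as np

def _search_float(q): return f"float32 search over {q.shape}"
def _search_int8(q):  return f"int8 search over {q.shape}"
def _search_uint8(q): return f"uint8 search over {q.shape}"

_DISPATCH = {
    np.dtype(np.float32): _search_float,
    np.dtype(np.int8): _search_int8,
    np.dtype(np.uint8): _search_uint8,
}

def search(query: np.ndarray):
    impl = _DISPATCH.get(query.dtype)
    if impl is None:
        raise ValueError(f"unsupported dtype {query.dtype}")
    return impl(query)

print(search(np.zeros(10, dtype=np.float32)))
```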
diff --git a/include/utils.h b/include/utils.h
index a011aca1e..f81f6a68b 100644
--- a/include/utils.h
+++ b/include/utils.h
@@ -2,6 +2,7 @@
 // Licensed under the MIT license.
 #pragma once
+
 #include
 #include "common_includes.h"
@@ -25,16 +26,14 @@ typedef int FileHandle;
 #include "ann_exception.h"
 #include "windows_customizations.h"
 #include "tsl/robin_set.h"
+#include "types.h"
+#include
 #ifdef EXEC_ENV_OLS
 #include "content_buf.h"
 #include "memory_mapped_files.h"
 #endif

-#include
-#include
-#include
-
 // taken from
 // https://github.com/Microsoft/BLAS-on-flash/blob/master/include/utils.h
 // round up X to the nearest multiple of Y
@@ -53,7 +52,10 @@
     4096 // all metadata of individual sub-component files is written in first
          // 4KB for unified files

-#define BUFFER_SIZE_FOR_CACHED_IO (_u64)1024 * (_u64)1048576
+#define BUFFER_SIZE_FOR_CACHED_IO (size_t)1024 * (size_t)1048576
+
+#define PBSTR "||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||"
+#define PBWIDTH 60

 inline bool file_exists(const std::string &name, bool dirCheck = false)
 {
@@ -63,8 +65,8 @@
     val = stat(name.c_str(), &buffer);
 #else
     // It is the 21st century but Windows API still thinks in 32-bit terms.
-    // Turns out calling stat() on a file > 4GB results in errno = 132 (OVERFLOW).
-    // How silly is this!? So calling _stat64()
+    // Turns out calling stat() on a file > 4GB results in errno = 132
+    // (OVERFLOW). How silly is this!?
So calling _stat64() struct _stat64 buffer; val = _stat64(name.c_str(), &buffer); #endif @@ -92,15 +94,6 @@ inline bool file_exists(const std::string &name, bool dirCheck = false) } } -typedef uint64_t _u64; -typedef int64_t _s64; -typedef uint32_t _u32; -typedef int32_t _s32; -typedef uint16_t _u16; -typedef int16_t _s16; -typedef uint8_t _u8; -typedef int8_t _s8; - inline void open_file_to_write(std::ofstream &writer, const std::string &filename) { writer.exceptions(std::ofstream::failbit | std::ofstream::badbit); @@ -113,21 +106,22 @@ inline void open_file_to_write(std::ofstream &writer, const std::string &filenam { char buff[1024]; #ifdef _WINDOWS - strerror_s(buff, 1024, errno); + auto ret = std::to_string(strerror_s(buff, 1024, errno)); #else - strerror_r(errno, buff, 1024); + auto ret = std::string(strerror_r(errno, buff, 1024)); #endif - diskann::cerr << std::string("Failed to open file") + filename + " for write because " + buff << std::endl; - throw diskann::ANNException(std::string("Failed to open file ") + filename + " for write because: " + buff, -1); + auto message = std::string("Failed to open file") + filename + " for write because " + buff + ", ret=" + ret; + diskann::cerr << message << std::endl; + throw diskann::ANNException(message, -1); } } -inline _u64 get_file_size(const std::string &fname) +inline size_t get_file_size(const std::string &fname) { std::ifstream reader(fname, std::ios::binary | std::ios::ate); if (!reader.fail() && reader.is_open()) { - _u64 end_pos = reader.tellg(); + size_t end_pos = reader.tellg(); reader.close(); return end_pos; } @@ -146,7 +140,8 @@ inline int delete_file(const std::string &fileName) if (rc != 0) { diskann::cerr << "Could not delete file: " << fileName - << " even though it exists. This might indicate a permissions issue. " + << " even though it exists. This might indicate a permissions " + "issue. " "If you see this message, please contact the diskann team." 
<< std::endl; } @@ -160,9 +155,9 @@ inline int delete_file(const std::string &fileName) inline void convert_labels_string_to_int(const std::string &inFileName, const std::string &outFileName, const std::string &mapFileName, const std::string &unv_label, - _u32& unv_label_id) + uint32_t& unv_label_id) { - std::unordered_map string_int_map; + std::unordered_map string_int_map; std::ofstream label_writer(outFileName); std::ifstream label_reader(inFileName); //if (unv_label != "") @@ -171,14 +166,14 @@ inline void convert_labels_string_to_int(const std::string &inFileName, const st while (std::getline(label_reader, line)) { std::istringstream new_iss(line); - std::vector<_u32> lbls; + std::vector lbls; while (getline(new_iss, token, ',')) { token.erase(std::remove(token.begin(), token.end(), '\n'), token.end()); token.erase(std::remove(token.begin(), token.end(), '\r'), token.end()); if (string_int_map.find(token) == string_int_map.end()) { - _u32 nextId = (_u32)string_int_map.size() + 1; + uint32_t nextId = (uint32_t)string_int_map.size() + 1; string_int_map[token] = nextId; } lbls.push_back(string_int_map[token]); @@ -349,6 +344,26 @@ inline void get_bin_metadata(const std::string &bin_file, size_t &nrows, size_t } // get_bin_metadata functions END +#ifndef EXEC_ENV_OLS +inline size_t get_graph_num_frozen_points(const std::string &graph_file) +{ + size_t expected_file_size; + uint32_t max_observed_degree, start; + size_t file_frozen_pts; + + std::ifstream in; + in.exceptions(std::ios::badbit | std::ios::failbit); + + in.open(graph_file, std::ios::binary); + in.read((char *)&expected_file_size, sizeof(size_t)); + in.read((char *)&max_observed_degree, sizeof(uint32_t)); + in.read((char *)&start, sizeof(uint32_t)); + in.read((char *)&file_frozen_pts, sizeof(size_t)); + + return file_frozen_pts; +} +#endif + template inline std::string getValues(T *data, size_t num) { std::stringstream stream; @@ -448,7 +463,7 @@ inline void wait_for_keystroke() inline void load_truthset(const std::string &bin_file, uint32_t *&ids, float *&dists, size_t &npts, size_t &dim) { - _u64 read_blk_size = 64 * 1024 * 1024; + size_t read_blk_size = 64 * 1024 * 1024; cached_ifstream reader(bin_file, read_blk_size); diskann::cout << "Reading truthset file " << bin_file.c_str() << " ..." << std::endl; size_t actual_file_size = reader.get_file_size(); @@ -496,9 +511,9 @@ inline void load_truthset(const std::string &bin_file, uint32_t *&ids, float *&d } inline void prune_truthset_for_range(const std::string &bin_file, float range, - std::vector> &groundtruth, size_t &npts) + std::vector> &groundtruth, size_t &npts) { - _u64 read_blk_size = 64 * 1024 * 1024; + size_t read_blk_size = 64 * 1024 * 1024; cached_ifstream reader(bin_file, read_blk_size); diskann::cout << "Reading truthset file " << bin_file.c_str() << "... " << std::endl; size_t actual_file_size = reader.get_file_size(); @@ -507,8 +522,8 @@ inline void prune_truthset_for_range(const std::string &bin_file, float range, reader.read((char *)&npts_i32, sizeof(int)); reader.read((char *)&dim_i32, sizeof(int)); npts = (unsigned)npts_i32; - _u64 dim = (unsigned)dim_i32; - _u32 *ids; + uint64_t dim = (unsigned)dim_i32; + uint32_t *ids; float *dists; diskann::cout << "Metadata: #pts = " << npts << ", #dims = " << dim << "... 
" << std::endl; @@ -542,10 +557,10 @@ inline void prune_truthset_for_range(const std::string &bin_file, float range, float min_dist = std::numeric_limits::max(); float max_dist = 0; groundtruth.resize(npts); - for (_u32 i = 0; i < npts; i++) + for (uint32_t i = 0; i < npts; i++) { groundtruth[i].clear(); - for (_u32 j = 0; j < dim; j++) + for (uint32_t j = 0; j < dim; j++) { if (dists[i * dim + j] <= range) { @@ -561,23 +576,24 @@ inline void prune_truthset_for_range(const std::string &bin_file, float range, delete[] dists; } -inline void load_range_truthset(const std::string &bin_file, std::vector> &groundtruth, _u64 >_num) +inline void load_range_truthset(const std::string &bin_file, std::vector> &groundtruth, + uint64_t >_num) { - _u64 read_blk_size = 64 * 1024 * 1024; + size_t read_blk_size = 64 * 1024 * 1024; cached_ifstream reader(bin_file, read_blk_size); diskann::cout << "Reading truthset file " << bin_file.c_str() << "... " << std::flush; size_t actual_file_size = reader.get_file_size(); - int npts_u32, total_u32; - reader.read((char *)&npts_u32, sizeof(int)); - reader.read((char *)&total_u32, sizeof(int)); + int nptsuint32_t, totaluint32_t; + reader.read((char *)&nptsuint32_t, sizeof(int)); + reader.read((char *)&totaluint32_t, sizeof(int)); - gt_num = (_u64)npts_u32; - _u64 total_res = (_u64)total_u32; + gt_num = (uint64_t)nptsuint32_t; + uint64_t total_res = (uint64_t)totaluint32_t; diskann::cout << "Metadata: #pts = " << gt_num << ", #total_results = " << total_res << "..." << std::endl; - size_t expected_file_size = 2 * sizeof(_u32) + gt_num * sizeof(_u32) + total_res * sizeof(_u32); + size_t expected_file_size = 2 * sizeof(uint32_t) + gt_num * sizeof(uint32_t) + total_res * sizeof(uint32_t); if (actual_file_size != expected_file_size) { @@ -589,26 +605,26 @@ inline void load_range_truthset(const std::string &bin_file, std::vector gt_count(gt_num); + std::vector gt_count(gt_num); - reader.read((char *)gt_count.data(), sizeof(_u32) * gt_num); + reader.read((char *)gt_count.data(), sizeof(uint32_t) * gt_num); - std::vector<_u32> gt_stats(gt_count); + std::vector gt_stats(gt_count); std::sort(gt_stats.begin(), gt_stats.end()); std::cout << "GT count percentiles:" << std::endl; - for (_u32 p = 0; p < 100; p += 5) + for (uint32_t p = 0; p < 100; p += 5) std::cout << "percentile " << p << ": " << gt_stats[static_cast(std::floor((p / 100.0) * gt_num))] << std::endl; std::cout << "percentile 100" << ": " << gt_stats[gt_num - 1] << std::endl; - for (_u32 i = 0; i < gt_num; i++) + for (uint32_t i = 0; i < gt_num; i++) { groundtruth[i].clear(); groundtruth[i].resize(gt_count[i]); if (gt_count[i] != 0) - reader.read((char *)groundtruth[i].data(), sizeof(_u32) * gt_count[i]); + reader.read((char *)groundtruth[i].data(), sizeof(uint32_t) * gt_count[i]); } } @@ -645,8 +661,8 @@ DISKANN_DLLEXPORT double calculate_recall(unsigned num_queries, unsigned *gold_s const tsl::robin_set &active_tags); DISKANN_DLLEXPORT double calculate_range_search_recall(unsigned num_queries, - std::vector> &groundtruth, - std::vector> &our_results); + std::vector> &groundtruth, + std::vector> &our_results); template inline void load_bin(const std::string &bin_file, std::unique_ptr &data, size_t &npts, size_t &dim, @@ -669,18 +685,19 @@ inline void open_file_to_write(std::ofstream &writer, const std::string &filenam { char buff[1024]; #ifdef _WINDOWS - strerror_s(buff, 1024, errno); + auto ret = std::to_string(strerror_s(buff, 1024, errno)); #else - strerror_r(errno, buff, 1024); + auto ret = 
std::string(strerror_r(errno, buff, 1024)); #endif - std::string error_message = std::string("Failed to open file") + filename + " for write because " + buff; + std::string error_message = + std::string("Failed to open file") + filename + " for write because " + buff + ", ret=" + ret; diskann::cerr << error_message << std::endl; throw diskann::ANNException(error_message, -1); } } template -inline uint64_t save_bin(const std::string &filename, T *data, size_t npts, size_t ndims, size_t offset = 0) +inline size_t save_bin(const std::string &filename, T *data, size_t npts, size_t ndims, size_t offset = 0) { std::ofstream writer; open_file_to_write(writer, filename); @@ -699,6 +716,16 @@ inline uint64_t save_bin(const std::string &filename, T *data, size_t npts, size diskann::cout << "Finished writing bin." << std::endl; return bytes_written; } + +inline void print_progress(double percentage) +{ + int val = (int)(percentage * 100); + int lpad = (int)(percentage * PBWIDTH); + int rpad = PBWIDTH - lpad; + printf("\r%3d%% [%.*s%*s]", val, lpad, PBSTR, rpad, ""); + fflush(stdout); +} + // load_aligned_bin functions START template @@ -784,7 +811,7 @@ template void convert_types(const InType *srcmat, OutType *destmat, size_t npts, size_t dim) { #pragma omp parallel for schedule(static, 65536) - for (int64_t i = 0; i < (_s64)npts; i++) + for (int64_t i = 0; i < (int64_t)npts; i++) { for (uint64_t j = 0; j < dim; j++) { @@ -806,17 +833,17 @@ template float prepare_base_for_inner_products(const std::string in std::cout << "Pre-processing base file by adding extra coordinate" << std::endl; std::ifstream in_reader(in_file.c_str(), std::ios::binary); std::ofstream out_writer(out_file.c_str(), std::ios::binary); - _u64 npts, in_dims, out_dims; + uint64_t npts, in_dims, out_dims; float max_norm = 0; - _u32 npts32, dims32; + uint32_t npts32, dims32; in_reader.read((char *)&npts32, sizeof(uint32_t)); in_reader.read((char *)&dims32, sizeof(uint32_t)); npts = npts32; in_dims = dims32; out_dims = in_dims + 1; - _u32 outdims32 = (_u32)out_dims; + uint32_t outdims32 = (uint32_t)out_dims; out_writer.write((char *)&npts32, sizeof(uint32_t)); out_writer.write((char *)&outdims32, sizeof(uint32_t)); @@ -827,19 +854,19 @@ template float prepare_base_for_inner_products(const std::string in std::unique_ptr out_block_data = std::make_unique(block_size * out_dims); std::memset(out_block_data.get(), 0, sizeof(float) * block_size * out_dims); - _u64 num_blocks = DIV_ROUND_UP(npts, block_size); + uint64_t num_blocks = DIV_ROUND_UP(npts, block_size); std::vector norms(npts, 0); - for (_u64 b = 0; b < num_blocks; b++) + for (uint64_t b = 0; b < num_blocks; b++) { - _u64 start_id = b * block_size; - _u64 end_id = (b + 1) * block_size < npts ? (b + 1) * block_size : npts; - _u64 block_pts = end_id - start_id; + uint64_t start_id = b * block_size; + uint64_t end_id = (b + 1) * block_size < npts ? 
(b + 1) * block_size : npts; + uint64_t block_pts = end_id - start_id; in_reader.read((char *)in_block_data.get(), block_pts * in_dims * sizeof(T)); - for (_u64 p = 0; p < block_pts; p++) + for (uint64_t p = 0; p < block_pts; p++) { - for (_u64 j = 0; j < in_dims; j++) + for (uint64_t j = 0; j < in_dims; j++) { norms[start_id + p] += in_block_data[p * in_dims + j] * in_block_data[p * in_dims + j]; } @@ -849,16 +876,16 @@ template float prepare_base_for_inner_products(const std::string in max_norm = std::sqrt(max_norm); - in_reader.seekg(2 * sizeof(_u32), std::ios::beg); - for (_u64 b = 0; b < num_blocks; b++) + in_reader.seekg(2 * sizeof(uint32_t), std::ios::beg); + for (uint64_t b = 0; b < num_blocks; b++) { - _u64 start_id = b * block_size; - _u64 end_id = (b + 1) * block_size < npts ? (b + 1) * block_size : npts; - _u64 block_pts = end_id - start_id; + uint64_t start_id = b * block_size; + uint64_t end_id = (b + 1) * block_size < npts ? (b + 1) * block_size : npts; + uint64_t block_pts = end_id - start_id; in_reader.read((char *)in_block_data.get(), block_pts * in_dims * sizeof(T)); - for (_u64 p = 0; p < block_pts; p++) + for (uint64_t p = 0; p < block_pts; p++) { - for (_u64 j = 0; j < in_dims; j++) + for (uint64_t j = 0; j < in_dims; j++) { out_block_data[p * out_dims + j] = in_block_data[p * in_dims + j] / max_norm; } @@ -883,7 +910,7 @@ template void save_Tvecs(const char *filename, T *data, size_t npts unsigned dims_u32 = (unsigned)ndims; // start writing - for (uint64_t i = 0; i < npts; i++) + for (size_t i = 0; i < npts; i++) { // write dims in u32 writer.write((char *)&dims_u32, sizeof(unsigned)); @@ -894,13 +921,13 @@ template void save_Tvecs(const char *filename, T *data, size_t npts } } template -inline uint64_t save_data_in_base_dimensions(const std::string &filename, T *data, size_t npts, size_t ndims, - size_t aligned_dim, size_t offset = 0) +inline size_t save_data_in_base_dimensions(const std::string &filename, T *data, size_t npts, size_t ndims, + size_t aligned_dim, size_t offset = 0) { std::ofstream writer; //(filename, std::ios::binary | std::ios::out); open_file_to_write(writer, filename); int npts_i32 = (int)npts, ndims_i32 = (int)ndims; - _u64 bytes_written = 2 * sizeof(uint32_t) + npts * ndims * sizeof(T); + size_t bytes_written = 2 * sizeof(uint32_t) + npts * ndims * sizeof(T); writer.seekp(offset, writer.beg); writer.write((char *)&npts_i32, sizeof(int)); writer.write((char *)&ndims_i32, sizeof(int)); @@ -958,7 +985,7 @@ inline void prefetch_vector_l2(const char *vec, size_t vecsize) } // NOTE: Implementation in utils.cpp. -void block_convert(std::ofstream &writr, std::ifstream &readr, float *read_buf, _u64 npts, _u64 ndims); +void block_convert(std::ofstream &writr, std::ifstream &readr, float *read_buf, uint64_t npts, uint64_t ndims); DISKANN_DLLEXPORT void normalize_data_file(const std::string &inFileName, const std::string &outFileName); @@ -1006,21 +1033,130 @@ inline bool validate_index_file_size(std::ifstream &in) return true; } -// This function is valid only for float data type. -template inline void normalize(T *arr, size_t dim) +template inline float get_norm(T *arr, const size_t dim) { float sum = 0.0f; for (uint32_t i = 0; i < dim; i++) { sum += arr[i] * arr[i]; } - sum = sqrt(sum); + return sqrt(sum); +} + +// This function is valid only for float data type. 
+template inline void normalize(T *arr, const size_t dim) +{ + float norm = get_norm(arr, dim); for (uint32_t i = 0; i < dim; i++) { - arr[i] = (T)(arr[i] / sum); + arr[i] = (T)(arr[i] / norm); } } +inline std::vector read_file_to_vector_of_strings(const std::string &filename, bool unique = false) +{ + std::vector result; + std::set elementSet; + if (filename != "") + { + std::ifstream file(filename); + if (file.fail()) + { + throw diskann::ANNException(std::string("Failed to open file ") + filename, -1); + } + std::string line; + while (std::getline(file, line)) + { + if (line.empty()) + { + break; + } + if (line.find(',') != std::string::npos) + { + std::cerr << "Every query must have exactly one filter" << std::endl; + exit(-1); + } + if (!line.empty() && (line.back() == '\r' || line.back() == '\n')) + { + line.erase(line.size() - 1); + } + if (!elementSet.count(line)) + { + result.push_back(line); + } + if (unique) + { + elementSet.insert(line); + } + } + file.close(); + } + else + { + throw diskann::ANNException(std::string("Failed to open file. filename can not be blank"), -1); + } + return result; +} + +inline void clean_up_artifacts(tsl::robin_set paths_to_clean, tsl::robin_set path_suffixes) +{ + try + { + for (const auto &path : paths_to_clean) + { + for (const auto &suffix : path_suffixes) + { + std::string curr_path_to_clean(path + "_" + suffix); + if (std::remove(curr_path_to_clean.c_str()) != 0) + diskann::cout << "Warning: Unable to remove file :" << curr_path_to_clean << std::endl; + } + } + diskann::cout << "Cleaned all artifacts" << std::endl; + } + catch (const std::exception &e) + { + diskann::cout << "Warning: Unable to clean all artifacts " << e.what() << std::endl; + } +} + +template inline const char *diskann_type_to_name() = delete; +template <> inline const char *diskann_type_to_name() +{ + return "float"; +} +template <> inline const char *diskann_type_to_name() +{ + return "uint8"; +} +template <> inline const char *diskann_type_to_name() +{ + return "int8"; +} +template <> inline const char *diskann_type_to_name() +{ + return "uint16"; +} +template <> inline const char *diskann_type_to_name() +{ + return "int16"; +} +template <> inline const char *diskann_type_to_name() +{ + return "uint32"; +} +template <> inline const char *diskann_type_to_name() +{ + return "int32"; +} +template <> inline const char *diskann_type_to_name() +{ + return "uint64"; +} +template <> inline const char *diskann_type_to_name() +{ + return "int64"; +} + #ifdef _WINDOWS #include #include diff --git a/pyproject.toml b/pyproject.toml index 107d56886..fb4349fab 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -3,39 +3,54 @@ requires = [ "setuptools>=59.6", "pybind11>=2.10.0", "cmake>=3.22", - "numpy>=1.21", + "numpy==1.25", # this is important to keep fixed. 
It also means anyone using something other than 1.25 won't be able to use this library "wheel", + "ninja" ] build-backend = "setuptools.build_meta" [project] name = "diskannpy" -version = "0.4.0" +version = "0.6.0" description = "DiskANN Python extension module" -# readme = "../README.md" -requires-python = ">=3.7" +readme = "python/README.md" +requires-python = ">=3.9" license = {text = "MIT License"} dependencies = [ - "numpy" + "numpy==1.25" ] authors = [ {name = "Harsha Vardhan Simhadri", email = "harshasi@microsoft.com"}, {name = "Dax Pryce", email = "daxpryce@microsoft.com"} ] +[project.optional-dependencies] +dev = ["black", "isort", "mypy"] + +[tool.setuptools] +package-dir = {"" = "python/src"} + +[tool.isort] +profile = "black" +multi_line_output = 3 + +[tool.mypy] +plugins = "numpy.typing.mypy_plugin" + [tool.cibuildwheel] -manylinux-x86_64-image = "manylinux_2_24" +manylinux-x86_64-image = "manylinux_2_28" +test-requires = ["scikit-learn~=1.2"] build-frontend = "build" -skip = "pp* *musllinux*" -test-command = "python -m unittest discover -s {package}/tests" - +skip = ["pp*", "*-win32", "*-manylinux_i686", "*-musllinux*"] +test-command = "python -m unittest discover {project}/python/tests" [tool.cibuildwheel.linux] -before-all = """\ - apt-get update && \ - apt-get -y upgrade && \ - apt-get install -y wget make cmake g++ libaio-dev libgoogle-perftools-dev libunwind-dev clang-format libboost-dev libboost-program-options-dev && \ - wget https://registrationcenter-download.intel.com/akdlm/irc_nas/18487/l_BaseKit_p_2022.1.2.146.sh && \ - sh l_BaseKit_p_2022.1.2.146.sh -a --components intel.oneapi.lin.mkl.devel --action install --eula accept -s --ignore-errors \ -""" +before-build = [ + "dnf makecache --refresh", + "dnf install -y epel-release", + "dnf config-manager -y --add-repo https://yum.repos.intel.com/mkl/setup/intel-mkl.repo", + "rpm --import https://yum.repos.intel.com/intel-gpg-keys/GPG-PUB-KEY-INTEL-SW-PRODUCTS-2019.PUB", + "dnf makecache --refresh -y", + "dnf install -y wget make cmake gcc-c++ libaio-devel gperftools-libs libunwind-devel clang-tools-extra boost-devel boost-program-options intel-mkl-2020.4-912" +] diff --git a/python/CMakeLists.txt b/python/CMakeLists.txt index d93d80422..d4faebf9b 100644 --- a/python/CMakeLists.txt +++ b/python/CMakeLists.txt @@ -3,7 +3,7 @@ cmake_minimum_required(VERSION 3.18...3.22) -set(CMAKE_CXX_STANDARD 14) +set(CMAKE_CXX_STANDARD 17) if (PYTHON_EXECUTABLE) set(Python3_EXECUTABLE ${PYTHON_EXECUTABLE}) @@ -26,14 +26,22 @@ execute_process(COMMAND ${Python3_EXECUTABLE} -c "import numpy; print(numpy.get_ # pybind11_add_module(diskannpy MODULE src/diskann_bindings.cpp) # the following is fairly synonymous with pybind11_add_module, but we need more target_link_libraries # see https://pybind11.readthedocs.io/en/latest/compiling.html#advanced-interface-library-targets for more details -add_library(diskannpy MODULE src/diskann_bindings.cpp) +add_library(_diskannpy MODULE + src/module.cpp + src/builder.cpp + src/dynamic_memory_index.cpp + src/static_memory_index.cpp + src/static_disk_index.cpp +) + +target_include_directories(_diskannpy AFTER PRIVATE include) if (MSVC) - target_compile_options(diskannpy PRIVATE /U_WINDLL) + target_compile_options(_diskannpy PRIVATE /U_WINDLL) endif() target_link_libraries( - diskannpy + _diskannpy PRIVATE pybind11::module pybind11::lto @@ -43,13 +51,13 @@ target_link_libraries( ${DISKANN_ASYNC_LIB} ) -pybind11_extension(diskannpy) +pybind11_extension(_diskannpy) if(NOT MSVC AND NOT ${CMAKE_BUILD_TYPE} MATCHES 
Debug|RelWithDebInfo) # Strip unnecessary sections of the binary on Linux/macOS - pybind11_strip(diskannpy) + pybind11_strip(_diskannpy) endif() -set_target_properties(diskannpy PROPERTIES CXX_VISIBILITY_PRESET "hidden" +set_target_properties(_diskannpy PROPERTIES CXX_VISIBILITY_PRESET "hidden" CUDA_VISIBILITY_PRESET "hidden") # generally, the VERSION_INFO flag is set by pyproject.toml, by way of setup.py. @@ -59,4 +67,4 @@ set_target_properties(diskannpy PROPERTIES CXX_VISIBILITY_PRESET "hidden" if(NOT VERSION_INFO) set(VERSION_INFO "0.0.0dev") endif() -target_compile_definitions(diskannpy PRIVATE VERSION_INFO="${VERSION_INFO}") +target_compile_definitions(_diskannpy PRIVATE VERSION_INFO="${VERSION_INFO}") diff --git a/python/README.md b/python/README.md new file mode 100644 index 000000000..1365fb422 --- /dev/null +++ b/python/README.md @@ -0,0 +1,55 @@ +# diskannpy + +[![DiskANN Paper](https://img.shields.io/badge/Paper-NeurIPS%3A_DiskANN-blue)](https://papers.nips.cc/paper/9527-rand-nsg-fast-accurate-billion-point-nearest-neighbor-search-on-a-single-node.pdf) +[![DiskANN Paper](https://img.shields.io/badge/Paper-Arxiv%3A_Fresh--DiskANN-blue)](https://arxiv.org/abs/2105.09613) +[![DiskANN Paper](https://img.shields.io/badge/Paper-Filtered--DiskANN-blue)](https://harsha-simhadri.org/pubs/Filtered-DiskANN23.pdf) +[![DiskANN Main](https://github.com/microsoft/DiskANN/actions/workflows/push-test.yml/badge.svg?branch=main)](https://github.com/microsoft/DiskANN/actions/workflows/push-test.yml) +[![PyPI version](https://img.shields.io/pypi/v/diskannpy.svg)](https://pypi.org/project/diskannpy/) +[![Downloads shield](https://pepy.tech/badge/diskannpy)](https://pepy.tech/project/diskannpy) +[![License: MIT](https://img.shields.io/badge/License-MIT-yellow.svg)](https://opensource.org/licenses/MIT) + +## Installation +Packages published to PyPI will always be built using the latest numpy major.minor release (at this time, 1.25). + +Conda distributions for versions 1.19-1.25 will be completed as a future effort. In the meantime, feel free to +clone this repository and build it yourself. + +## Local Build Instructions +Please see the [Project README](https://github.com/microsoft/DiskANN/blob/main/README.md) for system dependencies and requirements. + +After ensuring you've followed the directions to build the project library and executables, you will be ready to also +build `diskannpy` with these additional instructions. + +### Changing Numpy Version +In the root folder of DiskANN, there is a file `pyproject.toml`. You will need to edit the version of numpy in both the +`[build-system.requires]` section, as well as the `[project.dependencies]` section. The version numbers must match. + +#### Linux +```bash +python3.11 -m venv venv # versions from python3.9 and up should work +source venv/bin/activate +pip install build +python -m build +``` + +#### Windows +```powershell +py -3.11 -m venv venv # versions from python3.9 and up should work +venv\Scripts\Activate.ps1 +pip install build +python -m build +``` + +The built wheel will be placed in the `dist` directory in your DiskANN root. 
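Before installing, you can confirm the build actually produced an artifact; a minimal check (a hypothetical snippet, not part of this repo):

```python
# List the wheel(s) produced by `python -m build` (purely illustrative).
from pathlib import Path

wheels = sorted(Path("dist").glob("diskannpy-*.whl"))
assert wheels, "no wheel found - did the build succeed?"
print(wheels[-1])  # the most recently versioned wheel in dist/
```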
Install it using `pip install dist/.whl` + +## Citations +Please cite this software in your work as: +``` +@misc{diskann-github, + author = {Simhadri, Harsha Vardhan and Krishnaswamy, Ravishankar and Srinivasa, Gopal and Subramanya, Suhas Jayaram and Antonijevic, Andrija and Pryce, Dax and Kaczynski, David and Williams, Shane and Gollapudi, Siddarth and Sivashankar, Varun and Karia, Neel and Singh, Aditi and Jaiswal, Shikhar and Mahapatro, Neelam and Adams, Philip and Tower, Bryan and Patel, Yash}}, + title = {{DiskANN: Graph-structured Indices for Scalable, Fast, Fresh and Filtered Approximate Nearest Neighbor Search}}, + url = {https://github.com/Microsoft/DiskANN}, + version = {0.6.0}, + year = {2023} +} +``` diff --git a/python/apps/cli/__main__.py b/python/apps/cli/__main__.py new file mode 100644 index 000000000..d2c999052 --- /dev/null +++ b/python/apps/cli/__main__.py @@ -0,0 +1,152 @@ +import diskannpy as dap +import numpy as np +import numpy.typing as npt + +import fire + +from contextlib import contextmanager +from time import perf_counter + +from typing import Tuple + + +def _basic_setup( + dtype: str, + query_vectors_file: str +) -> Tuple[dap.VectorDType, npt.NDArray[dap.VectorDType]]: + _dtype = dap.valid_dtype(dtype) + vectors_to_query = dap.vectors_from_binary(query_vectors_file, dtype=_dtype) + return _dtype, vectors_to_query + + +def dynamic( + dtype: str, + index_vectors_file: str, + query_vectors_file: str, + build_complexity: int, + graph_degree: int, + K: int, + search_complexity: int, + num_insert_threads: int, + num_search_threads: int, + gt_file: str = "", +): + _dtype, vectors_to_query = _basic_setup(dtype, query_vectors_file) + vectors_to_index = dap.vectors_from_binary(index_vectors_file, dtype=_dtype) + + npts, ndims = vectors_to_index.shape + index = dap.DynamicMemoryIndex( + "l2", _dtype, ndims, npts, build_complexity, graph_degree + ) + + tags = np.arange(1, npts+1, dtype=np.uintc) + timer = Timer() + + with timer.time("batch insert"): + index.batch_insert(vectors_to_index, tags, num_insert_threads) + + delete_tags = np.random.choice( + np.array(range(1, npts + 1, 1), dtype=np.uintc), + size=int(0.5 * npts), + replace=False + ) + with timer.time("mark deletion"): + for tag in delete_tags: + index.mark_deleted(tag) + + with timer.time("consolidation"): + index.consolidate_delete() + + deleted_data = vectors_to_index[delete_tags - 1, :] + + with timer.time("re-insertion"): + index.batch_insert(deleted_data, delete_tags, num_insert_threads) + + with timer.time("batch searched"): + tags, dists = index.batch_search(vectors_to_query, K, search_complexity, num_search_threads) + + # res_ids = tags - 1 + # if gt_file != "": + # recall = utils.calculate_recall_from_gt_file(K, res_ids, gt_file) + # print(f"recall@{K} is {recall}") + +def static( + dtype: str, + index_directory: str, + index_vectors_file: str, + query_vectors_file: str, + build_complexity: int, + graph_degree: int, + K: int, + search_complexity: int, + num_threads: int, + gt_file: str = "", + index_prefix: str = "ann" +): + _dtype, vectors_to_query = _basic_setup(dtype, query_vectors_file) + timer = Timer() + with timer.time("build static index"): + # build index + dap.build_memory_index( + data=index_vectors_file, + metric="l2", + vector_dtype=_dtype, + index_directory=index_directory, + complexity=build_complexity, + graph_degree=graph_degree, + num_threads=num_threads, + index_prefix=index_prefix, + alpha=1.2, + use_pq_build=False, + num_pq_bytes=8, + use_opq=False, + ) + + with timer.time("load 
static index"): + # ready search object + index = dap.StaticMemoryIndex( + metric="l2", + vector_dtype=_dtype, + data_path=index_vectors_file, + index_directory=index_directory, + num_threads=num_threads, # this can be different at search time if you would like + initial_search_complexity=search_complexity, + index_prefix=index_prefix + ) + + ids, dists = index.batch_search(vectors_to_query, K, search_complexity, num_threads) + + # if gt_file != "": + # recall = utils.calculate_recall_from_gt_file(K, ids, gt_file) + # print(f"recall@{K} is {recall}") + +def dynamic_clustered(): + pass + +def generate_clusters(): + pass + + +class Timer: + def __init__(self): + self._start = -1 + + @contextmanager + def time(self, message: str): + start = perf_counter() + if self._start == -1: + self._start = start + yield + now = perf_counter() + print(f"Operation {message} completed in {(now - start):.3f}s, total: {(now - self._start):.3f}s") + + + + +if __name__ == "__main__": + fire.Fire({ + "in-mem-dynamic": dynamic, + "in-mem-static": static, + "in-mem-dynamic-clustered": dynamic_clustered, + "generate-clusters": generate_clusters + }, name="cli") diff --git a/python/apps/cluster.py b/python/apps/cluster.py new file mode 100644 index 000000000..27a34bb70 --- /dev/null +++ b/python/apps/cluster.py @@ -0,0 +1,28 @@ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT license. + +import argparse +import utils + + +if __name__ == "__main__": + parser = argparse.ArgumentParser( + prog="cluster", description="kmeans cluster points in a file" + ) + + parser.add_argument("-d", "--data_type", required=True) + parser.add_argument("-i", "--indexdata_file", required=True) + parser.add_argument("-k", "--num_clusters", type=int, required=True) + args = parser.parse_args() + + npts, ndims = get_bin_metadata(indexdata_file) + + data = utils.bin_to_numpy(args.data_type, args.indexdata_file) + + offsets, permutation = utils.cluster_and_permute( + args.data_type, npts, ndims, data, args.num_clusters + ) + + permuted_data = data[permutation] + + utils.numpy_to_bin(permuted_data, args.indexdata_file + ".cluster") diff --git a/python/apps/in-mem-dynamic.py b/python/apps/in-mem-dynamic.py new file mode 100644 index 000000000..f97e1313f --- /dev/null +++ b/python/apps/in-mem-dynamic.py @@ -0,0 +1,161 @@ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT license. + +import argparse + +import diskannpy +import numpy as np +import utils + +def insert_and_search( + dtype_str, + indexdata_file, + querydata_file, + Lb, + graph_degree, + K, + Ls, + num_insert_threads, + num_search_threads, + gt_file, +) -> dict[str, float]: + """ + + :param dtype_str: + :param indexdata_file: + :param querydata_file: + :param Lb: + :param graph_degree: + :param K: + :param Ls: + :param num_insert_threads: + :param num_search_threads: + :param gt_file: + :return: Dictionary of timings. 
Key is the event and value is the number of seconds the event took + """ + timer_results: dict[str, float] = {} + + method_timer: utils.Timer = utils.Timer() + + npts, ndims = utils.get_bin_metadata(indexdata_file) + + if dtype_str == "float": + dtype = np.float32 + elif dtype_str == "int8": + dtype = np.int8 + elif dtype_str == "uint8": + dtype = np.uint8 + else: + raise ValueError("data_type must be float, int8 or uint8") + + index = diskannpy.DynamicMemoryIndex( + distance_metric="l2", + vector_dtype=dtype, + dimensions=ndims, + max_vectors=npts, + complexity=Lb, + graph_degree=graph_degree + ) + queries = diskannpy.vectors_from_file(querydata_file, dtype) + data = diskannpy.vectors_from_file(indexdata_file, dtype) + + tags = np.zeros(npts, dtype=np.uintc) + timer = utils.Timer() + for i in range(npts): + tags[i] = i + 1 + index.batch_insert(data, tags, num_insert_threads) + compute_seconds = timer.elapsed() + print('batch_insert complete in', compute_seconds, 's') + timer_results["batch_insert_seconds"] = compute_seconds + + delete_tags = np.random.choice( + np.array(range(1, npts + 1, 1), dtype=np.uintc), + size=int(0.5 * npts), + replace=False + ) + + timer.reset() + for tag in delete_tags: + index.mark_deleted(tag) + compute_seconds = timer.elapsed() + timer_results['mark_deletion_seconds'] = compute_seconds + print('mark deletion completed in', compute_seconds, 's') + + timer.reset() + index.consolidate_delete() + compute_seconds = timer.elapsed() + print('consolidation completed in', compute_seconds, 's') + timer_results['consolidation_completed_seconds'] = compute_seconds + + deleted_data = data[delete_tags - 1, :] + + timer.reset() + index.batch_insert(deleted_data, delete_tags, num_insert_threads) + compute_seconds = timer.elapsed() + print('re-insertion completed in', compute_seconds, 's') + timer_results['re-insertion_seconds'] = compute_seconds + + timer.reset() + tags, dists = index.batch_search(queries, K, Ls, num_search_threads) + compute_seconds = timer.elapsed() + print('Batch searched', queries.shape[0], ' queries in ', compute_seconds, 's') + timer_results['batch_searched_seconds'] = compute_seconds + + res_ids = tags - 1 + if gt_file != "": + timer.reset() + recall = utils.calculate_recall_from_gt_file(K, res_ids, gt_file) + print(f"recall@{K} is {recall}") + timer_results['recall_computed_seconds'] = timer.elapsed() + + timer_results['total_time_seconds'] = method_timer.elapsed() + + return timer_results + + +if __name__ == "__main__": + parser = argparse.ArgumentParser( + prog="in-mem-dynamic", + description="Inserts points dynamically in a clustered order and search from vectors in a file.", + ) + + parser.add_argument("-d", "--data_type", required=True) + parser.add_argument("-i", "--indexdata_file", required=True) + parser.add_argument("-q", "--querydata_file", required=True) + parser.add_argument("-Lb", "--Lbuild", default=50, type=int) + parser.add_argument("-Ls", "--Lsearch", default=50, type=int) + parser.add_argument("-R", "--graph_degree", default=32, type=int) + parser.add_argument("-TI", "--num_insert_threads", default=8, type=int) + parser.add_argument("-TS", "--num_search_threads", default=8, type=int) + parser.add_argument("-K", default=10, type=int) + parser.add_argument("--gt_file", default="") + parser.add_argument("--json_timings_output", required=False, default=None, help="File to write out timings to as JSON. 
If not specified, timings will not be written out.") + args = parser.parse_args() + + timings = insert_and_search( + args.data_type, + args.indexdata_file, + args.querydata_file, + args.Lbuild, + args.graph_degree, # Build args + args.K, + args.Lsearch, + args.num_insert_threads, + args.num_search_threads, # search args + args.gt_file, + ) + + if args.json_timings_output is not None: + import json + timings['log_file'] = args.json_timings_output + with open(args.json_timings_output, "w") as f: + json.dump(timings, f) + +""" +An ingest optimized example with SIFT1M +source venv/bin/activate +python python/apps/in-mem-dynamic.py -d float \ +-i "$HOME/data/sift/sift_base.fbin" -q "$HOME/data/sift/sift_query.fbin" --gt_file "$HOME/data/sift/gt100_base" \ +-Lb 10 -R 30 -Ls 200 +""" + diff --git a/python/apps/in-mem-static.py b/python/apps/in-mem-static.py new file mode 100644 index 000000000..9fb9a2cce --- /dev/null +++ b/python/apps/in-mem-static.py @@ -0,0 +1,149 @@ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT license. + +import argparse +from xml.dom.pulldom import default_bufsize + +import diskannpy +import numpy as np +import utils + +def build_and_search( + metric, + dtype_str, + index_directory, + indexdata_file, + querydata_file, + Lb, + graph_degree, + K, + Ls, + num_threads, + gt_file, + index_prefix, + search_only +) -> dict[str, float]: + """ + + :param metric: + :param dtype_str: + :param index_directory: + :param indexdata_file: + :param querydata_file: + :param Lb: + :param graph_degree: + :param K: + :param Ls: + :param num_threads: + :param gt_file: + :param index_prefix: + :param search_only: + :return: Dictionary of timings. Key is the event and value is the number of seconds the event took + in wall-clock-time. 
+ """ + timer_results: dict[str, float] = {} + + method_timer: utils.Timer = utils.Timer() + + if dtype_str == "float": + dtype = np.single + elif dtype_str == "int8": + dtype = np.byte + elif dtype_str == "uint8": + dtype = np.ubyte + else: + raise ValueError("data_type must be float, int8 or uint8") + + # build index + if not search_only: + build_index_timer = utils.Timer() + diskannpy.build_memory_index( + data=indexdata_file, + distance_metric=metric, + vector_dtype=dtype, + index_directory=index_directory, + complexity=Lb, + graph_degree=graph_degree, + num_threads=num_threads, + index_prefix=index_prefix, + alpha=1.2, + use_pq_build=False, + num_pq_bytes=8, + use_opq=False, + ) + timer_results["build_index_seconds"] = build_index_timer.elapsed() + + # ready search object + load_index_timer = utils.Timer() + index = diskannpy.StaticMemoryIndex( + distance_metric=metric, + vector_dtype=dtype, + index_directory=index_directory, + num_threads=num_threads, # this can be different at search time if you would like + initial_search_complexity=Ls, + index_prefix=index_prefix + ) + timer_results["load_index_seconds"] = load_index_timer.elapsed() + + queries = utils.bin_to_numpy(dtype, querydata_file) + + query_timer = utils.Timer() + ids, dists = index.batch_search(queries, 10, Ls, num_threads) + query_time = query_timer.elapsed() + qps = round(queries.shape[0]/query_time, 1) + print('Batch searched', queries.shape[0], 'in', query_time, 's @', qps, 'QPS') + timer_results["query_seconds"] = query_time + + if gt_file != "": + recall_timer = utils.Timer() + recall = utils.calculate_recall_from_gt_file(K, ids, gt_file) + print(f"recall@{K} is {recall}") + timer_results["recall_seconds"] = recall_timer.elapsed() + + timer_results['total_time_seconds'] = method_timer.elapsed() + + return timer_results + +if __name__ == "__main__": + parser = argparse.ArgumentParser( + prog="in-mem-static", + description="Static in-memory build and search from vectors in a file", + ) + + parser.add_argument("-m", "--metric", required=False, default="l2") + parser.add_argument("-d", "--data_type", required=True) + parser.add_argument("-id", "--index_directory", required=False, default=".") + parser.add_argument("-i", "--indexdata_file", required=True) + parser.add_argument("-q", "--querydata_file", required=True) + parser.add_argument("-Lb", "--Lbuild", default=50, type=int) + parser.add_argument("-Ls", "--Lsearch", default=50, type=int) + parser.add_argument("-R", "--graph_degree", default=32, type=int) + parser.add_argument("-T", "--num_threads", default=8, type=int) + parser.add_argument("-K", default=10, type=int) + parser.add_argument("-G", "--gt_file", default="") + parser.add_argument("-ip", "--index_prefix", required=False, default="ann") + parser.add_argument("--search_only", required=False, default=False) + parser.add_argument("--json_timings_output", required=False, default=None, help="File to write out timings to as JSON. 
If not specified, timings will not be written out.") + args = parser.parse_args() + + timings: dict[str, float] = build_and_search( + args.metric, + args.data_type, + args.index_directory.strip(), + args.indexdata_file.strip(), + args.querydata_file.strip(), + args.Lbuild, + args.graph_degree, # Build args + args.K, + args.Lsearch, + args.num_threads, # search args + args.gt_file, + args.index_prefix, + args.search_only + ) + + if args.json_timings_output is not None: + import json + timings['log_file'] = args.json_timings_output + with open(args.json_timings_output, "w") as f: + json.dump(timings, f) diff --git a/python/apps/insert-in-clustered-order.py b/python/apps/insert-in-clustered-order.py new file mode 100644 index 000000000..25cb9d53c --- /dev/null +++ b/python/apps/insert-in-clustered-order.py @@ -0,0 +1,103 @@ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT license. + +import argparse + +import diskannpy +import numpy as np +import utils + + +def insert_and_search( + dtype_str, + indexdata_file, + querydata_file, + Lb, + graph_degree, + num_clusters, + num_insert_threads, + K, + Ls, + num_search_threads, + gt_file, +): + npts, ndims = utils.get_bin_metadata(indexdata_file) + + if dtype_str == "float": + dtype = np.float32 + elif dtype_str == "int8": + dtype = np.int8 + elif dtype_str == "uint8": + dtype = np.uint8 + else: + raise ValueError("data_type must be float, int8 or uint8") + + index = diskannpy.DynamicMemoryIndex( + distance_metric="l2", + vector_dtype=dtype, + dimensions=ndims, + max_vectors=npts, + complexity=Lb, + graph_degree=graph_degree + ) + queries = diskannpy.vectors_from_file(querydata_file, dtype) + data = diskannpy.vectors_from_file(indexdata_file, dtype) + + offsets, permutation = utils.cluster_and_permute( + dtype_str, npts, ndims, data, num_clusters + ) + + i = 0 + timer = utils.Timer() + for c in range(num_clusters): + cluster_index_range = range(offsets[c], offsets[c + 1]) + cluster_indices = np.array(permutation[cluster_index_range], dtype=np.uint32) + cluster_data = data[cluster_indices, :] + index.batch_insert(cluster_data, cluster_indices + 1, num_insert_threads) + print('Inserted cluster', c, 'in', timer.elapsed(), 's') + tags, dists = index.batch_search(queries, K, Ls, num_search_threads) + print('Batch searched', queries.shape[0], 'queries in', timer.elapsed(), 's') + res_ids = tags - 1 + + if gt_file != "": + recall = utils.calculate_recall_from_gt_file(K, res_ids, gt_file) + print(f"recall@{K} is {recall}") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser( + prog="in-mem-dynamic", + description="Inserts points dynamically in a clustered order and search from vectors in a file.", + ) + + parser.add_argument("-d", "--data_type", required=True) + parser.add_argument("-i", "--indexdata_file", required=True) + parser.add_argument("-q", "--querydata_file", required=True) + parser.add_argument("-Lb", "--Lbuild", default=50, type=int) + parser.add_argument("-Ls", "--Lsearch", default=50, type=int) + parser.add_argument("-R", "--graph_degree", default=32, type=int) + parser.add_argument("-TI", "--num_insert_threads", default=8, type=int) + parser.add_argument("-TS", "--num_search_threads", default=8, type=int) + parser.add_argument("-C", "--num_clusters", default=32, type=int) + parser.add_argument("-K", default=10, type=int) + parser.add_argument("--gt_file", default="") + args = parser.parse_args() + + insert_and_search( + args.data_type, + args.indexdata_file, + args.querydata_file, + 
args.Lbuild, + args.graph_degree, # Build args + args.num_clusters, + args.num_insert_threads, + args.K, + args.Lsearch, + args.num_search_threads, # search args + args.gt_file, + ) + +# An ingest optimized example with SIFT1M +# python3 ~/DiskANN/python/apps/insert-in-clustered-order.py -d float \ +# -i sift_base.fbin -q sift_query.fbin --gt_file gt100_base \ +# -Lb 10 -R 30 -Ls 200 -C 32 \ No newline at end of file diff --git a/python/apps/requirements.txt b/python/apps/requirements.txt new file mode 100644 index 000000000..87b4a72cc --- /dev/null +++ b/python/apps/requirements.txt @@ -0,0 +1,2 @@ +diskannpy +fire diff --git a/python/apps/utils.py b/python/apps/utils.py new file mode 100644 index 000000000..a52698470 --- /dev/null +++ b/python/apps/utils.py @@ -0,0 +1,120 @@ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT license. + +import numpy as np +from scipy.cluster.vq import vq, kmeans2 +from typing import Tuple +from time import perf_counter + + +def get_bin_metadata(bin_file) -> Tuple[int, int]: + array = np.fromfile(file=bin_file, dtype=np.uint32, count=2) + return array[0], array[1] + + +def bin_to_numpy(dtype, bin_file) -> np.ndarray: + npts, ndims = get_bin_metadata(bin_file) + return np.fromfile(file=bin_file, dtype=dtype, offset=8).reshape(npts, ndims) + + +class Timer: + last = perf_counter() + + def reset(self): + new = perf_counter() + self.last = new + + def elapsed(self, round_digit:int = 3): + new = perf_counter() + elapsed_time = new - self.last + self.last = new + return round(elapsed_time, round_digit) + + +def numpy_to_bin(array, out_file): + shape = np.shape(array) + npts = shape[0].astype(np.uint32) + ndims = shape[1].astype(np.uint32) + f = open(out_file, "wb") + f.write(npts.tobytes()) + f.write(ndims.tobytes()) + f.write(array.tobytes()) + f.close() + + +def read_gt_file(gt_file) -> Tuple[np.ndarray[int], np.ndarray[float]]: + """ + Return ids and distances to queries + """ + nq, K = get_bin_metadata(gt_file) + ids = np.fromfile(file=gt_file, dtype=np.uint32, offset=8, count=nq * K).reshape( + nq, K + ) + dists = np.fromfile( + file=gt_file, dtype=np.float32, offset=8 + nq * K * 4, count=nq * K + ).reshape(nq, K) + return ids, dists + + +def calculate_recall( + result_set_indices: np.ndarray[int], + truth_set_indices: np.ndarray[int], + recall_at: int = 5, +) -> float: + """ + result_set_indices and truth_set_indices correspond by row index. the columns in each row contain the indices of + the nearest neighbors, with result_set_indices being the approximate nearest neighbor results and truth_set_indices + being the brute force nearest neighbor calculation via sklearn's NearestNeighbor class. 
+ :param result_set_indices: + :param truth_set_indices: + :param recall_at: + :return: + """ + found = 0 + for i in range(0, result_set_indices.shape[0]): + result_set_set = set(result_set_indices[i][0:recall_at]) + truth_set_set = set(truth_set_indices[i][0:recall_at]) + found += len(result_set_set.intersection(truth_set_set)) + return found / (result_set_indices.shape[0] * recall_at) + + +def calculate_recall_from_gt_file(K: int, ids: np.ndarray[int], gt_file: str) -> float: + """ + Calculate recall from ids returned from search and those read from file + """ + gt_ids, gt_dists = read_gt_file(gt_file) + return calculate_recall(ids, gt_ids, K) + + +def cluster_and_permute( + dtype_str, npts, ndims, data, num_clusters +) -> Tuple[np.ndarray[int], np.ndarray[int]]: + """ + Cluster the data and return permutation of row indices + that would group indices of the same cluster together + """ + sample_size = min(100000, npts) + sample_indices = np.random.choice(range(npts), size=sample_size, replace=False) + sampled_data = data[sample_indices, :] + centroids, sample_labels = kmeans2(sampled_data, num_clusters, minit="++", iter=10) + labels, dist = vq(data, centroids) + + count = np.zeros(num_clusters) + for i in range(npts): + count[labels[i]] += 1 + print("Cluster counts") + print(count) + + offsets = np.zeros(num_clusters + 1, dtype=int) + for i in range(0, num_clusters, 1): + offsets[i + 1] = offsets[i] + count[i] + + permutation = np.zeros(npts, dtype=int) + counters = np.zeros(num_clusters, dtype=int) + for i in range(npts): + label = labels[i] + row = offsets[label] + counters[label] + counters[label] += 1 + permutation[row] = i + + return offsets, permutation diff --git a/python/include/builder.h b/python/include/builder.h new file mode 100644 index 000000000..fc12976e7 --- /dev/null +++ b/python/include/builder.h @@ -0,0 +1,26 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT license. + +#pragma once + +#include +#include + +#include "common.h" +#include "distance.h" + +namespace diskannpy +{ +template +void build_disk_index(diskann::Metric metric, const std::string &data_file_path, const std::string &index_prefix_path, + uint32_t complexity, uint32_t graph_degree, double final_index_ram_limit, + double indexing_ram_budget, uint32_t num_threads, uint32_t pq_disk_bytes); + +template +void build_memory_index(diskann::Metric metric, const std::string &vector_bin_path, + const std::string &index_output_path, uint32_t graph_degree, uint32_t complexity, + float alpha, uint32_t num_threads, bool use_pq_build, + size_t num_pq_bytes, bool use_opq, uint32_t filter_complexity, + bool use_tags = false); + +} diff --git a/python/include/common.h b/python/include/common.h new file mode 100644 index 000000000..7c63534fa --- /dev/null +++ b/python/include/common.h @@ -0,0 +1,24 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT license. + +#pragma once + +#include +#include + +#include +#include + +namespace py = pybind11; + +namespace diskannpy +{ + +typedef uint32_t filterT; + +typedef uint32_t StaticIdType; +typedef uint32_t DynamicIdType; + +template using NeighborsAndDistances = std::pair, py::array_t>; + +}; // namespace diskannpy diff --git a/python/include/dynamic_memory_index.h b/python/include/dynamic_memory_index.h new file mode 100644 index 000000000..02d6b8cce --- /dev/null +++ b/python/include/dynamic_memory_index.h @@ -0,0 +1,53 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. 
+// Licensed under the MIT license.
+
+#pragma once
+
+#include <cstdint>
+#include <string>
+
+#include <pybind11/numpy.h>
+#include <pybind11/pybind11.h>
+
+#include "common.h"
+#include "index.h"
+#include "parameters.h"
+
+namespace py = pybind11;
+
+namespace diskannpy
+{
+
+template <typename DT>
+class DynamicMemoryIndex
+{
+  public:
+    DynamicMemoryIndex(diskann::Metric m, size_t dimensions, size_t max_vectors, uint32_t complexity,
+                       uint32_t graph_degree, bool saturate_graph, uint32_t max_occlusion_size, float alpha,
+                       uint32_t num_threads, uint32_t filter_complexity, uint32_t num_frozen_points,
+                       uint32_t initial_search_complexity, uint32_t initial_search_threads,
+                       bool concurrent_consolidation);
+
+    void load(const std::string &index_path);
+    int insert(const py::array_t<DT> &vector, DynamicIdType id);
+    py::array_t<int> batch_insert(py::array_t<DT> &vectors,
+                                  py::array_t<DynamicIdType> &ids, int32_t num_inserts,
+                                  int num_threads = 0);
+    int mark_deleted(DynamicIdType id);
+    void save(const std::string &save_path, bool compact_before_save = false);
+    NeighborsAndDistances<DynamicIdType> search(py::array_t<DT> &query, uint64_t knn,
+                                                uint64_t complexity);
+    NeighborsAndDistances<DynamicIdType> batch_search(py::array_t<DT> &queries,
+                                                      uint64_t num_queries, uint64_t knn, uint64_t complexity,
+                                                      uint32_t num_threads);
+    void consolidate_delete();
+    size_t num_points();
+
+  private:
+    const uint32_t _initial_search_complexity;
+    const diskann::IndexWriteParameters _write_parameters;
+    diskann::Index<DT, DynamicIdType, filterT> _index;
+};
+
+}; // namespace diskannpy
\ No newline at end of file
diff --git a/python/include/static_disk_index.h b/python/include/static_disk_index.h
new file mode 100644
index 000000000..71a1b5aff
--- /dev/null
+++ b/python/include/static_disk_index.h
@@ -0,0 +1,52 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT license.
+
+#pragma once
+
+#include <cstdint>
+#include <string>
+
+#include <pybind11/numpy.h>
+#include <pybind11/pybind11.h>
+
+#ifdef _WINDOWS
+#include "windows_aligned_file_reader.h"
+#else
+#include "linux_aligned_file_reader.h"
+#endif
+
+#include "common.h"
+#include "pq_flash_index.h"
+
+namespace py = pybind11;
+
+namespace diskannpy {
+
+#ifdef _WINDOWS
+typedef WindowsAlignedFileReader PlatformSpecificAlignedFileReader;
+#else
+typedef LinuxAlignedFileReader PlatformSpecificAlignedFileReader;
+#endif
+
+template <typename DT>
+class StaticDiskIndex
+{
+  public:
+    StaticDiskIndex(diskann::Metric metric, const std::string &index_path_prefix, uint32_t num_threads,
+                    size_t num_nodes_to_cache, uint32_t cache_mechanism);
+
+    void cache_bfs_levels(size_t num_nodes_to_cache);
+
+    void cache_sample_paths(size_t num_nodes_to_cache, const std::string &warmup_query_file, uint32_t num_threads);
+
+    NeighborsAndDistances<StaticIdType> search(py::array_t<DT> &query, uint64_t knn,
+                                               uint64_t complexity, uint64_t beam_width);
+
+    NeighborsAndDistances<StaticIdType> batch_search(py::array_t<DT> &queries, uint64_t num_queries,
+                                                     uint64_t knn, uint64_t complexity, uint64_t beam_width,
+                                                     uint32_t num_threads);
+  private:
+    std::shared_ptr<AlignedFileReader> _reader;
+    diskann::PQFlashIndex<DT> _index;
+};
+}
diff --git a/python/include/static_memory_index.h b/python/include/static_memory_index.h
new file mode 100644
index 000000000..33f3187ae
--- /dev/null
+++ b/python/include/static_memory_index.h
@@ -0,0 +1,34 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT license.
+
+#pragma once
+
+#include <cstdint>
+#include <string>
+
+#include <pybind11/numpy.h>
+#include <pybind11/pybind11.h>
+
+#include "common.h"
+#include "index.h"
+
+namespace py = pybind11;
+
+namespace diskannpy {
+
+template <typename DT>
+class StaticMemoryIndex
+{
+  public:
+    StaticMemoryIndex(diskann::Metric m, const std::string &index_prefix, size_t num_points,
+                      size_t dimensions, uint32_t num_threads, uint32_t initial_search_complexity);
+
+    NeighborsAndDistances<StaticIdType> search(py::array_t<DT> &query, uint64_t knn,
+                                               uint64_t complexity);
+
+    NeighborsAndDistances<StaticIdType> batch_search(py::array_t<DT> &queries,
+                                                     uint64_t num_queries, uint64_t knn, uint64_t complexity,
+                                                     uint32_t num_threads);
+  private:
+    diskann::Index<DT, StaticIdType> _index;
+};
+}
\ No newline at end of file
diff --git a/python/src/__init__.py b/python/src/__init__.py
new file mode 100644
index 000000000..c2e1b07f6
--- /dev/null
+++ b/python/src/__init__.py
@@ -0,0 +1,138 @@
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT license.
+
+"""
+# Documentation Overview
+`diskannpy` is mostly structured around 2 distinct processes: [Index Builder Functions](#index-builders) and
+[Search Classes](#search-classes)
+
+It also includes a few nascent [utilities](#utilities).
+
+And lastly, it makes substantial use of type hints, with various shorthand [type aliases](#parameter-and-response-type-aliases) documented.
+When reading the `diskannpy` code we refer to the type aliases, though `pdoc` helpfully expands them.
+
+## Index Builders
+- `build_disk_index` - To build an index that cannot fully fit into memory when searching
+- `build_memory_index` - To build an index that can fully fit into memory when searching
+
+## Search Classes
+- `StaticMemoryIndex` - for indices that can fully fit in memory and won't be changed during the search operations
+- `StaticDiskIndex` - for indices that cannot fully fit in memory, thus relying on disk IO to search, and also won't be changed during search operations
+- `DynamicMemoryIndex` - for indices that can fully fit in memory and will be mutated via insert/deletion operations as well as search operations
+
+## Parameter Defaults
+- `diskannpy.defaults` - Default values exported from the C++ extension for Python users
+
+## Parameter and Response Type Aliases
+- `DistanceMetric` - What distance metrics does `diskannpy` support?
+- `VectorDType` - What vector datatypes does `diskannpy` support?
+- `QueryResponse` - What can I expect as a response to my search?
+- `QueryResponseBatch` - What can I expect as a response to my batch search?
+- `VectorIdentifier` - What types does `diskannpy` support as vector identifiers?
+- `VectorIdentifierBatch` - A batch of identifiers of the exact same type. The type can change, but they must **all** change.
+- `VectorLike` - What a vector looks like to `diskannpy`, whether being inserted or searched with.
+- `VectorLikeBatch` - A batch of those vectors, to be inserted or searched with.
+- `Metadata` - DiskANN vector binary file metadata (num_points, vector_dim)
+
+## Utilities
+- `vectors_to_file` - Turns a 2 dimensional `numpy.typing.NDArray[VectorDType]` with shape `(number_of_points, vector_dim)` into a DiskANN vector bin file.
+- `vectors_from_file` - Reads a DiskANN vector bin file representing stored vectors into a numpy ndarray.
+- `vectors_metadata_from_file` - Reads metadata stored in a DiskANN vector bin file without reading the entire file
+- `tags_to_file` - Turns a 1 dimensional `numpy.typing.NDArray[VectorIdentifier]` into a DiskANN tags bin file.
+- `tags_from_file` - Reads a DiskANN tags bin file representing stored tags into a numpy ndarray.
+- `valid_dtype` - Checks if a given vector dtype is supported by `diskannpy`
+"""
+
+from typing import Any, Literal, NamedTuple, Type, Union
+
+import numpy as np
+from numpy import typing as npt
+
+DistanceMetric = Literal["l2", "mips", "cosine"]
+""" Type alias for one of {"l2", "mips", "cosine"} """
+VectorDType = Union[Type[np.float32], Type[np.int8], Type[np.uint8]]
+""" Type alias for one of {`numpy.float32`, `numpy.int8`, `numpy.uint8`} """
+VectorLike = npt.NDArray[VectorDType]
+""" Type alias for something that can be treated as a vector """
+VectorLikeBatch = npt.NDArray[VectorDType]
+""" Type alias for a batch of VectorLikes """
+VectorIdentifier = np.uint32
+"""
+Type alias for a vector identifier, whether it be an implicit array index identifier from StaticMemoryIndex or
+StaticDiskIndex, or an explicit tag identifier from DynamicMemoryIndex
+"""
+VectorIdentifierBatch = npt.NDArray[np.uint32]
+""" Type alias for a batch of VectorIdentifiers """
+
+
+class QueryResponse(NamedTuple):
+    """
+    Tuple with two values, identifiers and distances. Both are 1d arrays, positionally correspond, and will contain the
+    nearest neighbors from [0..k_neighbors)
+    """
+
+    identifiers: npt.NDArray[VectorIdentifier]
+    """ A `numpy.typing.NDArray[VectorIdentifier]` array of vector identifiers, 1 dimensional """
+    distances: npt.NDArray[np.float32]
+    """
+    A `numpy.typing.NDArray[numpy.float32]` of distances as calculated by the distance metric function, 1 dimensional
+    """
+
+
+class QueryResponseBatch(NamedTuple):
+    """
+    Tuple with two values, identifiers and distances. Both are 2d arrays, with dimensionality determined by the
+    rows corresponding to the number of queries made, and the columns corresponding to the k neighbors
+    requested. The two 2d arrays have an implicit, position-based relationship
+    """

+    identifiers: npt.NDArray[VectorIdentifier]
+    """
+    A `numpy.typing.NDArray[VectorIdentifier]` array of vector identifiers, 2 dimensional. The row corresponds to index
+    of the query, and the column corresponds to the k neighbors requested
+    """
+    distances: npt.NDArray[np.float32]
+    """
+    A `numpy.typing.NDArray[numpy.float32]` of distances as calculated by the distance metric function, 2 dimensional.
+    The row corresponds to the index of the query, and the column corresponds to the distance of the query to the
+    *k-th* neighbor
+    """
+
+
+from . import defaults
+from ._builder import build_disk_index, build_memory_index
+from ._common import valid_dtype
+from ._dynamic_memory_index import DynamicMemoryIndex
+from ._files import (
+    Metadata,
+    tags_from_file,
+    tags_to_file,
+    vectors_from_file,
+    vectors_metadata_from_file,
+    vectors_to_file,
+)
+from ._static_disk_index import StaticDiskIndex
+from ._static_memory_index import StaticMemoryIndex
+
+__all__ = [
+    "build_disk_index",
+    "build_memory_index",
+    "StaticDiskIndex",
+    "StaticMemoryIndex",
+    "DynamicMemoryIndex",
+    "defaults",
+    "DistanceMetric",
+    "VectorDType",
+    "QueryResponse",
+    "QueryResponseBatch",
+    "VectorIdentifier",
+    "VectorIdentifierBatch",
+    "VectorLike",
+    "VectorLikeBatch",
+    "Metadata",
+    "vectors_metadata_from_file",
+    "vectors_to_file",
+    "vectors_from_file",
+    "tags_to_file",
+    "tags_from_file",
+    "valid_dtype",
+]
diff --git a/python/src/_builder.py b/python/src/_builder.py
new file mode 100644
index 000000000..18e9e9fa0
--- /dev/null
+++ b/python/src/_builder.py
@@ -0,0 +1,280 @@
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT license.
+
+import os
+import shutil
+from pathlib import Path
+from typing import Optional, Tuple, Union
+
+import numpy as np
+
+from . import DistanceMetric, VectorDType, VectorIdentifierBatch, VectorLikeBatch
+from . import _diskannpy as _native_dap
+from ._common import (
+    _assert,
+    _assert_is_nonnegative_uint32,
+    _assert_is_positive_uint32,
+    _castable_dtype_or_raise,
+    _valid_metric,
+    _write_index_metadata,
+    valid_dtype,
+)
+from ._diskannpy import defaults
+from ._files import tags_to_file, vectors_metadata_from_file, vectors_to_file
+
+
+def _valid_path_and_dtype(
+    data: Union[str, VectorLikeBatch],
+    vector_dtype: VectorDType,
+    index_path: str,
+    index_prefix: str,
+) -> Tuple[str, VectorDType]:
+    if isinstance(data, str):
+        vector_bin_path = data
+        _assert(
+            Path(data).exists() and Path(data).is_file(),
+            "if data is of type `str`, it must both exist and be a file",
+        )
+        vector_dtype_actual = valid_dtype(vector_dtype)
+    else:
+        vector_bin_path = os.path.join(index_path, f"{index_prefix}_vectors.bin")
+        if Path(vector_bin_path).exists():
+            raise ValueError(
+                f"The path {vector_bin_path} already exists. Remove it and try again."
+            )
+        vector_dtype_actual = valid_dtype(data.dtype)
+        vectors_to_file(vector_file=vector_bin_path, vectors=data)
+
+    return vector_bin_path, vector_dtype_actual
+
+
+def build_disk_index(
+    data: Union[str, VectorLikeBatch],
+    distance_metric: DistanceMetric,
+    index_directory: str,
+    complexity: int,
+    graph_degree: int,
+    search_memory_maximum: float,
+    build_memory_maximum: float,
+    num_threads: int,
+    pq_disk_bytes: int = defaults.PQ_DISK_BYTES,
+    vector_dtype: Optional[VectorDType] = None,
+    index_prefix: str = "ann",
+) -> None:
+    """
+    This function will construct a DiskANN disk index. Disk indices are ideal for very large datasets that
+    are too large to fit in memory. Memory is still used, but it is primarily used to provide precise disk
+    locations for fast retrieval of smaller subsets of the index without compromising much on recall.
+
+    If you provide a numpy array, it will save this array to disk in a temp location
+    in the format DiskANN's PQ Flash Index builder requires. This temp folder is deleted upon index creation completion
+    or error.
+
+    ### Parameters
+    - **data**: Either a `str` representing a path to a DiskANN vector bin file, or a numpy.ndarray of a supported
+      dtype, in 2 dimensions.
+      Note that `vector_dtype` must be provided if data is a `str`.
+    - **distance_metric**: A `str`, strictly one of {"l2", "mips", "cosine"}. `l2` and `cosine` are supported for all 3
+      vector dtypes, but `mips` is only available for single precision floats.
+    - **index_directory**: The index files will be saved to this **existing** directory path
+    - **complexity**: The size of the candidate nearest neighbor list to use when building the index. Values between 75
+      and 200 are typical. Larger values will take more time to build but result in indices that provide higher recall
+      for the same search complexity. Use a value that is at least as large as `graph_degree` unless you are prepared
+      to compromise on quality
+    - **graph_degree**: The degree of the graph index, typically between 60 and 150. A larger maximum degree will
+      result in larger indices and longer indexing times, but better search quality.
+    - **search_memory_maximum**: Build index with the expectation that the search will use at most
+      `search_memory_maximum`, in GB.
+    - **build_memory_maximum**: Build index using at most `build_memory_maximum`, in GB. Building processes typically
+      require more memory, while search memory can be reduced.
+    - **num_threads**: Number of threads to use when creating this index. `0` is used to indicate all available
+      logical processors should be used.
+    - **pq_disk_bytes**: Use `0` to store uncompressed data on SSD. This allows the index to asymptote to 100%
+      recall. If your vectors are too large to store on SSD, this parameter provides the option to compress the
+      vectors using PQ for storing on SSD. This will trade off recall. You would also want this to be greater
+      than the number of bytes used for the PQ compressed data stored in-memory. Default is `0`.
+    - **vector_dtype**: Required if the provided `data` is of type `str`, else we use the `data.dtype` if np array.
+    - **index_prefix**: The prefix of the index files. Defaults to "ann".
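+
+    ### Example
+    A minimal sketch of the intended call pattern (the array shape, metric, and `/tmp/indices` directory are
+    illustrative assumptions, not requirements of the API):
+
+    ```python
+    import numpy as np
+    import diskannpy
+
+    rng = np.random.default_rng(1234)
+    vectors = rng.random((10_000, 128), dtype=np.float32)  # 10K vectors, 128 dimensions
+    diskannpy.build_disk_index(
+        data=vectors,               # a numpy array, so vector_dtype is inferred
+        distance_metric="l2",
+        index_directory="/tmp/indices",  # must already exist
+        complexity=128,
+        graph_degree=64,
+        search_memory_maximum=4.0,  # GB
+        build_memory_maximum=8.0,   # GB
+        num_threads=0,              # use all logical processors
+    )
+    ```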
+ """ + + _assert( + (isinstance(data, str) and vector_dtype is not None) + or isinstance(data, np.ndarray), + "vector_dtype is required if data is a str representing a path to the vector bin file", + ) + dap_metric = _valid_metric(distance_metric) + _assert_is_positive_uint32(complexity, "complexity") + _assert_is_positive_uint32(graph_degree, "graph_degree") + _assert(search_memory_maximum > 0, "search_memory_maximum must be larger than 0") + _assert(build_memory_maximum > 0, "build_memory_maximum must be larger than 0") + _assert_is_nonnegative_uint32(num_threads, "num_threads") + _assert_is_nonnegative_uint32(pq_disk_bytes, "pq_disk_bytes") + _assert(index_prefix != "", "index_prefix cannot be an empty string") + + index_path = Path(index_directory) + _assert( + index_path.exists() and index_path.is_dir(), + "index_directory must both exist and be a directory", + ) + + vector_bin_path, vector_dtype_actual = _valid_path_and_dtype( + data, vector_dtype, index_directory, index_prefix + ) + + num_points, dimensions = vectors_metadata_from_file(vector_bin_path) + + if vector_dtype_actual == np.uint8: + _builder = _native_dap.build_disk_uint8_index + elif vector_dtype_actual == np.int8: + _builder = _native_dap.build_disk_int8_index + else: + _builder = _native_dap.build_disk_float_index + + index_prefix_path = os.path.join(index_directory, index_prefix) + + _builder( + distance_metric=dap_metric, + data_file_path=vector_bin_path, + index_prefix_path=index_prefix_path, + complexity=complexity, + graph_degree=graph_degree, + final_index_ram_limit=search_memory_maximum, + indexing_ram_budget=build_memory_maximum, + num_threads=num_threads, + pq_disk_bytes=pq_disk_bytes, + ) + _write_index_metadata( + index_prefix_path, vector_dtype_actual, dap_metric, num_points, dimensions + ) + + +def build_memory_index( + data: Union[str, VectorLikeBatch], + distance_metric: DistanceMetric, + index_directory: str, + complexity: int, + graph_degree: int, + num_threads: int, + alpha: float = defaults.ALPHA, + use_pq_build: bool = defaults.USE_PQ_BUILD, + num_pq_bytes: int = defaults.NUM_PQ_BYTES, + use_opq: bool = defaults.USE_OPQ, + vector_dtype: Optional[VectorDType] = None, + filter_complexity: int = defaults.FILTER_COMPLEXITY, + tags: Union[str, VectorIdentifierBatch] = "", + index_prefix: str = "ann", +) -> None: + """ + This function will construct a DiskANN memory index. Memory indices are ideal for smaller datasets whose + indices can fit into memory. Memory indices are faster than disk indices, but usually cannot scale to massive + sizes in an individual index on an individual machine. + + `diskannpy`'s memory indices take two forms: a `diskannpy.StaticMemoryIndex`, which will not be mutated, only + searched upon, and a `diskannpy.DynamicMemoryIndex`, which can be mutated AND searched upon in the same process. + + ## Important Note: + You **must** determine the type of index you are building for. If you are building for a + `diskannpy.DynamicMemoryIndex`, you **must** supply a valid value for the `tags` parameter. **Do not supply + tags if the index is intended to be `diskannpy.StaticMemoryIndex`**! + + ### Parameters + + - **data**: Either a `str` representing a path to an existing DiskANN vector bin file, or a numpy.ndarray of a + supported dtype in 2 dimensions. Note that `vector_dtype` must be provided if `data` is a `str`. + - **distance_metric**: A `str`, strictly one of {"l2", "mips", "cosine"}. 
+      `l2` and `cosine` are supported for all 3 vector dtypes, but `mips` is only available for single precision
+      floats.
+    - **index_directory**: The index files will be saved to this **existing** directory path
+    - **complexity**: The size of the candidate nearest neighbor list to use when building the index. Values between 75
+      and 200 are typical. Larger values will take more time to build but result in indices that provide higher recall
+      for the same search complexity. Use a value that is at least as large as `graph_degree` unless you are prepared
+      to compromise on quality
+    - **graph_degree**: The degree of the graph index, typically between 60 and 150. A larger maximum degree will
+      result in larger indices and longer indexing times, but better search quality.
+    - **num_threads**: Number of threads to use when creating this index. `0` is used to indicate all available
+      logical processors should be used.
+    - **alpha**: The alpha parameter (>=1) is used to control the nature and number of points that are added to the
+      graph. A higher alpha value (e.g., 1.4) will result in fewer hops (and IOs) to convergence, but probably more
+      distance comparisons compared to a lower alpha value.
+    - **use_pq_build**: Use product quantization during build. Product quantization is a lossy compression technique
+      that can reduce the size of the index on disk. This will trade off recall. Default is `True`.
+    - **num_pq_bytes**: The number of bytes used to store the PQ compressed data in memory. This will trade off recall.
+      Default is `0`.
+    - **use_opq**: Use optimized product quantization during build.
+    - **vector_dtype**: Required if the provided `data` is of type `str`, else we use the `data.dtype` if np array.
+    - **filter_complexity**: Complexity to use when using filters. Default is 0.
+    - **tags**: A `str` representing a path to a pre-built tags file on disk, or a `numpy.ndarray` of uint32 ids
+      corresponding to the ordinal position of the vectors provided to build the index. Defaults to "". **This value
+      must be provided if you want to build a memory index intended for use with `diskannpy.DynamicMemoryIndex`**.
+    - **index_prefix**: The prefix of the index files. Defaults to "ann".
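+
+    ### Example
+    A minimal sketch (shapes, paths, and the use of `tags` here are illustrative assumptions). Passing `tags`
+    marks the index for later use with `diskannpy.DynamicMemoryIndex`; omit it for a
+    `diskannpy.StaticMemoryIndex`:
+
+    ```python
+    import numpy as np
+    import diskannpy
+
+    rng = np.random.default_rng(1234)
+    vectors = rng.random((10_000, 128), dtype=np.float32)
+    tags = np.arange(1, 10_001, dtype=np.uint32)  # explicit ids for a future DynamicMemoryIndex
+    diskannpy.build_memory_index(
+        data=vectors,
+        distance_metric="cosine",
+        index_directory="/tmp/indices",  # must already exist
+        complexity=128,
+        graph_degree=64,
+        num_threads=0,
+        tags=tags,
+    )
+    ```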
+ """ + _assert( + (isinstance(data, str) and vector_dtype is not None) + or isinstance(data, np.ndarray), + "vector_dtype is required if data is a str representing a path to the vector bin file", + ) + dap_metric = _valid_metric(distance_metric) + _assert_is_positive_uint32(complexity, "complexity") + _assert_is_positive_uint32(graph_degree, "graph_degree") + _assert( + alpha >= 1, + "alpha must be >= 1, and realistically should be kept between [1.0, 2.0)", + ) + _assert_is_nonnegative_uint32(num_threads, "num_threads") + _assert_is_nonnegative_uint32(num_pq_bytes, "num_pq_bytes") + _assert_is_nonnegative_uint32(filter_complexity, "filter_complexity") + _assert(index_prefix != "", "index_prefix cannot be an empty string") + + index_path = Path(index_directory) + _assert( + index_path.exists() and index_path.is_dir(), + "index_directory must both exist and be a directory", + ) + + vector_bin_path, vector_dtype_actual = _valid_path_and_dtype( + data, vector_dtype, index_directory, index_prefix + ) + + num_points, dimensions = vectors_metadata_from_file(vector_bin_path) + + if vector_dtype_actual == np.uint8: + _builder = _native_dap.build_memory_uint8_index + elif vector_dtype_actual == np.int8: + _builder = _native_dap.build_memory_int8_index + else: + _builder = _native_dap.build_memory_float_index + + index_prefix_path = os.path.join(index_directory, index_prefix) + + if isinstance(tags, str) and tags != "": + use_tags = True + shutil.copy(tags, index_prefix_path + ".tags") + elif not isinstance(tags, str): + use_tags = True + tags_as_array = _castable_dtype_or_raise(tags, expected=np.uint32) + _assert(len(tags_as_array.shape) == 1, "Provided tags must be 1 dimensional") + _assert( + tags_as_array.shape[0] == num_points, + "Provided tags must contain an identical population to the number of points, " + f"{tags_as_array.shape[0]=}, {num_points=}", + ) + tags_to_file(index_prefix_path + ".tags", tags_as_array) + else: + use_tags = False + + _builder( + distance_metric=dap_metric, + data_file_path=vector_bin_path, + index_output_path=index_prefix_path, + complexity=complexity, + graph_degree=graph_degree, + alpha=alpha, + num_threads=num_threads, + use_pq_build=use_pq_build, + num_pq_bytes=num_pq_bytes, + use_opq=use_opq, + filter_complexity=filter_complexity, + use_tags=use_tags, + ) + + _write_index_metadata( + index_prefix_path, vector_dtype_actual, dap_metric, num_points, dimensions + ) diff --git a/python/src/_builder.pyi b/python/src/_builder.pyi new file mode 100644 index 000000000..5014880c6 --- /dev/null +++ b/python/src/_builder.pyi @@ -0,0 +1,74 @@ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT license. + +from typing import BinaryIO, Optional, overload + +import numpy as np + +from . import DistanceMetric, VectorDType, VectorIdentifierBatch, VectorLikeBatch + +def numpy_to_diskann_file(vectors: np.ndarray, file_handler: BinaryIO): ... +@overload +def build_disk_index( + data: str, + distance_metric: DistanceMetric, + index_directory: str, + complexity: int, + graph_degree: int, + search_memory_maximum: float, + build_memory_maximum: float, + num_threads: int, + pq_disk_bytes: int, + vector_dtype: VectorDType, + index_prefix: str, +) -> None: ... +@overload +def build_disk_index( + data: VectorLikeBatch, + distance_metric: DistanceMetric, + index_directory: str, + complexity: int, + graph_degree: int, + search_memory_maximum: float, + build_memory_maximum: float, + num_threads: int, + pq_disk_bytes: int, + index_prefix: str, +) -> None: ... 
+@overload
+def build_memory_index(
+    data: VectorLikeBatch,
+    distance_metric: DistanceMetric,
+    index_directory: str,
+    complexity: int,
+    graph_degree: int,
+    alpha: float,
+    num_threads: int,
+    use_pq_build: bool,
+    num_pq_bytes: int,
+    use_opq: bool,
+    filter_complexity: int,
+    tags: Optional[VectorIdentifierBatch],
+    index_prefix: str,
+) -> None: ...
+@overload
+def build_memory_index(
+    data: str,
+    distance_metric: DistanceMetric,
+    index_directory: str,
+    complexity: int,
+    graph_degree: int,
+    alpha: float,
+    num_threads: int,
+    use_pq_build: bool,
+    num_pq_bytes: int,
+    use_opq: bool,
+    vector_dtype: VectorDType,
+    filter_complexity: int,
+    tags: Optional[str],
+    index_prefix: str,
+) -> None: ...
diff --git a/python/src/_common.py b/python/src/_common.py
new file mode 100644
index 000000000..53f1dbcab
--- /dev/null
+++ b/python/src/_common.py
@@ -0,0 +1,249 @@
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT license.
+
+import os
+import warnings
+from enum import Enum
+from pathlib import Path
+from typing import Literal, NamedTuple, Optional, Tuple, Type, Union
+
+import numpy as np
+
+from . import (
+    DistanceMetric,
+    VectorDType,
+    VectorIdentifierBatch,
+    VectorLike,
+    VectorLikeBatch,
+)
+from . import _diskannpy as _native_dap
+
+__ALL__ = ["valid_dtype"]
+
+_VALID_DTYPES = [np.float32, np.int8, np.uint8]
+
+
+def valid_dtype(dtype: Type) -> VectorDType:
+    """
+    Utility method to determine whether the provided dtype is supported by `diskannpy`, and if so, the canonical
+    dtype we will use internally (e.g. np.single -> np.float32)
+    """
+    _assert_dtype(dtype)
+    if dtype == np.uint8:
+        return np.uint8
+    if dtype == np.int8:
+        return np.int8
+    if dtype == np.float32:
+        return np.float32
+
+
+def _assert(statement_eval: bool, message: str):
+    if not statement_eval:
+        raise ValueError(message)
+
+
+def _valid_metric(metric: str) -> _native_dap.Metric:
+    if not isinstance(metric, str):
+        raise ValueError("distance_metric must be a string")
+    if metric.lower() == "l2":
+        return _native_dap.L2
+    elif metric.lower() == "mips":
+        return _native_dap.INNER_PRODUCT
+    elif metric.lower() == "cosine":
+        return _native_dap.COSINE
+    else:
+        raise ValueError("distance_metric must be one of 'l2', 'mips', or 'cosine'")
+
+
+def _assert_dtype(dtype: Type):
+    _assert(
+        any(np.can_cast(dtype, _dtype) for _dtype in _VALID_DTYPES),
+        "Vector dtype must be one of {(np.single, np.float32), (np.byte, np.int8), (np.ubyte, np.uint8)}",
+    )
+
+
+def _castable_dtype_or_raise(
+    data: Union[VectorLike, VectorLikeBatch, VectorIdentifierBatch], expected: np.dtype
+) -> np.ndarray:
+    if isinstance(data, np.ndarray) and np.can_cast(data.dtype, expected):
+        return data.astype(expected, casting="safe")
+    else:
+        raise TypeError(
+            f"expecting a numpy ndarray safely castable to dtype {expected}, not a {type(data)}"
+        )
+
+
+def _assert_2d(vectors: np.ndarray, name: str):
+    _assert(len(vectors.shape) == 2, f"{name} must be 2d numpy array")
+
+
+__MAX_UINT32_VAL = 4_294_967_295
+
+
+def _assert_is_positive_uint32(test_value: int, parameter: str):
+    _assert(
+        test_value is not None and 0 < test_value < __MAX_UINT32_VAL,
+        f"{parameter} must be a positive integer in the uint32 range",
+    )
+
+
+def _assert_is_nonnegative_uint32(test_value: int, parameter: str):
+    _assert(
+        test_value is not None and -1 < test_value < __MAX_UINT32_VAL,
+        f"{parameter} must be a non-negative integer in the uint32 range",
+    )
+
+
+def _assert_is_nonnegative_uint64(test_value: int, parameter: str):
+    _assert(
+        -1 < test_value,
+        f"{parameter} must be a non-negative integer in the uint64 range",
+    )
+
+
+def _assert_existing_directory(path: str, parameter: str):
+    _path = Path(path)
+    _assert(
+        _path.exists() and _path.is_dir(), f"{parameter} must be an existing directory"
+    )
+
+
+def _assert_existing_file(path: str, parameter: str):
+    _path = Path(path)
+    _assert(_path.exists() and _path.is_file(), f"{parameter} must be an existing file")
+
+
+class _DataType(Enum):
+    FLOAT32 = 0
+    INT8 = 1
+    UINT8 = 2
+
+    @classmethod
+    def from_type(cls, vector_dtype: VectorDType) -> "_DataType":
+        if vector_dtype == np.float32:
+            return cls.FLOAT32
+        if vector_dtype == np.int8:
+            return cls.INT8
+        if vector_dtype == np.uint8:
+            return cls.UINT8
+
+    def to_type(self) -> VectorDType:
+        if self is _DataType.FLOAT32:
+            return np.float32
+        if self is _DataType.INT8:
+            return np.int8
+        if self is _DataType.UINT8:
+            return np.uint8
+
+
+class _Metric(Enum):
+    L2 = 0
+    MIPS = 1
+    COSINE = 2
+
+    @classmethod
+    def from_native(cls, metric: _native_dap.Metric) -> "_Metric":
+        if metric == _native_dap.L2:
+            return cls.L2
+        if metric == _native_dap.INNER_PRODUCT:
+            return cls.MIPS
+        if metric == _native_dap.COSINE:
+            return cls.COSINE
+
+    def to_native(self) -> _native_dap.Metric:
+        if self is _Metric.L2:
+            return _native_dap.L2
+        if self is _Metric.MIPS:
+            return _native_dap.INNER_PRODUCT
+        if self is _Metric.COSINE:
+            return _native_dap.COSINE
+
+    def to_str(self) -> str:
+        if self is _Metric.L2:
+            return "l2"
+        if self is _Metric.MIPS:
+            return "mips"
+        if self is _Metric.COSINE:
+            return "cosine"
+
+
+def _build_metadata_path(index_path_and_prefix: str) -> str:
+    return index_path_and_prefix + "_metadata.bin"
+
+
+def _write_index_metadata(
+    index_path_and_prefix: str,
+    dtype: VectorDType,
+    metric: _native_dap.Metric,
+    num_points: int,
+    dimensions: int,
+):
+    np.array(
+        [
+            _DataType.from_type(dtype).value,
+            _Metric.from_native(metric).value,
+            num_points,
+            dimensions,
+        ],
+        dtype=np.uint64,
+    ).tofile(_build_metadata_path(index_path_and_prefix))
+
+
+def _read_index_metadata(
+    index_path_and_prefix: str,
+) -> Optional[Tuple[VectorDType, str, np.uint64, np.uint64]]:
+    path = _build_metadata_path(index_path_and_prefix)
+    if not Path(path).exists():
+        return None
+    else:
+        metadata = np.fromfile(path, dtype=np.uint64, count=-1)
+        return (
+            _DataType(int(metadata[0])).to_type(),
+            _Metric(int(metadata[1])).to_str(),
+            metadata[2],
+            metadata[3],
+        )
+
+
+def _ensure_index_metadata(
+    index_path_and_prefix: str,
+    vector_dtype: Optional[VectorDType],
+    distance_metric: Optional[DistanceMetric],
+    max_vectors: int,
+    dimensions: Optional[int],
+) -> Tuple[VectorDType, str, np.uint64, np.uint64]:
+    possible_metadata = _read_index_metadata(index_path_and_prefix)
+    if possible_metadata is None:
+        _assert(
+            all([vector_dtype, distance_metric, dimensions]),
+            "distance_metric, vector_dtype, and dimensions must be provided if a corresponding metadata file has not "
+            "been built for this index, such as when an index was built via the CLI tools or prior to the addition "
+            "of a metadata file",
+        )
+        _assert_dtype(vector_dtype)
+        _assert_is_positive_uint32(max_vectors, "max_vectors")
+        _assert_is_positive_uint32(dimensions, "dimensions")
+        return vector_dtype, distance_metric, max_vectors, dimensions  # type: ignore
+    else:
+        vector_dtype, distance_metric, num_vectors, dimensions = possible_metadata
+        if max_vectors is not None and num_vectors > max_vectors:
+            warnings.warn(
+                "The number of vectors in the saved index exceeds the max_vectors parameter. "
+                "max_vectors is being adjusted to accommodate the dataset, but any insertions will fail."
+            )
+            max_vectors = num_vectors
+        elif num_vectors == max_vectors:
+            warnings.warn(
+                "The number of vectors in the saved index equals the max_vectors parameter. Any insertions will fail."
+            )
+        return possible_metadata
+
+
+def _valid_index_prefix(index_directory: str, index_prefix: str) -> str:
+    _assert(
+        index_directory is not None and index_directory != "",
+        "index_directory cannot be None or empty",
+    )
+    _assert_existing_directory(index_directory, "index_directory")
+    _assert(index_prefix != "", "index_prefix cannot be an empty string")
+    return os.path.join(index_directory, index_prefix)
diff --git a/python/src/_dynamic_memory_index.py b/python/src/_dynamic_memory_index.py
new file mode 100644
index 000000000..9570b8345
--- /dev/null
+++ b/python/src/_dynamic_memory_index.py
@@ -0,0 +1,509 @@
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT license.
+
+import os
+import warnings
+from pathlib import Path
+from typing import Optional
+
+import numpy as np
+
+from . import (
+    DistanceMetric,
+    QueryResponse,
+    QueryResponseBatch,
+    VectorDType,
+    VectorIdentifier,
+    VectorIdentifierBatch,
+    VectorLike,
+    VectorLikeBatch,
+)
+from . import _diskannpy as _native_dap
+from ._common import (
+    _assert,
+    _assert_2d,
+    _assert_dtype,
+    _assert_existing_directory,
+    _assert_is_nonnegative_uint32,
+    _assert_is_positive_uint32,
+    _castable_dtype_or_raise,
+    _ensure_index_metadata,
+    _valid_index_prefix,
+    _valid_metric,
+    _write_index_metadata,
+)
+from ._diskannpy import defaults
+
+__ALL__ = ["DynamicMemoryIndex"]
+
+
+class DynamicMemoryIndex:
+    """
+    A DynamicMemoryIndex instance is used to both search and mutate a `diskannpy` memory index. This index is unlike
+    either `diskannpy.StaticMemoryIndex` or `diskannpy.StaticDiskIndex` in the following ways:
+
+    - It requires an explicit vector identifier for each vector added to it.
+    - Insert and (lazy) deletion operations are provided for a flexible, living index
+
+    The mutable aspect of this index will absolutely impact search time performance as new vectors are added and
+    old ones deleted. `DynamicMemoryIndex.consolidate_delete()` should be called periodically to restructure the
+    index, removing deleted vectors and improving per-search performance, at the cost of an expensive index
+    consolidation.
+    """
+
+    @classmethod
+    def from_file(
+        cls,
+        index_directory: str,
+        max_vectors: int,
+        complexity: int,
+        graph_degree: int,
+        saturate_graph: bool = defaults.SATURATE_GRAPH,
+        max_occlusion_size: int = defaults.MAX_OCCLUSION_SIZE,
+        alpha: float = defaults.ALPHA,
+        num_threads: int = defaults.NUM_THREADS,
+        filter_complexity: int = defaults.FILTER_COMPLEXITY,
+        num_frozen_points: int = defaults.NUM_FROZEN_POINTS_DYNAMIC,
+        initial_search_complexity: int = 0,
+        search_threads: int = 0,
+        concurrent_consolidation: bool = True,
+        index_prefix: str = "ann",
+        distance_metric: Optional[DistanceMetric] = None,
+        vector_dtype: Optional[VectorDType] = None,
+        dimensions: Optional[int] = None,
+    ) -> "DynamicMemoryIndex":
+        """
+        The `from_file` classmethod is used to load a previously saved index from disk. This index *must* have been
+        created with a valid `tags` file or `tags` np.ndarray of `diskannpy.VectorIdentifier`s. It is *strongly*
It is *strongly* + recommended that you use the same parameters as the `diskannpy.build_memory_index()` function that created + the index. + + ### Parameters + - **index_directory**: The directory containing the index files. This directory must contain the following + files: + - `{index_prefix}.data` + - `{index_prefix}.tags` + - `{index_prefix}` + + It may also include the following optional files: + - `{index_prefix}_vectors.bin`: Optional. `diskannpy` builder functions may create this file in the + `index_directory` if the index was created from a numpy array + - `{index_prefix}_metadata.bin`: Optional. `diskannpy` builder functions create this file to store metadata + about the index, such as vector dtype, distance metric, number of vectors and vector dimensionality. + If an index is built from the `diskann` cli tools, this file will not exist. + - **max_vectors**: Capacity of the memory index including space for future insertions. + - **complexity**: Complexity (a.k.a `L`) references the size of the list we store candidate approximate + neighbors in. It's used during save (which is an index rebuild), and it's used as an initial search size to + warm up our index and lower the latency for initial real searches. + - **graph_degree**: Graph degree (a.k.a. `R`) is the maximum degree allowed for a node in the index's graph + structure. This degree will be pruned throughout the course of the index build, but it will never grow beyond + this value. Higher R values require longer index build times, but may result in an index showing excellent + recall and latency characteristics. + - **saturate_graph**: If True, the adjacency list of each node will be saturated with neighbors to have exactly + `graph_degree` neighbors. If False, each node will have between 1 and `graph_degree` neighbors. + - **max_occlusion_size**: The maximum number of points that can be considered by occlude_list function. + - **alpha**: The alpha parameter (>=1) is used to control the nature and number of points that are added to the + graph. A higher alpha value (e.g., 1.4) will result in fewer hops (and IOs) to convergence, but probably + more distance comparisons compared to a lower alpha value. + - **num_threads**: Number of threads to use when creating this index. `0` indicates we should use all available + logical processors. + - **filter_complexity**: Complexity to use when using filters. Default is 0. + - **num_frozen_points**: Number of points to freeze. Default is 1. + - **initial_search_complexity**: Should be set to the most common `complexity` expected to be used during the + life of this `diskannpy.DynamicMemoryIndex` object. The working scratch memory allocated is based off of + `initial_search_complexity` * `search_threads`. Note that it may be resized if a `search` or `batch_search` + operation requests a space larger than can be accommodated by these values. + - **search_threads**: Should be set to the most common `num_threads` expected to be used during the + life of this `diskannpy.DynamicMemoryIndex` object. The working scratch memory allocated is based off of + `initial_search_complexity` * `search_threads`. Note that it may be resized if a `batch_search` + operation requests a space larger than can be accommodated by these values. + - **concurrent_consolidation**: This flag dictates whether consolidation can be run alongside inserts and + deletes, or whether the index is locked down to changes while consolidation is ongoing. + - **index_prefix**: The prefix of the index files. Defaults to "ann". 
+        - **distance_metric**: A `str`, strictly one of {"l2", "mips", "cosine"}. `l2` and `cosine` are supported for
+          all 3 vector dtypes, but `mips` is only available for single precision floats. Default is `None`. **This
+          value is only used if a `{index_prefix}_metadata.bin` file does not exist.** If it does not exist,
+          you are required to provide it.
+        - **vector_dtype**: The vector dtype this index has been built with. **This value is only used if a
+          `{index_prefix}_metadata.bin` file does not exist.** If it does not exist, you are required to provide it.
+        - **dimensions**: The vector dimensionality of this index. All new vectors inserted must be the same
+          dimensionality. **This value is only used if a `{index_prefix}_metadata.bin` file does not exist.** If it
+          does not exist, you are required to provide it.
+
+        ### Returns
+        A `diskannpy.DynamicMemoryIndex` object, with the index loaded from disk and ready to use for insertions,
+        deletions, and searches.
+
+        """
+        index_prefix_path = _valid_index_prefix(index_directory, index_prefix)
+
+        # do tags exist?
+        tags_file = index_prefix_path + ".tags"
+        _assert(
+            Path(tags_file).exists(),
+            f"The file {tags_file} does not exist in {index_directory}",
+        )
+        vector_dtype, dap_metric, num_vectors, dimensions = _ensure_index_metadata(
+            index_prefix_path, vector_dtype, distance_metric, max_vectors, dimensions
+        )
+
+        index = cls(
+            distance_metric=dap_metric,  # type: ignore
+            vector_dtype=vector_dtype,
+            dimensions=dimensions,
+            max_vectors=max_vectors,
+            complexity=complexity,
+            graph_degree=graph_degree,
+            saturate_graph=saturate_graph,
+            max_occlusion_size=max_occlusion_size,
+            alpha=alpha,
+            num_threads=num_threads,
+            filter_complexity=filter_complexity,
+            num_frozen_points=num_frozen_points,
+            initial_search_complexity=initial_search_complexity,
+            search_threads=search_threads,
+            concurrent_consolidation=concurrent_consolidation,
+        )
+        index._index.load(index_prefix_path)
+        index._num_vectors = num_vectors  # current number of vectors loaded
+        return index
+
+    def __init__(
+        self,
+        distance_metric: DistanceMetric,
+        vector_dtype: VectorDType,
+        dimensions: int,
+        max_vectors: int,
+        complexity: int,
+        graph_degree: int,
+        saturate_graph: bool = defaults.SATURATE_GRAPH,
+        max_occlusion_size: int = defaults.MAX_OCCLUSION_SIZE,
+        alpha: float = defaults.ALPHA,
+        num_threads: int = defaults.NUM_THREADS,
+        filter_complexity: int = defaults.FILTER_COMPLEXITY,
+        num_frozen_points: int = defaults.NUM_FROZEN_POINTS_DYNAMIC,
+        initial_search_complexity: int = 0,
+        search_threads: int = 0,
+        concurrent_consolidation: bool = True,
+    ):
+        """
+        The `diskannpy.DynamicMemoryIndex` represents our python API into a mutable DiskANN memory index.
+
+        This constructor is used to create a new, empty index. If you wish to load a previously saved index from disk,
+        please use the `diskannpy.DynamicMemoryIndex.from_file` classmethod instead.
+
+        ### Parameters
+        - **distance_metric**: A `str`, strictly one of {"l2", "mips", "cosine"}. `l2` and `cosine` are supported for
+          all 3 vector dtypes, but `mips` is only available for single precision floats.
+        - **vector_dtype**: One of {`np.float32`, `np.int8`, `np.uint8`}. The dtype of the vectors this index will
+          be storing.
+        - **dimensions**: The vector dimensionality of this index. All new vectors inserted must be the same
+          dimensionality.
+        - **max_vectors**: Capacity of the data store including space for future insertions
+        - **complexity**: Complexity (a.k.a. `L`) references the size of the list we store candidate approximate
+          neighbors in. It's used during save (which is an index rebuild), and it's used as an initial search size to
+          warm up our index and lower the latency for initial real searches.
+        - **graph_degree**: Graph degree (a.k.a. `R`) is the maximum degree allowed for a node in the index's graph
+          structure. This degree will be pruned throughout the course of the index build, but it will never grow beyond
+          this value. Higher `graph_degree` values require longer index build times, but may result in an index showing
+          excellent recall and latency characteristics.
+        - **saturate_graph**: If True, the adjacency list of each node will be saturated with neighbors to have exactly
+          `graph_degree` neighbors. If False, each node will have between 1 and `graph_degree` neighbors.
+        - **max_occlusion_size**: The maximum number of points that can be considered by the `occlude_list` function.
+        - **alpha**: The alpha parameter (>=1) is used to control the nature and number of points that are added to the
+          graph. A higher alpha value (e.g., 1.4) will result in fewer hops (and IOs) to convergence, but probably
+          more distance comparisons compared to a lower alpha value.
+        - **num_threads**: Number of threads to use when creating this index. `0` indicates we should use all available
+          logical processors.
+        - **filter_complexity**: Complexity to use when using filters. Default is 0.
+        - **num_frozen_points**: Number of points to freeze. Default is 1.
+        - **initial_search_complexity**: Should be set to the most common `complexity` expected to be used during the
+          life of this `diskannpy.DynamicMemoryIndex` object. The working scratch memory allocated is based off of
+          `initial_search_complexity` * `search_threads`. Note that it may be resized if a `search` or `batch_search`
+          operation requests a space larger than can be accommodated by these values.
+        - **search_threads**: Should be set to the most common `num_threads` expected to be used during the
+          life of this `diskannpy.DynamicMemoryIndex` object. The working scratch memory allocated is based off of
+          `initial_search_complexity` * `search_threads`. Note that it may be resized if a `batch_search`
+          operation requests a space larger than can be accommodated by these values.
+        - **concurrent_consolidation**: This flag dictates whether consolidation can be run alongside inserts and
+          deletes, or whether the index is locked down to changes while consolidation is ongoing.
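+
+        ### Example
+        A minimal sketch of constructing an empty index and adding vectors to it (the dimensionality, capacity,
+        and ids are illustrative assumptions):
+
+        ```python
+        import numpy as np
+        import diskannpy
+
+        index = diskannpy.DynamicMemoryIndex(
+            distance_metric="l2",
+            vector_dtype=np.float32,
+            dimensions=128,
+            max_vectors=100_000,
+            complexity=64,
+            graph_degree=32,
+        )
+        rng = np.random.default_rng(1234)
+        index.batch_insert(
+            vectors=rng.random((1_000, 128), dtype=np.float32),
+            vector_ids=np.arange(1, 1_001, dtype=np.uint32),
+        )
+        ```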
+
+        """
+        self._num_vectors = 0
+        self._removed_num_vectors = 0
+        dap_metric = _valid_metric(distance_metric)
+        self._dap_metric = dap_metric
+        _assert_dtype(vector_dtype)
+        _assert_is_positive_uint32(dimensions, "dimensions")
+
+        self._vector_dtype = vector_dtype
+        self._dimensions = dimensions
+
+        _assert_is_positive_uint32(max_vectors, "max_vectors")
+        _assert_is_positive_uint32(complexity, "complexity")
+        _assert_is_positive_uint32(graph_degree, "graph_degree")
+        _assert(
+            alpha >= 1,
+            "alpha must be >= 1, and realistically should be kept between [1.0, 2.0)",
+        )
+        _assert_is_nonnegative_uint32(max_occlusion_size, "max_occlusion_size")
+        _assert_is_nonnegative_uint32(num_threads, "num_threads")
+        _assert_is_nonnegative_uint32(filter_complexity, "filter_complexity")
+        _assert_is_nonnegative_uint32(num_frozen_points, "num_frozen_points")
+        _assert_is_nonnegative_uint32(
+            initial_search_complexity, "initial_search_complexity"
+        )
+        _assert_is_nonnegative_uint32(search_threads, "search_threads")
+
+        self._max_vectors = max_vectors
+        self._complexity = complexity
+        self._graph_degree = graph_degree
+
+        if vector_dtype == np.uint8:
+            _index = _native_dap.DynamicMemoryUInt8Index
+        elif vector_dtype == np.int8:
+            _index = _native_dap.DynamicMemoryInt8Index
+        else:
+            _index = _native_dap.DynamicMemoryFloatIndex
+
+        self._index = _index(
+            distance_metric=dap_metric,
+            dimensions=dimensions,
+            max_vectors=max_vectors,
+            complexity=complexity,
+            graph_degree=graph_degree,
+            saturate_graph=saturate_graph,
+            max_occlusion_size=max_occlusion_size,
+            alpha=alpha,
+            num_threads=num_threads,
+            filter_complexity=filter_complexity,
+            num_frozen_points=num_frozen_points,
+            initial_search_complexity=initial_search_complexity,
+            search_threads=search_threads,
+            concurrent_consolidation=concurrent_consolidation,
+        )
+        self._points_deleted = False
+
+    def search(
+        self, query: VectorLike, k_neighbors: int, complexity: int
+    ) -> QueryResponse:
+        """
+        Searches the index by a single query vector.
+
+        ### Parameters
+        - **query**: 1d numpy array of the same dimensionality and dtype of the index.
+        - **k_neighbors**: Number of neighbors to be returned. If query vector exists in index, it almost definitely
+          will be returned as well, so adjust your ``k_neighbors`` as appropriate. Must be > 0.
+        - **complexity**: Size of distance ordered list of candidate neighbors to use while searching. List size
+          increases accuracy at the cost of latency. Must be at least k_neighbors in size.
+        """
+        _query = _castable_dtype_or_raise(query, expected=self._vector_dtype)
+        _assert(len(_query.shape) == 1, "query vector must be 1-d")
+        _assert(
+            _query.shape[0] == self._dimensions,
+            f"query vector must have the same dimensionality as the index; index dimensionality: {self._dimensions}, "
+            f"query dimensionality: {_query.shape[0]}",
+        )
+        _assert_is_positive_uint32(k_neighbors, "k_neighbors")
+        _assert_is_nonnegative_uint32(complexity, "complexity")
+
+        if k_neighbors > complexity:
+            warnings.warn(
+                f"k_neighbors={k_neighbors} asked for, but complexity={complexity} was smaller; "
+                f"increasing complexity to {k_neighbors}"
+            )
+            complexity = k_neighbors
+        return self._index.search(query=_query, knn=k_neighbors, complexity=complexity)
+
+    def batch_search(
+        self,
+        queries: VectorLikeBatch,
+        k_neighbors: int,
+        complexity: int,
+        num_threads: int,
+    ) -> QueryResponseBatch:
+        """
+        Searches the index by a batch of query vectors.
+
+        This search is parallelized and far more efficient than searching for each vector individually.
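+
+        A sketch of the expected shapes, assuming `index` is an instance over 128-dimensional `np.float32`
+        vectors (the numbers are illustrative):
+
+        ```python
+        import numpy as np
+
+        queries = np.random.default_rng(0).random((32, 128), dtype=np.float32)
+        neighbors, distances = index.batch_search(queries=queries, k_neighbors=10, complexity=64, num_threads=0)
+        # neighbors.shape == distances.shape == (32, 10)
+        ```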
+
+        ### Parameters
+        - **queries**: 2d numpy array, with column dimensionality matching the index and row dimensionality being the
+          number of queries intended to search for in parallel. Dtype must match dtype of the index.
+        - **k_neighbors**: Number of neighbors to be returned. If query vector exists in index, it almost definitely
+          will be returned as well, so adjust your ``k_neighbors`` as appropriate. Must be > 0.
+        - **complexity**: Size of distance ordered list of candidate neighbors to use while searching. List size
+          increases accuracy at the cost of latency. Must be at least k_neighbors in size.
+        - **num_threads**: Number of threads to use when searching this index. (>= 0), 0 = num_threads in system
+        """
+        _queries = _castable_dtype_or_raise(queries, expected=self._vector_dtype)
+        _assert_2d(_queries, "queries")
+        _assert(
+            _queries.shape[1] == self._dimensions,
+            f"query vectors must have the same dimensionality as the index; index dimensionality: {self._dimensions}, "
+            f"query dimensionality: {_queries.shape[1]}",
+        )
+
+        _assert_is_positive_uint32(k_neighbors, "k_neighbors")
+        _assert_is_positive_uint32(complexity, "complexity")
+        _assert_is_nonnegative_uint32(num_threads, "num_threads")
+
+        if k_neighbors > complexity:
+            warnings.warn(
+                f"k_neighbors={k_neighbors} asked for, but complexity={complexity} was smaller; "
+                f"increasing complexity to {k_neighbors}"
+            )
+            complexity = k_neighbors
+
+        num_queries, dim = _queries.shape
+        return self._index.batch_search(
+            queries=_queries,
+            num_queries=num_queries,
+            knn=k_neighbors,
+            complexity=complexity,
+            num_threads=num_threads,
+        )
+
+    def save(self, save_path: str, index_prefix: str = "ann"):
+        """
+        Saves this index to file.
+
+        ### Parameters
+        - **save_path**: The path to save these index files to.
+        - **index_prefix**: The prefix of the index files. Defaults to "ann".
+        """
+        if save_path == "":
+            raise ValueError("save_path cannot be empty")
+        if index_prefix == "":
+            raise ValueError("index_prefix cannot be empty")
+
+        index_prefix = index_prefix.format(complexity=self._complexity, graph_degree=self._graph_degree)
+        _assert_existing_directory(save_path, "save_path")
+        save_path = os.path.join(save_path, index_prefix)
+        if self._points_deleted is True:
+            warnings.warn(
+                "DynamicMemoryIndex.save() currently requires DynamicMemoryIndex.consolidate_delete() to be called "
+                "prior to save when items have been marked for deletion. This is being done automatically now, though "
+                "it will increase the time it takes to save; on large sets of data it can take a substantial amount of "
+                "time. In the future, we will implement a faster save with unconsolidated deletes, but for now this is "
+                "required."
+            )
+            self._index.consolidate_delete()
+        self._index.save(
+            save_path=save_path, compact_before_save=True
+        )  # we do not yet support uncompacted saves
+        _write_index_metadata(
+            save_path,
+            self._vector_dtype,
+            self._dap_metric,
+            self._index.num_points(),
+            self._dimensions,
+        )
+
+    def insert(self, vector: VectorLike, vector_id: VectorIdentifier):
+        """
+        Inserts a single vector into the index with the provided vector_id.
+
+        If this insertion will overrun the `max_vectors` count boundaries of this index, `consolidate_delete()` will
+        be executed automatically.
+
+        ### Parameters
+        - **vector**: The vector to insert. Note that dtype must match.
+        - **vector_id**: The vector_id to use for this vector.
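+
+        ### Example
+        A minimal sketch (the vector values and id are illustrative; the dtype and dimensionality must match
+        the index):
+
+        ```python
+        import numpy as np
+
+        index.insert(np.ones(128, dtype=np.float32), np.uint32(42))
+        ```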
+        """
+        _vector = _castable_dtype_or_raise(vector, expected=self._vector_dtype)
+        _assert(len(_vector.shape) == 1, "insert vector must be 1-d")
+        _assert_is_positive_uint32(vector_id, "vector_id")
+        if self._num_vectors + 1 > self._max_vectors:
+            if self._removed_num_vectors > 0:
+                warnings.warn(f"Inserting this vector would overrun the max_vectors={self._max_vectors} specified at "
+                              f"index construction. We are attempting to consolidate_delete() to make space.")
+                self.consolidate_delete()
+            else:
+                raise RuntimeError(f"Inserting this vector would overrun the max_vectors={self._max_vectors} specified "
+                                   f"at index construction. Unable to make space by consolidating deletions. The "
+                                   f"insert operation has failed.")
+        status = self._index.insert(_vector, np.uint32(vector_id))
+        if status == 0:
+            self._num_vectors += 1
+        else:
+            raise RuntimeError(
+                f"Insert was unable to complete successfully; error code returned from diskann C++ lib: {status}"
+            )
+
+    def batch_insert(
+        self,
+        vectors: VectorLikeBatch,
+        vector_ids: VectorIdentifierBatch,
+        num_threads: int = 0,
+    ):
+        """
+        Inserts a batch of vectors into the index with the provided vector_ids.
+
+        If this batch insertion will overrun the `max_vectors` count boundaries of this index, `consolidate_delete()`
+        will be executed automatically.
+
+        ### Parameters
+        - **vectors**: The 2d numpy array of vectors to insert.
+        - **vector_ids**: The 1d array of vector ids to use. This array must have the same number of elements as
+          the vectors array has rows. The dtype of vector_ids must be `np.uint32`
+        - **num_threads**: Number of threads to use when inserting into this index. (>= 0), 0 = num_threads in system
+        """
+        _vectors = _castable_dtype_or_raise(vectors, expected=self._vector_dtype)
+        _assert(len(_vectors.shape) == 2, "vectors must be a 2-d array")
+        _assert(
+            _vectors.shape[0] == vector_ids.shape[0],
+            "Number of vectors must be equal to number of ids",
+        )
+        _vector_ids = vector_ids.astype(dtype=np.uint32, casting="safe", copy=False)
+
+        if self._num_vectors + _vector_ids.shape[0] > self._max_vectors:
+            # check whether consolidating the pending deletes would free enough capacity for this batch
+            if self._num_vectors + _vector_ids.shape[0] - self._removed_num_vectors <= self._max_vectors:
+                warnings.warn(f"Inserting these vectors, count={_vector_ids.shape[0]}, would overrun the "
+                              f"max_vectors={self._max_vectors} specified at index construction. We are attempting to "
+                              f"consolidate_delete() to make space.")
+                self.consolidate_delete()
+            else:
+                raise RuntimeError(f"Inserting these vectors, count={_vector_ids.shape[0]}, would overrun the "
+                                   f"max_vectors={self._max_vectors} specified at index construction. Unable to make "
+                                   f"space by consolidating deletions. The batch insert operation has failed.")
+
+        statuses = self._index.batch_insert(
+            _vectors, _vector_ids, _vector_ids.shape[0], num_threads
+        )
+        successes = []
+        failures = []
+        for i in range(len(statuses)):
+            if statuses[i] == 0:
+                successes.append(i)
+            else:
+                failures.append(i)
+        self._num_vectors += len(successes)
+        if len(failures) == 0:
+            return
+        failed_ids = vector_ids[failures]
+        raise RuntimeError(
+            f"During batch insert, the following vector_ids were unable to be inserted into the index: {failed_ids}. "
+            f"{len(successes)} were successfully inserted"
+        )
+
+    def mark_deleted(self, vector_id: VectorIdentifier):
+        """
+        Mark vector for deletion.
+        This is a soft delete that won't return the vector id in any results, but does not remove it from the
+        underlying index files or memory structure. To execute a hard delete, call this method and then call the
+        much more expensive `consolidate_delete` method on this index.
+
+        ### Parameters
+        - **vector_id**: The vector id to delete. Must be a uint32.
+        """
+        _assert_is_positive_uint32(vector_id, "vector_id")
+        self._points_deleted = True
+        self._removed_num_vectors += 1
+        # we do not decrement self._num_vectors until consolidate_delete
+        self._index.mark_deleted(np.uint32(vector_id))
+
+    def consolidate_delete(self):
+        """
+        This method actually restructures the DiskANN index to remove the items that have been marked for deletion.
+        """
+        self._index.consolidate_delete()
+        self._points_deleted = False
+        self._num_vectors -= self._removed_num_vectors
+        self._removed_num_vectors = 0
diff --git a/python/src/_files.py b/python/src/_files.py
new file mode 100644
index 000000000..1c9fa2103
--- /dev/null
+++ b/python/src/_files.py
@@ -0,0 +1,111 @@
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT license.
+
+import warnings
+from typing import BinaryIO, NamedTuple
+
+import numpy as np
+import numpy.typing as npt
+
+from . import VectorDType, VectorIdentifierBatch, VectorLikeBatch
+from ._common import _assert, _assert_2d, _assert_dtype, _assert_existing_file
+
+
+class Metadata(NamedTuple):
+    """DiskANN binary vector files contain a small stanza containing some metadata about them."""
+
+    num_vectors: int
+    """ The number of vectors in the file. """
+    dimensions: int
+    """ The dimensionality of the vectors in the file. """
+
+
+def vectors_metadata_from_file(vector_file: str) -> Metadata:
+    """
+    Read the metadata from a DiskANN binary vector file.
+    ### Parameters
+    - **vector_file**: The path to the vector file to read the metadata from.
+
+    ### Returns
+    `diskannpy.Metadata`
+    """
+    _assert_existing_file(vector_file, "vector_file")
+    points, dims = np.fromfile(file=vector_file, dtype=np.int32, count=2)
+    return Metadata(points, dims)
+
+
+def _write_bin(data: np.ndarray, file_handler: BinaryIO):
+    if len(data.shape) == 1:
+        _ = file_handler.write(np.array([data.shape[0], 1], dtype=np.int32).tobytes())
+    else:
+        _ = file_handler.write(np.array(data.shape, dtype=np.int32).tobytes())
+    _ = file_handler.write(data.tobytes())
+
+
+def vectors_to_file(vector_file: str, vectors: VectorLikeBatch) -> None:
+    """
+    Utility function that writes a DiskANN binary vector formatted file to the location of your choosing.
+
+    ### Parameters
+    - **vector_file**: The path to the vector file to write the vectors to.
+    - **vectors**: A 2d array of dtype `numpy.float32`, `numpy.uint8`, or `numpy.int8`
+    """
+    _assert_dtype(vectors.dtype)
+    _assert_2d(vectors, "vectors")
+    with open(vector_file, "wb") as fh:
+        _write_bin(vectors, fh)
+
+
+def vectors_from_file(vector_file: str, dtype: VectorDType) -> npt.NDArray[VectorDType]:
+    """
+    Read vectors from a DiskANN binary vector file.
+
+    ### Parameters
+    - **vector_file**: The path to the vector file to read the vectors from.
+    - **dtype**: The data type of the vectors in the file.
diff --git a/python/src/_static_disk_index.py b/python/src/_static_disk_index.py
new file mode 100644
index 000000000..1ca93c0a4
--- /dev/null
+++ b/python/src/_static_disk_index.py
@@ -0,0 +1,197 @@
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT license.
+
+import os
+import warnings
+from typing import Optional
+
+import numpy as np
+
+from . import (
+    DistanceMetric,
+    QueryResponse,
+    QueryResponseBatch,
+    VectorDType,
+    VectorLike,
+    VectorLikeBatch,
+)
+from . import _diskannpy as _native_dap
+from ._common import (
+    _assert,
+    _assert_2d,
+    _assert_is_nonnegative_uint32,
+    _assert_is_positive_uint32,
+    _castable_dtype_or_raise,
+    _ensure_index_metadata,
+    _valid_index_prefix,
+    _valid_metric,
+)
+
+__all__ = ["StaticDiskIndex"]
+
+
+class StaticDiskIndex:
+    """
+    A StaticDiskIndex is a disk-backed index that is not mutable.
+    """
+
+    def __init__(
+        self,
+        index_directory: str,
+        num_threads: int,
+        num_nodes_to_cache: int,
+        cache_mechanism: int = 1,
+        distance_metric: Optional[DistanceMetric] = None,
+        vector_dtype: Optional[VectorDType] = None,
+        dimensions: Optional[int] = None,
+        index_prefix: str = "ann",
+    ):
+        """
+        ### Parameters
+        - **index_directory**: The directory containing the index files. This directory must contain the following
+          files:
+            - `{index_prefix}_sample_data.bin`
+            - `{index_prefix}_mem.index.data`
+            - `{index_prefix}_pq_compressed.bin`
+            - `{index_prefix}_pq_pivots.bin`
+            - `{index_prefix}_sample_ids.bin`
+            - `{index_prefix}_disk.index`
+
+          It may also include the following optional files:
+            - `{index_prefix}_vectors.bin`: Optional.
`diskannpy` builder functions may create this file in the + `index_directory` if the index was created from a numpy array + - `{index_prefix}_metadata.bin`: Optional. `diskannpy` builder functions create this file to store metadata + about the index, such as vector dtype, distance metric, number of vectors and vector dimensionality. + If an index is built from the `diskann` cli tools, this file will not exist. + - **num_threads**: Number of threads to use when searching this index. (>= 0), 0 = num_threads in system + - **num_nodes_to_cache**: Number of nodes to cache in memory (> -1) + - **cache_mechanism**: 1 -> use the generated sample_data.bin file for + the index to initialize a set of cached nodes, up to `num_nodes_to_cache`, 2 -> ready the cache for up to + `num_nodes_to_cache`, but do not initialize it with any nodes. Any other value disables node caching. + - **distance_metric**: A `str`, strictly one of {"l2", "mips", "cosine"}. `l2` and `cosine` are supported for all 3 + vector dtypes, but `mips` is only available for single precision floats. Default is `None`. **This + value is only used if a `{index_prefix}_metadata.bin` file does not exist.** If it does not exist, + you are required to provide it. + - **vector_dtype**: The vector dtype this index has been built with. **This value is only used if a + `{index_prefix}_metadata.bin` file does not exist.** If it does not exist, you are required to provide it. + - **dimensions**: The vector dimensionality of this index. All new vectors inserted must be the same + dimensionality. **This value is only used if a `{index_prefix}_metadata.bin` file does not exist.** If it + does not exist, you are required to provide it. + - **index_prefix**: The prefix of the index files. Defaults to "ann". + """ + index_prefix = _valid_index_prefix(index_directory, index_prefix) + vector_dtype, metric, _, _ = _ensure_index_metadata( + index_prefix, + vector_dtype, + distance_metric, + 1, # it doesn't matter because we don't need it in this context anyway + dimensions, + ) + dap_metric = _valid_metric(metric) + + _assert_is_nonnegative_uint32(num_threads, "num_threads") + _assert_is_nonnegative_uint32(num_nodes_to_cache, "num_nodes_to_cache") + + self._vector_dtype = vector_dtype + if vector_dtype == np.uint8: + _index = _native_dap.StaticDiskUInt8Index + elif vector_dtype == np.int8: + _index = _native_dap.StaticDiskInt8Index + else: + _index = _native_dap.StaticDiskFloatIndex + self._index = _index( + distance_metric=dap_metric, + index_path_prefix=os.path.join(index_directory, index_prefix), + num_threads=num_threads, + num_nodes_to_cache=num_nodes_to_cache, + cache_mechanism=cache_mechanism, + ) + + def search( + self, query: VectorLike, k_neighbors: int, complexity: int, beam_width: int = 2 + ) -> QueryResponse: + """ + Searches the index by a single query vector. + + ### Parameters + - **query**: 1d numpy array of the same dimensionality and dtype of the index. + - **k_neighbors**: Number of neighbors to be returned. If query vector exists in index, it almost definitely + will be returned as well, so adjust your ``k_neighbors`` as appropriate. Must be > 0. + - **complexity**: Size of distance ordered list of candidate neighbors to use while searching. List size + increases accuracy at the cost of latency. Must be at least k_neighbors in size. + - **beam_width**: The beamwidth to be used for search. This is the maximum number of IO requests each query + will issue per iteration of search code. 
Larger beamwidth will result in fewer IO round-trips per query, + but might result in slightly higher total number of IO requests to SSD per query. For the highest query + throughput with a fixed SSD IOps rating, use W=1. For best latency, use W=4,8 or higher complexity search. + Specifying 0 will optimize the beamwidth depending on the number of threads performing search, but will + involve some tuning overhead. + """ + _query = _castable_dtype_or_raise(query, expected=self._vector_dtype) + _assert(len(_query.shape) == 1, "query vector must be 1-d") + _assert_is_positive_uint32(k_neighbors, "k_neighbors") + _assert_is_positive_uint32(complexity, "complexity") + _assert_is_positive_uint32(beam_width, "beam_width") + + if k_neighbors > complexity: + warnings.warn( + f"{k_neighbors=} asked for, but {complexity=} was smaller. Increasing {complexity} to {k_neighbors}" + ) + complexity = k_neighbors + + return self._index.search( + query=_query, + knn=k_neighbors, + complexity=complexity, + beam_width=beam_width, + ) + + def batch_search( + self, + queries: VectorLikeBatch, + k_neighbors: int, + complexity: int, + num_threads: int, + beam_width: int = 2, + ) -> QueryResponseBatch: + """ + Searches the index by a batch of query vectors. + + This search is parallelized and far more efficient than searching for each vector individually. + + ### Parameters + - **queries**: 2d numpy array, with column dimensionality matching the index and row dimensionality being the + number of queries intended to search for in parallel. Dtype must match dtype of the index. + - **k_neighbors**: Number of neighbors to be returned. If query vector exists in index, it almost definitely + will be returned as well, so adjust your ``k_neighbors`` as appropriate. Must be > 0. + - **complexity**: Size of distance ordered list of candidate neighbors to use while searching. List size + increases accuracy at the cost of latency. Must be at least k_neighbors in size. + - **num_threads**: Number of threads to use when searching this index. (>= 0), 0 = num_threads in system + - **beam_width**: The beamwidth to be used for search. This is the maximum number of IO requests each query + will issue per iteration of search code. Larger beamwidth will result in fewer IO round-trips per query, + but might result in slightly higher total number of IO requests to SSD per query. For the highest query + throughput with a fixed SSD IOps rating, use W=1. For best latency, use W=4,8 or higher complexity search. + Specifying 0 will optimize the beamwidth depending on the number of threads performing search, but will + involve some tuning overhead. + """ + _queries = _castable_dtype_or_raise(queries, expected=self._vector_dtype) + _assert_2d(_queries, "queries") + _assert_is_positive_uint32(k_neighbors, "k_neighbors") + _assert_is_positive_uint32(complexity, "complexity") + _assert_is_nonnegative_uint32(num_threads, "num_threads") + _assert_is_positive_uint32(beam_width, "beam_width") + + if k_neighbors > complexity: + warnings.warn( + f"{k_neighbors=} asked for, but {complexity=} was smaller. 
Increasing {complexity} to {k_neighbors}" + ) + complexity = k_neighbors + + num_queries, dim = _queries.shape + return self._index.batch_search( + queries=_queries, + num_queries=num_queries, + knn=k_neighbors, + complexity=complexity, + beam_width=beam_width, + num_threads=num_threads, + ) diff --git a/python/src/_static_memory_index.py b/python/src/_static_memory_index.py new file mode 100644 index 000000000..8b87cd561 --- /dev/null +++ b/python/src/_static_memory_index.py @@ -0,0 +1,187 @@ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT license. + +import os +import warnings +from typing import Optional + +import numpy as np + +from . import ( + DistanceMetric, + QueryResponse, + QueryResponseBatch, + VectorDType, + VectorLike, + VectorLikeBatch, +) +from . import _diskannpy as _native_dap +from ._common import ( + _assert, + _assert_is_nonnegative_uint32, + _assert_is_positive_uint32, + _castable_dtype_or_raise, + _ensure_index_metadata, + _valid_index_prefix, + _valid_metric, +) + +__ALL__ = ["StaticMemoryIndex"] + + +class StaticMemoryIndex: + """ + A StaticMemoryIndex is an immutable in-memory DiskANN index. + """ + + def __init__( + self, + index_directory: str, + num_threads: int, + initial_search_complexity: int, + index_prefix: str = "ann", + distance_metric: Optional[DistanceMetric] = None, + vector_dtype: Optional[VectorDType] = None, + dimensions: Optional[int] = None, + ): + """ + ### Parameters + - **index_directory**: The directory containing the index files. This directory must contain the following + files: + - `{index_prefix}.data` + - `{index_prefix}` + + + It may also include the following optional files: + - `{index_prefix}_vectors.bin`: Optional. `diskannpy` builder functions may create this file in the + `index_directory` if the index was created from a numpy array + - `{index_prefix}_metadata.bin`: Optional. `diskannpy` builder functions create this file to store metadata + about the index, such as vector dtype, distance metric, number of vectors and vector dimensionality. + If an index is built from the `diskann` cli tools, this file will not exist. + - **num_threads**: Number of threads to use when searching this index. (>= 0), 0 = num_threads in system + - **initial_search_complexity**: Should be set to the most common `complexity` expected to be used during the + life of this `diskannpy.DynamicMemoryIndex` object. The working scratch memory allocated is based off of + `initial_search_complexity` * `search_threads`. Note that it may be resized if a `search` or `batch_search` + operation requests a space larger than can be accommodated by these values. + - **index_prefix**: The prefix of the index files. Defaults to "ann". + - **distance_metric**: A `str`, strictly one of {"l2", "mips", "cosine"}. `l2` and `cosine` are supported for all 3 + vector dtypes, but `mips` is only available for single precision floats. Default is `None`. **This + value is only used if a `{index_prefix}_metadata.bin` file does not exist.** If it does not exist, + you are required to provide it. + - **vector_dtype**: The vector dtype this index has been built with. **This value is only used if a + `{index_prefix}_metadata.bin` file does not exist.** If it does not exist, you are required to provide it. + - **dimensions**: The vector dimensionality of this index. All new vectors inserted must be the same + dimensionality. 
+          **This value is only used if a `{index_prefix}_metadata.bin` file does not exist.** If it
+          does not exist, you are required to provide it.
+        """
+        index_prefix = _valid_index_prefix(index_directory, index_prefix)
+        vector_dtype, metric, num_points, dims = _ensure_index_metadata(
+            index_prefix,
+            vector_dtype,
+            distance_metric,
+            1,  # it doesn't matter because we don't need it in this context anyway
+            dimensions,
+        )
+        dap_metric = _valid_metric(metric)
+
+        _assert_is_nonnegative_uint32(num_threads, "num_threads")
+        _assert_is_positive_uint32(
+            initial_search_complexity, "initial_search_complexity"
+        )
+
+        self._vector_dtype = vector_dtype
+        self._dimensions = dims
+
+        if vector_dtype == np.uint8:
+            _index = _native_dap.StaticMemoryUInt8Index
+        elif vector_dtype == np.int8:
+            _index = _native_dap.StaticMemoryInt8Index
+        else:
+            _index = _native_dap.StaticMemoryFloatIndex
+
+        self._index = _index(
+            distance_metric=dap_metric,
+            num_points=num_points,
+            dimensions=dims,
+            index_path=os.path.join(index_directory, index_prefix),
+            num_threads=num_threads,
+            initial_search_complexity=initial_search_complexity,
+        )
+
+    def search(
+        self, query: VectorLike, k_neighbors: int, complexity: int
+    ) -> QueryResponse:
+        """
+        Searches the index by a single query vector.
+
+        ### Parameters
+        - **query**: 1d numpy array of the same dimensionality and dtype of the index.
+        - **k_neighbors**: Number of neighbors to be returned. If query vector exists in index, it almost definitely
+          will be returned as well, so adjust your ``k_neighbors`` as appropriate. Must be > 0.
+        - **complexity**: Size of distance ordered list of candidate neighbors to use while searching. List size
+          increases accuracy at the cost of latency. Must be at least k_neighbors in size.
+        """
+        _query = _castable_dtype_or_raise(query, expected=self._vector_dtype)
+        _assert(len(_query.shape) == 1, "query vector must be 1-d")
+        _assert(
+            _query.shape[0] == self._dimensions,
+            f"query vector must have the same dimensionality as the index; index dimensionality: {self._dimensions}, "
+            f"query dimensionality: {_query.shape[0]}",
+        )
+        _assert_is_positive_uint32(k_neighbors, "k_neighbors")
+        _assert_is_nonnegative_uint32(complexity, "complexity")
+
+        if k_neighbors > complexity:
+            warnings.warn(
+                f"k_neighbors={k_neighbors} asked for, but complexity={complexity} was smaller. Increasing {complexity} to {k_neighbors}"
+            )
+            complexity = k_neighbors
+        return self._index.search(query=_query, knn=k_neighbors, complexity=complexity)
+
+    def batch_search(
+        self,
+        queries: VectorLikeBatch,
+        k_neighbors: int,
+        complexity: int,
+        num_threads: int,
+    ) -> QueryResponseBatch:
+        """
+        Searches the index by a batch of query vectors.
+
+        This search is parallelized and far more efficient than searching for each vector individually.
+
+        ### Parameters
+        - **queries**: 2d numpy array, with column dimensionality matching the index and row dimensionality being the
+          number of queries intended to search for in parallel. Dtype must match dtype of the index.
+        - **k_neighbors**: Number of neighbors to be returned. If query vector exists in index, it almost definitely
+          will be returned as well, so adjust your ``k_neighbors`` as appropriate. Must be > 0.
+        - **complexity**: Size of distance ordered list of candidate neighbors to use while searching. List size
+          increases accuracy at the cost of latency. Must be at least k_neighbors in size.
+        - **num_threads**: Number of threads to use when searching this index. (>= 0), 0 = num_threads in system
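+
+        ### Example
+        A hypothetical sketch; the `index/` directory, its 10-dimensional float32 contents, and the toy
+        queries are illustrative assumptions, not part of this API:
+
+        >>> import numpy as np
+        >>> from diskannpy import StaticMemoryIndex
+        >>> index = StaticMemoryIndex("index", num_threads=0, initial_search_complexity=64)
+        >>> queries = np.random.default_rng(0).random((8, 10), dtype=np.float32)
+        >>> neighbors, distances = index.batch_search(queries, k_neighbors=5, complexity=64, num_threads=0)
+        >>> neighbors.shape
+        (8, 5)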
+        """
+
+        _queries = _castable_dtype_or_raise(queries, expected=self._vector_dtype)
+        _assert(len(_queries.shape) == 2, "queries must be a 2-d np array")
+        _assert(
+            _queries.shape[1] == self._dimensions,
+            f"query vectors must have the same dimensionality as the index; index dimensionality: {self._dimensions}, "
+            f"query dimensionality: {_queries.shape[1]}",
+        )
+        _assert_is_positive_uint32(k_neighbors, "k_neighbors")
+        _assert_is_positive_uint32(complexity, "complexity")
+        _assert_is_nonnegative_uint32(num_threads, "num_threads")
+
+        if k_neighbors > complexity:
+            warnings.warn(
+                f"k_neighbors={k_neighbors} asked for, but complexity={complexity} was smaller. Increasing {complexity} to {k_neighbors}"
+            )
+            complexity = k_neighbors
+
+        num_queries, dim = _queries.shape
+        return self._index.batch_search(
+            queries=_queries,
+            num_queries=num_queries,
+            knn=k_neighbors,
+            complexity=complexity,
+            num_threads=num_threads,
+        )
diff --git a/python/src/builder.cpp b/python/src/builder.cpp
new file mode 100644
index 000000000..4485d66e6
--- /dev/null
+++ b/python/src/builder.cpp
@@ -0,0 +1,82 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT license.
+
+#include "builder.h"
+#include "common.h"
+#include "disk_utils.h"
+#include "index.h"
+#include "parameters.h"
+
+namespace diskannpy
+{
+template <typename DT>
+void build_disk_index(const diskann::Metric metric, const std::string &data_file_path,
+                      const std::string &index_prefix_path, const uint32_t complexity, const uint32_t graph_degree,
+                      const double final_index_ram_limit, const double indexing_ram_budget, const uint32_t num_threads,
+                      const uint32_t pq_disk_bytes)
+{
+    std::string params = std::to_string(graph_degree) + " " + std::to_string(complexity) + " " +
+                         std::to_string(final_index_ram_limit) + " " + std::to_string(indexing_ram_budget) + " " +
+                         std::to_string(num_threads);
+    if (pq_disk_bytes > 0)
+        params = params + " " + std::to_string(pq_disk_bytes);
+    diskann::build_disk_index<DT>(data_file_path.c_str(), index_prefix_path.c_str(), params.c_str(), metric);
+}
+
+template void build_disk_index<float>(diskann::Metric, const std::string &, const std::string &, uint32_t, uint32_t,
+                                      double, double, uint32_t, uint32_t);
+template void build_disk_index<uint8_t>(diskann::Metric, const std::string &, const std::string &, uint32_t, uint32_t,
+                                        double, double, uint32_t, uint32_t);
+template void build_disk_index<int8_t>(diskann::Metric, const std::string &, const std::string &, uint32_t, uint32_t,
+                                       double, double, uint32_t, uint32_t);
+
+template <typename DT, typename TagT>
+void build_memory_index(const diskann::Metric metric, const std::string &vector_bin_path,
+                        const std::string &index_output_path, const uint32_t graph_degree, const uint32_t complexity,
+                        const float alpha, const uint32_t num_threads, const bool use_pq_build,
+                        const size_t num_pq_bytes, const bool use_opq, const uint32_t filter_complexity,
+                        const bool use_tags)
+{
+    diskann::IndexWriteParameters index_build_params = diskann::IndexWriteParametersBuilder(complexity, graph_degree)
+                                                           .with_filter_list_size(filter_complexity)
+                                                           .with_alpha(alpha)
+                                                           .with_saturate_graph(false)
+                                                           .with_num_threads(num_threads)
+                                                           .build();
+    size_t data_num, data_dim;
+    diskann::get_bin_metadata(vector_bin_path, data_num, data_dim);
+    diskann::Index<DT, TagT> index(metric, data_dim, data_num, use_tags, use_tags, false, use_pq_build,
+                                   num_pq_bytes, use_opq);
+
+    if (use_tags)
+    {
+        const std::string tags_file = index_output_path + ".tags";
+        if (!file_exists(tags_file))
+        {
+            throw std::runtime_error("tags file not found at expected path: " + tags_file);
+        }
+        TagT *tags_data;
+        size_t tag_dims = 1;
+        diskann::load_bin<TagT>(tags_file, tags_data, data_num, tag_dims);
+        std::vector<TagT> tags(tags_data, tags_data + data_num);
+        index.build(vector_bin_path.c_str(), data_num, index_build_params, tags);
+    }
+    else
+    {
+        index.build(vector_bin_path.c_str(), data_num, index_build_params);
+    }
+
+    index.save(index_output_path.c_str());
+}
+
+template void build_memory_index<float, uint32_t>(diskann::Metric, const std::string &, const std::string &, uint32_t,
+                                                  uint32_t, float, uint32_t, bool, size_t, bool, uint32_t, bool);
+template void build_memory_index<uint8_t, uint32_t>(diskann::Metric, const std::string &, const std::string &,
+                                                    uint32_t, uint32_t, float, uint32_t, bool, size_t, bool, uint32_t,
+                                                    bool);
+template void build_memory_index<int8_t, uint32_t>(diskann::Metric, const std::string &, const std::string &, uint32_t,
+                                                   uint32_t, float, uint32_t, bool, size_t, bool, uint32_t, bool);
+
+} // namespace diskannpy
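To make the exposed builder surface concrete, here is a hypothetical call through the compiled extension; the keyword names mirror the `m.def` bindings registered in module.cpp further below, while the file paths and tuning values are placeholders:

    from diskannpy import _diskannpy as native

    native.build_memory_float_index(
        distance_metric=native.Metric.L2,
        data_file_path="data/rand_float_10D_10K_norm1.0.bin",
        index_output_path="index/ann",
        graph_degree=32,
        complexity=64,
        alpha=1.2,
        num_threads=0,
        use_pq_build=False,
        num_pq_bytes=0,
        use_opq=False,
    )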
diff --git a/python/src/defaults.py b/python/src/defaults.py
new file mode 100644
index 000000000..4e22983fd
--- /dev/null
+++ b/python/src/defaults.py
@@ -0,0 +1,71 @@
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT license.
+
+"""
+# Parameter Defaults
+These parameter defaults are re-exported from the C++ extension module, and used to keep the pythonic wrapper in sync
+with the C++.
+"""
+from ._diskannpy import defaults as _defaults
+
+ALPHA = _defaults.ALPHA
+"""
+Note that, as ALPHA is a `float32` (single precision float) in C++, when converted into Python it becomes a
+`float64` (double precision float). The actual value is 1.2f. The alpha parameter (>=1) is used to control the nature
+and number of points that are added to the graph. A higher alpha value (e.g., 1.4) will result in fewer hops (and IOs)
+to convergence, but probably more distance comparisons compared to a lower alpha value.
+"""
+NUM_THREADS = _defaults.NUM_THREADS
+""" Number of threads to use. `0` will use all available detected logical processors """
+MAX_OCCLUSION_SIZE = _defaults.MAX_OCCLUSION_SIZE
+"""
+The maximum number of points that can be occluded by a single point. This is used to prevent a single point from
+dominating the graph structure. If a point has more than `max_occlusion_size` neighbors closer to it than the current
+point, it will not be added to the graph. This is a tradeoff between index build time and search quality.
+"""
+FILTER_COMPLEXITY = _defaults.FILTER_COMPLEXITY
+"""
+Complexity (a.k.a. `L`) references the size of the list we store candidate approximate neighbors in while doing a
+filtered search. This value must be larger than `k_neighbors`, and larger values tend toward higher recall in the
+resultant ANN search at the cost of more time.
+"""
+NUM_FROZEN_POINTS_STATIC = _defaults.NUM_FROZEN_POINTS_STATIC
+""" Number of points frozen by default in a StaticMemoryIndex """
+NUM_FROZEN_POINTS_DYNAMIC = _defaults.NUM_FROZEN_POINTS_DYNAMIC
+""" Number of points frozen by default in a DynamicMemoryIndex """
+SATURATE_GRAPH = _defaults.SATURATE_GRAPH
+""" Whether to saturate the graph or not. Default is `True` """
+GRAPH_DEGREE = _defaults.GRAPH_DEGREE
+"""
+Graph degree (a.k.a. `R`) is the maximum degree allowed for a node in the index's graph structure. This degree will be
+pruned throughout the course of the index build, but it will never grow beyond this value. Higher R values require
+longer index build times, but may result in an index showing excellent recall and latency characteristics.
+"""
+COMPLEXITY = _defaults.COMPLEXITY
+"""
+Complexity (a.k.a `L`) references the size of the list we store candidate approximate neighbors in while doing build
+or search tasks. It's used during index build as part of the index optimization processes. It's used in index search
+classes both to help mitigate poor latencies during cold start, as well as on subsequent queries to conduct the search.
+Large values will likely increase latency but also may improve recall, and tuning these values for your particular
+index is certainly a reasonable choice.
+"""
+PQ_DISK_BYTES = _defaults.PQ_DISK_BYTES
+"""
+Use `0` to store uncompressed data on SSD. This allows the index to asymptote to 100% recall. If your vectors are
+too large to store in SSD, this parameter provides the option to compress the vectors using PQ for storing on SSD.
+This will trade off recall. You would also want this to be greater than the number of bytes used for the PQ
+compressed data stored in-memory. Default is `0`.
+"""
+USE_PQ_BUILD = _defaults.USE_PQ_BUILD
+"""
+Whether to use product quantization in the index building process. Product quantization is an approximation
+technique that can vastly speed up vector computations and comparisons in a spatial neighborhood, but it is still an
+approximation technique. It should be preferred when index creation times take longer than you can afford for your
+use case.
+"""
+NUM_PQ_BYTES = _defaults.NUM_PQ_BYTES
+"""
+The number of product quantization bytes to use. More bytes require more resources in both memory and time, but are
+likely to result in better approximations.
+"""
+USE_OPQ = _defaults.USE_OPQ
+""" Whether to use Optimized Product Quantization or not. """
diff --git a/python/src/diskann_bindings.cpp b/python/src/diskann_bindings.cpp
index 613211a0b..8b1378917 100644
--- a/python/src/diskann_bindings.cpp
+++ b/python/src/diskann_bindings.cpp
@@ -1,549 +1 @@
-// Copyright (c) Microsoft Corporation. All rights reserved.
-// Licensed under the MIT license. -#include -#include -#include - -#include -#include -#include -#include -#include - -#ifdef _WINDOWS -#include "windows_aligned_file_reader.h" -#else -#include "linux_aligned_file_reader.h" -#endif - -#include "disk_utils.h" -#include "pq_flash_index.h" - -PYBIND11_MAKE_OPAQUE(std::vector); -PYBIND11_MAKE_OPAQUE(std::vector); -PYBIND11_MAKE_OPAQUE(std::vector); -PYBIND11_MAKE_OPAQUE(std::vector); - -namespace py = pybind11; -using namespace diskann; - -template struct DiskANNIndex -{ - PQFlashIndex *pq_flash_index; - std::shared_ptr reader; - - DiskANNIndex(diskann::Metric metric) - { -#ifdef _WINDOWS - reader = std::make_shared(); -#else - reader = std::make_shared(); -#endif - pq_flash_index = new PQFlashIndex(reader, metric); - } - - ~DiskANNIndex() - { - delete pq_flash_index; - } - - auto get_metric() - { - return pq_flash_index->get_metric(); - } - - void cache_bfs_levels(size_t num_nodes_to_cache) - { - std::vector node_list; - pq_flash_index->cache_bfs_levels(num_nodes_to_cache, node_list); - pq_flash_index->load_cache_list(node_list); - } - - void cache_sample_paths(size_t num_nodes_to_cache, const std::string &warmup_query_file, uint32_t num_threads) - { - if (!file_exists(warmup_query_file)) - { - return; - } - - std::vector node_list; - pq_flash_index->generate_cache_list_from_sample_queries(warmup_query_file, 15, 4, num_nodes_to_cache, - num_threads, node_list); - pq_flash_index->load_cache_list(node_list); - } - - int load_index(const std::string &index_path_prefix, const int num_threads, const size_t num_nodes_to_cache, - int cache_mechanism) - { - const std::string index_path = index_path_prefix + std::string("_disk.index"); - int load_success = pq_flash_index->load(num_threads, index_path.c_str()); - if (load_success != 0) - { - return load_success; - } - if (cache_mechanism == 0) - { - // Nothing to do - } - else if (cache_mechanism == 1) - { - std::string sample_file = index_path_prefix + std::string("_sample_data.bin"); - cache_sample_paths(num_nodes_to_cache, sample_file, num_threads); - } - else if (cache_mechanism == 2) - { - cache_bfs_levels(num_nodes_to_cache); - } - return 0; - } - - void search(std::vector &query, const _u64 query_idx, const _u64 dim, const _u64 num_queries, const _u64 knn, - const _u64 l_search, const _u64 beam_width, std::vector &ids, std::vector &dists) - { - QueryStats stats; - if (ids.size() < knn * num_queries) - { - ids.resize(knn * num_queries); - dists.resize(knn * num_queries); - } - std::vector<_u64> _u64_ids(knn); - pq_flash_index->cached_beam_search(query.data() + (query_idx * dim), knn, l_search, _u64_ids.data(), - dists.data() + (query_idx * knn), beam_width, &stats); - for (_u64 i = 0; i < knn; i++) - ids[(query_idx * knn) + i] = _u64_ids[i]; - } - - void batch_search(std::vector &queries, const _u64 dim, const _u64 num_queries, const _u64 knn, - const _u64 l_search, const _u64 beam_width, std::vector &ids, std::vector &dists, - const int num_threads) - { - if (ids.size() < knn * num_queries) - { - ids.resize(knn * num_queries); - dists.resize(knn * num_queries); - } - omp_set_num_threads(num_threads); -#pragma omp parallel for schedule(dynamic, 1) - for (int64_t q = 0; q < num_queries; ++q) - { - std::vector<_u64> u64_ids(knn); - - pq_flash_index->cached_beam_search(queries.data() + q * dim, knn, l_search, u64_ids.data(), - dists.data() + q * knn, beam_width); - for (_u64 i = 0; i < knn; i++) - ids[(q * knn) + i] = u64_ids[i]; - } - } - - auto search_numpy_input(py::array_t &query, const _u64 
dim, - const _u64 knn, const _u64 l_search, const _u64 beam_width) - { - py::array_t ids(knn); - py::array_t dists(knn); - - std::vector u32_ids(knn); - std::vector<_u64> u64_ids(knn); - QueryStats stats; - - pq_flash_index->cached_beam_search(query.data(), knn, l_search, u64_ids.data(), dists.mutable_data(), - beam_width, &stats); - - auto r = ids.mutable_unchecked<1>(); - for (_u64 i = 0; i < knn; ++i) - r(i) = (unsigned)u64_ids[i]; - - return std::make_pair(ids, dists); - } - - auto batch_search_numpy_input(py::array_t &queries, const _u64 dim, - const _u64 num_queries, const _u64 knn, const _u64 l_search, const _u64 beam_width, - const int num_threads) - { - py::array_t ids({num_queries, knn}); - py::array_t dists({num_queries, knn}); - - std::vector<_u64> u64_ids(knn * num_queries); - diskann::QueryStats *stats = new diskann::QueryStats[num_queries]; - -#pragma omp parallel for schedule(dynamic, 1) - for (int64_t i = 0; i < num_queries; i++) - { - pq_flash_index->cached_beam_search(queries.data(i), knn, l_search, u64_ids.data() + i * knn, - dists.mutable_data(i), beam_width, stats + i); - } - - auto r = ids.mutable_unchecked(); - for (_u64 i = 0; i < num_queries; ++i) - for (_u64 j = 0; j < knn; ++j) - r(i, j) = (unsigned)u64_ids[i * knn + j]; - - std::unordered_map collective_stats; - collective_stats["mean_latency"] = diskann::get_mean_stats( - stats, num_queries, [](const diskann::QueryStats &stats) { return stats.total_us; }); - collective_stats["latency_999"] = diskann::get_percentile_stats( - stats, num_queries, 0.999, [](const diskann::QueryStats &stats) { return stats.total_us; }); - collective_stats["mean_ssd_ios"] = diskann::get_mean_stats( - stats, num_queries, [](const diskann::QueryStats &stats) { return stats.n_ios; }); - collective_stats["mean_dist_comps"] = diskann::get_mean_stats( - stats, num_queries, [](const diskann::QueryStats &stats) { return stats.n_cmps; }); - delete[] stats; - return std::make_pair(std::make_pair(ids, dists), collective_stats); - } - - auto batch_range_search_numpy_input(py::array_t &queries, - const _u64 dim, const _u64 num_queries, const double range, - const _u64 min_list_size, const _u64 max_list_size, const _u64 beam_width, - const int num_threads) - { - py::array_t offsets(num_queries + 1); - - std::vector> u64_ids(num_queries); - std::vector> dists(num_queries); - - auto offsets_mutable = offsets.mutable_unchecked(); - offsets_mutable(0) = 0; - - diskann::QueryStats *stats = new diskann::QueryStats[num_queries]; - -#pragma omp parallel for schedule(dynamic, 1) - for (int64_t i = 0; i < num_queries; i++) - { - _u32 res_count = pq_flash_index->range_search(queries.data(i), range, min_list_size, max_list_size, - u64_ids[i], dists[i], beam_width, stats + i); - offsets_mutable(i + 1) = res_count; - } - - uint64_t total_res_count = 0; - for (_u64 i = 0; i < num_queries; ++i) - { - total_res_count += offsets_mutable(i + 1); - } - - py::array_t ids(total_res_count); - py::array_t res_dists(total_res_count); - - auto ids_mutable = ids.mutable_unchecked(); - auto res_dists_mutable = res_dists.mutable_unchecked(); - size_t pos = 0; - for (_u64 i = 0; i < num_queries; ++i) - { - for (_u64 j = 0; j < offsets_mutable(i + 1); ++j) - { - ids_mutable(pos) = (unsigned)u64_ids[i][j]; - res_dists_mutable(pos++) = dists[i][j]; - } - offsets_mutable(i + 1) = offsets_mutable(i) + offsets_mutable(i + 1); - } - - std::unordered_map collective_stats; - collective_stats["mean_latency"] = diskann::get_mean_stats( - stats, num_queries, [](const diskann::QueryStats 
&stats) { return stats.total_us; }); - collective_stats["latency_999"] = diskann::get_percentile_stats( - stats, num_queries, 0.999, [](const diskann::QueryStats &stats) { return stats.total_us; }); - collective_stats["mean_ssd_ios"] = diskann::get_mean_stats( - stats, num_queries, [](const diskann::QueryStats &stats) { return stats.n_ios; }); - collective_stats["mean_dist_comps"] = diskann::get_mean_stats( - stats, num_queries, [](const diskann::QueryStats &stats) { return stats.n_cmps; }); - delete[] stats; - return std::make_pair(std::make_pair(offsets, std::make_pair(ids, res_dists)), collective_stats); - } -}; - -PYBIND11_MODULE(diskannpy, m) -{ - m.doc() = "DiskANN Python Bindings"; -#ifdef VERSION_INFO - m.attr("__version__") = VERSION_INFO; -#else - m.attr("__version__") = "dev"; -#endif - - py::bind_vector>(m, "VectorUnsigned"); - py::bind_vector>(m, "VectorFloat"); - py::bind_vector>(m, "VectorInt8"); - py::bind_vector>(m, "VectorUInt8"); - - py::enum_(m, "Metric") - .value("L2", Metric::L2) - .value("INNER_PRODUCT", Metric::INNER_PRODUCT) - .export_values(); - - py::class_(m, "Parameters") - .def(py::init<>()) - .def( - "set", - [](Parameters &self, const std::string &name, py::object value) { - if (py::isinstance(value)) - { - return self.Set(name, py::cast(value)); - } - else if (py::isinstance(value)) - { - return self.Set(name, py::cast(value)); - } - else if (py::isinstance(value)) - { - return self.Set(name, py::cast(value)); - } - }, - py::arg("name"), py::arg("value")); - - py::class_(m, "Neighbor") - .def(py::init<>()) - .def(py::init()) - .def(py::self < py::self) - .def(py::self == py::self); - - py::class_(m, "AlignedFileReader"); - -#ifdef _WINDOWS - py::class_(m, "WindowsAlignedFileReader").def(py::init<>()); -#else - py::class_(m, "LinuxAlignedFileReader").def(py::init<>()); -#endif - - m.def( - "omp_set_num_threads", [](const size_t num_threads) { omp_set_num_threads(num_threads); }, - py::arg("num_threads") = 1); - - m.def("omp_get_max_threads", []() { return omp_get_max_threads(); }); - - m.def( - "load_aligned_bin_float", - [](const std::string &path, std::vector &data) { - float *data_ptr = nullptr; - size_t num, dims, aligned_dims; - load_aligned_bin(path, data_ptr, num, dims, aligned_dims); - data.assign(data_ptr, data_ptr + num * aligned_dims); - auto l = py::list(3); - l[0] = py::int_(num); - l[1] = py::int_(dims); - l[2] = py::int_(aligned_dims); - aligned_free(data_ptr); - return l; - }, - py::arg("path"), py::arg("data")); - - m.def( - "load_truthset", - [](const std::string &path, std::vector &ids, std::vector &distances) { - unsigned *id_ptr = nullptr; - float *dist_ptr = nullptr; - size_t num, dims; - load_truthset(path, id_ptr, dist_ptr, num, dims); - // TODO: Remove redundant copies. 
- ids.assign(id_ptr, id_ptr + num * dims); - distances.assign(dist_ptr, dist_ptr + num * dims); - auto l = py::list(2); - l[0] = py::int_(num); - l[1] = py::int_(dims); - delete[] id_ptr; - delete[] dist_ptr; - return l; - }, - py::arg("path"), py::arg("ids"), py::arg("distances")); - - m.def( - "calculate_recall", - [](const unsigned num_queries, std::vector &ground_truth_ids, std::vector &ground_truth_dists, - const unsigned ground_truth_dims, std::vector &results, const unsigned result_dims, - const unsigned recall_at) { - unsigned *gti_ptr = ground_truth_ids.data(); - float *gtd_ptr = ground_truth_dists.data(); - unsigned *r_ptr = results.data(); - - double total_recall = 0; - std::set gt, res; - for (size_t i = 0; i < num_queries; i++) - { - gt.clear(); - res.clear(); - size_t tie_breaker = recall_at; - if (gtd_ptr != nullptr) - { - tie_breaker = recall_at - 1; - float *gt_dist_vec = gtd_ptr + ground_truth_dims * i; - while (tie_breaker < ground_truth_dims && gt_dist_vec[tie_breaker] == gt_dist_vec[recall_at - 1]) - tie_breaker++; - } - - gt.insert(gti_ptr + ground_truth_dims * i, gti_ptr + ground_truth_dims * i + tie_breaker); - res.insert(r_ptr + result_dims * i, r_ptr + result_dims * i + recall_at); - unsigned cur_recall = 0; - for (auto &v : gt) - { - if (res.find(v) != res.end()) - { - cur_recall++; - } - } - total_recall += cur_recall; - } - return py::float_(total_recall / (num_queries) * (100.0 / recall_at)); - }, - py::arg("num_queries"), py::arg("ground_truth_ids"), py::arg("ground_truth_dists"), - py::arg("ground_truth_dims"), py::arg("results"), py::arg("result_dims"), py::arg("recall_at")); - - m.def( - "calculate_recall_numpy_input", - [](const unsigned num_queries, std::vector &ground_truth_ids, std::vector &ground_truth_dists, - const unsigned ground_truth_dims, py::array_t &results, - const unsigned result_dims, const unsigned recall_at) { - unsigned *gti_ptr = ground_truth_ids.data(); - float *gtd_ptr = ground_truth_dists.data(); - unsigned *r_ptr = results.mutable_data(); - - double total_recall = 0; - std::set gt, res; - for (size_t i = 0; i < num_queries; i++) - { - gt.clear(); - res.clear(); - size_t tie_breaker = recall_at; - if (gtd_ptr != nullptr) - { - tie_breaker = recall_at - 1; - float *gt_dist_vec = gtd_ptr + ground_truth_dims * i; - while (tie_breaker < ground_truth_dims && gt_dist_vec[tie_breaker] == gt_dist_vec[recall_at - 1]) - tie_breaker++; - } - - gt.insert(gti_ptr + ground_truth_dims * i, gti_ptr + ground_truth_dims * i + tie_breaker); - res.insert(r_ptr + result_dims * i, r_ptr + result_dims * i + recall_at); - unsigned cur_recall = 0; - for (auto &v : gt) - { - if (res.find(v) != res.end()) - { - cur_recall++; - } - } - total_recall += cur_recall; - } - return py::float_(total_recall / (num_queries) * (100.0 / recall_at)); - }, - py::arg("num_queries"), py::arg("ground_truth_ids"), py::arg("ground_truth_dists"), - py::arg("ground_truth_dims"), py::arg("results"), py::arg("result_dims"), py::arg("recall_at")); - - m.def( - "save_bin_u32", - [](const std::string &file_name, std::vector &data, size_t npts, size_t dims) { - save_bin<_u32>(file_name, data.data(), npts, dims); - }, - py::arg("file_name"), py::arg("data"), py::arg("npts"), py::arg("dims")); - - py::class_>(m, "DiskANNFloatIndex") - .def(py::init([](diskann::Metric metric) { - return std::unique_ptr>(new DiskANNIndex(metric)); - })) - .def("cache_bfs_levels", &DiskANNIndex::cache_bfs_levels, py::arg("num_nodes_to_cache")) - .def("load_index", &DiskANNIndex::load_index, 
py::arg("index_path_prefix"), py::arg("num_threads"), - py::arg("num_nodes_to_cache"), py::arg("cache_mechanism") = 1) - .def("search", &DiskANNIndex::search, py::arg("query"), py::arg("query_idx"), py::arg("dim"), - py::arg("num_queries"), py::arg("knn"), py::arg("l_search"), py::arg("beam_width"), py::arg("ids"), - py::arg("dists")) - .def("batch_search", &DiskANNIndex::batch_search, py::arg("queries"), py::arg("dim"), - py::arg("num_queries"), py::arg("knn"), py::arg("l_search"), py::arg("beam_width"), py::arg("ids"), - py::arg("dists"), py::arg("num_threads")) - .def("search_numpy_input", &DiskANNIndex::search_numpy_input, py::arg("query"), py::arg("dim"), - py::arg("knn"), py::arg("l_search"), py::arg("beam_width")) - .def("batch_search_numpy_input", &DiskANNIndex::batch_search_numpy_input, py::arg("queries"), - py::arg("dim"), py::arg("num_queries"), py::arg("knn"), py::arg("l_search"), py::arg("beam_width"), - py::arg("num_threads")) - .def("batch_range_search_numpy_input", &DiskANNIndex::batch_range_search_numpy_input, py::arg("queries"), - py::arg("dim"), py::arg("num_queries"), py::arg("range"), py::arg("min_list_size"), - py::arg("max_list_size"), py::arg("beam_width"), py::arg("num_threads")) - .def( - "build", - [](DiskANNIndex &self, const char *data_file_path, const char *index_prefix_path, unsigned R, - unsigned L, double final_index_ram_limit, double indexing_ram_budget, unsigned num_threads, - unsigned pq_disk_bytes) { - std::string params = std::to_string(R) + " " + std::to_string(L) + " " + - std::to_string(final_index_ram_limit) + " " + std::to_string(indexing_ram_budget) + - " " + std::to_string(num_threads); - if (pq_disk_bytes > 0) - { - params = params + " " + std::to_string(pq_disk_bytes); - } - diskann::build_disk_index(data_file_path, index_prefix_path, params.c_str(), self.get_metric()); - }, - py::arg("data_file_path"), py::arg("index_prefix_path"), py::arg("R"), py::arg("L"), - py::arg("final_index_ram_limit"), py::arg("indexing_ram_limit"), py::arg("num_threads"), - py::arg("pq_disk_bytes") = 0); - - py::class_>(m, "DiskANNInt8Index") - .def(py::init([](diskann::Metric metric) { - return std::unique_ptr>(new DiskANNIndex(metric)); - })) - .def("cache_bfs_levels", &DiskANNIndex::cache_bfs_levels, py::arg("num_nodes_to_cache")) - .def("load_index", &DiskANNIndex::load_index, py::arg("index_path_prefix"), py::arg("num_threads"), - py::arg("num_nodes_to_cache"), py::arg("cache_mechanism") = 1) - .def("search", &DiskANNIndex::search, py::arg("query"), py::arg("query_idx"), py::arg("dim"), - py::arg("num_queries"), py::arg("knn"), py::arg("l_search"), py::arg("beam_width"), py::arg("ids"), - py::arg("dists")) - .def("batch_search", &DiskANNIndex::batch_search, py::arg("queries"), py::arg("dim"), - py::arg("num_queries"), py::arg("knn"), py::arg("l_search"), py::arg("beam_width"), py::arg("ids"), - py::arg("dists"), py::arg("num_threads")) - .def("search_numpy_input", &DiskANNIndex::search_numpy_input, py::arg("query"), py::arg("dim"), - py::arg("knn"), py::arg("l_search"), py::arg("beam_width")) - .def("batch_search_numpy_input", &DiskANNIndex::batch_search_numpy_input, py::arg("queries"), - py::arg("dim"), py::arg("num_queries"), py::arg("knn"), py::arg("l_search"), py::arg("beam_width"), - py::arg("num_threads")) - .def("batch_range_search_numpy_input", &DiskANNIndex::batch_range_search_numpy_input, - py::arg("queries"), py::arg("dim"), py::arg("num_queries"), py::arg("range"), py::arg("min_list_size"), - py::arg("max_list_size"), py::arg("beam_width"), 
py::arg("num_threads")) - .def( - "build", - [](DiskANNIndex &self, const char *data_file_path, const char *index_prefix_path, unsigned R, - unsigned L, double final_index_ram_limit, double indexing_ram_budget, unsigned num_threads, - unsigned pq_disk_bytes) { - std::string params = std::to_string(R) + " " + std::to_string(L) + " " + - std::to_string(final_index_ram_limit) + " " + std::to_string(indexing_ram_budget) + - " " + std::to_string(num_threads); - if (pq_disk_bytes > 0) - params = params + " " + std::to_string(pq_disk_bytes); - diskann::build_disk_index(data_file_path, index_prefix_path, params.c_str(), self.get_metric()); - }, - py::arg("data_file_path"), py::arg("index_prefix_path"), py::arg("R"), py::arg("L"), - py::arg("final_index_ram_limit"), py::arg("indexing_ram_limit"), py::arg("num_threads"), - py::arg("pq_disk_bytes") = 0); - - py::class_>(m, "DiskANNUInt8Index") - .def(py::init([](diskann::Metric metric) { - return std::unique_ptr>(new DiskANNIndex(metric)); - })) - .def("cache_bfs_levels", &DiskANNIndex::cache_bfs_levels, py::arg("num_nodes_to_cache")) - .def("load_index", &DiskANNIndex::load_index, py::arg("index_path_prefix"), py::arg("num_threads"), - py::arg("num_nodes_to_cache"), py::arg("cache_mechanism") = 1) - .def("search", &DiskANNIndex::search, py::arg("query"), py::arg("query_idx"), py::arg("dim"), - py::arg("num_queries"), py::arg("knn"), py::arg("l_search"), py::arg("beam_width"), py::arg("ids"), - py::arg("dists")) - .def("batch_search", &DiskANNIndex::batch_search, py::arg("queries"), py::arg("dim"), - py::arg("num_queries"), py::arg("knn"), py::arg("l_search"), py::arg("beam_width"), py::arg("ids"), - py::arg("dists"), py::arg("num_threads")) - .def("search_numpy_input", &DiskANNIndex::search_numpy_input, py::arg("query"), py::arg("dim"), - py::arg("knn"), py::arg("l_search"), py::arg("beam_width")) - .def("batch_search_numpy_input", &DiskANNIndex::batch_search_numpy_input, py::arg("queries"), - py::arg("dim"), py::arg("num_queries"), py::arg("knn"), py::arg("l_search"), py::arg("beam_width"), - py::arg("num_threads")) - .def("batch_range_search_numpy_input", &DiskANNIndex::batch_range_search_numpy_input, - py::arg("queries"), py::arg("dim"), py::arg("num_queries"), py::arg("range"), py::arg("min_list_size"), - py::arg("max_list_size"), py::arg("beam_width"), py::arg("num_threads")) - .def( - "build", - [](DiskANNIndex &self, const char *data_file_path, const char *index_prefix_path, unsigned R, - unsigned L, double final_index_ram_limit, double indexing_ram_budget, unsigned num_threads, - unsigned pq_disk_bytes) { - std::string params = std::to_string(R) + " " + std::to_string(L) + " " + - std::to_string(final_index_ram_limit) + " " + std::to_string(indexing_ram_budget) + - " " + std::to_string(num_threads); - if (pq_disk_bytes > 0) - params = params + " " + std::to_string(pq_disk_bytes); - diskann::build_disk_index(data_file_path, index_prefix_path, params.c_str(), - self.get_metric()); - }, - py::arg("data_file_path"), py::arg("index_prefix_path"), py::arg("R"), py::arg("L"), - py::arg("final_index_ram_limit"), py::arg("indexing_ram_limit"), py::arg("num_threads"), - py::arg("pq_disk_bytes") = 0); -} diff --git a/python/src/dynamic_memory_index.cpp b/python/src/dynamic_memory_index.cpp new file mode 100644 index 000000000..af276b85f --- /dev/null +++ b/python/src/dynamic_memory_index.cpp @@ -0,0 +1,171 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT license. 
+ +#include "parameters.h" +#include "dynamic_memory_index.h" + +#include "pybind11/numpy.h" + +namespace diskannpy +{ + +diskann::IndexWriteParameters dynamic_index_write_parameters(const uint32_t complexity, const uint32_t graph_degree, + const bool saturate_graph, + const uint32_t max_occlusion_size, const float alpha, + const uint32_t num_threads, + const uint32_t filter_complexity, + const uint32_t num_frozen_points) +{ + return diskann::IndexWriteParametersBuilder(complexity, graph_degree) + .with_saturate_graph(saturate_graph) + .with_max_occlusion_size(max_occlusion_size) + .with_alpha(alpha) + .with_num_threads(num_threads) + .with_filter_list_size(filter_complexity) + .with_num_frozen_points(num_frozen_points) + .build(); +} + +template +diskann::Index dynamic_index_builder(const diskann::Metric m, + const diskann::IndexWriteParameters &write_params, + const size_t dimensions, const size_t max_vectors, + const uint32_t initial_search_complexity, + const uint32_t initial_search_threads, + const bool concurrent_consolidation) +{ + const uint32_t _initial_search_threads = + initial_search_threads != 0 ? initial_search_threads : omp_get_num_threads(); + return diskann::Index( + m, dimensions, max_vectors, + true, // dynamic_index + write_params, // used for insert + initial_search_complexity, // used to prepare the scratch space for searching. can / may + // be expanded if the search asks for a larger L. + _initial_search_threads, // also used for the scratch space + true, // enable_tags + concurrent_consolidation, + false, // pq_dist_build + 0, // num_pq_chunks + false); // use_opq = false +} + +template +DynamicMemoryIndex
::DynamicMemoryIndex(const diskann::Metric m, const size_t dimensions, const size_t max_vectors, + const uint32_t complexity, const uint32_t graph_degree, + const bool saturate_graph, const uint32_t max_occlusion_size, + const float alpha, const uint32_t num_threads, + const uint32_t filter_complexity, const uint32_t num_frozen_points, + const uint32_t initial_search_complexity, + const uint32_t initial_search_threads, const bool concurrent_consolidation) + : _initial_search_complexity(initial_search_complexity != 0 ? initial_search_complexity : complexity), + _write_parameters(dynamic_index_write_parameters(complexity, graph_degree, saturate_graph, max_occlusion_size, + alpha, num_threads, filter_complexity, num_frozen_points)), + _index(dynamic_index_builder
(m, _write_parameters, dimensions, max_vectors, _initial_search_complexity, + initial_search_threads, concurrent_consolidation)) +{ +} + +template void DynamicMemoryIndex
::load(const std::string &index_path) +{ + const std::string tags_file = index_path + ".tags"; + if (!file_exists(tags_file)) + { + throw std::runtime_error("tags file not found at expected path: " + tags_file); + } + _index.load(index_path.c_str(), _write_parameters.num_threads, _initial_search_complexity); +} + +template +int DynamicMemoryIndex
::insert(const py::array_t &vector, + const DynamicIdType id) +{ + return _index.insert_point(vector.data(), id); +} + +template +py::array_t DynamicMemoryIndex
::batch_insert( + py::array_t &vectors, + py::array_t &ids, const int32_t num_inserts, + const int num_threads) +{ + if (num_threads == 0) + omp_set_num_threads(omp_get_num_procs()); + else + omp_set_num_threads(num_threads); + py::array_t insert_retvals(num_inserts); + +#pragma omp parallel for schedule(dynamic, 1) default(none) shared(num_inserts, insert_retvals, vectors, ids) + for (int32_t i = 0; i < num_inserts; i++) + { + insert_retvals.mutable_data()[i] = _index.insert_point(vectors.data(i), *(ids.data(i))); + } + + return insert_retvals; +} + +template int DynamicMemoryIndex
::mark_deleted(const DynamicIdType id) +{ + return this->_index.lazy_delete(id); +} + +template void DynamicMemoryIndex
::save(const std::string &save_path, const bool compact_before_save) +{ + if (save_path.empty()) + { + throw std::runtime_error("A save_path must be provided"); + } + _index.save(save_path.c_str(), compact_before_save); +} + +template +NeighborsAndDistances DynamicMemoryIndex
::search( + py::array_t &query, const uint64_t knn, const uint64_t complexity) +{ + py::array_t ids(knn); + py::array_t dists(knn); + std::vector
empty_vector; + _index.search_with_tags(query.data(), knn, complexity, ids.mutable_data(), dists.mutable_data(), empty_vector); + return std::make_pair(ids, dists); +} + +template +NeighborsAndDistances DynamicMemoryIndex
::batch_search( + py::array_t &queries, const uint64_t num_queries, const uint64_t knn, + const uint64_t complexity, const uint32_t num_threads) +{ + py::array_t ids({num_queries, knn}); + py::array_t dists({num_queries, knn}); + std::vector
empty_vector; + + if (num_threads == 0) + omp_set_num_threads(omp_get_num_procs()); + else + omp_set_num_threads(static_cast(num_threads)); + +#pragma omp parallel for schedule(dynamic, 1) default(none) \ + shared(num_queries, queries, knn, complexity, ids, dists, empty_vector) + for (int64_t i = 0; i < (int64_t)num_queries; i++) + { + _index.search_with_tags(queries.data(i), knn, complexity, ids.mutable_data(i), dists.mutable_data(i), + empty_vector); + } + + return std::make_pair(ids, dists); +} + +template void DynamicMemoryIndex
::consolidate_delete() +{ + _index.consolidate_deletes(_write_parameters); +} + +template size_t DynamicMemoryIndex
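Putting the binding and the pythonic wrapper together, the intended lifecycle of a dynamic index looks roughly like this. It is a sketch: the `DynamicMemoryIndex` constructor arguments shown are assumptions inferred from the binding signature registered in module.cpp below, and the data is illustrative:

    import numpy as np
    import diskannpy

    index = diskannpy.DynamicMemoryIndex(
        distance_metric="l2", vector_dtype=np.float32, dimensions=10,
        max_vectors=10_000, complexity=64, graph_degree=32,
    )
    vectors = np.random.default_rng(0).random((100, 10), dtype=np.float32)
    ids = np.arange(1, 101, dtype=np.uint32)
    index.batch_insert(vectors, ids)
    neighbors, distances = index.search(vectors[0], k_neighbors=10, complexity=64)
    index.mark_deleted(ids[0])    # soft delete only
    index.consolidate_delete()    # hard delete; reclaims capacity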
diff --git a/python/src/module.cpp b/python/src/module.cpp
new file mode 100644
index 000000000..7aea9fc03
--- /dev/null
+++ b/python/src/module.cpp
@@ -0,0 +1,134 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT license.
+
+#include <cstdint>
+
+#include <pybind11/pybind11.h>
+#include <pybind11/numpy.h>
+
+#include "defaults.h"
+#include "distance.h"
+
+#include "builder.h"
+#include "dynamic_memory_index.h"
+#include "static_disk_index.h"
+#include "static_memory_index.h"
+
+PYBIND11_MAKE_OPAQUE(std::vector<unsigned>);
+PYBIND11_MAKE_OPAQUE(std::vector<float>);
+PYBIND11_MAKE_OPAQUE(std::vector<int8_t>);
+PYBIND11_MAKE_OPAQUE(std::vector<uint8_t>);
+
+namespace py = pybind11;
+using namespace pybind11::literals;
+
+struct Variant
+{
+    std::string disk_builder_name;
+    std::string memory_builder_name;
+    std::string dynamic_memory_index_name;
+    std::string static_memory_index_name;
+    std::string static_disk_index_name;
+};
+
+const Variant FloatVariant{"build_disk_float_index", "build_memory_float_index", "DynamicMemoryFloatIndex",
+                           "StaticMemoryFloatIndex", "StaticDiskFloatIndex"};
+
+const Variant UInt8Variant{"build_disk_uint8_index", "build_memory_uint8_index", "DynamicMemoryUInt8Index",
+                           "StaticMemoryUInt8Index", "StaticDiskUInt8Index"};
+
+const Variant Int8Variant{"build_disk_int8_index", "build_memory_int8_index", "DynamicMemoryInt8Index",
+                          "StaticMemoryInt8Index", "StaticDiskInt8Index"};
+
+template <typename T> inline void add_variant(py::module_ &m, const Variant &variant)
+{
+    m.def(variant.disk_builder_name.c_str(), &diskannpy::build_disk_index<T>, "distance_metric"_a, "data_file_path"_a,
+          "index_prefix_path"_a, "complexity"_a, "graph_degree"_a, "final_index_ram_limit"_a, "indexing_ram_budget"_a,
+          "num_threads"_a, "pq_disk_bytes"_a);
+
+    m.def(variant.memory_builder_name.c_str(), &diskannpy::build_memory_index<T>, "distance_metric"_a,
+          "data_file_path"_a, "index_output_path"_a, "graph_degree"_a, "complexity"_a, "alpha"_a, "num_threads"_a,
+          "use_pq_build"_a, "num_pq_bytes"_a, "use_opq"_a, "filter_complexity"_a = 0, "use_tags"_a = false);
+
+    py::class_<diskannpy::StaticMemoryIndex<T>>(m, variant.static_memory_index_name.c_str())
+        .def(py::init<const diskann::Metric, const std::string &, const size_t, const size_t, const uint32_t,
+                      const uint32_t>(),
+             "distance_metric"_a, "index_path"_a, "num_points"_a, "dimensions"_a, "num_threads"_a,
+             "initial_search_complexity"_a)
+        .def("search", &diskannpy::StaticMemoryIndex<T>::search, "query"_a, "knn"_a, "complexity"_a)
+        .def("batch_search", &diskannpy::StaticMemoryIndex<T>::batch_search, "queries"_a, "num_queries"_a, "knn"_a,
+             "complexity"_a, "num_threads"_a);
+
+    py::class_<diskannpy::DynamicMemoryIndex<T>>(m, variant.dynamic_memory_index_name.c_str())
+        .def(py::init<const diskann::Metric, const size_t, const size_t, const uint32_t, const uint32_t, const bool,
+                      const uint32_t, const float, const uint32_t, const uint32_t, const uint32_t, const uint32_t,
+                      const uint32_t, const bool>(),
+             "distance_metric"_a, "dimensions"_a, "max_vectors"_a, "complexity"_a, "graph_degree"_a,
+             "saturate_graph"_a = diskann::defaults::SATURATE_GRAPH,
+             "max_occlusion_size"_a = diskann::defaults::MAX_OCCLUSION_SIZE, "alpha"_a = diskann::defaults::ALPHA,
+             "num_threads"_a = diskann::defaults::NUM_THREADS,
+             "filter_complexity"_a = diskann::defaults::FILTER_LIST_SIZE,
+             "num_frozen_points"_a = diskann::defaults::NUM_FROZEN_POINTS_DYNAMIC, "initial_search_complexity"_a = 0,
+             "search_threads"_a = 0, "concurrent_consolidation"_a = true)
+        .def("search", &diskannpy::DynamicMemoryIndex<T>::search, "query"_a, "knn"_a, "complexity"_a)
+        .def("load", &diskannpy::DynamicMemoryIndex<T>::load, "index_path"_a)
+        .def("batch_search", &diskannpy::DynamicMemoryIndex<T>::batch_search, "queries"_a, "num_queries"_a, "knn"_a,
+             "complexity"_a, "num_threads"_a)
+        .def("batch_insert", &diskannpy::DynamicMemoryIndex<T>::batch_insert, "vectors"_a, "ids"_a, "num_inserts"_a,
+             "num_threads"_a)
+        .def("save", &diskannpy::DynamicMemoryIndex<T>::save, "save_path"_a = "", "compact_before_save"_a = false)
+        .def("insert", &diskannpy::DynamicMemoryIndex<T>::insert, "vector"_a, "id"_a)
+        .def("mark_deleted", &diskannpy::DynamicMemoryIndex<T>::mark_deleted, "id"_a)
+        .def("consolidate_delete", &diskannpy::DynamicMemoryIndex<T>::consolidate_delete)
+        .def("num_points", &diskannpy::DynamicMemoryIndex<T>::num_points);
"complexity"_a, "num_threads"_a) + .def("batch_insert", &diskannpy::DynamicMemoryIndex::batch_insert, "vectors"_a, "ids"_a, "num_inserts"_a, + "num_threads"_a) + .def("save", &diskannpy::DynamicMemoryIndex::save, "save_path"_a = "", "compact_before_save"_a = false) + .def("insert", &diskannpy::DynamicMemoryIndex::insert, "vector"_a, "id"_a) + .def("mark_deleted", &diskannpy::DynamicMemoryIndex::mark_deleted, "id"_a) + .def("consolidate_delete", &diskannpy::DynamicMemoryIndex::consolidate_delete) + .def("num_points", &diskannpy::DynamicMemoryIndex::num_points); + + py::class_>(m, variant.static_disk_index_name.c_str()) + .def(py::init(), + "distance_metric"_a, "index_path_prefix"_a, "num_threads"_a, "num_nodes_to_cache"_a, + "cache_mechanism"_a = 1) + .def("cache_bfs_levels", &diskannpy::StaticDiskIndex::cache_bfs_levels, "num_nodes_to_cache"_a) + .def("search", &diskannpy::StaticDiskIndex::search, "query"_a, "knn"_a, "complexity"_a, "beam_width"_a) + .def("batch_search", &diskannpy::StaticDiskIndex::batch_search, "queries"_a, "num_queries"_a, "knn"_a, + "complexity"_a, "beam_width"_a, "num_threads"_a); +} + +PYBIND11_MODULE(_diskannpy, m) +{ + m.doc() = "DiskANN Python Bindings"; +#ifdef VERSION_INFO + m.attr("__version__") = VERSION_INFO; +#else + m.attr("__version__") = "dev"; +#endif + + // let's re-export our defaults + py::module_ default_values = m.def_submodule( + "defaults", + "A collection of the default values used for common diskann operations. `GRAPH_DEGREE` and `COMPLEXITY` are not" + " set as defaults, but some semi-reasonable default values are selected for your convenience. We urge you to " + "investigate their meaning and adjust them for your use cases."); + + default_values.attr("ALPHA") = diskann::defaults::ALPHA; + default_values.attr("NUM_THREADS") = diskann::defaults::NUM_THREADS; + default_values.attr("MAX_OCCLUSION_SIZE") = diskann::defaults::MAX_OCCLUSION_SIZE; + default_values.attr("FILTER_COMPLEXITY") = diskann::defaults::FILTER_LIST_SIZE; + default_values.attr("NUM_FROZEN_POINTS_STATIC") = diskann::defaults::NUM_FROZEN_POINTS_STATIC; + default_values.attr("NUM_FROZEN_POINTS_DYNAMIC") = diskann::defaults::NUM_FROZEN_POINTS_DYNAMIC; + default_values.attr("SATURATE_GRAPH") = diskann::defaults::SATURATE_GRAPH; + default_values.attr("GRAPH_DEGREE") = diskann::defaults::MAX_DEGREE; + default_values.attr("COMPLEXITY") = diskann::defaults::BUILD_LIST_SIZE; + default_values.attr("PQ_DISK_BYTES") = (uint32_t)0; + default_values.attr("USE_PQ_BUILD") = false; + default_values.attr("NUM_PQ_BYTES") = (uint32_t)0; + default_values.attr("USE_OPQ") = false; + + add_variant(m, FloatVariant); + add_variant(m, UInt8Variant); + add_variant(m, Int8Variant); + + py::enum_(m, "Metric") + .value("L2", diskann::Metric::L2) + .value("INNER_PRODUCT", diskann::Metric::INNER_PRODUCT) + .value("COSINE", diskann::Metric::COSINE) + .export_values(); +} diff --git a/python/src/py.typed b/python/src/py.typed new file mode 100644 index 000000000..e69de29bb diff --git a/python/src/static_disk_index.cpp b/python/src/static_disk_index.cpp new file mode 100644 index 000000000..654f8ec30 --- /dev/null +++ b/python/src/static_disk_index.cpp @@ -0,0 +1,108 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT license. + +#include "static_disk_index.h" + +#include "pybind11/numpy.h" + +namespace diskannpy +{ + +template +StaticDiskIndex
<DT>::StaticDiskIndex(const diskann::Metric metric, const std::string &index_path_prefix,
+                                     const uint32_t num_threads, const size_t num_nodes_to_cache,
+                                     const uint32_t cache_mechanism)
+    : _reader(std::make_shared<PlatformSpecificAlignedFileReader>()), _index(_reader, metric)
+{
+    int load_success = _index.load(num_threads, index_path_prefix.c_str());
+    if (load_success != 0)
+    {
+        throw std::runtime_error("index load failed.");
+    }
+    if (cache_mechanism == 1)
+    {
+        std::string sample_file = index_path_prefix + std::string("_sample_data.bin");
+        cache_sample_paths(num_nodes_to_cache, sample_file, num_threads);
+    }
+    else if (cache_mechanism == 2)
+    {
+        cache_bfs_levels(num_nodes_to_cache);
+    }
+}
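The constructor above also selects the cache warm-up strategy: `cache_mechanism == 1` builds the cache list from the sample queries stored beside the index (`<prefix>_sample_data.bin`), `2` caches the top BFS levels of the graph, and any other value leaves the cache cold. A minimal sketch of how this is reached through the `diskannpy` wrapper, mirroring the tests later in this diff (the index directory is hypothetical):

```python
import diskannpy as dap
import numpy as np

# hypothetical directory previously populated by dap.build_disk_index(...)
index = dap.StaticDiskIndex(
    distance_metric="l2",
    vector_dtype=np.float32,
    index_directory="/tmp/ann_disk_index",
    num_threads=16,
    num_nodes_to_cache=10_000,  # forwarded to the warm-up path chosen above
)
```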
+template <typename DT> void StaticDiskIndex<DT>::cache_bfs_levels(const size_t num_nodes_to_cache)
+{
+    std::vector<uint32_t> node_list;
+    _index.cache_bfs_levels(num_nodes_to_cache, node_list);
+    _index.load_cache_list(node_list);
+}
+template <typename DT>
+void StaticDiskIndex<DT>::cache_sample_paths(const size_t num_nodes_to_cache, const std::string &warmup_query_file,
+                                             const uint32_t num_threads)
+{
+    if (!file_exists(warmup_query_file))
+    {
+        return;
+    }
+
+    std::vector<uint32_t> node_list;
+    _index.generate_cache_list_from_sample_queries(warmup_query_file, 15, 4, num_nodes_to_cache, num_threads,
+                                                   node_list);
+    _index.load_cache_list(node_list);
+}
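Note that `cache_sample_paths` returns silently when the warmup file is missing, so an index built without sample data simply starts cold, and the samples are replayed with a hard-coded complexity of 15 and beam width of 4. The BFS variant is also exposed to Python (module.cpp binds `cache_bfs_levels` on the bound class), so the cache can be refilled explicitly. A sketch against the raw binding; the object name and construction here are assumptions, since the pure-Python wrapper is not shown in this diff:

```python
import _diskannpy  # the extension module defined by module.cpp in this diff

# assumption: raw_index was constructed as _diskannpy.StaticDiskFloatIndex(...)
# with cache_mechanism set so that no warm-up has happened yet
raw_index.cache_bfs_levels(num_nodes_to_cache=10_000)
```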
+template <typename DT>
+NeighborsAndDistances<StaticIdType> StaticDiskIndex<DT>::search(
+    py::array_t<DT, py::array::c_style | py::array::forcecast> &query, const uint64_t knn, const uint64_t complexity,
+    const uint64_t beam_width)
+{
+    py::array_t<StaticIdType> ids(knn);
+    py::array_t<float> dists(knn);
+
+    std::vector<uint32_t> u32_ids(knn);
+    std::vector<uint64_t> u64_ids(knn);
+    diskann::QueryStats stats;
+
+    _index.cached_beam_search(query.data(), knn, complexity, u64_ids.data(), dists.mutable_data(), beam_width, false,
+                              &stats);
+
+    auto r = ids.mutable_unchecked<1>();
+    for (uint64_t i = 0; i < knn; ++i)
+        r(i) = (unsigned)u64_ids[i];
+
+    return std::make_pair(ids, dists);
+}
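Single-query `search` runs one `cached_beam_search` and then narrows the internal 64-bit ids into the 32-bit array handed back to Python. Usage through the wrapper, as exercised by `test_single` later in this diff (query values are arbitrary):

```python
import numpy as np

rng = np.random.default_rng(12345)
query = rng.random(10, dtype=np.float32)  # one 10-dimensional query

ids, dists = index.search(query, k_neighbors=5, complexity=32, beam_width=2)
assert ids.shape == (5,) and dists.shape == (5,)
```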
+template <typename DT>
+NeighborsAndDistances<StaticIdType> StaticDiskIndex<DT>::batch_search(
+    py::array_t<DT, py::array::c_style | py::array::forcecast> &queries, const uint64_t num_queries,
+    const uint64_t knn, const uint64_t complexity, const uint64_t beam_width, const uint32_t num_threads)
+{
+    py::array_t<StaticIdType> ids({num_queries, knn});
+    py::array_t<float> dists({num_queries, knn});
+
+    omp_set_num_threads(num_threads);
+
+    std::vector<uint64_t> u64_ids(knn * num_queries);
+
+#pragma omp parallel for schedule(dynamic, 1) default(none)                                                           \
+    shared(num_queries, queries, knn, complexity, u64_ids, dists, beam_width)
+    for (int64_t i = 0; i < (int64_t)num_queries; i++)
+    {
+        _index.cached_beam_search(queries.data(i), knn, complexity, u64_ids.data() + i * knn, dists.mutable_data(i),
+                                  beam_width);
+    }
+
+    auto r = ids.mutable_unchecked();
+    for (uint64_t i = 0; i < num_queries; ++i)
+        for (uint64_t j = 0; j < knn; ++j)
+            r(i, j) = (uint32_t)u64_ids[i * knn + j];
+
+    return std::make_pair(ids, dists);
+}
+
+template class StaticDiskIndex<float>;
+template class StaticDiskIndex<uint8_t>;
+template class StaticDiskIndex<int8_t>;
+} // namespace diskannpy
\ No newline at end of file
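`batch_search` pins the OpenMP thread count and fans the queries out with `schedule(dynamic, 1)`, one query per task, before doing the same u64 to u32 id narrowing in a single pass. Continuing the sketch above:

```python
queries = rng.random((100, 10), dtype=np.float32)  # 100 queries of 10 dimensions

neighbors, distances = index.batch_search(
    queries, k_neighbors=5, complexity=32, beam_width=2, num_threads=16
)
assert neighbors.shape == (100, 5) and distances.shape == (100, 5)
```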
diff --git a/python/src/static_memory_index.cpp b/python/src/static_memory_index.cpp
new file mode 100644
index 000000000..3bd927174
--- /dev/null
+++ b/python/src/static_memory_index.cpp
@@ -0,0 +1,77 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT license.
+
+#include "static_memory_index.h"
+
+#include "pybind11/numpy.h"
+
+namespace diskannpy
+{
+
+template <typename DT>
+diskann::Index<DT, StaticIdType, filterT> static_index_builder(const diskann::Metric m, const size_t num_points,
+                                                               const size_t dimensions,
+                                                               const uint32_t initial_search_complexity)
+{
+    if (initial_search_complexity == 0)
+    {
+        throw std::runtime_error("initial_search_complexity must be a positive uint32_t");
+    }
+
+    return diskann::Index<DT, StaticIdType, filterT>(m, dimensions, num_points,
+                                                     false, // not a dynamic_index
+                                                     false, // no enable_tags/ids
+                                                     false, // no concurrent_consolidate,
+                                                     false, // pq_dist_build
+                                                     0,     // num_pq_chunks
+                                                     false, // use_opq = false
+                                                     0);    // num_frozen_points
+}
+template <typename DT>
+StaticMemoryIndex<DT>::StaticMemoryIndex(const diskann::Metric m, const std::string &index_prefix,
+                                         const size_t num_points, const size_t dimensions, const uint32_t num_threads,
+                                         const uint32_t initial_search_complexity)
+    : _index(static_index_builder<DT>(m, num_points, dimensions, initial_search_complexity))
+{
+    const uint32_t _num_threads = num_threads != 0 ? num_threads : omp_get_num_threads();
+    _index.load(index_prefix.c_str(), _num_threads, initial_search_complexity);
+}
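`StaticMemoryIndex` first constructs an empty, read-only `diskann::Index` via `static_index_builder` (which rejects `initial_search_complexity == 0`, since that value is forwarded to `load` to size the search scratch space) and then loads the serialized graph into it. Through the wrapper, as the tests later in this diff do (the directory is hypothetical):

```python
import diskannpy as dap

index = dap.StaticMemoryIndex(
    index_directory="/tmp/ann_memory_index",  # hypothetical: built by dap.build_memory_index
    num_threads=16,
    initial_search_complexity=32,  # 0 raises, per the guard in static_index_builder
)
```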
+template <typename DT>
+NeighborsAndDistances<StaticIdType> StaticMemoryIndex<DT>::search(
+    py::array_t<DT, py::array::c_style | py::array::forcecast> &query, const uint64_t knn, const uint64_t complexity)
+{
+    py::array_t<StaticIdType> ids(knn);
+    py::array_t<float> dists(knn);
+    std::vector<DT *> empty_vector;
+    _index.search(query.data(), knn, complexity, ids.mutable_data(), dists.mutable_data());
+    return std::make_pair(ids, dists);
+}
+template <typename DT>
+NeighborsAndDistances<StaticIdType> StaticMemoryIndex<DT>::batch_search(
+    py::array_t<DT, py::array::c_style | py::array::forcecast> &queries, const uint64_t num_queries,
+    const uint64_t knn, const uint64_t complexity, const uint32_t num_threads)
+{
+    const uint32_t _num_threads = num_threads != 0 ? num_threads : omp_get_num_threads();
+    py::array_t<StaticIdType> ids({num_queries, knn});
+    py::array_t<float> dists({num_queries, knn});
+    std::vector<DT *> empty_vector;
+
+    omp_set_num_threads(static_cast<int32_t>(_num_threads));
+
+#pragma omp parallel for schedule(dynamic, 1) default(none) shared(num_queries, queries, knn, complexity, ids, dists)
+    for (int64_t i = 0; i < (int64_t)num_queries; i++)
+    {
+        _index.search(queries.data(i), knn, complexity, ids.mutable_data(i), dists.mutable_data(i));
+    }
+
+    return std::make_pair(ids, dists);
+}
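One wrinkle shared by both batch entry points: when `num_threads` is 0 the fallback is `omp_get_num_threads()`, which evaluates to 1 outside of a parallel region, so 0 effectively means single-threaded here rather than "use all cores" (that would be `omp_get_max_threads()`). Passing an explicit count sidesteps the question; with a `StaticMemoryIndex` built as in the previous sketch:

```python
import numpy as np

queries = np.random.default_rng(0).random((100, 10), dtype=np.float32)
# index: the StaticMemoryIndex from the previous sketch
neighbors, distances = index.batch_search(
    queries, k_neighbors=5, complexity=32, num_threads=16
)
```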
+
+template class StaticMemoryIndex<float>;
+template class StaticMemoryIndex<uint8_t>;
+template class StaticMemoryIndex<int8_t>;
+
+} // namespace diskannpy
\ No newline at end of file
diff --git a/python/tests/fixtures/__init__.py b/python/tests/fixtures/__init__.py
new file mode 100644
index 000000000..4aeb96087
--- /dev/null
+++ b/python/tests/fixtures/__init__.py
@@ -0,0 +1,6 @@
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT license.
+
+from .build_memory_index import build_random_vectors_and_memory_index
+from .create_test_data import random_vectors, vectors_as_temp_file, write_vectors
+from .recall import calculate_recall
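Every test module consumes the same deterministic fixture tuple; unpacked, it reads like this (a sketch mirroring the test classes later in this diff):

```python
import numpy as np
from fixtures import build_random_vectors_and_memory_index

(metric, dtype, query_vectors, index_vectors,
 ann_dir, vector_bin_file, tags) = build_random_vectors_and_memory_index(
    np.float32, "l2", with_tags=True
)
# 1000 query vectors and 10000 indexed vectors, all 10-dimensional,
# plus the temp directory the index was written into
```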
diff --git a/python/tests/fixtures/build_memory_index.py b/python/tests/fixtures/build_memory_index.py
new file mode 100644
index 000000000..3c30bed25
--- /dev/null
+++ b/python/tests/fixtures/build_memory_index.py
@@ -0,0 +1,51 @@
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT license.
+
+import os
+from tempfile import mkdtemp
+
+import diskannpy as dap
+import numpy as np
+
+from .create_test_data import random_vectors
+
+
+def build_random_vectors_and_memory_index(
+    dtype, metric, with_tags: bool = False, index_prefix: str = "ann", seed: int = 12345
+):
+    query_vectors: np.ndarray = random_vectors(1000, 10, dtype=dtype, seed=seed)
+    index_vectors: np.ndarray = random_vectors(10000, 10, dtype=dtype, seed=seed)
+    ann_dir = mkdtemp()
+
+    if with_tags:
+        rng = np.random.default_rng(seed)
+        tags = np.arange(start=1, stop=10001, dtype=np.uint32)
+        rng.shuffle(tags)
+    else:
+        tags = ""
+
+    dap.build_memory_index(
+        data=index_vectors,
+        distance_metric=metric,
+        index_directory=ann_dir,
+        graph_degree=16,
+        complexity=32,
+        alpha=1.2,
+        num_threads=0,
+        use_pq_build=False,
+        num_pq_bytes=8,
+        use_opq=False,
+        filter_complexity=32,
+        tags=tags,
+        index_prefix=index_prefix,
+    )
+
+    return (
+        metric,
+        dtype,
+        query_vectors,
+        index_vectors,
+        ann_dir,
+        os.path.join(ann_dir, "vectors.bin"),
+        tags,
+    )
diff --git a/python/tests/fixtures/create_test_data.py b/python/tests/fixtures/create_test_data.py
new file mode 100644
index 000000000..44e413ed6
--- /dev/null
+++ b/python/tests/fixtures/create_test_data.py
@@ -0,0 +1,40 @@
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT license.
+ +from contextlib import contextmanager +from pathlib import Path +from tempfile import NamedTemporaryFile +from typing import BinaryIO + +import numpy as np + + +def random_vectors(rows: int, dimensions: int, dtype, seed: int = 12345) -> np.ndarray: + rng = np.random.default_rng(seed) + if dtype == np.float32: + vectors = rng.random((rows, dimensions), dtype=dtype) + elif dtype == np.uint8: + vectors = rng.integers( + low=0, high=256, size=(rows, dimensions), dtype=dtype + ) # low is inclusive, high is exclusive + elif dtype == np.int8: + vectors = rng.integers( + low=-128, high=128, size=(rows, dimensions), dtype=dtype + ) # low is inclusive, high is exclusive + else: + raise RuntimeError("Only np.float32, np.int8, and np.uint8 are supported") + return vectors + + +def write_vectors(file_handler: BinaryIO, vectors: np.ndarray): + _ = file_handler.write(np.array(vectors.shape, dtype=np.int32).tobytes()) + _ = file_handler.write(vectors.tobytes()) + + +@contextmanager +def vectors_as_temp_file(vectors: np.ndarray) -> str: + temp = NamedTemporaryFile(mode="wb", delete=False) + write_vectors(temp, vectors) + temp.close() + yield temp.name + Path(temp.name).unlink() diff --git a/python/tests/fixtures/recall.py b/python/tests/fixtures/recall.py new file mode 100644 index 000000000..03f38f37c --- /dev/null +++ b/python/tests/fixtures/recall.py @@ -0,0 +1,24 @@ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT license. + +import numpy as np + + +def calculate_recall( + result_set_indices: np.ndarray, truth_set_indices: np.ndarray, recall_at: int = 5 +) -> float: + """ + result_set_indices and truth_set_indices correspond by row index. the columns in each row contain the indices of + the nearest neighbors, with result_set_indices being the approximate nearest neighbor results and truth_set_indices + being the brute force nearest neighbor calculation via sklearn's NearestNeighbor class. + :param result_set_indices: + :param truth_set_indices: + :param recall_at: + :return: + """ + found = 0 + for i in range(0, result_set_indices.shape[0]): + result_set_set = set(result_set_indices[i][0:recall_at]) + truth_set_set = set(truth_set_indices[i][0:recall_at]) + found += len(result_set_set.intersection(truth_set_set)) + return found / (result_set_indices.shape[0] * recall_at) diff --git a/python/tests/test_build_disk_index.py b/python/tests/test_build_disk_index.py deleted file mode 100644 index 2c7c304ed..000000000 --- a/python/tests/test_build_disk_index.py +++ /dev/null @@ -1,26 +0,0 @@ -# Copyright (c) Microsoft Corporation. All rights reserved. -# Licensed under the MIT license. 
- -import time -import argparse -import diskannpy -from diskannpy import Metric, Parameters, DiskANNFloatIndex - - -parser = argparse.ArgumentParser() -parser.add_argument('data_path', type=str, help='Path to the input base set of vectors.') -parser.add_argument('save_path', type=str, help='Path to the built index.') -parser.add_argument('R', type=int, help='Graph degree.') -parser.add_argument('L', type=int, help='Index build complexity.') -parser.add_argument('B', type=float, help='Memory budget in GB for the final index.') -parser.add_argument('M', type=float, help='Memory budget in GB for the index construction.') -parser.add_argument('T', type=int, help='Number of threads for index construction.') - -args = parser.parse_args() - -start = time.time() -index = DiskANNFloatIndex(diskannpy.L2) -index.build(args.data_path, args.save_path, args.R, args.L, args.B, args.M, args.T) -end = time.time() - -print("Indexing Time: " + str(end - start) + " seconds") \ No newline at end of file diff --git a/python/tests/test_builder.py b/python/tests/test_builder.py new file mode 100644 index 000000000..cc484c938 --- /dev/null +++ b/python/tests/test_builder.py @@ -0,0 +1,137 @@ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT license. + +import unittest + +import diskannpy as dap +import numpy as np + + +class TestBuildDiskIndex(unittest.TestCase): + def test_valid_shape(self): + rng = np.random.default_rng(12345) + rando = rng.random((1000, 100, 5), dtype=np.single) + with self.assertRaises(ValueError): + dap.build_disk_index( + data=rando, + distance_metric="l2", + index_directory="test", + complexity=5, + graph_degree=5, + search_memory_maximum=0.01, + build_memory_maximum=0.01, + num_threads=1, + pq_disk_bytes=0, + ) + + rando = rng.random(1000, dtype=np.single) + with self.assertRaises(ValueError): + dap.build_disk_index( + data=rando, + distance_metric="l2", + index_directory="test", + complexity=5, + graph_degree=5, + search_memory_maximum=0.01, + build_memory_maximum=0.01, + num_threads=1, + pq_disk_bytes=0, + ) + + def test_value_ranges_build(self): + good_ranges = { + "vector_dtype": np.single, + "distance_metric": "l2", + "graph_degree": 5, + "complexity": 5, + "search_memory_maximum": 0.01, + "build_memory_maximum": 0.01, + "num_threads": 1, + "pq_disk_bytes": 0, + } + bad_ranges = { + "vector_dtype": np.float64, + "distance_metric": "soups this time", + "graph_degree": -1, + "complexity": -1, + "search_memory_maximum": 0, + "build_memory_maximum": 0, + "num_threads": -1, + "pq_disk_bytes": -1, + } + for bad_value_key in good_ranges.keys(): + kwargs = good_ranges.copy() + kwargs[bad_value_key] = bad_ranges[bad_value_key] + with self.subTest( + f"testing bad value key: {bad_value_key} with bad value: {bad_ranges[bad_value_key]}" + ): + with self.assertRaises(ValueError): + dap.build_disk_index(data="test", index_directory="test", **kwargs) + + +class TestBuildMemoryIndex(unittest.TestCase): + def test_valid_shape(self): + rng = np.random.default_rng(12345) + rando = rng.random((1000, 100, 5), dtype=np.single) + with self.assertRaises(ValueError): + dap.build_memory_index( + data=rando, + distance_metric="l2", + index_directory="test", + complexity=5, + graph_degree=5, + alpha=1.2, + num_threads=1, + use_pq_build=False, + num_pq_bytes=0, + use_opq=False, + ) + + rando = rng.random(1000, dtype=np.single) + with self.assertRaises(ValueError): + dap.build_memory_index( + data=rando, + distance_metric="l2", + index_directory="test", + complexity=5, + 
graph_degree=5, + alpha=1.2, + num_threads=1, + use_pq_build=False, + num_pq_bytes=0, + use_opq=False, + ) + + def test_value_ranges_build(self): + good_ranges = { + "vector_dtype": np.single, + "distance_metric": "l2", + "graph_degree": 5, + "complexity": 5, + "alpha": 1.2, + "num_threads": 1, + "num_pq_bytes": 0, + } + bad_ranges = { + "vector_dtype": np.float64, + "distance_metric": "soups this time", + "graph_degree": -1, + "complexity": -1, + "alpha": -1.2, + "num_threads": 1, + "num_pq_bytes": -60, + } + for bad_value_key in good_ranges.keys(): + kwargs = good_ranges.copy() + kwargs[bad_value_key] = bad_ranges[bad_value_key] + with self.subTest( + f"testing bad value key: {bad_value_key} with bad value: {bad_ranges[bad_value_key]}" + ): + with self.assertRaises(ValueError): + dap.build_memory_index( + data="test", + index_directory="test", + use_pq_build=True, + use_opq=False, + **kwargs, + ) diff --git a/python/tests/test_dynamic_memory_index.py b/python/tests/test_dynamic_memory_index.py new file mode 100644 index 000000000..ff9c8981d --- /dev/null +++ b/python/tests/test_dynamic_memory_index.py @@ -0,0 +1,440 @@ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT license. + +import shutil +import tempfile +import unittest +import warnings + +import diskannpy as dap +import numpy as np +from fixtures import build_random_vectors_and_memory_index +from sklearn.neighbors import NearestNeighbors + + +def _calculate_recall( + result_set_tags: np.ndarray, + original_indices_to_tags: np.ndarray, + truth_set_indices: np.ndarray, + recall_at: int = 5, +) -> float: + found = 0 + for i in range(0, result_set_tags.shape[0]): + result_set_set = set(result_set_tags[i][0:recall_at]) + truth_set_set = set() + for knn_index in truth_set_indices[i][0:recall_at]: + truth_set_set.add( + original_indices_to_tags[knn_index] + ) # mapped into our tag number instead + found += len(result_set_set.intersection(truth_set_set)) + return found / (result_set_tags.shape[0] * recall_at) + + +class TestDynamicMemoryIndex(unittest.TestCase): + @classmethod + def setUpClass(cls) -> None: + cls._test_matrix = [ + build_random_vectors_and_memory_index(np.float32, "l2", with_tags=True), + build_random_vectors_and_memory_index(np.uint8, "l2", with_tags=True), + build_random_vectors_and_memory_index(np.int8, "l2", with_tags=True), + build_random_vectors_and_memory_index(np.float32, "cosine", with_tags=True), + build_random_vectors_and_memory_index(np.uint8, "cosine", with_tags=True), + build_random_vectors_and_memory_index(np.int8, "cosine", with_tags=True), + ] + cls._example_ann_dir = cls._test_matrix[0][4] + + @classmethod + def tearDownClass(cls) -> None: + for test in cls._test_matrix: + try: + ann_dir = test[4] + shutil.rmtree(ann_dir, ignore_errors=True) + except: + pass + + def test_recall_and_batch(self): + for ( + metric, + dtype, + query_vectors, + index_vectors, + ann_dir, + vector_bin_file, + generated_tags, + ) in self._test_matrix: + with self.subTest(msg=f"Testing dtype {dtype}"): + index = dap.DynamicMemoryIndex.from_file( + index_directory=ann_dir, + max_vectors=11_000, + complexity=64, + graph_degree=32, + num_threads=16, + ) + + k = 5 + diskann_neighbors, diskann_distances = index.batch_search( + query_vectors, + k_neighbors=k, + complexity=5, + num_threads=16, + ) + if metric == "l2" or metric == "cosine": + knn = NearestNeighbors( + n_neighbors=100, algorithm="auto", metric=metric + ) + knn.fit(index_vectors) + knn_distances, knn_indices = 
knn.kneighbors(query_vectors) + recall = _calculate_recall( + diskann_neighbors, generated_tags, knn_indices, k + ) + self.assertTrue( + recall > 0.70, + f"Recall [{recall}] was not over 0.7", + ) + + def test_single(self): + for ( + metric, + dtype, + query_vectors, + index_vectors, + ann_dir, + vector_bin_file, + generated_tags, + ) in self._test_matrix: + with self.subTest(msg=f"Testing dtype {dtype}"): + index = dap.DynamicMemoryIndex( + distance_metric="l2", + vector_dtype=dtype, + dimensions=10, + max_vectors=11_000, + complexity=64, + graph_degree=32, + num_threads=16, + ) + index.batch_insert(vectors=index_vectors, vector_ids=generated_tags) + + k = 5 + ids, dists = index.search(query_vectors[0], k_neighbors=k, complexity=5) + self.assertEqual(ids.shape[0], k) + self.assertEqual(dists.shape[0], k) + + def test_valid_metric(self): + with self.assertRaises(ValueError): + dap.DynamicMemoryIndex( + distance_metric="sandwich", + vector_dtype=np.single, + dimensions=10, + max_vectors=11_000, + complexity=64, + graph_degree=32, + num_threads=16, + ) + with self.assertRaises(ValueError): + dap.DynamicMemoryIndex( + distance_metric=None, + vector_dtype=np.single, + dimensions=10, + max_vectors=11_000, + complexity=64, + graph_degree=32, + num_threads=16, + ) + dap.DynamicMemoryIndex( + distance_metric="l2", + vector_dtype=np.single, + dimensions=10, + max_vectors=11_000, + complexity=64, + graph_degree=32, + num_threads=16, + ) + dap.DynamicMemoryIndex( + distance_metric="mips", + vector_dtype=np.single, + dimensions=10, + max_vectors=11_000, + complexity=64, + graph_degree=32, + num_threads=16, + ) + dap.DynamicMemoryIndex( + distance_metric="MiPs", + vector_dtype=np.single, + dimensions=10, + max_vectors=11_000, + complexity=64, + graph_degree=32, + num_threads=16, + ) + + def test_valid_vector_dtype(self): + aliases = {np.single: np.float32, np.byte: np.int8, np.ubyte: np.uint8} + for ( + metric, + dtype, + query_vectors, + index_vectors, + ann_dir, + vector_bin_file, + generated_tags, + ) in self._test_matrix: + with self.subTest(): + index = dap.DynamicMemoryIndex( + distance_metric="l2", + vector_dtype=aliases[dtype], + dimensions=10, + max_vectors=11_000, + complexity=64, + graph_degree=32, + num_threads=16, + ) + + invalid = [np.double, np.float64, np.ulonglong] + for invalid_vector_dtype in invalid: + with self.subTest(): + with self.assertRaises(ValueError, msg=invalid_vector_dtype): + dap.DynamicMemoryIndex( + distance_metric="l2", + vector_dtype=invalid_vector_dtype, + dimensions=10, + max_vectors=11_000, + complexity=64, + graph_degree=32, + num_threads=16, + ) + + def test_value_ranges_ctor(self): + ( + metric, + dtype, + query_vectors, + index_vectors, + ann_dir, + vector_bin_file, + generated_tags, + ) = build_random_vectors_and_memory_index( + np.single, "l2", with_tags=True, index_prefix="not_ann" + ) + good_ranges = { + "distance_metric": "l2", + "vector_dtype": np.single, + "dimensions": 10, + "max_vectors": 11_000, + "complexity": 64, + "graph_degree": 32, + "max_occlusion_size": 10, + "alpha": 1.2, + "num_threads": 16, + "filter_complexity": 10, + "num_frozen_points": 10, + "initial_search_complexity": 32, + "search_threads": 0, + } + + bad_ranges = { + "distance_metric": "l200000", + "vector_dtype": np.double, + "dimensions": -1, + "max_vectors": -1, + "complexity": 0, + "graph_degree": 0, + "max_occlusion_size": -1, + "alpha": -1, + "num_threads": -1, + "filter_complexity": -1, + "num_frozen_points": -1, + "initial_search_complexity": -1, + "search_threads": -1, + } 
+ for bad_value_key in good_ranges.keys(): + kwargs = good_ranges.copy() + kwargs[bad_value_key] = bad_ranges[bad_value_key] + with self.subTest(): + with self.assertRaises( + ValueError, + msg=f"expected to fail with parameter {bad_value_key}={bad_ranges[bad_value_key]}", + ): + index = dap.DynamicMemoryIndex(saturate_graph=False, **kwargs) + + def test_value_ranges_search(self): + good_ranges = {"complexity": 5, "k_neighbors": 10} + bad_ranges = {"complexity": -1, "k_neighbors": 0} + for bad_value_key in good_ranges.keys(): + kwargs = good_ranges.copy() + kwargs[bad_value_key] = bad_ranges[bad_value_key] + with self.subTest(msg=f"Test value ranges search with {kwargs=}"): + with self.assertRaises(ValueError): + index = dap.DynamicMemoryIndex.from_file( + index_directory=self._example_ann_dir, + num_threads=16, + initial_search_complexity=32, + max_vectors=10001, + complexity=64, + graph_degree=32, + ) + index.search(query=np.array([], dtype=np.single), **kwargs) + + def test_value_ranges_batch_search(self): + good_ranges = { + "complexity": 5, + "k_neighbors": 10, + "num_threads": 5, + } + bad_ranges = { + "complexity": 0, + "k_neighbors": 0, + "num_threads": -1, + } + for bad_value_key in good_ranges.keys(): + kwargs = good_ranges.copy() + kwargs[bad_value_key] = bad_ranges[bad_value_key] + with self.subTest(msg=f"Testing value ranges batch search with {kwargs=}"): + with self.assertRaises(ValueError): + index = dap.DynamicMemoryIndex.from_file( + index_directory=self._example_ann_dir, + num_threads=16, + initial_search_complexity=32, + max_vectors=10001, + complexity=64, + graph_degree=32, + ) + index.batch_search( + queries=np.array([[]], dtype=np.single), **kwargs + ) + + # Issue #400 + def test_issue400(self): + _, _, _, index_vectors, ann_dir, _, generated_tags = self._test_matrix[0] + + deletion_tag = generated_tags[10] # arbitrary choice + deletion_vector = index_vectors[10] + + index = dap.DynamicMemoryIndex.from_file( + index_directory=ann_dir, + num_threads=16, + initial_search_complexity=32, + max_vectors=10100, + complexity=64, + graph_degree=32, + ) + index.insert(np.array([1.0] * 10, dtype=np.single), 10099) + index.insert(np.array([2.0] * 10, dtype=np.single), 10050) + index.insert(np.array([3.0] * 10, dtype=np.single), 10053) + tags, distances = index.search( + np.array([3.0] * 10, dtype=np.single), k_neighbors=5, complexity=64 + ) + self.assertIn(10053, tags) + tags, distances = index.search(deletion_vector, k_neighbors=5, complexity=64) + self.assertIn( + deletion_tag, tags, "deletion_tag should exist, as we have not deleted yet" + ) + index.mark_deleted(deletion_tag) + tags, distances = index.search(deletion_vector, k_neighbors=5, complexity=64) + self.assertNotIn( + deletion_tag, + tags, + "deletion_tag should not exist, as we have marked it for deletion", + ) + with tempfile.TemporaryDirectory() as tmpdir: + index.save(tmpdir) + + index2 = dap.DynamicMemoryIndex.from_file( + index_directory=tmpdir, + num_threads=16, + initial_search_complexity=32, + max_vectors=10100, + complexity=64, + graph_degree=32, + ) + tags, distances = index2.search( + deletion_vector, k_neighbors=5, complexity=64 + ) + self.assertNotIn( + deletion_tag, + tags, + "deletion_tag should not exist, as we saved and reloaded the index without it", + ) + + def test_inserts_past_max_vectors(self): + def _tiny_index(): + return dap.DynamicMemoryIndex( + distance_metric="l2", + vector_dtype=np.float32, + dimensions=10, + max_vectors=2, + complexity=64, + graph_degree=32, + num_threads=16, + ) + + + 
rng = np.random.default_rng(12345) + + # insert 3 vectors and look for an exception + index = _tiny_index() + index.insert(rng.random(10, dtype=np.float32), 1) + index.insert(rng.random(10, dtype=np.float32), 2) + with self.assertRaises(RuntimeError): + index.insert(rng.random(10, dtype=np.float32), 3) + + # insert 2 vectors, delete 1, and insert another and expect a warning + index = _tiny_index() + index.insert(rng.random(10, dtype=np.float32), 1) + index.insert(rng.random(10, dtype=np.float32), 2) + index.mark_deleted(2) + with self.assertWarns(UserWarning): + self.assertEqual(index._removed_num_vectors, 1) + self.assertEqual(index._num_vectors, 2) + index.insert(rng.random(10, dtype=np.float32), 3) + self.assertEqual(index._removed_num_vectors, 0) + self.assertEqual(index._num_vectors, 2) + + # insert 3 batch and look for an exception + index = _tiny_index() + with self.assertRaises(RuntimeError): + index.batch_insert( + rng.random((3, 10), dtype=np.float32), + np.array([1,2,3], dtype=np.uint32) + ) + + + # insert 2 batch, remove 1, add 1 and expect a warning, remove 1, insert 2 batch and look for an exception + index = _tiny_index() + index.batch_insert( + rng.random((2, 10), dtype=np.float32), + np.array([1,2], dtype=np.uint32) + ) + index.mark_deleted(1) + with self.assertWarns(UserWarning): + index.insert(rng.random(10, dtype=np.float32), 3) + index.mark_deleted(2) + with self.assertRaises(RuntimeError): + index.batch_insert(rng.random((2,10), dtype=np.float32), np.array([4, 5], dtype=np.uint32)) + + # insert 1, remove it, add 2 batch, and expect a warning + index = _tiny_index() + index.insert(rng.random(10, dtype=np.float32), 1) + index.mark_deleted(1) + with self.assertWarns(UserWarning): + index.batch_insert(rng.random((2, 10), dtype=np.float32), np.array([10, 20], dtype=np.uint32)) + + # insert 2 batch, remove both, add 2 batch, and expect a warning + index = _tiny_index() + index.batch_insert(rng.random((2,10), dtype=np.float32), np.array([10, 20], dtype=np.uint32)) + index.mark_deleted(10) + index.mark_deleted(20) + with self.assertWarns(UserWarning): + index.batch_insert(rng.random((2, 10), dtype=np.float32), np.array([15, 25], dtype=np.uint32)) + + # insert 2 batch, remove both, consolidate_delete, add 2 batch and do not expect warning + index = _tiny_index() + index.batch_insert(rng.random((2,10), dtype=np.float32), np.array([10, 20], dtype=np.uint32)) + index.mark_deleted(10) + index.mark_deleted(20) + index.consolidate_delete() + with warnings.catch_warnings(): + warnings.simplefilter("error") # turns warnings into raised exceptions + index.batch_insert(rng.random((2, 10), dtype=np.float32), np.array([15, 25], dtype=np.uint32)) + + diff --git a/python/tests/test_search_disk_index.py b/python/tests/test_search_disk_index.py deleted file mode 100644 index 520a85a0f..000000000 --- a/python/tests/test_search_disk_index.py +++ /dev/null @@ -1,105 +0,0 @@ -# Copyright (c) Microsoft Corporation. All rights reserved. -# Licensed under the MIT license. 
- -import time -import argparse -import numpy as np -import diskannpy - - -parser = argparse.ArgumentParser() -parser.add_argument('query_path', type=str, help='Path to the input query set of vectors.') -parser.add_argument('ground_truth_path', type=str, help='Path to the input groundtruth set.') -parser.add_argument('index_path_prefix', type=str, help='Path prefix for index files.') -parser.add_argument('output_path_prefix', type=str, help='Prefix for the generated output files.') -parser.add_argument('K', type=int, help='k value for recall@K.') -parser.add_argument('W', type=int, help='Beamwidth for search.') -parser.add_argument('T', type=int, help='Number of threads to use for search.') -parser.add_argument('C', type=int, help='Number of nodes to cache for search.') - - -args = parser.parse_args() - -recall_at = args.K -W = args.W -# Use multi-threaded search only for batch mode. -num_threads = args.T -num_nodes_to_cache = args.C -single_query_mode = False -l_search = [40, 50, 60, 70, 80, 90, 100, 110, 120] - - -query_data = diskannpy.VectorFloat() -ground_truth_ids = diskannpy.VectorUnsigned() -ground_truth_dists = diskannpy.VectorFloat() - -num_queries, query_dims, query_aligned_dims = diskannpy.load_aligned_bin_float(args.query_path, query_data) -num_ground_truth, ground_truth_dims = diskannpy.load_truthset(args.ground_truth_path, ground_truth_ids, ground_truth_dists) - -index = diskannpy.DiskANNFloatIndex(diskannpy.L2) -if index.load_index(args.index_path_prefix, num_threads, num_nodes_to_cache) != 0: - print("Index load failed") -else: - print("Index Loaded") - - if single_query_mode: - print("Ls QPS Mean Latency (mus) 99.9 Latency Recall@10") - print("================================================================") - for i, L in enumerate(l_search): - latency_stats = [] - query_result_ids = diskannpy.VectorUnsigned() - query_result_dists = diskannpy.VectorFloat() - - s = time.time() - - for j in range(num_queries): - qs = time.time() - index.search(query_data, j, query_aligned_dims, num_queries, - recall_at, L, W, query_result_ids, query_result_dists) - qe = time.time() - latency_stats.append(float((qe - qs) * 1000000)) - - e = time.time() - qps = (num_queries / (e - s)) - recall = diskannpy.calculate_recall(num_queries, ground_truth_ids, - ground_truth_dists, ground_truth_dims, - query_result_ids, recall_at, - recall_at) - latency_stats.sort() - mean_latency = sum(latency_stats) / num_queries - print(str(L) + "{:>10}".format("{:.2f}".format(qps)) + - "{:>15}".format("{:.2f}".format(mean_latency)) + - "{:>20}".format("{:.2f}".format(latency_stats[int((0.999 * num_queries))])) - + "{:>15}".format("{:.2f}".format(recall))) - - result_path = args.output_path_prefix + "_" + str(L) + "_idx_uint32.bin" - diskannpy.save_bin_u32(result_path, query_result_ids, num_queries, recall_at) - else: - print("Ls QPS Mean Latency (mus) Recall@10") - print("=============================================") - for i, L in enumerate(l_search): - diskannpy.omp_set_num_threads(num_threads) - - query_result_ids = diskannpy.VectorUnsigned() - query_result_dists = diskannpy.VectorFloat() - - qs = time.time() - index.batch_search(query_data, query_aligned_dims, num_queries, - recall_at, L, W, - query_result_ids, query_result_dists, - num_threads) - qe = time.time() - latency_stats = float((qe - qs) * 1000000) - - qps = (num_queries / (qe - qs)) - recall = diskannpy.calculate_recall(num_queries, ground_truth_ids, - ground_truth_dists, ground_truth_dims, - query_result_ids, recall_at, - recall_at) - 
mean_latency = latency_stats / num_queries - print(str(L) + "{:>10}".format("{:.2f}".format(qps)) + - "{:>15}".format("{:.2f}".format(mean_latency)) + - "{:>15}".format("{:.2f}".format(recall))) - - result_path = args.output_path_prefix + "_" + str(L) + "_idx_uint32.bin" - diskannpy.save_bin_u32(result_path, query_result_ids, num_queries, recall_at) diff --git a/python/tests/test_search_disk_index_numpy.py b/python/tests/test_search_disk_index_numpy.py deleted file mode 100644 index d1f2e6827..000000000 --- a/python/tests/test_search_disk_index_numpy.py +++ /dev/null @@ -1,62 +0,0 @@ -# Copyright (c) Microsoft Corporation. All rights reserved. -# Licensed under the MIT license. -import time -import argparse -import numpy as np -import diskannpy - - -parser = argparse.ArgumentParser() -parser.add_argument('query_path', type=str, help='Path to the input query set of vectors.') -parser.add_argument('ground_truth_path', type=str, help='Path to the input groundtruth set.') -parser.add_argument('index_path_prefix', type=str, help='Path prefix for index files.') -parser.add_argument('K', type=int, help='k value for recall@K.') -parser.add_argument('W', type=int, help='Beamwidth for search.') -parser.add_argument('T', type=int, help='Number of threads to use for search.') -parser.add_argument('C', type=int, help='Number of nodes to cache for search.') - -args = parser.parse_args() -args = parser.parse_args() - -recall_at = args.K -W = args.W -# Use multi-threaded search only for batch mode. -num_threads = args.T -num_nodes_to_cache = args.C -l_search = [40, 50, 60, 70, 80, 90, 100, 110, 120] - - -query_data = diskannpy.VectorFloat() -ground_truth_ids = diskannpy.VectorUnsigned() -ground_truth_dists = diskannpy.VectorFloat() - -num_queries, query_dims, query_aligned_dims = diskannpy.load_aligned_bin_float(args.query_path, query_data) -num_ground_truth, ground_truth_dims = diskannpy.load_truthset(args.ground_truth_path, ground_truth_ids, ground_truth_dists) - -query_data_numpy = np.zeros((num_queries,query_dims), dtype=np.float32) -for i in range(0, num_queries): - for d in range(0, query_dims): - query_data_numpy[i,d] = query_data[i * query_aligned_dims + d] - -index = diskannpy.DiskANNFloatIndex(diskannpy.L2) -if index.load_index(args.index_path_prefix, num_threads, num_nodes_to_cache) != 0: - print("Index load failed") -else: - print("Index Loaded") - print("Ls QPS Recall@10") - print("========================") - for i, L in enumerate(l_search): - diskannpy.omp_set_num_threads(num_threads) - - qs = time.time() - ids, dists = index.batch_search_numpy_input(query_data_numpy, query_aligned_dims, - num_queries, recall_at, L, W, num_threads) - qe = time.time() - latency_stats = float((qe - qs) * 1000000) - qps = (num_queries / (qe - qs)) - - recall = diskannpy.calculate_recall_numpy_input(num_queries, ground_truth_ids, - ground_truth_dists, ground_truth_dims, - ids, recall_at, recall_at) - mean_latency = latency_stats / num_queries - print(str(L) + "{:>10}".format("{:.2f}".format(qps)) + "{:>15}".format("{:.2f}".format(recall))) diff --git a/python/tests/test_static_disk_index.py b/python/tests/test_static_disk_index.py new file mode 100644 index 000000000..4ba544106 --- /dev/null +++ b/python/tests/test_static_disk_index.py @@ -0,0 +1,146 @@ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT license. 
+ +import shutil +import unittest +from tempfile import mkdtemp + +import diskannpy as dap +import numpy as np +from fixtures import calculate_recall, random_vectors, vectors_as_temp_file +from sklearn.neighbors import NearestNeighbors + + +def _build_random_vectors_and_index(dtype, metric): + query_vectors = random_vectors(1000, 10, dtype=dtype) + index_vectors = random_vectors(10000, 10, dtype=dtype) + with vectors_as_temp_file(index_vectors) as vector_temp: + ann_dir = mkdtemp() + dap.build_disk_index( + data=vector_temp, + distance_metric=metric, + vector_dtype=dtype, + index_directory=ann_dir, + graph_degree=16, + complexity=32, + search_memory_maximum=0.00003, + build_memory_maximum=1, + num_threads=1, + pq_disk_bytes=0, + ) + return metric, dtype, query_vectors, index_vectors, ann_dir + + +class TestStaticDiskIndex(unittest.TestCase): + @classmethod + def setUpClass(cls) -> None: + cls._test_matrix = [ + _build_random_vectors_and_index(np.float32, "l2"), + _build_random_vectors_and_index(np.uint8, "l2"), + _build_random_vectors_and_index(np.int8, "l2"), + ] + cls._example_ann_dir = cls._test_matrix[0][4] + + @classmethod + def tearDownClass(cls) -> None: + for test in cls._test_matrix: + try: + _, _, _, _, ann_dir = test + shutil.rmtree(ann_dir, ignore_errors=True) + except: + pass + + def test_recall_and_batch(self): + for metric, dtype, query_vectors, index_vectors, ann_dir in self._test_matrix: + with self.subTest(msg=f"Testing dtype {dtype}"): + index = dap.StaticDiskIndex( + distance_metric="l2", + vector_dtype=dtype, + index_directory=ann_dir, + num_threads=16, + num_nodes_to_cache=10, + ) + + k = 5 + diskann_neighbors, diskann_distances = index.batch_search( + query_vectors, + k_neighbors=k, + complexity=5, + beam_width=2, + num_threads=16, + ) + if metric == "l2": + knn = NearestNeighbors( + n_neighbors=100, algorithm="auto", metric="l2" + ) + knn.fit(index_vectors) + knn_distances, knn_indices = knn.kneighbors(query_vectors) + recall = calculate_recall(diskann_neighbors, knn_indices, k) + self.assertTrue( + recall > 0.70, + f"Recall [{recall}] was not over 0.7", + ) + + def test_single(self): + for metric, dtype, query_vectors, index_vectors, ann_dir in self._test_matrix: + with self.subTest(msg=f"Testing dtype {dtype}"): + index = dap.StaticDiskIndex( + distance_metric="l2", + vector_dtype=dtype, + index_directory=ann_dir, + num_threads=16, + num_nodes_to_cache=10, + ) + + k = 5 + ids, dists = index.search( + query_vectors[0], k_neighbors=k, complexity=5, beam_width=2 + ) + self.assertEqual(ids.shape[0], k) + self.assertEqual(dists.shape[0], k) + + def test_value_ranges_search(self): + good_ranges = {"complexity": 5, "k_neighbors": 10, "beam_width": 2} + bad_ranges = {"complexity": -1, "k_neighbors": 0, "beam_width": 0} + for bad_value_key in good_ranges.keys(): + kwargs = good_ranges.copy() + kwargs[bad_value_key] = bad_ranges[bad_value_key] + with self.subTest(): + with self.assertRaises(ValueError): + index = dap.StaticDiskIndex( + distance_metric="l2", + vector_dtype=np.single, + index_directory=self._example_ann_dir, + num_threads=16, + num_nodes_to_cache=10, + ) + index.search(query=np.array([], dtype=np.single), **kwargs) + + def test_value_ranges_batch_search(self): + good_ranges = { + "complexity": 5, + "k_neighbors": 10, + "beam_width": 2, + "num_threads": 5, + } + bad_ranges = { + "complexity": 0, + "k_neighbors": 0, + "beam_width": -1, + "num_threads": -1, + } + for bad_value_key in good_ranges.keys(): + kwargs = good_ranges.copy() + kwargs[bad_value_key] = 
bad_ranges[bad_value_key] + with self.subTest(): + with self.assertRaises(ValueError): + index = dap.StaticDiskIndex( + distance_metric="l2", + vector_dtype=np.single, + index_directory=self._example_ann_dir, + num_threads=16, + num_nodes_to_cache=10, + ) + index.batch_search( + queries=np.array([[]], dtype=np.single), **kwargs + ) diff --git a/python/tests/test_static_memory_index.py b/python/tests/test_static_memory_index.py new file mode 100644 index 000000000..cb7f0f01d --- /dev/null +++ b/python/tests/test_static_memory_index.py @@ -0,0 +1,162 @@ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT license. + +import shutil +import unittest + +import diskannpy as dap +import numpy as np +from fixtures import build_random_vectors_and_memory_index, calculate_recall +from sklearn.neighbors import NearestNeighbors + + +class TestStaticMemoryIndex(unittest.TestCase): + @classmethod + def setUpClass(cls) -> None: + cls._test_matrix = [ + build_random_vectors_and_memory_index(np.float32, "l2"), + build_random_vectors_and_memory_index(np.uint8, "l2"), + build_random_vectors_and_memory_index(np.int8, "l2"), + build_random_vectors_and_memory_index(np.float32, "cosine"), + build_random_vectors_and_memory_index(np.uint8, "cosine"), + build_random_vectors_and_memory_index(np.int8, "cosine"), + ] + cls._example_ann_dir = cls._test_matrix[0][4] + + @classmethod + def tearDownClass(cls) -> None: + for test in cls._test_matrix: + try: + ann_dir = test[4] + shutil.rmtree(ann_dir, ignore_errors=True) + except: + pass + + def test_recall_and_batch(self): + for ( + metric, + dtype, + query_vectors, + index_vectors, + ann_dir, + vector_bin_file, + _, + ) in self._test_matrix: + with self.subTest(msg=f"Testing dtype {dtype}"): + index = dap.StaticMemoryIndex( + index_directory=ann_dir, + num_threads=16, + initial_search_complexity=32, + ) + + k = 5 + diskann_neighbors, diskann_distances = index.batch_search( + query_vectors, + k_neighbors=k, + complexity=5, + num_threads=16, + ) + if metric in ["l2", "cosine"]: + knn = NearestNeighbors( + n_neighbors=100, algorithm="auto", metric=metric + ) + knn.fit(index_vectors) + knn_distances, knn_indices = knn.kneighbors(query_vectors) + recall = calculate_recall(diskann_neighbors, knn_indices, k) + self.assertTrue( + recall > 0.70, + f"Recall [{recall}] was not over 0.7", + ) + + def test_single(self): + for ( + metric, + dtype, + query_vectors, + index_vectors, + ann_dir, + vector_bin_file, + _, + ) in self._test_matrix: + with self.subTest(msg=f"Testing dtype {dtype}"): + index = dap.StaticMemoryIndex( + index_directory=ann_dir, + num_threads=16, + initial_search_complexity=32, + ) + + k = 5 + ids, dists = index.search(query_vectors[0], k_neighbors=k, complexity=5) + self.assertEqual(ids.shape[0], k) + self.assertEqual(dists.shape[0], k) + + def test_value_ranges_ctor(self): + ( + metric, + dtype, + query_vectors, + index_vectors, + ann_dir, + vector_bin_file, + _, + ) = build_random_vectors_and_memory_index(np.single, "l2", "not_ann") + good_ranges = { + "index_directory": ann_dir, + "num_threads": 16, + "initial_search_complexity": 32, + "index_prefix": "not_ann", + } + + bad_ranges = { + "index_directory": "sandwiches", + "num_threads": -100, + "initial_search_complexity": 0, + "index_prefix": "", + } + for bad_value_key in good_ranges.keys(): + kwargs = good_ranges.copy() + kwargs[bad_value_key] = bad_ranges[bad_value_key] + with self.subTest(): + with self.assertRaises(ValueError): + index = dap.StaticMemoryIndex(**kwargs) + + 
def test_value_ranges_search(self): + good_ranges = {"complexity": 5, "k_neighbors": 10} + bad_ranges = {"complexity": -1, "k_neighbors": 0} + for bad_value_key in good_ranges.keys(): + kwargs = good_ranges.copy() + kwargs[bad_value_key] = bad_ranges[bad_value_key] + with self.subTest(): + with self.assertRaises(ValueError): + index = dap.StaticMemoryIndex( + index_directory=self._example_ann_dir, + num_threads=16, + initial_search_complexity=32, + ) + index.search(query=np.array([], dtype=np.single), **kwargs) + + def test_value_ranges_batch_search(self): + good_ranges = { + "complexity": 5, + "k_neighbors": 10, + "num_threads": 5, + } + bad_ranges = { + "complexity": 0, + "k_neighbors": 0, + "num_threads": -1, + } + vector_bin_file = self._test_matrix[0][5] + for bad_value_key in good_ranges.keys(): + kwargs = good_ranges.copy() + kwargs[bad_value_key] = bad_ranges[bad_value_key] + with self.subTest(): + with self.assertRaises(ValueError): + index = dap.StaticMemoryIndex( + index_directory=self._example_ann_dir, + num_threads=16, + initial_search_complexity=32, + ) + index.batch_search( + queries=np.array([[]], dtype=np.single), **kwargs + ) diff --git a/rust/Cargo.lock b/rust/Cargo.lock new file mode 100644 index 000000000..2e58e9322 --- /dev/null +++ b/rust/Cargo.lock @@ -0,0 +1,1814 @@ +# This file is automatically @generated by Cargo. +# It is not intended for manual editing. +version = 3 + +[[package]] +name = "adler" +version = "1.0.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f26201604c87b1e01bd3d98f8d5d9a8fcbb815e8cedb41ffccbeb4bf593a35fe" + +[[package]] +name = "ahash" +version = "0.8.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2c99f64d1e06488f620f932677e24bc6e2897582980441ae90a671415bd7ec2f" +dependencies = [ + "cfg-if", + "once_cell", + "version_check", +] + +[[package]] +name = "anes" +version = "0.1.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4b46cbb362ab8752921c97e041f5e366ee6297bd428a31275b9fcf1e380f7299" + +[[package]] +name = "anstream" +version = "0.3.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0ca84f3628370c59db74ee214b3263d58f9aadd9b4fe7e711fd87dc452b7f163" +dependencies = [ + "anstyle", + "anstyle-parse", + "anstyle-query", + "anstyle-wincon", + "colorchoice", + "is-terminal", + "utf8parse", +] + +[[package]] +name = "anstyle" +version = "1.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3a30da5c5f2d5e72842e00bcb57657162cdabef0931f40e2deb9b4140440cecd" + +[[package]] +name = "anstyle-parse" +version = "0.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "938874ff5980b03a87c5524b3ae5b59cf99b1d6bc836848df7bc5ada9643c333" +dependencies = [ + "utf8parse", +] + +[[package]] +name = "anstyle-query" +version = "1.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5ca11d4be1bab0c8bc8734a9aa7bf4ee8316d462a08c6ac5052f888fef5b494b" +dependencies = [ + "windows-sys 0.48.0", +] + +[[package]] +name = "anstyle-wincon" +version = "1.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "180abfa45703aebe0093f79badacc01b8fd4ea2e35118747e5811127f926e188" +dependencies = [ + "anstyle", + "windows-sys 0.48.0", +] + +[[package]] +name = "anyhow" +version = "1.0.71" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9c7d0618f0e0b7e8ff11427422b64564d5fb0be1940354bfe2e0529b18a9d9b8" + 
+[[package]] +name = "approx" +version = "0.5.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cab112f0a86d568ea0e627cc1d6be74a1e9cd55214684db5561995f6dad897c6" +dependencies = [ + "num-traits", +] + +[[package]] +name = "autocfg" +version = "1.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d468802bab17cbc0cc575e9b053f41e72aa36bfa6b7f55e3529ffa43161b97fa" + +[[package]] +name = "base64" +version = "0.21.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "604178f6c5c21f02dc555784810edfb88d34ac2c73b2eae109655649ee73ce3d" + +[[package]] +name = "bincode" +version = "1.3.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b1f45e9417d87227c7a56d22e471c6206462cba514c7590c09aff4cf6d1ddcad" +dependencies = [ + "serde", +] + +[[package]] +name = "bit-vec" +version = "0.6.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "349f9b6a179ed607305526ca489b34ad0a41aed5f7980fa90eb03160b69598fb" + +[[package]] +name = "bitflags" +version = "1.3.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bef38d45163c2f1dde094a7dfd33ccf595c92905c8f8f4fdc18d06fb1037718a" + +[[package]] +name = "build_and_insert_delete_memory_index" +version = "0.1.0" +dependencies = [ + "diskann", + "logger", + "vector", +] + +[[package]] +name = "build_and_insert_memory_index" +version = "0.1.0" +dependencies = [ + "diskann", + "logger", + "vector", +] + +[[package]] +name = "build_disk_index" +version = "0.1.0" +dependencies = [ + "diskann", + "logger", + "openblas-src", + "vector", +] + +[[package]] +name = "build_memory_index" +version = "0.1.0" +dependencies = [ + "clap", + "diskann", + "logger", + "vector", +] + +[[package]] +name = "bumpalo" +version = "3.13.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a3e2c3daef883ecc1b5d58c15adae93470a91d425f3532ba1695849656af3fc1" + +[[package]] +name = "bytemuck" +version = "1.13.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "17febce684fd15d89027105661fec94afb475cb995fbc59d2865198446ba2eea" + +[[package]] +name = "byteorder" +version = "1.4.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "14c189c53d098945499cdfa7ecc63567cf3886b3332b312a5b4585d8d3a6a610" + +[[package]] +name = "bytes" +version = "1.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "89b2fd2a0dcf38d7971e2194b6b6eebab45ae01067456a7fd93d5547a61b70be" + +[[package]] +name = "cast" +version = "0.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "37b2a672a2cb129a2e41c10b1224bb368f9f37a2b16b612598138befd7b37eb5" + +[[package]] +name = "cblas" +version = "0.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3de46dff748ed7e891bc46faae117f48d2a7911041c6630aed3c61a3fe12326f" +dependencies = [ + "cblas-sys", + "libc", + "num-complex", +] + +[[package]] +name = "cblas-sys" +version = "0.1.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b6feecd82cce51b0204cf063f0041d69f24ce83f680d87514b004248e7b0fa65" +dependencies = [ + "libc", +] + +[[package]] +name = "cc" +version = "1.0.79" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "50d30906286121d95be3d479533b458f87493b30a4b5f79a607db8f5d11aa91f" + +[[package]] +name = "cfg-if" +version = "1.0.0" +source = 
"registry+https://github.com/rust-lang/crates.io-index" +checksum = "baf1de4339761588bc0619e3cbc0120ee582ebb74b53b4efbf79117bd2da40fd" + +[[package]] +name = "ciborium" +version = "0.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "effd91f6c78e5a4ace8a5d3c0b6bfaec9e2baaef55f3efc00e45fb2e477ee926" +dependencies = [ + "ciborium-io", + "ciborium-ll", + "serde", +] + +[[package]] +name = "ciborium-io" +version = "0.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cdf919175532b369853f5d5e20b26b43112613fd6fe7aee757e35f7a44642656" + +[[package]] +name = "ciborium-ll" +version = "0.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "defaa24ecc093c77630e6c15e17c51f5e187bf35ee514f4e2d67baaa96dae22b" +dependencies = [ + "ciborium-io", + "half 1.8.2", +] + +[[package]] +name = "clap" +version = "4.3.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d9394150f5b4273a1763355bd1c2ec54cc5a2593f790587bcd6b2c947cfa9211" +dependencies = [ + "clap_builder", + "clap_derive", + "once_cell", +] + +[[package]] +name = "clap_builder" +version = "4.3.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9a78fbdd3cc2914ddf37ba444114bc7765bbdcb55ec9cbe6fa054f0137400717" +dependencies = [ + "anstream", + "anstyle", + "bitflags", + "clap_lex", + "strsim", +] + +[[package]] +name = "clap_derive" +version = "4.3.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b8cd2b2a819ad6eec39e8f1d6b53001af1e5469f8c177579cdaeb313115b825f" +dependencies = [ + "heck", + "proc-macro2", + "quote", + "syn 2.0.18", +] + +[[package]] +name = "clap_lex" +version = "0.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2da6da31387c7e4ef160ffab6d5e7f00c42626fe39aea70a7b0f1773f7dd6c1b" + +[[package]] +name = "colorchoice" +version = "1.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "acbf1af155f9b9ef647e42cdc158db4b64a1b61f743629225fde6f3e0be2a7c7" + +[[package]] +name = "convert_f32_to_bf16" +version = "0.1.0" +dependencies = [ + "half 2.2.1", +] + +[[package]] +name = "core-foundation" +version = "0.9.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "194a7a9e6de53fa55116934067c844d9d749312f75c6f6d0980e8c252f8c2146" +dependencies = [ + "core-foundation-sys", + "libc", +] + +[[package]] +name = "core-foundation-sys" +version = "0.8.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e496a50fda8aacccc86d7529e2c1e0892dbd0f898a6b5645b5561b89c3210efa" + +[[package]] +name = "crc32fast" +version = "1.3.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b540bd8bc810d3885c6ea91e2018302f68baba2129ab3e88f32389ee9370880d" +dependencies = [ + "cfg-if", +] + +[[package]] +name = "criterion" +version = "0.5.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f2b12d017a929603d80db1831cd3a24082f8137ce19c69e6447f54f5fc8d692f" +dependencies = [ + "anes", + "cast", + "ciborium", + "clap", + "criterion-plot", + "is-terminal", + "itertools", + "num-traits", + "once_cell", + "oorandom", + "plotters", + "rayon", + "regex", + "serde", + "serde_derive", + "serde_json", + "tinytemplate", + "walkdir", +] + +[[package]] +name = "criterion-plot" +version = "0.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = 
"6b50826342786a51a89e2da3a28f1c32b06e387201bc2d19791f622c673706b1" +dependencies = [ + "cast", + "itertools", +] + +[[package]] +name = "crossbeam" +version = "0.8.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2801af0d36612ae591caa9568261fddce32ce6e08a7275ea334a06a4ad021a2c" +dependencies = [ + "cfg-if", + "crossbeam-channel", + "crossbeam-deque", + "crossbeam-epoch", + "crossbeam-queue", + "crossbeam-utils", +] + +[[package]] +name = "crossbeam-channel" +version = "0.5.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a33c2bf77f2df06183c3aa30d1e96c0695a313d4f9c453cc3762a6db39f99200" +dependencies = [ + "cfg-if", + "crossbeam-utils", +] + +[[package]] +name = "crossbeam-deque" +version = "0.8.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ce6fd6f855243022dcecf8702fef0c297d4338e226845fe067f6341ad9fa0cef" +dependencies = [ + "cfg-if", + "crossbeam-epoch", + "crossbeam-utils", +] + +[[package]] +name = "crossbeam-epoch" +version = "0.9.15" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ae211234986c545741a7dc064309f67ee1e5ad243d0e48335adc0484d960bcc7" +dependencies = [ + "autocfg", + "cfg-if", + "crossbeam-utils", + "memoffset", + "scopeguard", +] + +[[package]] +name = "crossbeam-queue" +version = "0.3.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d1cfb3ea8a53f37c40dea2c7bedcbd88bdfae54f5e2175d6ecaff1c988353add" +dependencies = [ + "cfg-if", + "crossbeam-utils", +] + +[[package]] +name = "crossbeam-utils" +version = "0.8.16" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5a22b2d63d4d1dc0b7f1b6b2747dd0088008a9be28b6ddf0b1e7d335e3037294" +dependencies = [ + "cfg-if", +] + +[[package]] +name = "crunchy" +version = "0.2.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7a81dae078cea95a014a339291cec439d2f232ebe854a9d672b796c6afafa9b7" + +[[package]] +name = "dirs" +version = "3.0.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "30baa043103c9d0c2a57cf537cc2f35623889dc0d405e6c3cccfadbc81c71309" +dependencies = [ + "dirs-sys", +] + +[[package]] +name = "dirs-sys" +version = "0.3.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1b1d1d91c932ef41c0f2663aa8b0ca0342d444d842c06914aa0a7e352d0bada6" +dependencies = [ + "libc", + "redox_users", + "winapi", +] + +[[package]] +name = "diskann" +version = "0.1.0" +dependencies = [ + "approx", + "bincode", + "bit-vec", + "byteorder", + "cblas", + "cc", + "criterion", + "crossbeam", + "half 2.2.1", + "hashbrown 0.13.2", + "logger", + "num-traits", + "once_cell", + "openblas-src", + "platform", + "rand", + "rayon", + "serde", + "thiserror", + "vector", + "winapi", +] + +[[package]] +name = "either" +version = "1.8.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7fcaabb2fef8c910e7f4c7ce9f67a1283a1715879a7c230ca9d6d1ae31f16d91" + +[[package]] +name = "errno" +version = "0.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4bcfec3a70f97c962c307b2d2c56e358cf1d00b558d74262b5f929ee8cc7e73a" +dependencies = [ + "errno-dragonfly", + "libc", + "windows-sys 0.48.0", +] + +[[package]] +name = "errno-dragonfly" +version = "0.1.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "aa68f1b12764fab894d2755d2518754e71b4fd80ecfb822714a1206c2aab39bf" +dependencies = [ + "cc", + "libc", +] + 
+[[package]] +name = "fastrand" +version = "1.9.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e51093e27b0797c359783294ca4f0a911c270184cb10f85783b118614a1501be" +dependencies = [ + "instant", +] + +[[package]] +name = "filetime" +version = "0.2.21" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5cbc844cecaee9d4443931972e1289c8ff485cb4cc2767cb03ca139ed6885153" +dependencies = [ + "cfg-if", + "libc", + "redox_syscall 0.2.16", + "windows-sys 0.48.0", +] + +[[package]] +name = "fixedbitset" +version = "0.4.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0ce7134b9999ecaf8bcd65542e436736ef32ddca1b3e06094cb6ec5755203b80" + +[[package]] +name = "flate2" +version = "1.0.26" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3b9429470923de8e8cbd4d2dc513535400b4b3fef0319fb5c4e1f520a7bef743" +dependencies = [ + "crc32fast", + "miniz_oxide", +] + +[[package]] +name = "foreign-types" +version = "0.3.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f6f339eb8adc052cd2ca78910fda869aefa38d22d5cb648e6485e4d3fc06f3b1" +dependencies = [ + "foreign-types-shared", +] + +[[package]] +name = "foreign-types-shared" +version = "0.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "00b0228411908ca8685dba7fc2cdd70ec9990a6e753e89b6ac91a84c40fbaf4b" + +[[package]] +name = "form_urlencoded" +version = "1.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a62bc1cf6f830c2ec14a513a9fb124d0a213a629668a4186f329db21fe045652" +dependencies = [ + "percent-encoding", +] + +[[package]] +name = "getrandom" +version = "0.2.10" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "be4136b2a15dd319360be1c07d9933517ccf0be8f16bf62a3bee4f0d618df427" +dependencies = [ + "cfg-if", + "libc", + "wasi", +] + +[[package]] +name = "half" +version = "1.8.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "eabb4a44450da02c90444cf74558da904edde8fb4e9035a9a6a4e15445af0bd7" + +[[package]] +name = "half" +version = "2.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "02b4af3693f1b705df946e9fe5631932443781d0aabb423b62fcd4d73f6d2fd0" +dependencies = [ + "crunchy", +] + +[[package]] +name = "hashbrown" +version = "0.12.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8a9ee70c43aaf417c914396645a0fa852624801b24ebb7ae78fe8272889ac888" + +[[package]] +name = "hashbrown" +version = "0.13.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "43a3c133739dddd0d2990f9a4bdf8eb4b21ef50e4851ca85ab661199821d510e" +dependencies = [ + "ahash", +] + +[[package]] +name = "heck" +version = "0.4.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "95505c38b4572b2d910cecb0281560f54b440a19336cbbcb27bf6ce6adc6f5a8" + +[[package]] +name = "hermit-abi" +version = "0.2.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ee512640fe35acbfb4bb779db6f0d80704c2cacfa2e39b601ef3e3f47d1ae4c7" +dependencies = [ + "libc", +] + +[[package]] +name = "hermit-abi" +version = "0.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fed44880c466736ef9a5c5b5facefb5ed0785676d0c02d612db14e54f0d84286" + +[[package]] +name = "idna" +version = "0.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = 
"7d20d6b07bfbc108882d88ed8e37d39636dcc260e15e30c45e6ba089610b917c" +dependencies = [ + "unicode-bidi", + "unicode-normalization", +] + +[[package]] +name = "indexmap" +version = "1.9.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bd070e393353796e801d209ad339e89596eb4c8d430d18ede6a1cced8fafbd99" +dependencies = [ + "autocfg", + "hashbrown 0.12.3", +] + +[[package]] +name = "instant" +version = "0.1.12" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7a5bbe824c507c5da5956355e86a746d82e0e1464f65d862cc5e71da70e94b2c" +dependencies = [ + "cfg-if", +] + +[[package]] +name = "io-lifetimes" +version = "1.0.11" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "eae7b9aee968036d54dce06cebaefd919e4472e753296daccd6d344e3e2df0c2" +dependencies = [ + "hermit-abi 0.3.1", + "libc", + "windows-sys 0.48.0", +] + +[[package]] +name = "is-terminal" +version = "0.4.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "adcf93614601c8129ddf72e2d5633df827ba6551541c6d8c59520a371475be1f" +dependencies = [ + "hermit-abi 0.3.1", + "io-lifetimes", + "rustix", + "windows-sys 0.48.0", +] + +[[package]] +name = "itertools" +version = "0.10.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b0fd2260e829bddf4cb6ea802289de2f86d6a7a690192fbe91b3f46e0f2c8473" +dependencies = [ + "either", +] + +[[package]] +name = "itoa" +version = "1.0.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "453ad9f582a441959e5f0d088b02ce04cfe8d51a8eaf077f12ac6d3e94164ca6" + +[[package]] +name = "js-sys" +version = "0.3.64" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c5f195fe497f702db0f318b07fdd68edb16955aed830df8363d837542f8f935a" +dependencies = [ + "wasm-bindgen", +] + +[[package]] +name = "lazy_static" +version = "1.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e2abad23fbc42b3700f2f279844dc832adb2b2eb069b2df918f455c4e18cc646" + +[[package]] +name = "libc" +version = "0.2.146" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f92be4933c13fd498862a9e02a3055f8a8d9c039ce33db97306fd5a6caa7f29b" + +[[package]] +name = "linux-raw-sys" +version = "0.3.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ef53942eb7bf7ff43a617b3e2c1c4a5ecf5944a7c1bc12d7ee39bbb15e5c1519" + +[[package]] +name = "load_and_insert_memory_index" +version = "0.1.0" +dependencies = [ + "diskann", + "logger", + "vector", +] + +[[package]] +name = "log" +version = "0.4.19" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b06a4cde4c0f271a446782e3eff8de789548ce57dbc8eca9292c27f4a42004b4" + +[[package]] +name = "logger" +version = "0.1.0" +dependencies = [ + "lazy_static", + "log", + "once_cell", + "prost", + "prost-build", + "prost-types", + "thiserror", + "vcpkg", + "win_etw_macros", + "win_etw_provider", +] + +[[package]] +name = "memoffset" +version = "0.9.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5a634b1c61a95585bd15607c6ab0c4e5b226e695ff2800ba0cdccddf208c406c" +dependencies = [ + "autocfg", +] + +[[package]] +name = "miniz_oxide" +version = "0.7.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e7810e0be55b428ada41041c41f32c9f1a42817901b4ccf45fa3d4b6561e74c7" +dependencies = [ + "adler", +] + +[[package]] +name = "multimap" +version = "0.8.3" +source = 
"registry+https://github.com/rust-lang/crates.io-index" +checksum = "e5ce46fe64a9d73be07dcbe690a38ce1b293be448fd8ce1e6c1b8062c9f72c6a" + +[[package]] +name = "native-tls" +version = "0.2.11" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "07226173c32f2926027b63cce4bcd8076c3552846cbe7925f3aaffeac0a3b92e" +dependencies = [ + "lazy_static", + "libc", + "log", + "openssl", + "openssl-probe", + "openssl-sys", + "schannel", + "security-framework", + "security-framework-sys", + "tempfile", +] + +[[package]] +name = "num-complex" +version = "0.4.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "02e0d21255c828d6f128a1e41534206671e8c3ea0c62f32291e808dc82cff17d" +dependencies = [ + "num-traits", +] + +[[package]] +name = "num-traits" +version = "0.2.15" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "578ede34cf02f8924ab9447f50c28075b4d3e5b269972345e7e0372b38c6cdcd" +dependencies = [ + "autocfg", +] + +[[package]] +name = "num_cpus" +version = "1.15.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0fac9e2da13b5eb447a6ce3d392f23a29d8694bff781bf03a16cd9ac8697593b" +dependencies = [ + "hermit-abi 0.2.6", + "libc", +] + +[[package]] +name = "once_cell" +version = "1.18.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "dd8b5dd2ae5ed71462c540258bedcb51965123ad7e7ccf4b9a8cafaa4a63576d" + +[[package]] +name = "oorandom" +version = "11.1.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0ab1bc2a289d34bd04a330323ac98a1b4bc82c9d9fcb1e66b63caa84da26b575" + +[[package]] +name = "openblas-build" +version = "0.10.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "eba42c395477605f400a8d79ee0b756cfb82abe3eb5618e35fa70d3a36010a7f" +dependencies = [ + "anyhow", + "flate2", + "native-tls", + "tar", + "thiserror", + "ureq", + "walkdir", +] + +[[package]] +name = "openblas-src" +version = "0.10.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "38e5d8af0b707ac2fe1574daa88b4157da73b0de3dc7c39fe3e2c0bb64070501" +dependencies = [ + "dirs", + "openblas-build", + "vcpkg", +] + +[[package]] +name = "openssl" +version = "0.10.55" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "345df152bc43501c5eb9e4654ff05f794effb78d4efe3d53abc158baddc0703d" +dependencies = [ + "bitflags", + "cfg-if", + "foreign-types", + "libc", + "once_cell", + "openssl-macros", + "openssl-sys", +] + +[[package]] +name = "openssl-macros" +version = "0.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a948666b637a0f465e8564c73e89d4dde00d72d4d473cc972f390fc3dcee7d9c" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.18", +] + +[[package]] +name = "openssl-probe" +version = "0.1.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ff011a302c396a5197692431fc1948019154afc178baf7d8e37367442a4601cf" + +[[package]] +name = "openssl-sys" +version = "0.9.90" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "374533b0e45f3a7ced10fcaeccca020e66656bc03dac384f852e4e5a7a8104a6" +dependencies = [ + "cc", + "libc", + "pkg-config", + "vcpkg", +] + +[[package]] +name = "percent-encoding" +version = "2.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9b2a4787296e9989611394c33f193f676704af1686e70b8f8033ab5ba9a35a94" + +[[package]] +name = "petgraph" +version = "0.6.3" 
+source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4dd7d28ee937e54fe3080c91faa1c3a46c06de6252988a7f4592ba2310ef22a4" +dependencies = [ + "fixedbitset", + "indexmap", +] + +[[package]] +name = "pkg-config" +version = "0.3.27" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "26072860ba924cbfa98ea39c8c19b4dd6a4a25423dbdf219c1eca91aa0cf6964" + +[[package]] +name = "platform" +version = "0.1.0" +dependencies = [ + "log", + "winapi", +] + +[[package]] +name = "plotters" +version = "0.3.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d2c224ba00d7cadd4d5c660deaf2098e5e80e07846537c51f9cfa4be50c1fd45" +dependencies = [ + "num-traits", + "plotters-backend", + "plotters-svg", + "wasm-bindgen", + "web-sys", +] + +[[package]] +name = "plotters-backend" +version = "0.3.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9e76628b4d3a7581389a35d5b6e2139607ad7c75b17aed325f210aa91f4a9609" + +[[package]] +name = "plotters-svg" +version = "0.3.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "38f6d39893cca0701371e3c27294f09797214b86f1fb951b89ade8ec04e2abab" +dependencies = [ + "plotters-backend", +] + +[[package]] +name = "ppv-lite86" +version = "0.2.17" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5b40af805b3121feab8a3c29f04d8ad262fa8e0561883e7653e024ae4479e6de" + +[[package]] +name = "prettyplease" +version = "0.1.25" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6c8646e95016a7a6c4adea95bafa8a16baab64b583356217f2c85db4a39d9a86" +dependencies = [ + "proc-macro2", + "syn 1.0.109", +] + +[[package]] +name = "proc-macro2" +version = "1.0.60" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "dec2b086b7a862cf4de201096214fa870344cf922b2b30c167badb3af3195406" +dependencies = [ + "unicode-ident", +] + +[[package]] +name = "prost" +version = "0.11.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0b82eaa1d779e9a4bc1c3217db8ffbeabaae1dca241bf70183242128d48681cd" +dependencies = [ + "bytes", + "prost-derive", +] + +[[package]] +name = "prost-build" +version = "0.11.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "119533552c9a7ffacc21e099c24a0ac8bb19c2a2a3f363de84cd9b844feab270" +dependencies = [ + "bytes", + "heck", + "itertools", + "lazy_static", + "log", + "multimap", + "petgraph", + "prettyplease", + "prost", + "prost-types", + "regex", + "syn 1.0.109", + "tempfile", + "which", +] + +[[package]] +name = "prost-derive" +version = "0.11.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e5d2d8d10f3c6ded6da8b05b5fb3b8a5082514344d56c9f871412d29b4e075b4" +dependencies = [ + "anyhow", + "itertools", + "proc-macro2", + "quote", + "syn 1.0.109", +] + +[[package]] +name = "prost-types" +version = "0.11.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "213622a1460818959ac1181aaeb2dc9c7f63df720db7d788b3e24eacd1983e13" +dependencies = [ + "prost", +] + +[[package]] +name = "quote" +version = "1.0.28" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1b9ab9c7eadfd8df19006f1cf1a4aed13540ed5cbc047010ece5826e10825488" +dependencies = [ + "proc-macro2", +] + +[[package]] +name = "rand" +version = "0.8.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = 
"34af8d1a0e25924bc5b7c43c079c942339d8f0a8b57c39049bef581b46327404" +dependencies = [ + "libc", + "rand_chacha", + "rand_core", +] + +[[package]] +name = "rand_chacha" +version = "0.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e6c10a63a0fa32252be49d21e7709d4d4baf8d231c2dbce1eaa8141b9b127d88" +dependencies = [ + "ppv-lite86", + "rand_core", +] + +[[package]] +name = "rand_core" +version = "0.6.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ec0be4795e2f6a28069bec0b5ff3e2ac9bafc99e6a9a7dc3547996c5c816922c" +dependencies = [ + "getrandom", +] + +[[package]] +name = "rayon" +version = "1.7.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1d2df5196e37bcc87abebc0053e20787d73847bb33134a69841207dd0a47f03b" +dependencies = [ + "either", + "rayon-core", +] + +[[package]] +name = "rayon-core" +version = "1.11.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4b8f95bd6966f5c87776639160a66bd8ab9895d9d4ab01ddba9fc60661aebe8d" +dependencies = [ + "crossbeam-channel", + "crossbeam-deque", + "crossbeam-utils", + "num_cpus", +] + +[[package]] +name = "redox_syscall" +version = "0.2.16" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fb5a58c1855b4b6819d59012155603f0b22ad30cad752600aadfcb695265519a" +dependencies = [ + "bitflags", +] + +[[package]] +name = "redox_syscall" +version = "0.3.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "567664f262709473930a4bf9e51bf2ebf3348f2e748ccc50dea20646858f8f29" +dependencies = [ + "bitflags", +] + +[[package]] +name = "redox_users" +version = "0.4.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b033d837a7cf162d7993aded9304e30a83213c648b6e389db233191f891e5c2b" +dependencies = [ + "getrandom", + "redox_syscall 0.2.16", + "thiserror", +] + +[[package]] +name = "regex" +version = "1.8.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d0ab3ca65655bb1e41f2a8c8cd662eb4fb035e67c3f78da1d61dffe89d07300f" +dependencies = [ + "regex-syntax", +] + +[[package]] +name = "regex-syntax" +version = "0.7.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "436b050e76ed2903236f032a59761c1eb99e1b0aead2c257922771dab1fc8c78" + +[[package]] +name = "rustix" +version = "0.37.20" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b96e891d04aa506a6d1f318d2771bcb1c7dfda84e126660ace067c9b474bb2c0" +dependencies = [ + "bitflags", + "errno", + "io-lifetimes", + "libc", + "linux-raw-sys", + "windows-sys 0.48.0", +] + +[[package]] +name = "rustls-native-certs" +version = "0.6.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a9aace74cb666635c918e9c12bc0d348266037aa8eb599b5cba565709a8dff00" +dependencies = [ + "openssl-probe", + "rustls-pemfile", + "schannel", + "security-framework", +] + +[[package]] +name = "rustls-pemfile" +version = "1.0.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d194b56d58803a43635bdc398cd17e383d6f71f9182b9a192c127ca42494a59b" +dependencies = [ + "base64", +] + +[[package]] +name = "ryu" +version = "1.0.13" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f91339c0467de62360649f8d3e185ca8de4224ff281f66000de5eb2a77a79041" + +[[package]] +name = "same-file" +version = "1.0.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = 
"93fc1dc3aaa9bfed95e02e6eadabb4baf7e3078b0bd1b4d7b6b0b68378900502" +dependencies = [ + "winapi-util", +] + +[[package]] +name = "schannel" +version = "0.1.21" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "713cfb06c7059f3588fb8044c0fad1d09e3c01d225e25b9220dbfdcf16dbb1b3" +dependencies = [ + "windows-sys 0.42.0", +] + +[[package]] +name = "scopeguard" +version = "1.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d29ab0c6d3fc0ee92fe66e2d99f700eab17a8d57d1c1d3b748380fb20baa78cd" + +[[package]] +name = "search_memory_index" +version = "0.1.0" +dependencies = [ + "bytemuck", + "diskann", + "num_cpus", + "rayon", + "vector", +] + +[[package]] +name = "security-framework" +version = "2.9.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1fc758eb7bffce5b308734e9b0c1468893cae9ff70ebf13e7090be8dcbcc83a8" +dependencies = [ + "bitflags", + "core-foundation", + "core-foundation-sys", + "libc", + "security-framework-sys", +] + +[[package]] +name = "security-framework-sys" +version = "2.9.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f51d0c0d83bec45f16480d0ce0058397a69e48fcdc52d1dc8855fb68acbd31a7" +dependencies = [ + "core-foundation-sys", + "libc", +] + +[[package]] +name = "serde" +version = "1.0.164" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9e8c8cf938e98f769bc164923b06dce91cea1751522f46f8466461af04c9027d" +dependencies = [ + "serde_derive", +] + +[[package]] +name = "serde_derive" +version = "1.0.164" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d9735b638ccc51c28bf6914d90a2e9725b377144fc612c49a611fddd1b631d68" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.18", +] + +[[package]] +name = "serde_json" +version = "1.0.97" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bdf3bf93142acad5821c99197022e170842cdbc1c30482b98750c688c640842a" +dependencies = [ + "itoa", + "ryu", + "serde", +] + +[[package]] +name = "sha1_smol" +version = "1.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ae1a47186c03a32177042e55dbc5fd5aee900b8e0069a8d70fba96a9375cd012" + +[[package]] +name = "strsim" +version = "0.10.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "73473c0e59e6d5812c5dfe2a064a6444949f089e20eec9a2e5506596494e4623" + +[[package]] +name = "syn" +version = "1.0.109" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "72b64191b275b66ffe2469e8af2c1cfe3bafa67b529ead792a6d0160888b4237" +dependencies = [ + "proc-macro2", + "quote", + "unicode-ident", +] + +[[package]] +name = "syn" +version = "2.0.18" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "32d41677bcbe24c20c52e7c70b0d8db04134c5d1066bf98662e2871ad200ea3e" +dependencies = [ + "proc-macro2", + "quote", + "unicode-ident", +] + +[[package]] +name = "tar" +version = "0.4.38" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4b55807c0344e1e6c04d7c965f5289c39a8d94ae23ed5c0b57aabac549f871c6" +dependencies = [ + "filetime", + "libc", + "xattr", +] + +[[package]] +name = "tempfile" +version = "3.6.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "31c0432476357e58790aaa47a8efb0c5138f137343f3b5f23bd36a27e3b0a6d6" +dependencies = [ + "autocfg", + "cfg-if", + "fastrand", + "redox_syscall 0.3.5", + "rustix", + "windows-sys 0.48.0", +] + +[[package]] 
+name = "thiserror" +version = "1.0.40" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "978c9a314bd8dc99be594bc3c175faaa9794be04a5a5e153caba6915336cebac" +dependencies = [ + "thiserror-impl", +] + +[[package]] +name = "thiserror-impl" +version = "1.0.40" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f9456a42c5b0d803c8cd86e73dd7cc9edd429499f37a3550d286d5e86720569f" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.18", +] + +[[package]] +name = "tinytemplate" +version = "1.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "be4d6b5f19ff7664e8c98d03e2139cb510db9b0a60b55f8e8709b689d939b6bc" +dependencies = [ + "serde", + "serde_json", +] + +[[package]] +name = "tinyvec" +version = "1.6.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "87cc5ceb3875bb20c2890005a4e226a4651264a5c75edb2421b52861a0a0cb50" +dependencies = [ + "tinyvec_macros", +] + +[[package]] +name = "tinyvec_macros" +version = "0.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1f3ccbac311fea05f86f61904b462b55fb3df8837a366dfc601a0161d0532f20" + +[[package]] +name = "unicode-bidi" +version = "0.3.13" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "92888ba5573ff080736b3648696b70cafad7d250551175acbaa4e0385b3e1460" + +[[package]] +name = "unicode-ident" +version = "1.0.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b15811caf2415fb889178633e7724bad2509101cde276048e013b9def5e51fa0" + +[[package]] +name = "unicode-normalization" +version = "0.1.22" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5c5713f0fc4b5db668a2ac63cdb7bb4469d8c9fed047b1d0292cc7b0ce2ba921" +dependencies = [ + "tinyvec", +] + +[[package]] +name = "ureq" +version = "2.7.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0b11c96ac7ee530603dcdf68ed1557050f374ce55a5a07193ebf8cbc9f8927e9" +dependencies = [ + "base64", + "flate2", + "log", + "native-tls", + "once_cell", + "rustls-native-certs", + "url", +] + +[[package]] +name = "url" +version = "2.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "50bff7831e19200a85b17131d085c25d7811bc4e186efdaf54bbd132994a88cb" +dependencies = [ + "form_urlencoded", + "idna", + "percent-encoding", +] + +[[package]] +name = "utf8parse" +version = "0.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "711b9620af191e0cdc7468a8d14e709c3dcdb115b36f838e601583af800a370a" + +[[package]] +name = "uuid" +version = "1.3.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0fa2982af2eec27de306107c027578ff7f423d65f7250e40ce0fea8f45248b81" +dependencies = [ + "sha1_smol", +] + +[[package]] +name = "vcpkg" +version = "0.2.15" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "accd4ea62f7bb7a82fe23066fb0957d48ef677f6eeb8215f372f52e48bb32426" + +[[package]] +name = "vector" +version = "0.1.0" +dependencies = [ + "approx", + "base64", + "bincode", + "bytemuck", + "cc", + "half 2.2.1", + "rand", + "serde", + "thiserror", +] + +[[package]] +name = "vector_base64" +version = "0.1.0" +dependencies = [ + "base64", + "bincode", + "half 2.2.1", + "serde", +] + +[[package]] +name = "version_check" +version = "0.9.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = 
"49874b5167b65d7193b8aba1567f5c7d93d001cafc34600cee003eda787e483f" + +[[package]] +name = "w32-error" +version = "1.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fa7c61a6bd91e168c12fc170985725340f6b458eb6f971d1cf6c34f74ffafb43" +dependencies = [ + "winapi", +] + +[[package]] +name = "walkdir" +version = "2.3.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "36df944cda56c7d8d8b7496af378e6b16de9284591917d307c9b4d313c44e698" +dependencies = [ + "same-file", + "winapi-util", +] + +[[package]] +name = "wasi" +version = "0.11.0+wasi-snapshot-preview1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9c8d87e72b64a3b4db28d11ce29237c246188f4f51057d65a7eab63b7987e423" + +[[package]] +name = "wasm-bindgen" +version = "0.2.87" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7706a72ab36d8cb1f80ffbf0e071533974a60d0a308d01a5d0375bf60499a342" +dependencies = [ + "cfg-if", + "wasm-bindgen-macro", +] + +[[package]] +name = "wasm-bindgen-backend" +version = "0.2.87" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5ef2b6d3c510e9625e5fe6f509ab07d66a760f0885d858736483c32ed7809abd" +dependencies = [ + "bumpalo", + "log", + "once_cell", + "proc-macro2", + "quote", + "syn 2.0.18", + "wasm-bindgen-shared", +] + +[[package]] +name = "wasm-bindgen-macro" +version = "0.2.87" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "dee495e55982a3bd48105a7b947fd2a9b4a8ae3010041b9e0faab3f9cd028f1d" +dependencies = [ + "quote", + "wasm-bindgen-macro-support", +] + +[[package]] +name = "wasm-bindgen-macro-support" +version = "0.2.87" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "54681b18a46765f095758388f2d0cf16eb8d4169b639ab575a8f5693af210c7b" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.18", + "wasm-bindgen-backend", + "wasm-bindgen-shared", +] + +[[package]] +name = "wasm-bindgen-shared" +version = "0.2.87" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ca6ad05a4870b2bf5fe995117d3728437bd27d7cd5f06f13c17443ef369775a1" + +[[package]] +name = "web-sys" +version = "0.3.64" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9b85cbef8c220a6abc02aefd892dfc0fc23afb1c6a426316ec33253a3877249b" +dependencies = [ + "js-sys", + "wasm-bindgen", +] + +[[package]] +name = "which" +version = "4.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2441c784c52b289a054b7201fc93253e288f094e2f4be9058343127c4226a269" +dependencies = [ + "either", + "libc", + "once_cell", +] + +[[package]] +name = "widestring" +version = "1.0.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "653f141f39ec16bba3c5abe400a0c60da7468261cc2cbf36805022876bc721a8" + +[[package]] +name = "win_etw_macros" +version = "0.1.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f1bc4c591edb4858e3445f7a60c7e0a50915aedadfa044f28f17c98c145ef54d" +dependencies = [ + "proc-macro2", + "quote", + "sha1_smol", + "syn 1.0.109", + "uuid", + "win_etw_metadata", +] + +[[package]] +name = "win_etw_metadata" +version = "0.1.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e50d0fa665033a19ecefd281b4fb5481eba2972dedbb5ec129c9392a206d652f" +dependencies = [ + "bitflags", +] + +[[package]] +name = "win_etw_provider" +version = "0.1.8" +source = 
"registry+https://github.com/rust-lang/crates.io-index" +checksum = "dffcc196e0e180e73a275a91f6914f173227fd627cabac3efdd8d6adec113892" +dependencies = [ + "w32-error", + "widestring", + "win_etw_metadata", + "winapi", + "zerocopy", +] + +[[package]] +name = "winapi" +version = "0.3.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5c839a674fcd7a98952e593242ea400abe93992746761e38641405d28b00f419" +dependencies = [ + "winapi-i686-pc-windows-gnu", + "winapi-x86_64-pc-windows-gnu", +] + +[[package]] +name = "winapi-i686-pc-windows-gnu" +version = "0.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ac3b87c63620426dd9b991e5ce0329eff545bccbbb34f3be09ff6fb6ab51b7b6" + +[[package]] +name = "winapi-util" +version = "0.1.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "70ec6ce85bb158151cae5e5c87f95a8e97d2c0c4b001223f33a334e3ce5de178" +dependencies = [ + "winapi", +] + +[[package]] +name = "winapi-x86_64-pc-windows-gnu" +version = "0.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "712e227841d057c1ee1cd2fb22fa7e5a5461ae8e48fa2ca79ec42cfc1931183f" + +[[package]] +name = "windows-sys" +version = "0.42.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5a3e1820f08b8513f676f7ab6c1f99ff312fb97b553d30ff4dd86f9f15728aa7" +dependencies = [ + "windows_aarch64_gnullvm 0.42.2", + "windows_aarch64_msvc 0.42.2", + "windows_i686_gnu 0.42.2", + "windows_i686_msvc 0.42.2", + "windows_x86_64_gnu 0.42.2", + "windows_x86_64_gnullvm 0.42.2", + "windows_x86_64_msvc 0.42.2", +] + +[[package]] +name = "windows-sys" +version = "0.48.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "677d2418bec65e3338edb076e806bc1ec15693c5d0104683f2efe857f61056a9" +dependencies = [ + "windows-targets", +] + +[[package]] +name = "windows-targets" +version = "0.48.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7b1eb6f0cd7c80c79759c929114ef071b87354ce476d9d94271031c0497adfd5" +dependencies = [ + "windows_aarch64_gnullvm 0.48.0", + "windows_aarch64_msvc 0.48.0", + "windows_i686_gnu 0.48.0", + "windows_i686_msvc 0.48.0", + "windows_x86_64_gnu 0.48.0", + "windows_x86_64_gnullvm 0.48.0", + "windows_x86_64_msvc 0.48.0", +] + +[[package]] +name = "windows_aarch64_gnullvm" +version = "0.42.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "597a5118570b68bc08d8d59125332c54f1ba9d9adeedeef5b99b02ba2b0698f8" + +[[package]] +name = "windows_aarch64_gnullvm" +version = "0.48.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "91ae572e1b79dba883e0d315474df7305d12f569b400fcf90581b06062f7e1bc" + +[[package]] +name = "windows_aarch64_msvc" +version = "0.42.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e08e8864a60f06ef0d0ff4ba04124db8b0fb3be5776a5cd47641e942e58c4d43" + +[[package]] +name = "windows_aarch64_msvc" +version = "0.48.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b2ef27e0d7bdfcfc7b868b317c1d32c641a6fe4629c171b8928c7b08d98d7cf3" + +[[package]] +name = "windows_i686_gnu" +version = "0.42.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c61d927d8da41da96a81f029489353e68739737d3beca43145c8afec9a31a84f" + +[[package]] +name = "windows_i686_gnu" +version = "0.48.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = 
"622a1962a7db830d6fd0a69683c80a18fda201879f0f447f065a3b7467daa241" + +[[package]] +name = "windows_i686_msvc" +version = "0.42.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "44d840b6ec649f480a41c8d80f9c65108b92d89345dd94027bfe06ac444d1060" + +[[package]] +name = "windows_i686_msvc" +version = "0.48.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4542c6e364ce21bf45d69fdd2a8e455fa38d316158cfd43b3ac1c5b1b19f8e00" + +[[package]] +name = "windows_x86_64_gnu" +version = "0.42.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8de912b8b8feb55c064867cf047dda097f92d51efad5b491dfb98f6bbb70cb36" + +[[package]] +name = "windows_x86_64_gnu" +version = "0.48.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ca2b8a661f7628cbd23440e50b05d705db3686f894fc9580820623656af974b1" + +[[package]] +name = "windows_x86_64_gnullvm" +version = "0.42.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "26d41b46a36d453748aedef1486d5c7a85db22e56aff34643984ea85514e94a3" + +[[package]] +name = "windows_x86_64_gnullvm" +version = "0.48.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7896dbc1f41e08872e9d5e8f8baa8fdd2677f29468c4e156210174edc7f7b953" + +[[package]] +name = "windows_x86_64_msvc" +version = "0.42.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9aec5da331524158c6d1a4ac0ab1541149c0b9505fde06423b02f5ef0106b9f0" + +[[package]] +name = "windows_x86_64_msvc" +version = "0.48.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1a515f5799fe4961cb532f983ce2b23082366b898e52ffbce459c86f67c8378a" + +[[package]] +name = "xattr" +version = "0.2.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6d1526bbe5aaeb5eb06885f4d987bcdfa5e23187055de9b83fe00156a821fabc" +dependencies = [ + "libc", +] + +[[package]] +name = "zerocopy" +version = "0.6.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "332f188cc1bcf1fe1064b8c58d150f497e697f49774aa846f2dc949d9a25f236" +dependencies = [ + "byteorder", + "zerocopy-derive", +] + +[[package]] +name = "zerocopy-derive" +version = "0.3.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6505e6815af7de1746a08f69c69606bb45695a17149517680f3b2149713b19a3" +dependencies = [ + "proc-macro2", + "quote", + "syn 1.0.109", +] diff --git a/rust/Cargo.toml b/rust/Cargo.toml new file mode 100644 index 000000000..5236f96a0 --- /dev/null +++ b/rust/Cargo.toml @@ -0,0 +1,23 @@ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT license. + +[workspace] +members = [ + "cmd_drivers/build_memory_index", + "cmd_drivers/build_and_insert_memory_index", + "cmd_drivers/load_and_insert_memory_index", + "cmd_drivers/convert_f32_to_bf16", + "cmd_drivers/search_memory_index", + "cmd_drivers/build_disk_index", + "cmd_drivers/build_and_insert_delete_memory_index", + "vector", + "diskann", + "platform", + "logger", + "vector_base64" +] +resolver = "2" + +[profile.release] +opt-level = 3 +codegen-units=1 diff --git a/rust/cmd_drivers/build_and_insert_delete_memory_index/Cargo.toml b/rust/cmd_drivers/build_and_insert_delete_memory_index/Cargo.toml new file mode 100644 index 000000000..42aa1851a --- /dev/null +++ b/rust/cmd_drivers/build_and_insert_delete_memory_index/Cargo.toml @@ -0,0 +1,14 @@ +# Copyright (c) Microsoft Corporation. 
All rights reserved. +# Licensed under the MIT license. +[package] +name = "build_and_insert_delete_memory_index" +version = "0.1.0" +edition = "2021" + +# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html + +[dependencies] +diskann = { path = "../../diskann" } +logger = { path = "../../logger" } +vector = { path = "../../vector" } + diff --git a/rust/cmd_drivers/build_and_insert_delete_memory_index/src/main.rs b/rust/cmd_drivers/build_and_insert_delete_memory_index/src/main.rs new file mode 100644 index 000000000..4593a9ed5 --- /dev/null +++ b/rust/cmd_drivers/build_and_insert_delete_memory_index/src/main.rs @@ -0,0 +1,420 @@ +/* + * Copyright (c) Microsoft Corporation. All rights reserved. + * Licensed under the MIT license. + */ +use std::env; + +use diskann::{ + common::{ANNError, ANNResult}, + index::create_inmem_index, + model::{ + configuration::index_write_parameters::IndexWriteParametersBuilder, + vertex::{DIM_104, DIM_128, DIM_256}, + IndexConfiguration, + }, + utils::round_up, + utils::{file_exists, load_ids_to_delete_from_file, load_metadata_from_file, Timer}, +}; + +use vector::{FullPrecisionDistance, Half, Metric}; + +// The main function to build an in-memory index +#[allow(clippy::too_many_arguments)] +fn build_and_insert_delete_in_memory_index<T>( + metric: Metric, + data_path: &str, + delta_path: &str, + r: u32, + l: u32, + alpha: f32, + save_path: &str, + num_threads: u32, + _use_pq_build: bool, + _num_pq_bytes: usize, + use_opq: bool, + delete_path: &str, +) -> ANNResult<()> +where + T: Default + Copy + Sync + Send + Into<f32>, + [T; DIM_104]: FullPrecisionDistance<T, DIM_104>, + [T; DIM_128]: FullPrecisionDistance<T, DIM_128>, + [T; DIM_256]: FullPrecisionDistance<T, DIM_256>, +{ + let index_write_parameters = IndexWriteParametersBuilder::new(l, r) + .with_alpha(alpha) + .with_saturate_graph(false) + .with_num_threads(num_threads) + .build(); + + let (data_num, data_dim) = load_metadata_from_file(data_path)?; + + let config = IndexConfiguration::new( + metric, + data_dim, + round_up(data_dim as u64, 8_u64) as usize, + data_num, + false, + 0, + use_opq, + 0, + 2.0f32, + index_write_parameters, + ); + let mut index = create_inmem_index::<T>(config)?; + + let timer = Timer::new(); + + index.build(data_path, data_num)?; + + let diff = timer.elapsed(); + + println!("Initial indexing time: {}", diff.as_secs_f64()); + + let (delta_data_num, _) = load_metadata_from_file(delta_path)?; + + index.insert(delta_path, delta_data_num)?; + + if !delete_path.is_empty() { + if !file_exists(delete_path) { + return Err(ANNError::log_index_error(format!( + "ERROR: Data file for delete {} does not exist.", + delete_path + ))); + } + + let (num_points_to_delete, vertex_ids_to_delete) = + load_ids_to_delete_from_file(delete_path)?; + index.soft_delete(vertex_ids_to_delete, num_points_to_delete)?; + } + + index.save(save_path)?; + + Ok(()) +} + +fn main() -> ANNResult<()> { + let mut data_type = String::new(); + let mut dist_fn = String::new(); + let mut data_path = String::new(); + let mut insert_path = String::new(); + let mut index_path_prefix = String::new(); + let mut delete_path = String::new(); + + let mut num_threads = 0u32; + let mut r = 64u32; + let mut l = 100u32; + + let mut alpha = 1.2f32; + let mut build_pq_bytes = 0u32; + let mut _use_pq_build = false; + let mut use_opq = false; + + let args: Vec<String> = env::args().collect(); + let mut iter = args.iter().skip(1).peekable(); + + while let Some(arg) = iter.next() { + match arg.as_str() { + "--help" | "-h" => { + print_help(); + return
Ok(()); + } + "--data_type" => { + data_type = iter + .next() + .ok_or_else(|| { + ANNError::log_index_config_error( + "data_type".to_string(), + "Missing data type".to_string(), + ) + })? + .to_owned(); + } + "--dist_fn" => { + dist_fn = iter + .next() + .ok_or_else(|| { + ANNError::log_index_config_error( + "dist_fn".to_string(), + "Missing distance function".to_string(), + ) + })? + .to_owned(); + } + "--data_path" => { + data_path = iter + .next() + .ok_or_else(|| { + ANNError::log_index_config_error( + "data_path".to_string(), + "Missing data path".to_string(), + ) + })? + .to_owned(); + } + "--insert_path" => { + insert_path = iter + .next() + .ok_or_else(|| { + ANNError::log_index_config_error( + "insert_path".to_string(), + "Missing insert path".to_string(), + ) + })? + .to_owned(); + } + "--index_path_prefix" => { + index_path_prefix = iter + .next() + .ok_or_else(|| { + ANNError::log_index_config_error( + "index_path_prefix".to_string(), + "Missing index path prefix".to_string(), + ) + })? + .to_owned(); + } + "--max_degree" | "-R" => { + r = iter + .next() + .ok_or_else(|| { + ANNError::log_index_config_error( + "max_degree".to_string(), + "Missing max degree".to_string(), + ) + })? + .parse() + .map_err(|err| { + ANNError::log_index_config_error( + "max_degree".to_string(), + format!("ParseIntError: {}", err), + ) + })?; + } + "--Lbuild" | "-L" => { + l = iter + .next() + .ok_or_else(|| { + ANNError::log_index_config_error( + "Lbuild".to_string(), + "Missing build complexity".to_string(), + ) + })? + .parse() + .map_err(|err| { + ANNError::log_index_config_error( + "Lbuild".to_string(), + format!("ParseIntError: {}", err), + ) + })?; + } + "--alpha" => { + alpha = iter + .next() + .ok_or_else(|| { + ANNError::log_index_config_error( + "alpha".to_string(), + "Missing alpha".to_string(), + ) + })? + .parse() + .map_err(|err| { + ANNError::log_index_config_error( + "alpha".to_string(), + format!("ParseFloatError: {}", err), + ) + })?; + } + "--num_threads" | "-T" => { + num_threads = iter + .next() + .ok_or_else(|| { + ANNError::log_index_config_error( + "num_threads".to_string(), + "Missing number of threads".to_string(), + ) + })? + .parse() + .map_err(|err| { + ANNError::log_index_config_error( + "num_threads".to_string(), + format!("ParseIntError: {}", err), + ) + })?; + } + "--build_PQ_bytes" => { + build_pq_bytes = iter + .next() + .ok_or_else(|| { + ANNError::log_index_config_error( + "build_PQ_bytes".to_string(), + "Missing PQ bytes".to_string(), + ) + })? + .parse() + .map_err(|err| { + ANNError::log_index_config_error( + "build_PQ_bytes".to_string(), + format!("ParseIntError: {}", err), + ) + })?; + } + "--use_opq" => { + use_opq = iter + .next() + .ok_or_else(|| { + ANNError::log_index_config_error( + "use_opq".to_string(), + "Missing use_opq flag".to_string(), + ) + })? + .parse() + .map_err(|err| { + ANNError::log_index_config_error( + "use_opq".to_string(), + format!("ParseBoolError: {}", err), + ) + })?; + } + "--delete_path" => { + delete_path = iter + .next() + .ok_or_else(|| { + ANNError::log_index_config_error( + "delete_path".to_string(), + "Missing delete_path".to_string(), + ) + })? 
+ .parse() + .map_err(|err| { + ANNError::log_index_config_error( + "delete_set_path".to_string(), + format!("ParseStringError: {}", err), + ) + })?; + } + _ => { + return Err(ANNError::log_index_config_error( + String::from(""), + format!("Unknown argument: {}", arg), + )); + } + } + } + + if data_type.is_empty() + || dist_fn.is_empty() + || data_path.is_empty() + || index_path_prefix.is_empty() + { + return Err(ANNError::log_index_config_error( + String::from(""), + "Missing required arguments".to_string(), + )); + } + + _use_pq_build = build_pq_bytes > 0; + + let metric = dist_fn + .parse::<Metric>() + .map_err(|err| ANNError::log_index_config_error("dist_fn".to_string(), err.to_string()))?; + + println!( + "Starting index build with R: {} Lbuild: {} alpha: {} #threads: {}", + r, l, alpha, num_threads + ); + + match data_type.as_str() { + "int8" => { + build_and_insert_delete_in_memory_index::<i8>( + metric, + &data_path, + &insert_path, + r, + l, + alpha, + &index_path_prefix, + num_threads, + _use_pq_build, + build_pq_bytes as usize, + use_opq, + &delete_path, + )?; + } + "uint8" => { + build_and_insert_delete_in_memory_index::<u8>( + metric, + &data_path, + &insert_path, + r, + l, + alpha, + &index_path_prefix, + num_threads, + _use_pq_build, + build_pq_bytes as usize, + use_opq, + &delete_path, + )?; + } + "float" => { + build_and_insert_delete_in_memory_index::<f32>( + metric, + &data_path, + &insert_path, + r, + l, + alpha, + &index_path_prefix, + num_threads, + _use_pq_build, + build_pq_bytes as usize, + use_opq, + &delete_path, + )?; + } + "f16" => { + build_and_insert_delete_in_memory_index::<Half>( + metric, + &data_path, + &insert_path, + r, + l, + alpha, + &index_path_prefix, + num_threads, + _use_pq_build, + build_pq_bytes as usize, + use_opq, + &delete_path, + )?; + } + _ => { + println!("Unsupported type. Use one of int8, uint8, float or f16."); + return Err(ANNError::log_index_config_error( + "data_type".to_string(), + "Invalid data type".to_string(), + )); + } + } + + Ok(()) +} + +fn print_help() { + println!("Arguments"); + println!("--help, -h Print information on arguments"); + println!("--data_type data type (required)"); + println!("--dist_fn distance function (required)"); + println!( + "--data_path Input data file in bin format for initial build (required)" + ); + println!("--insert_path Input data file in bin format for insert (required)"); + println!("--delete_path Input file with vertex ids to soft-delete after insert (optional)"); + println!("--index_path_prefix Path prefix for saving index file components (required)"); + println!("--max_degree, -R Maximum graph degree (default: 64)"); + println!("--Lbuild, -L Build complexity, higher value results in better graphs (default: 100)"); + println!("--alpha alpha controls density and diameter of graph, set 1 for sparse graph, 1.2 or 1.4 for denser graphs with lower diameter (default: 1.2)"); + println!("--num_threads, -T Number of threads used for building index (defaults to num of CPU logic cores)"); + println!("--build_PQ_bytes Number of PQ bytes to build the index; 0 for full precision build (default: 0)"); + println!("--use_opq Set true for OPQ compression while using PQ distance comparisons for building the index, and false for PQ compression (default: false)"); +} + diff --git a/rust/cmd_drivers/build_and_insert_memory_index/Cargo.toml b/rust/cmd_drivers/build_and_insert_memory_index/Cargo.toml new file mode 100644 index 000000000..d9811fc22 --- /dev/null +++ b/rust/cmd_drivers/build_and_insert_memory_index/Cargo.toml @@ -0,0 +1,14 @@ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT license.
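+# +# A hypothetical example invocation of the build_and_insert_delete_memory_index driver above, using only the flags its argument parser defines; the file names and parameter values here are illustrative, not taken from the source: +# +# cargo run --release -p build_and_insert_delete_memory_index -- --data_type float --dist_fn l2 --data_path data/base.bin --insert_path data/delta.bin --delete_path data/delete_ids.bin --index_path_prefix index/mem -R 64 -L 100 --alpha 1.2 -T 8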
+[package] +name = "build_and_insert_memory_index" +version = "0.1.0" +edition = "2021" + +# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html + +[dependencies] +diskann = { path = "../../diskann" } +logger = { path = "../../logger" } +vector = { path = "../../vector" } + diff --git a/rust/cmd_drivers/build_and_insert_memory_index/src/main.rs b/rust/cmd_drivers/build_and_insert_memory_index/src/main.rs new file mode 100644 index 000000000..46e4ba4a4 --- /dev/null +++ b/rust/cmd_drivers/build_and_insert_memory_index/src/main.rs @@ -0,0 +1,382 @@ +/* + * Copyright (c) Microsoft Corporation. All rights reserved. + * Licensed under the MIT license. + */ +use std::env; + +use diskann::{ + common::{ANNResult, ANNError}, + index::create_inmem_index, + utils::round_up, + model::{ + IndexWriteParametersBuilder, + IndexConfiguration, + vertex::{DIM_128, DIM_256, DIM_104} + }, + utils::{load_metadata_from_file, Timer}, +}; + +use vector::{Metric, FullPrecisionDistance, Half}; + +// The main function to build an in-memory index +#[allow(clippy::too_many_arguments)] +fn build_and_insert_in_memory_index<T> ( + metric: Metric, + data_path: &str, + delta_path: &str, + r: u32, + l: u32, + alpha: f32, + save_path: &str, + num_threads: u32, + _use_pq_build: bool, + _num_pq_bytes: usize, + use_opq: bool +) -> ANNResult<()> +where + T: Default + Copy + Sync + Send + Into<f32>, + [T; DIM_104]: FullPrecisionDistance<T, DIM_104>, + [T; DIM_128]: FullPrecisionDistance<T, DIM_128>, + [T; DIM_256]: FullPrecisionDistance<T, DIM_256> +{ + let index_write_parameters = IndexWriteParametersBuilder::new(l, r) + .with_alpha(alpha) + .with_saturate_graph(false) + .with_num_threads(num_threads) + .build(); + + let (data_num, data_dim) = load_metadata_from_file(data_path)?; + + let config = IndexConfiguration::new( + metric, + data_dim, + round_up(data_dim as u64, 8_u64) as usize, + data_num, + false, + 0, + use_opq, + 0, + 2.0f32, + index_write_parameters, + ); + let mut index = create_inmem_index::<T>(config)?; + + let timer = Timer::new(); + + index.build(data_path, data_num)?; + + let diff = timer.elapsed(); + + println!("Initial indexing time: {}", diff.as_secs_f64()); + + let (delta_data_num, _) = load_metadata_from_file(delta_path)?; + + index.insert(delta_path, delta_data_num)?; + + index.save(save_path)?; + + Ok(()) +} + +fn main() -> ANNResult<()> { + let mut data_type = String::new(); + let mut dist_fn = String::new(); + let mut data_path = String::new(); + let mut insert_path = String::new(); + let mut index_path_prefix = String::new(); + + let mut num_threads = 0u32; + let mut r = 64u32; + let mut l = 100u32; + + let mut alpha = 1.2f32; + let mut build_pq_bytes = 0u32; + let mut _use_pq_build = false; + let mut use_opq = false; + + let args: Vec<String> = env::args().collect(); + let mut iter = args.iter().skip(1).peekable(); + + while let Some(arg) = iter.next() { + match arg.as_str() { + "--help" | "-h" => { + print_help(); + return Ok(()); + } + "--data_type" => { + data_type = iter + .next() + .ok_or_else(|| { + ANNError::log_index_config_error( + "data_type".to_string(), + "Missing data type".to_string(), + ) + })? + .to_owned(); + } + "--dist_fn" => { + dist_fn = iter + .next() + .ok_or_else(|| { + ANNError::log_index_config_error( + "dist_fn".to_string(), + "Missing distance function".to_string(), + ) + })? + .to_owned(); + } + "--data_path" => { + data_path = iter + .next() + .ok_or_else(|| { + ANNError::log_index_config_error( + "data_path".to_string(), + "Missing data path".to_string(), + ) + })?
+ .to_owned(); + } + "--insert_path" => { + insert_path = iter + .next() + .ok_or_else(|| { + ANNError::log_index_config_error( + "insert_path".to_string(), + "Missing insert path".to_string(), + ) + })? + .to_owned(); + } + "--index_path_prefix" => { + index_path_prefix = iter + .next() + .ok_or_else(|| { + ANNError::log_index_config_error( + "index_path_prefix".to_string(), + "Missing index path prefix".to_string(), + ) + })? + .to_owned(); + } + "--max_degree" | "-R" => { + r = iter + .next() + .ok_or_else(|| { + ANNError::log_index_config_error( + "max_degree".to_string(), + "Missing max degree".to_string(), + ) + })? + .parse() + .map_err(|err| { + ANNError::log_index_config_error( + "max_degree".to_string(), + format!("ParseIntError: {}", err), + ) + })?; + } + "--Lbuild" | "-L" => { + l = iter + .next() + .ok_or_else(|| { + ANNError::log_index_config_error( + "Lbuild".to_string(), + "Missing build complexity".to_string(), + ) + })? + .parse() + .map_err(|err| { + ANNError::log_index_config_error( + "Lbuild".to_string(), + format!("ParseIntError: {}", err), + ) + })?; + } + "--alpha" => { + alpha = iter + .next() + .ok_or_else(|| { + ANNError::log_index_config_error( + "alpha".to_string(), + "Missing alpha".to_string(), + ) + })? + .parse() + .map_err(|err| { + ANNError::log_index_config_error( + "alpha".to_string(), + format!("ParseFloatError: {}", err), + ) + })?; + } + "--num_threads" | "-T" => { + num_threads = iter + .next() + .ok_or_else(|| { + ANNError::log_index_config_error( + "num_threads".to_string(), + "Missing number of threads".to_string(), + ) + })? + .parse() + .map_err(|err| { + ANNError::log_index_config_error( + "num_threads".to_string(), + format!("ParseIntError: {}", err), + ) + })?; + } + "--build_PQ_bytes" => { + build_pq_bytes = iter + .next() + .ok_or_else(|| { + ANNError::log_index_config_error( + "build_PQ_bytes".to_string(), + "Missing PQ bytes".to_string(), + ) + })? + .parse() + .map_err(|err| { + ANNError::log_index_config_error( + "build_PQ_bytes".to_string(), + format!("ParseIntError: {}", err), + ) + })?; + } + "--use_opq" => { + use_opq = iter + .next() + .ok_or_else(|| { + ANNError::log_index_config_error( + "use_opq".to_string(), + "Missing use_opq flag".to_string(), + ) + })? 
+ .parse() + .map_err(|err| { + ANNError::log_index_config_error( + "use_opq".to_string(), + format!("ParseBoolError: {}", err), + ) + })?; + } + _ => { + return Err(ANNError::log_index_config_error( + String::from(""), + format!("Unknown argument: {}", arg), + )); + } + } + } + + if data_type.is_empty() + || dist_fn.is_empty() + || data_path.is_empty() + || index_path_prefix.is_empty() + { + return Err(ANNError::log_index_config_error( + String::from(""), + "Missing required arguments".to_string(), + )); + } + + _use_pq_build = build_pq_bytes > 0; + + let metric = dist_fn + .parse::<Metric>() + .map_err(|err| ANNError::log_index_config_error( + "dist_fn".to_string(), + err.to_string(), + ))?; + + println!( + "Starting index build with R: {} Lbuild: {} alpha: {} #threads: {}", + r, l, alpha, num_threads + ); + + match data_type.as_str() { + "int8" => { + build_and_insert_in_memory_index::<i8>( + metric, + &data_path, + &insert_path, + r, + l, + alpha, + &index_path_prefix, + num_threads, + _use_pq_build, + build_pq_bytes as usize, + use_opq, + )?; + } + "uint8" => { + build_and_insert_in_memory_index::<u8>( + metric, + &data_path, + &insert_path, + r, + l, + alpha, + &index_path_prefix, + num_threads, + _use_pq_build, + build_pq_bytes as usize, + use_opq, + )?; + } + "float" => { + build_and_insert_in_memory_index::<f32>( + metric, + &data_path, + &insert_path, + r, + l, + alpha, + &index_path_prefix, + num_threads, + _use_pq_build, + build_pq_bytes as usize, + use_opq, + )?; + } + "f16" => { + build_and_insert_in_memory_index::<Half>( + metric, + &data_path, + &insert_path, + r, + l, + alpha, + &index_path_prefix, + num_threads, + _use_pq_build, + build_pq_bytes as usize, + use_opq, + )?; + } + _ => { + println!("Unsupported type. Use one of int8, uint8, float or f16."); + return Err(ANNError::log_index_config_error("data_type".to_string(), "Invalid data type".to_string())); + } + } + + Ok(()) +} + +fn print_help() { + println!("Arguments"); + println!("--help, -h Print information on arguments"); + println!("--data_type data type (required)"); + println!("--dist_fn distance function (required)"); + println!("--data_path Input data file in bin format for initial build (required)"); + println!("--insert_path Input data file in bin format for insert (required)"); + println!("--index_path_prefix Path prefix for saving index file components (required)"); + println!("--max_degree, -R Maximum graph degree (default: 64)"); + println!("--Lbuild, -L Build complexity, higher value results in better graphs (default: 100)"); + println!("--alpha alpha controls density and diameter of graph, set 1 for sparse graph, 1.2 or 1.4 for denser graphs with lower diameter (default: 1.2)"); + println!("--num_threads, -T Number of threads used for building index (defaults to num of CPU logic cores)"); + println!("--build_PQ_bytes Number of PQ bytes to build the index; 0 for full precision build (default: 0)"); + println!("--use_opq Set true for OPQ compression while using PQ distance comparisons for building the index, and false for PQ compression (default: false)"); +} + diff --git a/rust/cmd_drivers/build_disk_index/Cargo.toml b/rust/cmd_drivers/build_disk_index/Cargo.toml new file mode 100644 index 000000000..afe5e5b33 --- /dev/null +++ b/rust/cmd_drivers/build_disk_index/Cargo.toml @@ -0,0 +1,14 @@ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT license.
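+# +# A hypothetical example invocation of the build_and_insert_memory_index driver above (paths and parameter values are illustrative only): +# +# cargo run --release -p build_and_insert_memory_index -- --data_type f16 --dist_fn l2 --data_path data/base.bin --insert_path data/delta.bin --index_path_prefix index/mem -R 64 -L 100 --alpha 1.2 -T 8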
+[package] +name = "build_disk_index" +version = "0.1.0" +edition = "2021" + +# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html + +[dependencies] +diskann = { path = "../../diskann" } +logger = { path = "../../logger" } +vector = { path = "../../vector" } +openblas-src = { version = "0.10.8", features = ["system", "static"] } diff --git a/rust/cmd_drivers/build_disk_index/src/main.rs b/rust/cmd_drivers/build_disk_index/src/main.rs new file mode 100644 index 000000000..e0b6dbe24 --- /dev/null +++ b/rust/cmd_drivers/build_disk_index/src/main.rs @@ -0,0 +1,377 @@ +/* + * Copyright (c) Microsoft Corporation. All rights reserved. + * Licensed under the MIT license. + */ +use std::env; + +use diskann::{ + common::{ANNError, ANNResult}, + index::ann_disk_index::create_disk_index, + model::{ + default_param_vals::ALPHA, + vertex::{DIM_104, DIM_128, DIM_256}, + DiskIndexBuildParameters, IndexConfiguration, IndexWriteParametersBuilder, + }, + storage::DiskIndexStorage, + utils::round_up, + utils::{load_metadata_from_file, Timer}, +}; + +use vector::{FullPrecisionDistance, Half, Metric}; + +/// The main function to build a disk index +#[allow(clippy::too_many_arguments)] +fn build_disk_index<T>( + metric: Metric, + data_path: &str, + r: u32, + l: u32, + index_path_prefix: &str, + num_threads: u32, + search_ram_limit_gb: f64, + index_build_ram_limit_gb: f64, + num_pq_chunks: usize, + use_opq: bool, +) -> ANNResult<()> +where + T: Default + Copy + Sync + Send + Into<f32>, + [T; DIM_104]: FullPrecisionDistance<T, DIM_104>, + [T; DIM_128]: FullPrecisionDistance<T, DIM_128>, + [T; DIM_256]: FullPrecisionDistance<T, DIM_256>, +{ + let disk_index_build_parameters = + DiskIndexBuildParameters::new(search_ram_limit_gb, index_build_ram_limit_gb)?; + + let index_write_parameters = IndexWriteParametersBuilder::new(l, r) + .with_saturate_graph(true) + .with_num_threads(num_threads) + .build(); + + let (data_num, data_dim) = load_metadata_from_file(data_path)?; + + let config = IndexConfiguration::new( + metric, + data_dim, + round_up(data_dim as u64, 8_u64) as usize, + data_num, + num_pq_chunks > 0, + num_pq_chunks, + use_opq, + 0, + 1f32, + index_write_parameters, + ); + let storage = DiskIndexStorage::new(data_path.to_string(), index_path_prefix.to_string())?; + let mut index = create_disk_index::<T>(Some(disk_index_build_parameters), config, storage)?; + + let timer = Timer::new(); + + index.build("")?; + + let diff = timer.elapsed(); + println!("Indexing time: {}", diff.as_secs_f64()); + + Ok(()) +} + +fn main() -> ANNResult<()> { + let mut data_type = String::new(); + let mut dist_fn = String::new(); + let mut data_path = String::new(); + let mut index_path_prefix = String::new(); + + let mut num_threads = 0u32; + let mut r = 64u32; + let mut l = 100u32; + let mut search_ram_limit_gb = 0f64; + let mut index_build_ram_limit_gb = 0f64; + + let mut build_pq_bytes = 0u32; + let mut use_opq = false; + + let args: Vec<String> = env::args().collect(); + let mut iter = args.iter().skip(1).peekable(); + + while let Some(arg) = iter.next() { + match arg.as_str() { + "--help" | "-h" => { + print_help(); + return Ok(()); + } + "--data_type" => { + data_type = iter + .next() + .ok_or_else(|| { + ANNError::log_index_config_error( + "data_type".to_string(), + "Missing data type".to_string(), + ) + })? + .to_owned(); + } + "--dist_fn" => { + dist_fn = iter + .next() + .ok_or_else(|| { + ANNError::log_index_config_error( + "dist_fn".to_string(), + "Missing distance function".to_string(), + ) + })?
+ .to_owned(); + } + "--data_path" => { + data_path = iter + .next() + .ok_or_else(|| { + ANNError::log_index_config_error( + "data_path".to_string(), + "Missing data path".to_string(), + ) + })? + .to_owned(); + } + "--index_path_prefix" => { + index_path_prefix = iter + .next() + .ok_or_else(|| { + ANNError::log_index_config_error( + "index_path_prefix".to_string(), + "Missing index path prefix".to_string(), + ) + })? + .to_owned(); + } + "--max_degree" | "-R" => { + r = iter + .next() + .ok_or_else(|| { + ANNError::log_index_config_error( + "max_degree".to_string(), + "Missing max degree".to_string(), + ) + })? + .parse() + .map_err(|err| { + ANNError::log_index_config_error( + "max_degree".to_string(), + format!("ParseIntError: {}", err), + ) + })?; + } + "--Lbuild" | "-L" => { + l = iter + .next() + .ok_or_else(|| { + ANNError::log_index_config_error( + "Lbuild".to_string(), + "Missing build complexity".to_string(), + ) + })? + .parse() + .map_err(|err| { + ANNError::log_index_config_error( + "Lbuild".to_string(), + format!("ParseIntError: {}", err), + ) + })?; + } + "--num_threads" | "-T" => { + num_threads = iter + .next() + .ok_or_else(|| { + ANNError::log_index_config_error( + "num_threads".to_string(), + "Missing number of threads".to_string(), + ) + })? + .parse() + .map_err(|err| { + ANNError::log_index_config_error( + "num_threads".to_string(), + format!("ParseIntError: {}", err), + ) + })?; + } + "--build_PQ_bytes" => { + build_pq_bytes = iter + .next() + .ok_or_else(|| { + ANNError::log_index_config_error( + "build_PQ_bytes".to_string(), + "Missing PQ bytes".to_string(), + ) + })? + .parse() + .map_err(|err| { + ANNError::log_index_config_error( + "build_PQ_bytes".to_string(), + format!("ParseIntError: {}", err), + ) + })?; + } + "--use_opq" => { + use_opq = iter + .next() + .ok_or_else(|| { + ANNError::log_index_config_error( + "use_opq".to_string(), + "Missing use_opq flag".to_string(), + ) + })? + .parse() + .map_err(|err| { + ANNError::log_index_config_error( + "use_opq".to_string(), + format!("ParseBoolError: {}", err), + ) + })?; + } + "--search_DRAM_budget" | "-B" => { + search_ram_limit_gb = iter + .next() + .ok_or_else(|| { + ANNError::log_index_config_error( + "search_DRAM_budget".to_string(), + "Missing search_DRAM_budget flag".to_string(), + ) + })? + .parse() + .map_err(|err| { + ANNError::log_index_config_error( + "search_DRAM_budget".to_string(), + format!("ParseFloatError: {}", err), + ) + })?; + } + "--build_DRAM_budget" | "-M" => { + index_build_ram_limit_gb = iter + .next() + .ok_or_else(|| { + ANNError::log_index_config_error( + "build_DRAM_budget".to_string(), + "Missing build_DRAM_budget flag".to_string(), + ) + })?
+                    .parse()
+                    .map_err(|err| {
+                        ANNError::log_index_config_error(
+                            "build_DRAM_budget".to_string(),
+                            format!("ParseFloatError: {}", err),
+                        )
+                    })?;
+            }
+            _ => {
+                return Err(ANNError::log_index_config_error(
+                    String::from(""),
+                    format!("Unknown argument: {}", arg),
+                ));
+            }
+        }
+    }
+
+    if data_type.is_empty()
+        || dist_fn.is_empty()
+        || data_path.is_empty()
+        || index_path_prefix.is_empty()
+    {
+        return Err(ANNError::log_index_config_error(
+            String::from(""),
+            "Missing required arguments".to_string(),
+        ));
+    }
+
+    let metric = dist_fn
+        .parse::<Metric>()
+        .map_err(|err| ANNError::log_index_config_error("dist_fn".to_string(), err.to_string()))?;
+
+    println!(
+        "Starting index build with R: {} Lbuild: {} alpha: {} #threads: {} search_DRAM_budget: {} build_DRAM_budget: {}",
+        r, l, ALPHA, num_threads, search_ram_limit_gb, index_build_ram_limit_gb
+    );
+
+    let err = match data_type.as_str() {
+        "int8" => build_disk_index::<i8>(
+            metric,
+            &data_path,
+            r,
+            l,
+            &index_path_prefix,
+            num_threads,
+            search_ram_limit_gb,
+            index_build_ram_limit_gb,
+            build_pq_bytes as usize,
+            use_opq,
+        ),
+        "uint8" => build_disk_index::<u8>(
+            metric,
+            &data_path,
+            r,
+            l,
+            &index_path_prefix,
+            num_threads,
+            search_ram_limit_gb,
+            index_build_ram_limit_gb,
+            build_pq_bytes as usize,
+            use_opq,
+        ),
+        "float" => build_disk_index::<f32>(
+            metric,
+            &data_path,
+            r,
+            l,
+            &index_path_prefix,
+            num_threads,
+            search_ram_limit_gb,
+            index_build_ram_limit_gb,
+            build_pq_bytes as usize,
+            use_opq,
+        ),
+        "f16" => build_disk_index::<Half>(
+            metric,
+            &data_path,
+            r,
+            l,
+            &index_path_prefix,
+            num_threads,
+            search_ram_limit_gb,
+            index_build_ram_limit_gb,
+            build_pq_bytes as usize,
+            use_opq,
+        ),
+        _ => {
+            println!("Unsupported type. Use one of int8, uint8, float or f16.");
+            return Err(ANNError::log_index_config_error(
+                "data_type".to_string(),
+                "Invalid data type".to_string(),
+            ));
+        }
+    };
+
+    match err {
+        Ok(_) => {
+            println!("Index build completed successfully");
+            Ok(())
+        }
+        Err(err) => {
+            eprintln!("Error: {:?}", err);
+            Err(err)
+        }
+    }
+}
+
+fn print_help() {
+    println!("Arguments");
+    println!("--help, -h Print information on arguments");
+    println!("--data_type data type (required)");
+    println!("--dist_fn distance function (required)");
+    println!("--data_path Input data file in bin format (required)");
+    println!("--index_path_prefix Path prefix for saving index file components (required)");
+    println!("--max_degree, -R Maximum graph degree (default: 64)");
+    println!("--Lbuild, -L Build complexity, higher value results in better graphs (default: 100)");
+    println!("--search_DRAM_budget Bound on the memory footprint of the index at search time in GB. Once built, the index will use up only the specified RAM limit; the rest will reside on disk");
+    println!("--build_DRAM_budget Limit on the memory allowed for building the index in GB");
+    println!("--num_threads, -T Number of threads used for building index (defaults to number of CPU logical cores)");
+    println!("--build_PQ_bytes Number of PQ bytes to build the index; 0 for full precision build (default: 0)");
+    println!("--use_opq Set true for OPQ compression while using PQ distance comparisons for building the index, and false for PQ compression (default: false)");
+}
diff --git a/rust/cmd_drivers/build_memory_index/Cargo.toml b/rust/cmd_drivers/build_memory_index/Cargo.toml
new file mode 100644
index 000000000..eb4708d84
--- /dev/null
+++ b/rust/cmd_drivers/build_memory_index/Cargo.toml
@@ -0,0 +1,15 @@
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT license.
+[package]
+name = "build_memory_index"
+version = "0.1.0"
+edition = "2021"
+
+# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
+
+[dependencies]
+clap = { version = "4.3.8", features = ["derive"] }
+diskann = { path = "../../diskann" }
+logger = { path = "../../logger" }
+vector = { path = "../../vector" }
+
diff --git a/rust/cmd_drivers/build_memory_index/src/args.rs b/rust/cmd_drivers/build_memory_index/src/args.rs
new file mode 100644
index 000000000..ede31f2db
--- /dev/null
+++ b/rust/cmd_drivers/build_memory_index/src/args.rs
@@ -0,0 +1,62 @@
+use clap::{Parser, ValueEnum};
+
+#[derive(Debug, Clone, ValueEnum)]
+enum DataType {
+    /// Float data type.
+    Float,
+
+    /// Half data type.
+    FP16,
+}
+
+#[derive(Debug, Clone, ValueEnum)]
+enum DistanceFunction {
+    /// Euclidean distance.
+    L2,
+
+    /// Cosine distance.
+    Cosine,
+}
+
+#[derive(Debug, Parser)]
+struct BuildMemoryIndexArgs {
+    /// Data type of the vectors.
+    #[clap(long, default_value = "float")]
+    pub data_type: DataType,
+
+    /// Distance function to use.
+    #[clap(long, default_value = "l2")]
+    pub dist_fn: DistanceFunction,
+
+    /// Path to the data file. The file should be in the format specified by the `data_type` argument.
+    #[clap(long, short, required = true)]
+    pub data_path: String,
+
+    /// Path to the index file. The index will be saved to this prefixed name.
+    #[clap(long, short, required = true)]
+    pub index_path_prefix: String,
+
+    /// Number of max out degree from a vertex.
+    #[clap(long, default_value = "32")]
+    pub max_degree: usize,
+
+    /// Number of candidates to consider when building out edges
+    #[clap(long, short, default_value = "50")]
+    pub l_build: usize,
+
+    /// Alpha to use to build diverse edges
+    #[clap(long, short, default_value = "1.0")]
+    pub alpha: f32,
+
+    /// Number of threads to use.
+    #[clap(long, short, default_value = "1")]
+    pub num_threads: u8,
+
+    /// Number of PQ bytes to use.
+    #[clap(long, short, default_value = "8")]
+    pub build_pq_bytes: usize,
+
+    /// Use opq?
+    #[clap(long, short, default_value = "false")]
+    pub use_opq: bool,
+}
diff --git a/rust/cmd_drivers/build_memory_index/src/main.rs b/rust/cmd_drivers/build_memory_index/src/main.rs
new file mode 100644
index 000000000..cdccc0061
--- /dev/null
+++ b/rust/cmd_drivers/build_memory_index/src/main.rs
@@ -0,0 +1,174 @@
+/*
+ * Copyright (c) Microsoft Corporation. All rights reserved.
+ * Licensed under the MIT license.
+ */
+use clap::{Parser, ValueEnum};
+use std::path::PathBuf;
+
+use diskann::{
+    common::ANNResult,
+    index::create_inmem_index,
+    model::{
+        vertex::{DIM_104, DIM_128, DIM_256},
+        IndexConfiguration, IndexWriteParametersBuilder,
+    },
+    utils::round_up,
+    utils::{load_metadata_from_file, Timer},
+};
+
+use vector::{FullPrecisionDistance, Half, Metric};
+
+/// The main function to build an in-memory index
+#[allow(clippy::too_many_arguments)]
+fn build_in_memory_index<T>(
+    metric: Metric,
+    data_path: &str,
+    r: u32,
+    l: u32,
+    alpha: f32,
+    save_path: &str,
+    num_threads: u32,
+    _use_pq_build: bool,
+    _num_pq_bytes: usize,
+    use_opq: bool,
+) -> ANNResult<()>
+where
+    T: Default + Copy + Sync + Send + Into<f32>,
+    [T; DIM_104]: FullPrecisionDistance<T, DIM_104>,
+    [T; DIM_128]: FullPrecisionDistance<T, DIM_128>,
+    [T; DIM_256]: FullPrecisionDistance<T, DIM_256>,
+{
+    let index_write_parameters = IndexWriteParametersBuilder::new(l, r)
+        .with_alpha(alpha)
+        .with_saturate_graph(false)
+        .with_num_threads(num_threads)
+        .build();
+
+    let (data_num, data_dim) = load_metadata_from_file(data_path)?;
+
+    let config = IndexConfiguration::new(
+        metric,
+        data_dim,
+        round_up(data_dim as u64, 8_u64) as usize,
+        data_num,
+        false,
+        0,
+        use_opq,
+        0,
+        1f32,
+        index_write_parameters,
+    );
+    let mut index = create_inmem_index::<T>(config)?;
+
+    let timer = Timer::new();
+
+    index.build(data_path, data_num)?;
+
+    let diff = timer.elapsed();
+
+    println!("Indexing time: {}", diff.as_secs_f64());
+    index.save(save_path)?;
+
+    Ok(())
+}
+
+fn main() -> ANNResult<()> {
+    let args = BuildMemoryIndexArgs::parse();
+
+    let _use_pq_build = args.build_pq_bytes > 0;
+
+    println!(
+        "Starting index build with R: {} Lbuild: {} alpha: {} #threads: {}",
+        args.max_degree, args.l_build, args.alpha, args.num_threads
+    );
+
+    let err = match args.data_type {
+        DataType::Float => build_in_memory_index::<f32>(
+            args.dist_fn,
+            &args.data_path.to_string_lossy(),
+            args.max_degree,
+            args.l_build,
+            args.alpha,
+            &args.index_path_prefix,
+            args.num_threads,
+            _use_pq_build,
+            args.build_pq_bytes,
+            args.use_opq,
+        ),
+        DataType::FP16 => build_in_memory_index::<Half>(
+            args.dist_fn,
+            &args.data_path.to_string_lossy(),
+            args.max_degree,
+            args.l_build,
+            args.alpha,
+            &args.index_path_prefix,
+            args.num_threads,
+            _use_pq_build,
+            args.build_pq_bytes,
+            args.use_opq,
+        ),
+    };
+
+    match err {
+        Ok(_) => {
+            println!("Index build completed successfully");
+            Ok(())
+        }
+        Err(err) => {
+            eprintln!("Error: {:?}", err);
+            Err(err)
+        }
+    }
+}
+
+#[derive(Copy, Clone, PartialEq, Eq, PartialOrd, Ord, ValueEnum, Debug)]
+enum DataType {
+    /// Float data type.
+    Float,
+
+    /// Half data type.
+    FP16,
+}
+
+#[derive(Debug, Parser)]
+struct BuildMemoryIndexArgs {
+    /// data type (required)
+    #[arg(long = "data_type", default_value = "float")]
+    pub data_type: DataType,
+
+    /// Distance function to use.
+    #[arg(long = "dist_fn", default_value = "l2")]
+    pub dist_fn: Metric,
+
+    /// Path to the data file. The file should be in the format specified by the `data_type` argument.
+    #[arg(long = "data_path", short, required = true)]
+    pub data_path: PathBuf,
+
+    /// Path to the index file. The index will be saved to this prefixed name.
+    #[arg(long = "index_path_prefix", short, required = true)]
+    pub index_path_prefix: String,
+
+    /// Number of max out degree from a vertex.
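+    /// Corresponds to -R on the command line; larger values typically yield denser graphs at higher build cost.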
+ #[arg(long = "max_degree", short = 'R', default_value = "64")] + pub max_degree: u32, + + /// Number of candidates to consider when building out edges + #[arg(long = "l_build", short = 'L', default_value = "100")] + pub l_build: u32, + + /// alpha controls density and diameter of graph, set 1 for sparse graph, 1.2 or 1.4 for denser graphs with lower diameter + #[arg(long, short, default_value = "1.2")] + pub alpha: f32, + + /// Number of threads to use. + #[arg(long = "num_threads", short = 'T', default_value = "1")] + pub num_threads: u32, + + /// Number of PQ bytes to build the index; 0 for full precision build + #[arg(long = "build_pq_bytes", short, default_value = "0")] + pub build_pq_bytes: usize, + + /// Set true for OPQ compression while using PQ distance comparisons for building the index, and false for PQ compression + #[arg(long = "use_opq", short, default_value = "false")] + pub use_opq: bool, +} diff --git a/rust/cmd_drivers/convert_f32_to_bf16/Cargo.toml b/rust/cmd_drivers/convert_f32_to_bf16/Cargo.toml new file mode 100644 index 000000000..1993aab9d --- /dev/null +++ b/rust/cmd_drivers/convert_f32_to_bf16/Cargo.toml @@ -0,0 +1,11 @@ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT license. +[package] +name = "convert_f32_to_bf16" +version = "0.1.0" +edition = "2021" + +# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html + +[dependencies] +half = "2.2.1" diff --git a/rust/cmd_drivers/convert_f32_to_bf16/src/main.rs b/rust/cmd_drivers/convert_f32_to_bf16/src/main.rs new file mode 100644 index 000000000..87b4fbaf3 --- /dev/null +++ b/rust/cmd_drivers/convert_f32_to_bf16/src/main.rs @@ -0,0 +1,154 @@ +/* + * Copyright (c) Microsoft Corporation. All rights reserved. + * Licensed under the MIT license. 
+ */
+use half::{bf16, f16};
+use std::env;
+use std::fs::{File, OpenOptions};
+use std::io::{self, Read, Write, BufReader, BufWriter};
+
+enum F16OrBF16 {
+    F16(f16),
+    BF16(bf16),
+}
+
+fn main() -> io::Result<()> {
+    // Retrieve command-line arguments
+    let args: Vec<String> = env::args().collect();
+
+    match args.len() {
+        3 | 4 | 5 | 6 => {}
+        _ => {
+            print_usage();
+            std::process::exit(1);
+        }
+    }
+
+    // Retrieve the input and output file paths from the arguments
+    let input_file_path = &args[1];
+    let output_file_path = &args[2];
+    let use_f16 = args.len() >= 4 && args[3] == "f16";
+    let save_as_float = args.len() >= 5 && args[4] == "save_as_float";
+    let batch_size = if args.len() >= 6 {
+        args[5].parse::<i32>().unwrap()
+    } else {
+        100000
+    };
+    println!("use_f16: {}", use_f16);
+    println!("save_as_float: {}", save_as_float);
+    println!("batch_size: {}", batch_size);
+
+    // Open the input file for reading
+    let mut input_file = BufReader::new(File::open(input_file_path)?);
+
+    // Open the output file for writing
+    let mut output_file =
+        BufWriter::new(OpenOptions::new().write(true).create(true).open(output_file_path)?);
+
+    // Read the first 8 bytes as metadata
+    let mut metadata = [0; 8];
+    input_file.read_exact(&mut metadata)?;
+
+    // Write the metadata to the output file
+    output_file.write_all(&metadata)?;
+
+    // Extract the number of points and dimension from the metadata
+    let num_points = i32::from_le_bytes(metadata[..4].try_into().unwrap());
+    let dimension = i32::from_le_bytes(metadata[4..].try_into().unwrap());
+    let num_batches = num_points / batch_size;
+    // Calculate the size of one batch of data points in bytes
+    let data_point_size = (dimension * 4 * batch_size) as usize;
+    let mut batches_processed = 0;
+    let numbers_to_print = 2;
+    let mut numbers_printed = 0;
+    let mut num_bf16_wins = 0;
+    let mut num_f16_wins = 0;
+    let mut bf16_overflow = 0;
+    let mut f16_overflow = 0;
+
+    // Process each batch of data points
+    for _ in 0..num_batches {
+        // Read one batch from the input file
+        let mut buffer = vec![0; data_point_size];
+        match input_file.read_exact(&mut buffer) {
+            Ok(()) => {
+                // Convert the float32 data to f16/bf16
+                let half_data: Vec<F16OrBF16> = buffer
+                    .chunks_exact(4)
+                    .map(|chunk| {
+                        let value = f32::from_le_bytes([chunk[0], chunk[1], chunk[2], chunk[3]]);
+                        let converted_bf16 = bf16::from_f32(value);
+                        let converted_f16 = f16::from_f32(value);
+                        let distance_f16 = (converted_f16.to_f32() - value).abs();
+                        let distance_bf16 = (converted_bf16.to_f32() - value).abs();
+
+                        if distance_f16 < distance_bf16 {
+                            num_f16_wins += 1;
+                        } else {
+                            num_bf16_wins += 1;
+                        }
+
+                        if (converted_bf16 == bf16::INFINITY) || (converted_bf16 == bf16::NEG_INFINITY) {
+                            bf16_overflow += 1;
+                        }
+
+                        if (converted_f16 == f16::INFINITY) || (converted_f16 == f16::NEG_INFINITY) {
+                            f16_overflow += 1;
+                        }
+
+                        if numbers_printed < numbers_to_print {
+                            numbers_printed += 1;
+                            println!("f32 value: {} f16 value: {} | distance {}, bf16 value: {} | distance {},",
+                                value, converted_f16, converted_f16.to_f32() - value, converted_bf16, converted_bf16.to_f32() - value);
+                        }
+
+                        if use_f16 {
+                            F16OrBF16::F16(converted_f16)
+                        } else {
+                            F16OrBF16::BF16(converted_bf16)
+                        }
+                    })
+                    .collect();
+
+                batches_processed += 1;
+
+                match save_as_float {
+                    true => {
+                        for float_val in half_data {
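+                            // save_as_float: upcast each half value back to f32 and write 4 LE bytes per value.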
+                            match float_val {
+                                F16OrBF16::F16(f16_val) => output_file.write_all(&f16_val.to_f32().to_le_bytes())?,
+                                F16OrBF16::BF16(bf16_val) => output_file.write_all(&bf16_val.to_f32().to_le_bytes())?,
+                            }
+                        }
+                    }
+                    false => {
+                        for float_val in half_data {
+                            match float_val {
+                                F16OrBF16::F16(f16_val) => output_file.write_all(&f16_val.to_le_bytes())?,
+                                F16OrBF16::BF16(bf16_val) => output_file.write_all(&bf16_val.to_le_bytes())?,
+                            }
+                        }
+                    }
+                }
+
+                // Print the number of points processed
+                println!("Processed {} points out of {}", batches_processed * batch_size, num_points);
+            }
+            Err(ref e) if e.kind() == io::ErrorKind::UnexpectedEof => {
+                println!("Conversion completed! {} of times f16 wins | overflow count {}, {} of times bf16 wins | overflow count {}",
+                    num_f16_wins, f16_overflow, num_bf16_wins, bf16_overflow);
+                break;
+            }
+            Err(err) => {
+                println!("Error: {}", err);
+                break;
+            }
+        };
+    }
+
+    Ok(())
+}
+
+/// Prints the usage information
+fn print_usage() {
+    println!("Usage: program_name input_file output_file [f16] [save_as_float] [batch_size]");
+    println!("specify f16 to downscale to f16. otherwise, downscale to bf16.");
+    println!("specify save_as_float to downcast to f16 or bf16, and upcast to float before saving the output data. otherwise, the data will be saved as half type.");
+    println!("specify the batch_size as an int; the default value is 100000.");
+}
+
diff --git a/rust/cmd_drivers/load_and_insert_memory_index/Cargo.toml b/rust/cmd_drivers/load_and_insert_memory_index/Cargo.toml
new file mode 100644
index 000000000..cbb4e1e3c
--- /dev/null
+++ b/rust/cmd_drivers/load_and_insert_memory_index/Cargo.toml
@@ -0,0 +1,14 @@
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT license.
+[package]
+name = "load_and_insert_memory_index"
+version = "0.1.0"
+edition = "2021"
+
+# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
+
+[dependencies]
+diskann = { path = "../../diskann" }
+logger = { path = "../../logger" }
+vector = { path = "../../vector" }
+
diff --git a/rust/cmd_drivers/load_and_insert_memory_index/src/main.rs b/rust/cmd_drivers/load_and_insert_memory_index/src/main.rs
new file mode 100644
index 000000000..41680460a
--- /dev/null
+++ b/rust/cmd_drivers/load_and_insert_memory_index/src/main.rs
@@ -0,0 +1,313 @@
+/*
+ * Copyright (c) Microsoft Corporation. All rights reserved.
+ * Licensed under the MIT license.
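+ *
+ * Loads a previously built in-memory index, inserts additional vectors from a
+ * delta file (--insert_path), and saves the updated index.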
+ */
+use std::env;
+
+use diskann::{
+    common::{ANNResult, ANNError},
+    index::create_inmem_index,
+    utils::round_up,
+    model::{
+        IndexWriteParametersBuilder,
+        IndexConfiguration,
+        vertex::{DIM_128, DIM_256, DIM_104}
+    },
+    utils::{Timer, load_metadata_from_file},
+};
+
+use vector::{Metric, FullPrecisionDistance, Half};
+
+// The main function to load an in-memory index and insert new points into it
+#[allow(clippy::too_many_arguments)]
+fn load_and_insert_in_memory_index<T>(
+    metric: Metric,
+    data_path: &str,
+    delta_path: &str,
+    r: u32,
+    l: u32,
+    alpha: f32,
+    save_path: &str,
+    num_threads: u32,
+    _use_pq_build: bool,
+    _num_pq_bytes: usize,
+    use_opq: bool,
+) -> ANNResult<()>
+where
+    T: Default + Copy + Sync + Send + Into<f32>,
+    [T; DIM_104]: FullPrecisionDistance<T, DIM_104>,
+    [T; DIM_128]: FullPrecisionDistance<T, DIM_128>,
+    [T; DIM_256]: FullPrecisionDistance<T, DIM_256>,
+{
+    let index_write_parameters = IndexWriteParametersBuilder::new(l, r)
+        .with_alpha(alpha)
+        .with_saturate_graph(false)
+        .with_num_threads(num_threads)
+        .build();
+
+    let (data_num, data_dim) = load_metadata_from_file(&format!("{}.data", data_path))?;
+
+    let config = IndexConfiguration::new(
+        metric,
+        data_dim,
+        round_up(data_dim as u64, 8_u64) as usize,
+        data_num,
+        false,
+        0,
+        use_opq,
+        0,
+        2.0f32,
+        index_write_parameters,
+    );
+    let mut index = create_inmem_index::<T>(config)?;
+
+    let timer = Timer::new();
+
+    index.load(data_path, data_num)?;
+
+    let diff = timer.elapsed();
+
+    println!("Initial indexing time: {}", diff.as_secs_f64());
+
+    let (delta_data_num, _) = load_metadata_from_file(delta_path)?;
+
+    index.insert(delta_path, delta_data_num)?;
+
+    index.save(save_path)?;
+
+    Ok(())
+}
+
+fn main() -> ANNResult<()> {
+    let mut data_type = String::new();
+    let mut dist_fn = String::new();
+    let mut data_path = String::new();
+    let mut insert_path = String::new();
+    let mut index_path_prefix = String::new();
+
+    let mut num_threads = 0u32;
+    let mut r = 64u32;
+    let mut l = 100u32;
+
+    let mut alpha = 1.2f32;
+    let mut build_pq_bytes = 0u32;
+    let mut _use_pq_build = false;
+    let mut use_opq = false;
+
+    let args: Vec<String> = env::args().collect();
+    let mut iter = args.iter().skip(1).peekable();
+
+    while let Some(arg) = iter.next() {
+        match arg.as_str() {
+            "--help" | "-h" => {
+                print_help();
+                return Ok(());
+            }
+            "--data_type" => {
+                data_type = iter.next().ok_or_else(|| ANNError::log_index_config_error(
+                    "data_type".to_string(),
+                    "Missing data type".to_string())
+                )?
+                .to_owned();
+            }
+            "--dist_fn" => {
+                dist_fn = iter.next().ok_or_else(|| ANNError::log_index_config_error(
+                    "dist_fn".to_string(),
+                    "Missing distance function".to_string())
+                )?
+                .to_owned();
+            }
+            "--data_path" => {
+                data_path = iter.next().ok_or_else(|| ANNError::log_index_config_error(
+                    "data_path".to_string(),
+                    "Missing data path".to_string())
+                )?
+                .to_owned();
+            }
+            "--insert_path" => {
+                insert_path = iter.next().ok_or_else(|| ANNError::log_index_config_error(
+                    "insert_path".to_string(),
+                    "Missing insert path".to_string())
+                )?
+                .to_owned();
+            }
+            "--index_path_prefix" => {
+                index_path_prefix = iter.next().ok_or_else(|| ANNError::log_index_config_error(
+                    "index_path_prefix".to_string(),
+                    "Missing index path prefix".to_string()))?
+                .to_owned();
+            }
+            "--max_degree" | "-R" => {
+                r = iter.next().ok_or_else(|| ANNError::log_index_config_error(
+                    "max_degree".to_string(),
+                    "Missing max degree".to_string()))?
+                .parse()
+                .map_err(|err| ANNError::log_index_config_error(
+                    "max_degree".to_string(),
+                    format!("ParseIntError: {}", err))
+                )?;
+            }
+            "--Lbuild" | "-L" => {
+                l = iter.next().ok_or_else(|| ANNError::log_index_config_error(
+                    "Lbuild".to_string(),
+                    "Missing build complexity".to_string()))?
+                .parse()
+                .map_err(|err| ANNError::log_index_config_error(
+                    "Lbuild".to_string(),
+                    format!("ParseIntError: {}", err))
+                )?;
+            }
+            "--alpha" => {
+                alpha = iter.next().ok_or_else(|| ANNError::log_index_config_error(
+                    "alpha".to_string(),
+                    "Missing alpha".to_string()))?
+                .parse()
+                .map_err(|err| ANNError::log_index_config_error(
+                    "alpha".to_string(),
+                    format!("ParseFloatError: {}", err))
+                )?;
+            }
+            "--num_threads" | "-T" => {
+                num_threads = iter.next().ok_or_else(|| ANNError::log_index_config_error(
+                    "num_threads".to_string(),
+                    "Missing number of threads".to_string()))?
+                .parse()
+                .map_err(|err| ANNError::log_index_config_error(
+                    "num_threads".to_string(),
+                    format!("ParseIntError: {}", err))
+                )?;
+            }
+            "--build_PQ_bytes" => {
+                build_pq_bytes = iter.next().ok_or_else(|| ANNError::log_index_config_error(
+                    "build_PQ_bytes".to_string(),
+                    "Missing PQ bytes".to_string()))?
+                .parse()
+                .map_err(|err| ANNError::log_index_config_error(
+                    "build_PQ_bytes".to_string(),
+                    format!("ParseIntError: {}", err))
+                )?;
+            }
+            "--use_opq" => {
+                use_opq = iter.next().ok_or_else(|| ANNError::log_index_config_error(
+                    "use_opq".to_string(),
+                    "Missing use_opq flag".to_string()))?
+                .parse()
+                .map_err(|err| ANNError::log_index_config_error(
+                    "use_opq".to_string(),
+                    format!("ParseBoolError: {}", err))
+                )?;
+            }
+            _ => {
+                return Err(ANNError::log_index_config_error(String::from(""), format!("Unknown argument: {}", arg)));
+            }
+        }
+    }
+
+    if data_type.is_empty()
+        || dist_fn.is_empty()
+        || data_path.is_empty()
+        || index_path_prefix.is_empty()
+    {
+        return Err(ANNError::log_index_config_error(String::from(""), "Missing required arguments".to_string()));
+    }
+
+    _use_pq_build = build_pq_bytes > 0;
+
+    let metric = dist_fn
+        .parse::<Metric>()
+        .map_err(|err| ANNError::log_index_config_error(
+            "dist_fn".to_string(),
+            err.to_string(),
+        ))?;
+
+    println!(
+        "Starting index build with R: {} Lbuild: {} alpha: {} #threads: {}",
+        r, l, alpha, num_threads
+    );
+
+    match data_type.as_str() {
+        "int8" => {
+            load_and_insert_in_memory_index::<i8>(
+                metric,
+                &data_path,
+                &insert_path,
+                r,
+                l,
+                alpha,
+                &index_path_prefix,
+                num_threads,
+                _use_pq_build,
+                build_pq_bytes as usize,
+                use_opq,
+            )?;
+        }
+        "uint8" => {
+            load_and_insert_in_memory_index::<u8>(
+                metric,
+                &data_path,
+                &insert_path,
+                r,
+                l,
+                alpha,
+                &index_path_prefix,
+                num_threads,
+                _use_pq_build,
+                build_pq_bytes as usize,
+                use_opq,
+            )?;
+        }
+        "float" => {
+            load_and_insert_in_memory_index::<f32>(
+                metric,
+                &data_path,
+                &insert_path,
+                r,
+                l,
+                alpha,
+                &index_path_prefix,
+                num_threads,
+                _use_pq_build,
+                build_pq_bytes as usize,
+                use_opq,
+            )?;
+        }
+        "f16" => {
+            load_and_insert_in_memory_index::<Half>(
+                metric,
+                &data_path,
+                &insert_path,
+                r,
+                l,
+                alpha,
+                &index_path_prefix,
+                num_threads,
+                _use_pq_build,
+                build_pq_bytes as usize,
+                use_opq,
+            )?
+        }
+        _ => {
+            println!("Unsupported type. Use one of int8, uint8, float or f16.");
+            return Err(ANNError::log_index_config_error("data_type".to_string(), "Invalid data type".to_string()));
+        }
+    }
+
+    Ok(())
+}
+
+fn print_help() {
+    println!("Arguments");
+    println!("--help, -h Print information on arguments");
+    println!("--data_type data type (required)");
+    println!("--dist_fn distance function (required)");
+    println!("--data_path Input data file in bin format for initial build (required)");
+    println!("--insert_path Input data file in bin format for insert (required)");
+    println!("--index_path_prefix Path prefix for saving index file components (required)");
+    println!("--max_degree, -R Maximum graph degree (default: 64)");
+    println!("--Lbuild, -L Build complexity, higher value results in better graphs (default: 100)");
+    println!("--alpha alpha controls density and diameter of graph, set 1 for sparse graph, 1.2 or 1.4 for denser graphs with lower diameter (default: 1.2)");
+    println!("--num_threads, -T Number of threads used for building index (defaults to number of CPU logical cores)");
+    println!("--build_PQ_bytes Number of PQ bytes to build the index; 0 for full precision build (default: 0)");
+    println!("--use_opq Set true for OPQ compression while using PQ distance comparisons for building the index, and false for PQ compression (default: false)");
+}
+
diff --git a/rust/cmd_drivers/search_memory_index/Cargo.toml b/rust/cmd_drivers/search_memory_index/Cargo.toml
new file mode 100644
index 000000000..cba3709aa
--- /dev/null
+++ b/rust/cmd_drivers/search_memory_index/Cargo.toml
@@ -0,0 +1,16 @@
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT license.
+[package]
+name = "search_memory_index"
+version = "0.1.0"
+edition = "2021"
+
+# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
+
+[dependencies]
+bytemuck = "1.13.1"
+diskann = { path = "../../diskann" }
+num_cpus = "1.15.0"
+rayon = "1.7.0"
+vector = { path = "../../vector" }
+
diff --git a/rust/cmd_drivers/search_memory_index/src/main.rs b/rust/cmd_drivers/search_memory_index/src/main.rs
new file mode 100644
index 000000000..ca4d4cd1d
--- /dev/null
+++ b/rust/cmd_drivers/search_memory_index/src/main.rs
@@ -0,0 +1,430 @@
+/*
+ * Copyright (c) Microsoft Corporation. All rights reserved.
+ * Licensed under the MIT license.
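+ *
+ * Searches an in-memory index over a query set for each L value given, reporting
+ * QPS, average distance comparisons, mean and 99.9th percentile latency, and
+ * (when a ground-truth file is supplied) recall@K.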
+ */
+mod search_index_utils;
+use bytemuck::Pod;
+use diskann::{
+    common::{ANNError, ANNResult},
+    index,
+    model::{
+        configuration::index_write_parameters::{default_param_vals, IndexWriteParametersBuilder},
+        vertex::{DIM_104, DIM_128, DIM_256},
+        IndexConfiguration,
+    },
+    utils::{load_metadata_from_file, save_bin_u32},
+};
+use std::{env, path::Path, process::exit, time::Instant};
+use vector::{FullPrecisionDistance, Half, Metric};
+
+use rayon::prelude::*;
+
+#[allow(clippy::too_many_arguments)]
+fn search_memory_index<T>(
+    metric: Metric,
+    index_path: &str,
+    result_path_prefix: &str,
+    query_file: &str,
+    truthset_file: &str,
+    num_threads: u32,
+    recall_at: u32,
+    print_all_recalls: bool,
+    l_vec: &Vec<u32>,
+    show_qps_per_thread: bool,
+    fail_if_recall_below: f32,
+) -> ANNResult<i32>
+where
+    T: Default + Copy + Sized + Pod + Sync + Send + Into<f32>,
+    [T; DIM_104]: FullPrecisionDistance<T, DIM_104>,
+    [T; DIM_128]: FullPrecisionDistance<T, DIM_128>,
+    [T; DIM_256]: FullPrecisionDistance<T, DIM_256>,
+{
+    // Load the query file
+    let (query, query_num, query_dim, query_aligned_dim) =
+        search_index_utils::load_aligned_bin::<T>(query_file)?;
+    let mut gt_dim: usize = 0;
+    let mut gt_ids: Option<Vec<u32>> = None;
+    let mut gt_dists: Option<Vec<f32>> = None;
+
+    // Check for ground truth
+    let mut calc_recall_flag = false;
+    if !truthset_file.is_empty() && Path::new(truthset_file).exists() {
+        let ret = search_index_utils::load_truthset(truthset_file)?;
+        gt_ids = Some(ret.0);
+        gt_dists = ret.1;
+        let gt_num = ret.2;
+        gt_dim = ret.3;
+
+        if gt_num != query_num {
+            println!("Error. Mismatch in number of queries and ground truth data");
+        }
+
+        calc_recall_flag = true;
+    } else {
+        println!(
+            "Truthset file {} not found. Not computing recall",
+            truthset_file
+        );
+    }
+
+    let num_frozen_pts = search_index_utils::get_graph_num_frozen_points(index_path)?;
+
+    // C++ uses the max given L value, so we do the same here.
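+    // The scratch space is therefore sized for the largest L and shrunk per run via set_capacity (see search_with_l_override).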
+    // Max degree is never specified in C++, so use the Rust default.
+    let index_write_params = IndexWriteParametersBuilder::new(
+        *l_vec.iter().max().unwrap(),
+        default_param_vals::MAX_DEGREE,
+    )
+    .with_num_threads(num_threads)
+    .build();
+
+    let (index_num_points, _) = load_metadata_from_file(&format!("{}.data", index_path))?;
+
+    let index_config = IndexConfiguration::new(
+        metric,
+        query_dim,
+        query_aligned_dim,
+        index_num_points,
+        false,
+        0,
+        false,
+        num_frozen_pts,
+        1f32,
+        index_write_params,
+    );
+    let mut index = index::create_inmem_index::<T>(index_config)?;
+
+    index.load(index_path, index_num_points)?;
+
+    println!("Using {} threads to search", num_threads);
+    let qps_title = if show_qps_per_thread {
+        "QPS/thread"
+    } else {
+        "QPS"
+    };
+    let mut table_width = 4 + 12 + 18 + 20 + 15;
+    let mut table_header_str = format!(
+        "{:>4}{:>12}{:>18}{:>20}{:>15}",
+        "Ls", qps_title, "Avg dist cmps", "Mean Latency (mus)", "99.9 Latency"
+    );
+
+    let first_recall: u32 = if print_all_recalls { 1 } else { recall_at };
+    let mut recalls_to_print: usize = 0;
+    if calc_recall_flag {
+        for curr_recall in first_recall..=recall_at {
+            let recall_str = format!("Recall@{}", curr_recall);
+            table_header_str.push_str(&format!("{:>12}", recall_str));
+            recalls_to_print = (recall_at + 1 - first_recall) as usize;
+            table_width += recalls_to_print * 12;
+        }
+    }
+
+    println!("{}", table_header_str);
+    println!("{}", "=".repeat(table_width));
+
+    let mut query_result_ids: Vec<Vec<u32>> =
+        vec![vec![0; query_num * recall_at as usize]; l_vec.len()];
+    let mut latency_stats: Vec<f32> = vec![0.0; query_num];
+    let mut cmp_stats: Vec<u32> = vec![0; query_num];
+    let mut best_recall = 0.0;
+
+    std::env::set_var("RAYON_NUM_THREADS", num_threads.to_string());
+
+    for test_id in 0..l_vec.len() {
+        let l_value = l_vec[test_id];
+
+        if l_value < recall_at {
+            println!(
+                "Ignoring search with L:{} since it's smaller than K:{}",
+                l_value, recall_at
+            );
+            continue;
+        }
+
+        let zipped = cmp_stats
+            .par_iter_mut()
+            .zip(latency_stats.par_iter_mut())
+            .zip(query_result_ids[test_id].par_chunks_mut(recall_at as usize))
+            .zip(query.par_chunks(query_aligned_dim));
+
+        let start = Instant::now();
+        zipped.for_each(|(((cmp, latency), query_result), query_chunk)| {
+            let query_start = Instant::now();
+            *cmp = index
+                .search(query_chunk, recall_at as usize, l_value, query_result)
+                .unwrap();
+
+            let query_end = Instant::now();
+            let diff = query_end.duration_since(query_start);
+            *latency = diff.as_micros() as f32;
+        });
+        let diff = Instant::now().duration_since(start);
+
+        let mut displayed_qps: f32 = query_num as f32 / diff.as_secs_f32();
+        if show_qps_per_thread {
+            displayed_qps /= num_threads as f32;
+        }
+
+        let mut recalls: Vec<f32> = Vec::new();
+        if calc_recall_flag {
+            recalls.reserve(recalls_to_print);
+            for curr_recall in first_recall..=recall_at {
+                recalls.push(search_index_utils::calculate_recall(
+                    query_num,
+                    gt_ids.as_ref().unwrap(),
+                    &gt_dists,
+                    gt_dim,
+                    &query_result_ids[test_id],
+                    recall_at,
+                    curr_recall,
+                )? as f32);
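+                // calculate_recall averages recall at the curr_recall cutoff over all queries and scales it to a percentage.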
+            }
+        }
+
+        latency_stats.sort_by(|a, b| a.partial_cmp(b).unwrap());
+        let mean_latency = latency_stats.iter().sum::<f32>() / query_num as f32;
+        let avg_cmps = cmp_stats.iter().sum::<u32>() as f32 / query_num as f32;
+
+        let mut stat_str = format!(
+            "{: >4}{: >12.2}{: >18.2}{: >20.2}{: >15.2}",
+            l_value,
+            displayed_qps,
+            avg_cmps,
+            mean_latency,
+            latency_stats[(0.999 * query_num as f32).round() as usize]
+        );
+
+        for recall in recalls.iter() {
+            stat_str.push_str(&format!("{: >12.2}", recall));
+            best_recall = f32::max(best_recall, *recall);
+        }
+
+        println!("{}", stat_str);
+    }
+
+    println!("Done searching. Now saving results");
+    for (test_id, l_value) in l_vec.iter().enumerate() {
+        if *l_value < recall_at {
+            println!(
+                "Ignoring all search with L: {} since it's smaller than K: {}",
+                l_value, recall_at
+            );
+        }
+
+        let cur_result_path = format!("{}_{}_idx_uint32.bin", result_path_prefix, l_value);
+        save_bin_u32(
+            &cur_result_path,
+            query_result_ids[test_id].as_slice(),
+            query_num,
+            recall_at as usize,
+            0,
+        )?;
+    }
+
+    if best_recall >= fail_if_recall_below {
+        Ok(0)
+    } else {
+        Ok(-1)
+    }
+}
+
+fn main() -> ANNResult<()> {
+    let return_val: i32;
+    {
+        let mut data_type: String = String::new();
+        let mut metric: Option<Metric> = None;
+        let mut index_path: String = String::new();
+        let mut result_path_prefix: String = String::new();
+        let mut query_file: String = String::new();
+        let mut truthset_file: String = String::new();
+        let mut num_cpus: u32 = num_cpus::get() as u32;
+        let mut recall_at: Option<u32> = None;
+        let mut print_all_recalls: bool = false;
+        let mut l_vec: Vec<u32> = Vec::new();
+        let mut show_qps_per_thread: bool = false;
+        let mut fail_if_recall_below: f32 = 0.0;
+
+        let args: Vec<String> = env::args().collect();
+        let mut iter = args.iter().skip(1).peekable();
+        while let Some(arg) = iter.next() {
+            let ann_error =
+                || ANNError::log_index_config_error(String::from(arg), format!("Missing {}", arg));
+            match arg.as_str() {
+                "--help" | "-h" => {
+                    print_help();
+                    return Ok(());
+                }
+                "--data_type" => {
+                    data_type = iter.next().ok_or_else(ann_error)?.to_owned();
+                }
+                "--dist_fn" => {
+                    metric = Some(iter.next().ok_or_else(ann_error)?.parse().map_err(|err| {
+                        ANNError::log_index_config_error(
+                            String::from(arg),
+                            format!("ParseError: {}", err),
+                        )
+                    })?);
+                }
+                "--index_path_prefix" => {
+                    index_path = iter.next().ok_or_else(ann_error)?.to_owned();
+                }
+                "--result_path" => {
+                    result_path_prefix = iter.next().ok_or_else(ann_error)?.to_owned();
+                }
+                "--query_file" => {
+                    query_file = iter.next().ok_or_else(ann_error)?.to_owned();
+                }
+                "--gt_file" => {
+                    truthset_file = iter.next().ok_or_else(ann_error)?.to_owned();
+                }
+                "--recall_at" | "-K" => {
+                    recall_at =
+                        Some(iter.next().ok_or_else(ann_error)?.parse().map_err(|err| {
+                            ANNError::log_index_config_error(
+                                String::from(arg),
+                                format!("ParseError: {}", err),
+                            )
+                        })?);
+                }
+                "--print_all_recalls" => {
+                    print_all_recalls = true;
+                }
+                "--search_list" | "-L" => {
+                    while iter.peek().is_some() && !iter.peek().unwrap().starts_with('-') {
+                        l_vec.push(iter.next().ok_or_else(ann_error)?.parse().map_err(|err| {
+                            ANNError::log_index_config_error(
+                                String::from(arg),
+                                format!("ParseError: {}", err),
+                            )
+                        })?);
+                    }
+                }
+                "--num_threads" => {
+                    num_cpus = iter.next().ok_or_else(ann_error)?.parse().map_err(|err| {
+                        ANNError::log_index_config_error(
+                            String::from(arg),
+                            format!("ParseError: {}", err),
+                        )
+                    })?;
+                }
+                "--qps_per_thread" => {
+                    show_qps_per_thread = true;
+                }
+                "--fail_if_recall_below" => {
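+                    // Percentage threshold; if the best recall stays below it, the process exits with -1 (see exit(return_val) below).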
+                    fail_if_recall_below =
+                        iter.next().ok_or_else(ann_error)?.parse().map_err(|err| {
+                            ANNError::log_index_config_error(
+                                String::from(arg),
+                                format!("ParseError: {}", err),
+                            )
+                        })?;
+                }
+                _ => {
+                    return Err(ANNError::log_index_error(format!(
+                        "Unknown argument: {}",
+                        arg
+                    )));
+                }
+            }
+        }
+
+        if metric.is_none() {
+            return Err(ANNError::log_index_error(String::from("No metric given!")));
+        } else if recall_at.is_none() {
+            return Err(ANNError::log_index_error(String::from(
+                "No recall_at given!",
+            )));
+        }
+
+        // Seems like float is the only supported data type for FullPrecisionDistance right now,
+        // but keep the structure in place here for future data types
+        match data_type.as_str() {
+            "float" => {
+                return_val = search_memory_index::<f32>(
+                    metric.unwrap(),
+                    &index_path,
+                    &result_path_prefix,
+                    &query_file,
+                    &truthset_file,
+                    num_cpus,
+                    recall_at.unwrap(),
+                    print_all_recalls,
+                    &l_vec,
+                    show_qps_per_thread,
+                    fail_if_recall_below,
+                )?;
+            }
+            "int8" => {
+                return_val = search_memory_index::<i8>(
+                    metric.unwrap(),
+                    &index_path,
+                    &result_path_prefix,
+                    &query_file,
+                    &truthset_file,
+                    num_cpus,
+                    recall_at.unwrap(),
+                    print_all_recalls,
+                    &l_vec,
+                    show_qps_per_thread,
+                    fail_if_recall_below,
+                )?;
+            }
+            "uint8" => {
+                return_val = search_memory_index::<u8>(
+                    metric.unwrap(),
+                    &index_path,
+                    &result_path_prefix,
+                    &query_file,
+                    &truthset_file,
+                    num_cpus,
+                    recall_at.unwrap(),
+                    print_all_recalls,
+                    &l_vec,
+                    show_qps_per_thread,
+                    fail_if_recall_below,
+                )?;
+            }
+            "f16" => {
+                return_val = search_memory_index::<Half>(
+                    metric.unwrap(),
+                    &index_path,
+                    &result_path_prefix,
+                    &query_file,
+                    &truthset_file,
+                    num_cpus,
+                    recall_at.unwrap(),
+                    print_all_recalls,
+                    &l_vec,
+                    show_qps_per_thread,
+                    fail_if_recall_below,
+                )?;
+            }
+            _ => {
+                return Err(ANNError::log_index_error(format!(
+                    "Unknown data type: {}!",
+                    data_type
+                )));
+            }
+        }
+    }
+
+    // Rust only allows returning values with this method, but this will immediately terminate the program without running destructors on the
+    // stack. To get around this enclose main function logic in a block so that by the time we return here all destructors have been called.
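+    // By this point every local has been dropped, so exiting immediately is safe.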
+    exit(return_val);
+}
+
+fn print_help() {
+    println!("Arguments");
+    println!("--help, -h Print information on arguments");
+    println!("--data_type data type (required)");
+    println!("--dist_fn distance function (required)");
+    println!("--index_path_prefix Path prefix to the index (required)");
+    println!("--result_path Path prefix for saving results of the queries (required)");
+    println!("--query_file Query file in binary format");
+    println!("--gt_file Ground truth file for the queryset");
+    println!("--recall_at, -K Number of neighbors to be returned");
+    println!("--print_all_recalls Print recalls at all positions, from 1 up to specified recall_at value");
+    println!("--search_list List of L values of search");
+    println!("--num_threads, -T Number of threads used for building index (defaults to num_cpus::get())");
+    println!("--qps_per_thread Print overall QPS divided by the number of threads in the output table");
+    println!("--fail_if_recall_below If set to a value >0 and <100%, program returns -1 if best recall found is below this threshold");
+}
diff --git a/rust/cmd_drivers/search_memory_index/src/search_index_utils.rs b/rust/cmd_drivers/search_memory_index/src/search_index_utils.rs
new file mode 100644
index 000000000..c7b04a47f
--- /dev/null
+++ b/rust/cmd_drivers/search_memory_index/src/search_index_utils.rs
@@ -0,0 +1,186 @@
+/*
+ * Copyright (c) Microsoft Corporation. All rights reserved.
+ * Licensed under the MIT license.
+ */
+use bytemuck::{cast_slice, Pod};
+use diskann::{
+    common::{ANNError, ANNResult, AlignedBoxWithSlice},
+    model::data_store::DatasetDto,
+    utils::{copy_aligned_data_from_file, is_aligned, round_up},
+};
+use std::collections::HashSet;
+use std::fs::File;
+use std::io::Read;
+use std::mem::size_of;
+
+pub(crate) fn calculate_recall(
+    num_queries: usize,
+    gold_std: &[u32],
+    gs_dist: &Option<Vec<f32>>,
+    dim_gs: usize,
+    our_results: &[u32],
+    dim_or: u32,
+    recall_at: u32,
+) -> ANNResult<f64> {
+    let mut total_recall: f64 = 0.0;
+    let (mut gt, mut res): (HashSet<u32>, HashSet<u32>) = (HashSet::new(), HashSet::new());
+
+    for i in 0..num_queries {
+        gt.clear();
+        res.clear();
+
+        let gt_slice = &gold_std[dim_gs * i..];
+        let res_slice = &our_results[dim_or as usize * i..];
+        let mut tie_breaker = recall_at as usize;
+
+        if gs_dist.is_some() {
+            tie_breaker = (recall_at - 1) as usize;
+            let gt_dist_vec = &gs_dist.as_ref().unwrap()[dim_gs * i..];
+            while tie_breaker < dim_gs
+                && gt_dist_vec[tie_breaker] == gt_dist_vec[(recall_at - 1) as usize]
+            {
+                tie_breaker += 1;
+            }
+        }
+
+        (0..tie_breaker).for_each(|idx| {
+            gt.insert(gt_slice[idx]);
+        });
+
+        (0..tie_breaker).for_each(|idx| {
+            res.insert(res_slice[idx]);
+        });
+
+        let mut cur_recall: u32 = 0;
+        for v in gt.iter() {
+            if res.contains(v) {
+                cur_recall += 1;
+            }
+        }
+
+        total_recall += cur_recall as f64;
+    }
+
+    Ok(total_recall / num_queries as f64 * (100.0 / recall_at as f64))
+}
+
+pub(crate) fn get_graph_num_frozen_points(graph_file: &str) -> ANNResult<usize> {
+    let mut file = File::open(graph_file)?;
+    let mut usize_buffer = [0; size_of::<usize>()];
+    let mut u32_buffer = [0; size_of::<u32>()];
+
+    file.read_exact(&mut usize_buffer)?;
+    file.read_exact(&mut u32_buffer)?;
+    file.read_exact(&mut u32_buffer)?;
+    file.read_exact(&mut usize_buffer)?;
+    let file_frozen_pts = usize::from_le_bytes(usize_buffer);
+
+    Ok(file_frozen_pts)
+}
+
+#[inline]
+pub(crate) fn load_truthset(
+    bin_file: &str,
+) -> ANNResult<(Vec<u32>, Option<Vec<f32>>, usize, usize)> {
+    let mut file = File::open(bin_file)?;
+    let actual_file_size = file.metadata()?.len() as usize;
+
+    let mut buffer = [0; size_of::<i32>()];
+    file.read_exact(&mut buffer)?;
+    let npts = i32::from_le_bytes(buffer) as usize;
+
+    file.read_exact(&mut buffer)?;
+    let dim = i32::from_le_bytes(buffer) as usize;
+
+    println!("Metadata: #pts = {npts}, #dims = {dim}... ");
+
+    let expected_file_size_with_dists: usize =
+        2 * npts * dim * size_of::<u32>() + 2 * size_of::<u32>();
+    let expected_file_size_just_ids: usize = npts * dim * size_of::<u32>() + 2 * size_of::<u32>();
+
+    let truthset_type: i32 = match actual_file_size {
+        // This is in the C++ code, but nothing is done in this case. Keeping it here for future reference just in case.
+        // expected_file_size_just_ids => 2,
+        x if x == expected_file_size_with_dists => 1,
+        _ => return Err(ANNError::log_index_error(format!("Error. File size mismatch. File should have bin format, with npts followed by ngt
+            followed by npts*ngt ids and optionally followed by npts*ngt distance values; actual size: {}, expected: {} or {}",
+            actual_file_size,
+            expected_file_size_with_dists,
+            expected_file_size_just_ids)))
+    };
+
+    let mut ids: Vec<u32> = vec![0; npts * dim];
+    let mut buffer = vec![0; npts * dim * size_of::<u32>()];
+    file.read_exact(&mut buffer)?;
+    ids.clone_from_slice(cast_slice::<u8, u32>(&buffer));
+
+    if truthset_type == 1 {
+        let mut dists: Vec<f32> = vec![0.0; npts * dim];
+        let mut buffer = vec![0; npts * dim * size_of::<f32>()];
+        file.read_exact(&mut buffer)?;
+        dists.clone_from_slice(cast_slice::<u8, f32>(&buffer));
+
+        return Ok((ids, Some(dists), npts, dim));
+    }
+
+    Ok((ids, None, npts, dim))
+}
+
+#[inline]
+pub(crate) fn load_aligned_bin<T: Default + Copy>(
+    bin_file: &str,
+) -> ANNResult<(AlignedBoxWithSlice<T>, usize, usize, usize)> {
+    let t_size = size_of::<T>();
+    let (npts, dim, file_size): (usize, usize, usize);
+    {
+        println!("Reading (with alignment) bin file: {bin_file}");
+        let mut file = File::open(bin_file)?;
+        file_size = file.metadata()?.len() as usize;
+
+        let mut buffer = [0; size_of::<i32>()];
+        file.read_exact(&mut buffer)?;
+        npts = i32::from_le_bytes(buffer) as usize;
+
+        file.read_exact(&mut buffer)?;
+        dim = i32::from_le_bytes(buffer) as usize;
+    }
+
+    let rounded_dim = round_up(dim, 8);
+    let expected_actual_file_size = npts * dim * size_of::<T>() + 2 * size_of::<i32>();
+
+    if file_size != expected_actual_file_size {
+        return Err(ANNError::log_index_error(format!(
+            "ERROR: File size mismatch. Actual size is {} while expected size is {}
+            npts = {}, #dims = {}, aligned_dim = {}",
+            file_size, expected_actual_file_size, npts, dim, rounded_dim
+        )));
+    }
+
+    println!("Metadata: #pts = {npts}, #dims = {dim}, aligned_dim = {rounded_dim}...");
+
+    let alloc_size = npts * rounded_dim;
+    let alignment = 8 * t_size;
+    println!(
+        "allocating aligned memory of {} bytes... ",
+        alloc_size * t_size
+    );
+    if !is_aligned(alloc_size * t_size, alignment) {
+        return Err(ANNError::log_index_error(format!(
+            "Requested memory size is not a multiple of {}. Can not be allocated.",
+            alignment
+        )));
+    }
+
+    let mut data = AlignedBoxWithSlice::<T>::new(alloc_size, alignment)?;
+    let dto = DatasetDto {
+        data: &mut data,
+        rounded_dim,
+    };
+
+    println!("done. Copying data to mem_aligned buffer...");
+
+    let (_, _) = copy_aligned_data_from_file(bin_file, dto, 0)?;
+
+    Ok((data, npts, dim, rounded_dim))
+}
diff --git a/rust/diskann/Cargo.toml b/rust/diskann/Cargo.toml
new file mode 100644
index 000000000..a5be54750
--- /dev/null
+++ b/rust/diskann/Cargo.toml
@@ -0,0 +1,45 @@
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT license.
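+# Core library crate: build/search/prune algorithms, model types, and storage used by the cmd_drivers above.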
+[package]
+name = "diskann"
+version = "0.1.0"
+edition = "2021"
+
+# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
+
+[dependencies]
+bincode = "1.3.3"
+bit-vec = "0.6.3"
+byteorder = "1.4.3"
+cblas = "0.4.0"
+crossbeam = "0.8.2"
+half = "2.2.1"
+hashbrown = "0.13.2"
+num-traits = "0.2.15"
+once_cell = "1.17.1"
+openblas-src = { version = "0.10.8", features = ["system"] }
+rand = { version = "0.8.5", features = [ "small_rng" ] }
+rayon = "1.7.0"
+serde = { version = "1.0.130", features = ["derive"] }
+thiserror = "1.0.40"
+winapi = { version = "0.3.9", features = ["errhandlingapi", "fileapi", "ioapiset", "handleapi", "winnt", "minwindef", "basetsd", "winerror", "winbase"] }
+
+logger = { path = "../logger" }
+platform = { path = "../platform" }
+vector = { path = "../vector" }
+
+[build-dependencies]
+cc = "1.0.79"
+
+[dev-dependencies]
+approx = "0.5.1"
+criterion = "0.5.1"
+
+
+[[bench]]
+name = "distance_bench"
+harness = false
+
+[[bench]]
+name = "neighbor_bench"
+harness = false
diff --git a/rust/diskann/benches/distance_bench.rs b/rust/diskann/benches/distance_bench.rs
new file mode 100644
index 000000000..885c95bac
--- /dev/null
+++ b/rust/diskann/benches/distance_bench.rs
@@ -0,0 +1,47 @@
+/*
+ * Copyright (c) Microsoft Corporation. All rights reserved.
+ * Licensed under the MIT license.
+ */
+use criterion::{black_box, criterion_group, criterion_main, Criterion};
+
+use rand::{thread_rng, Rng};
+use vector::{FullPrecisionDistance, Metric};
+
+// make sure the vector is 256-bit (32 bytes) aligned, as required by _mm256_load_ps
+#[repr(C, align(32))]
+struct Vector32ByteAligned {
+    v: [f32; 256],
+}
+
+fn benchmark_l2_distance_float_rust(c: &mut Criterion) {
+    let (a, b) = prepare_random_aligned_vectors();
+    let mut group = c.benchmark_group("avx-computation");
+    group.sample_size(5000);
+
+    group.bench_function("AVX Rust run", |f| {
+        f.iter(|| {
+            black_box(<[f32; 256]>::distance_compare(
+                black_box(&a.v),
+                black_box(&b.v),
+                Metric::L2,
+            ))
+        })
+    });
+}
+
+// make sure the vectors are 256-bit (32 bytes) aligned, as required by _mm256_load_ps
+fn prepare_random_aligned_vectors() -> (Box<Vector32ByteAligned>, Box<Vector32ByteAligned>) {
+    let a = Box::new(Vector32ByteAligned {
+        v: [(); 256].map(|_| thread_rng().gen_range(0.0..100.0)),
+    });
+
+    let b = Box::new(Vector32ByteAligned {
+        v: [(); 256].map(|_| thread_rng().gen_range(0.0..100.0)),
+    });
+
+    (a, b)
+}
+
+criterion_group!(benches, benchmark_l2_distance_float_rust,);
+criterion_main!(benches);
+
diff --git a/rust/diskann/benches/kmeans_bench.rs b/rust/diskann/benches/kmeans_bench.rs
new file mode 100644
index 000000000..c69c16a8c
--- /dev/null
+++ b/rust/diskann/benches/kmeans_bench.rs
@@ -0,0 +1,70 @@
+/*
+ * Copyright (c) Microsoft Corporation. All rights reserved.
+ * Licensed under the MIT license.
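+ *
+ * Benchmarks k_means_clustering over NUM_POINTS random DIM-dimensional points
+ * with NUM_CENTERS centers (see the constants below).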
+ */
+use criterion::{criterion_group, criterion_main, Criterion};
+use diskann::utils::k_means_clustering;
+use rand::Rng;
+
+const NUM_POINTS: usize = 10000;
+const DIM: usize = 100;
+const NUM_CENTERS: usize = 256;
+const MAX_KMEANS_REPS: usize = 12;
+
+fn benchmark_kmeans_rust(c: &mut Criterion) {
+    let mut rng = rand::thread_rng();
+    let data: Vec<f32> = (0..NUM_POINTS * DIM)
+        .map(|_| rng.gen_range(-1.0..1.0))
+        .collect();
+    let centers: Vec<f32> = vec![0.0; NUM_CENTERS * DIM];
+
+    let mut group = c.benchmark_group("kmeans-computation");
+    group.sample_size(500);
+
+    group.bench_function("K-Means Rust run", |f| {
+        f.iter(|| {
+            // let mut centers_copy = centers.clone();
+            let data_copy = data.clone();
+            let mut centers_copy = centers.clone();
+            k_means_clustering(
+                &data_copy,
+                NUM_POINTS,
+                DIM,
+                &mut centers_copy,
+                NUM_CENTERS,
+                MAX_KMEANS_REPS,
+            )
+        })
+    });
+}
+
+fn benchmark_kmeans_c(c: &mut Criterion) {
+    let mut rng = rand::thread_rng();
+    let data: Vec<f32> = (0..NUM_POINTS * DIM)
+        .map(|_| rng.gen_range(-1.0..1.0))
+        .collect();
+    let centers: Vec<f32> = vec![0.0; NUM_CENTERS * DIM];
+
+    let mut group = c.benchmark_group("kmeans-computation");
+    group.sample_size(500);
+
+    group.bench_function("K-Means C++ Run", |f| {
+        f.iter(|| {
+            let data_copy = data.clone();
+            let mut centers_copy = centers.clone();
+            let _ = k_means_clustering(
+                data_copy.as_slice(),
+                NUM_POINTS,
+                DIM,
+                centers_copy.as_mut_slice(),
+                NUM_CENTERS,
+                MAX_KMEANS_REPS,
+            );
+        })
+    });
+}
+
+criterion_group!(benches, benchmark_kmeans_rust, benchmark_kmeans_c);
+
+criterion_main!(benches);
+
diff --git a/rust/diskann/benches/neighbor_bench.rs b/rust/diskann/benches/neighbor_bench.rs
new file mode 100644
index 000000000..958acdce2
--- /dev/null
+++ b/rust/diskann/benches/neighbor_bench.rs
@@ -0,0 +1,49 @@
+/*
+ * Copyright (c) Microsoft Corporation. All rights reserved.
+ * Licensed under the MIT license.
+ */
+use std::time::Duration;
+
+use criterion::{black_box, criterion_group, criterion_main, Criterion};
+
+use diskann::model::{Neighbor, NeighborPriorityQueue};
+use rand::distributions::{Distribution, Uniform};
+use rand::rngs::StdRng;
+use rand::SeedableRng;
+
+fn benchmark_priority_queue_insert(c: &mut Criterion) {
+    let vec = generate_random_floats();
+    let mut group = c.benchmark_group("neighborqueue-insert");
+    group.measurement_time(Duration::from_secs(3)).sample_size(500);
+
+    let mut queue = NeighborPriorityQueue::with_capacity(64_usize);
+    group.bench_function("Neighbor Priority Queue Insert", |f| {
+        f.iter(|| {
+            queue.clear();
+            for n in vec.iter() {
+                queue.insert(*n);
+            }
+
+            black_box(&1)
+        });
+    });
+}
+
+fn generate_random_floats() -> Vec<Neighbor> {
+    let seed: [u8; 32] = [73; 32];
+    let mut rng: StdRng = SeedableRng::from_seed(seed);
+    let range = Uniform::new(0.0, 1.0);
+    let mut random_floats = Vec::with_capacity(100);
+
+    for i in 0..100 {
+        let random_float = range.sample(&mut rng) as f32;
+        let n = Neighbor::new(i, random_float);
+        random_floats.push(n);
+    }
+
+    random_floats
+}
+
+criterion_group!(benches, benchmark_priority_queue_insert);
+criterion_main!(benches);
+
diff --git a/rust/diskann/src/algorithm/mod.rs b/rust/diskann/src/algorithm/mod.rs
new file mode 100644
index 000000000..87e377c8b
--- /dev/null
+++ b/rust/diskann/src/algorithm/mod.rs
@@ -0,0 +1,7 @@
+/*
+ * Copyright (c) Microsoft Corporation. All rights reserved.
+ * Licensed under the MIT license.
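+ *
+ * Module root for the graph-construction algorithms: greedy search and pruning.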
+ */
+pub mod search;
+
+pub mod prune;
diff --git a/rust/diskann/src/algorithm/prune/mod.rs b/rust/diskann/src/algorithm/prune/mod.rs
new file mode 100644
index 000000000..4627eeb10
--- /dev/null
+++ b/rust/diskann/src/algorithm/prune/mod.rs
@@ -0,0 +1,6 @@
+/*
+ * Copyright (c) Microsoft Corporation. All rights reserved.
+ * Licensed under the MIT license.
+ */
+#[allow(clippy::module_inception)]
+pub mod prune;
diff --git a/rust/diskann/src/algorithm/prune/prune.rs b/rust/diskann/src/algorithm/prune/prune.rs
new file mode 100644
index 000000000..40fec4a5d
--- /dev/null
+++ b/rust/diskann/src/algorithm/prune/prune.rs
@@ -0,0 +1,288 @@
+/*
+ * Copyright (c) Microsoft Corporation. All rights reserved.
+ * Licensed under the MIT license.
+ */
+use hashbrown::HashSet;
+use vector::{FullPrecisionDistance, Metric};
+
+use crate::common::{ANNError, ANNResult};
+use crate::index::InmemIndex;
+use crate::model::graph::AdjacencyList;
+use crate::model::neighbor::SortedNeighborVector;
+use crate::model::scratch::InMemQueryScratch;
+use crate::model::Neighbor;
+
+impl<T, const N: usize> InmemIndex<T, N>
+where
+    T: Default + Copy + Sync + Send + Into<f32>,
+    [T; N]: FullPrecisionDistance<T, N>,
+{
+    /// A method that occludes a list of neighbors based on some criteria
+    #[allow(clippy::too_many_arguments)]
+    fn occlude_list(
+        &self,
+        location: u32,
+        pool: &mut SortedNeighborVector,
+        alpha: f32,
+        degree: u32,
+        max_candidate_size: usize,
+        result: &mut AdjacencyList,
+        scratch: &mut InMemQueryScratch<T, N>,
+        delete_set_ptr: Option<&HashSet<u32>>,
+    ) -> ANNResult<()> {
+        if pool.is_empty() {
+            return Ok(());
+        }
+
+        if !result.is_empty() {
+            return Err(ANNError::log_index_error(
+                "result is not empty.".to_string(),
+            ));
+        }
+
+        // Truncate pool at max_candidate_size and initialize scratch spaces
+        if pool.len() > max_candidate_size {
+            pool.truncate(max_candidate_size);
+        }
+
+        let occlude_factor = &mut scratch.occlude_factor;
+
+        // occlude_list can be called with the same scratch more than once by
+        // search_for_point_and_add_link through inter_insert.
+        occlude_factor.clear();
+
+        // Initialize occlude_factor to pool.len() many 0.0 values for correctness
+        occlude_factor.resize(pool.len(), 0.0);
+
+        let mut cur_alpha = 1.0;
+        while cur_alpha <= alpha && result.len() < degree as usize {
+            for (i, neighbor) in pool.iter().enumerate() {
+                if result.len() >= degree as usize {
+                    break;
+                }
+                if occlude_factor[i] > cur_alpha {
+                    continue;
+                }
+                // Set the entry to f32::MAX so that it is not considered again
+                occlude_factor[i] = f32::MAX;
+
+                // Add the entry to the result if it hasn't been deleted and doesn't
+                // introduce a self-loop
+                if delete_set_ptr.map_or(true, |delete_set| !delete_set.contains(&neighbor.id))
+                    && neighbor.id != location
+                {
+                    result.push(neighbor.id);
+                }
+
+                // Update occlude factor for points from i+1 to pool.len()
+                for (j, neighbor2) in pool.iter().enumerate().skip(i + 1) {
+                    if occlude_factor[j] > alpha {
+                        continue;
+                    }
+
+                    // todo - self.filtered_index
+                    let djk = self.get_distance(neighbor2.id, neighbor.id)?;
+                    match self.configuration.dist_metric {
+                        Metric::L2 | Metric::Cosine => {
+                            occlude_factor[j] = if djk == 0.0 {
+                                f32::MAX
+                            } else {
+                                occlude_factor[j].max(neighbor2.distance / djk)
+                            };
+                        }
+                    }
+                }
+            }
+
+            cur_alpha *= 1.2;
+        }
+
+        Ok(())
+    }
+
+    /// Prunes the neighbors of a given data point based on some criteria and returns a list of pruned ids.
+    ///
+    /// # Arguments
+    ///
+    /// * `location` - The id of the data point whose neighbors are to be pruned.
+    /// * `pool` - A vector of neighbors to be pruned, sorted by distance to the query point.
+    /// * `pruned_list` - A vector to store the ids of the pruned neighbors.
+    /// * `scratch` - A mutable reference to a scratch space for in-memory queries.
+    ///
+    /// # Error
+    ///
+    /// Returns an error if `pruned_list` contains more than `max_degree` elements after pruning.
+    pub fn prune_neighbors(
+        &self,
+        location: u32,
+        pool: &mut Vec<Neighbor>,
+        pruned_list: &mut AdjacencyList,
+        scratch: &mut InMemQueryScratch<T, N>,
+    ) -> ANNResult<()> {
+        self.robust_prune(
+            location,
+            pool,
+            self.configuration.index_write_parameter.max_degree,
+            self.configuration.index_write_parameter.max_occlusion_size,
+            self.configuration.index_write_parameter.alpha,
+            pruned_list,
+            scratch,
+        )
+    }
+
+    /// Prunes the neighbors of a given data point based on some criteria and returns a list of pruned ids.
+    ///
+    /// # Arguments
+    ///
+    /// * `location` - The id of the data point whose neighbors are to be pruned.
+    /// * `pool` - A vector of neighbors to be pruned, sorted by distance to the query point.
+    /// * `range` - The maximum number of neighbors to keep after pruning.
+    /// * `max_candidate_size` - The maximum number of candidates to consider for pruning.
+    /// * `alpha` - A parameter that controls the occlusion pruning strategy.
+    /// * `pruned_list` - A vector to store the ids of the pruned neighbors.
+    /// * `scratch` - A mutable reference to a scratch space for in-memory queries.
+    ///
+    /// # Error
+    ///
+    /// Returns an error if `pruned_list` contains more than `range` elements after pruning.
+    #[allow(clippy::too_many_arguments)]
+    fn robust_prune(
+        &self,
+        location: u32,
+        pool: &mut Vec<Neighbor>,
+        range: u32,
+        max_candidate_size: u32,
+        alpha: f32,
+        pruned_list: &mut AdjacencyList,
+        scratch: &mut InMemQueryScratch<T, N>,
+    ) -> ANNResult<()> {
+        if pool.is_empty() {
+            // if the pool is empty, behave like a noop
+            pruned_list.clear();
+            return Ok(());
+        }
+
+        // If using _pq_build, over-write the PQ distances with actual distances
+        // todo : pq_dist
+
+        // sort the pool based on distance to query and prune it with occlude_list
+        let mut pool = SortedNeighborVector::new(pool);
+        pruned_list.clear();
+
+        self.occlude_list(
+            location,
+            &mut pool,
+            alpha,
+            range,
+            max_candidate_size as usize,
+            pruned_list,
+            scratch,
+            Option::None,
+        )?;
+
+        if pruned_list.len() > range as usize {
+            return Err(ANNError::log_index_error(format!(
+                "pruned_list's len {} is over range {}.",
+                pruned_list.len(),
+                range
+            )));
+        }
+
+        if self.configuration.index_write_parameter.saturate_graph && alpha > 1.0f32 {
+            for neighbor in pool.iter() {
+                if pruned_list.len() >= (range as usize) {
+                    break;
+                }
+                if !pruned_list.contains(&neighbor.id) && neighbor.id != location {
+                    pruned_list.push(neighbor.id);
+                }
+            }
+        }
+
+        Ok(())
+    }
+
+    /// A method that inserts a point n into the graph of its neighbors and their neighbors,
+    /// pruning the graph if necessary to keep it within the specified range
+    /// * `n` - The index of the new point
+    /// * `pruned_list` - A vector of the neighbors of n that have been pruned by a previous step
+    /// * `range` - The target number of neighbors for each point
+    /// * `scratch` - A mutable reference to a scratch space that can be reused for intermediate computations
+    pub fn inter_insert(
+        &self,
+        n: u32,
+        pruned_list: &Vec<u32>,
+        range: u32,
+        scratch: &mut InMemQueryScratch<T, N>,
+    ) -> ANNResult<()> {
+        // Borrow the pruned_list as a source pool of neighbors
+        let src_pool = pruned_list;
+
+        if src_pool.is_empty() {
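+            // An empty pruned list means the caller supplied no candidate neighbors; treat it as an index error.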
return Err(ANNError::log_index_error("src_pool is empty.".to_string())); + } + + for &vertex_id in src_pool { + // vertex is the index of a neighbor of n + // Assert that vertex is within the valid range of points + if (vertex_id as usize) + >= self.configuration.max_points + self.configuration.num_frozen_pts + { + return Err(ANNError::log_index_error(format!( + "vertex_id {} is out of valid range of points {}", + vertex_id, + self.configuration.max_points + self.configuration.num_frozen_pts, + ))); + } + + let neighbors = self.add_to_neighbors(vertex_id, n, range)?; + + if let Some(copy_of_neighbors) = neighbors { + // Pruning is needed, create a dummy set and a dummy vector to store the unique neighbors of vertex_id + let mut dummy_pool = self.get_unique_neighbors(©_of_neighbors, vertex_id)?; + + // Create a new vector to store the pruned neighbors of vertex_id + let mut new_out_neighbors = + AdjacencyList::for_range(self.configuration.write_range()); + // Prune the neighbors of vertex_id using a helper method + self.prune_neighbors(vertex_id, &mut dummy_pool, &mut new_out_neighbors, scratch)?; + + self.set_neighbors(vertex_id, new_out_neighbors)?; + } + } + + Ok(()) + } + + /// Adds a node to the list of neighbors for the given node. + /// + /// # Arguments + /// + /// * `vertex_id` - The ID of the node to add the neighbor to. + /// * `node_id` - The ID of the node to add. + /// * `range` - The range of the graph. + /// + /// # Return + /// + /// Returns `None` if the node is already in the list of neighbors, or a `Vec` containing the updated list of neighbors if the list of neighbors is full. + fn add_to_neighbors( + &self, + vertex_id: u32, + node_id: u32, + range: u32, + ) -> ANNResult>> { + // vertex contains a vector of the neighbors of vertex_id + let mut vertex_guard = self.final_graph.write_vertex_and_neighbors(vertex_id)?; + + Ok(vertex_guard.add_to_neighbors(node_id, range)) + } + + fn set_neighbors(&self, vertex_id: u32, new_out_neighbors: AdjacencyList) -> ANNResult<()> { + // vertex contains a vector of the neighbors of vertex_id + let mut vertex_guard = self.final_graph.write_vertex_and_neighbors(vertex_id)?; + + vertex_guard.set_neighbors(new_out_neighbors); + Ok(()) + } +} + diff --git a/rust/diskann/src/algorithm/search/mod.rs b/rust/diskann/src/algorithm/search/mod.rs new file mode 100644 index 000000000..9f007ab69 --- /dev/null +++ b/rust/diskann/src/algorithm/search/mod.rs @@ -0,0 +1,7 @@ +/* + * Copyright (c) Microsoft Corporation. All rights reserved. + * Licensed under the MIT license. + */ +#[allow(clippy::module_inception)] +pub mod search; + diff --git a/rust/diskann/src/algorithm/search/search.rs b/rust/diskann/src/algorithm/search/search.rs new file mode 100644 index 000000000..ab6d01696 --- /dev/null +++ b/rust/diskann/src/algorithm/search/search.rs @@ -0,0 +1,359 @@ +/* + * Copyright (c) Microsoft Corporation. All rights reserved. + * Licensed under the MIT license. + */ +#![warn(missing_debug_implementations, missing_docs)] + +//! 
+
+use crate::common::{ANNError, ANNResult};
+use crate::index::InmemIndex;
+use crate::model::{scratch::InMemQueryScratch, Neighbor, Vertex};
+use hashbrown::hash_set::Entry::*;
+use vector::FullPrecisionDistance;
+
+impl<T, const N: usize> InmemIndex<T, N>
+where
+    T: Default + Copy + Sync + Send + Into<f32>,
+    [T; N]: FullPrecisionDistance<T, N>,
+{
+    /// Search for query using given L value, for benchmarking purposes
+    /// # Arguments
+    /// * `query` - query vertex
+    /// * `scratch` - in-memory query scratch
+    /// * `search_list_size` - search list size to use for the benchmark
+    pub fn search_with_l_override(
+        &self,
+        query: &Vertex<T, N>,
+        scratch: &mut InMemQueryScratch<T, N>,
+        search_list_size: usize,
+    ) -> ANNResult<u32> {
+        let init_ids = self.get_init_ids()?;
+        self.init_graph_for_point(query, init_ids, scratch)?;
+        // Scratch is created using the largest L value from search_memory_index, so we artificially make it smaller here.
+        // This allows us to use the same scratch for all L values without having to rebuild the query scratch.
+        scratch.best_candidates.set_capacity(search_list_size);
+        let (_, cmp) = self.greedy_search(query, scratch)?;
+
+        Ok(cmp)
+    }
+
+    /// Search for point
+    /// # Arguments
+    /// * `query` - query vertex
+    /// * `scratch` - in-memory query scratch
+    /// TODO: use_filter, filteredLindex
+    pub fn search_for_point(
+        &self,
+        query: &Vertex<T, N>,
+        scratch: &mut InMemQueryScratch<T, N>,
+    ) -> ANNResult<Vec<Neighbor>> {
+        let init_ids = self.get_init_ids()?;
+        self.init_graph_for_point(query, init_ids, scratch)?;
+        let (mut visited_nodes, _) = self.greedy_search(query, scratch)?;
+
+        visited_nodes.retain(|&element| element.id != query.vertex_id());
+        Ok(visited_nodes)
+    }
+
+    /// Returns the locations of start point and frozen points suitable for use with iterate_to_fixed_point.
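+    /// Frozen points occupy the tail slots `[max_points, max_points + num_frozen_pts)`
+    /// of the dataset; the start point is always returned first, and any frozen point
+    /// that coincides with it is skipped to avoid duplicates.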
+    fn get_init_ids(&self) -> ANNResult<Vec<u32>> {
+        let mut init_ids = Vec::with_capacity(1 + self.configuration.num_frozen_pts);
+        init_ids.push(self.start);
+
+        for frozen in self.configuration.max_points
+            ..(self.configuration.max_points + self.configuration.num_frozen_pts)
+        {
+            let frozen_u32 = frozen.try_into()?;
+            if frozen_u32 != self.start {
+                init_ids.push(frozen_u32);
+            }
+        }
+
+        Ok(init_ids)
+    }
+
+    /// Initialize graph for point
+    /// # Arguments
+    /// * `query` - query vertex
+    /// * `init_ids` - initial nodes from which search starts
+    /// * `scratch` - in-memory query scratch
+    fn init_graph_for_point(
+        &self,
+        query: &Vertex<T, N>,
+        init_ids: Vec<u32>,
+        scratch: &mut InMemQueryScratch<T, N>,
+    ) -> ANNResult<()> {
+        scratch
+            .best_candidates
+            .reserve(self.configuration.index_write_parameter.search_list_size as usize);
+        scratch.query.memcpy(query.vector())?;
+
+        if !scratch.id_scratch.is_empty() {
+            return Err(ANNError::log_index_error(
+                "id_scratch is not empty.".to_string(),
+            ));
+        }
+
+        let query_vertex = Vertex::<T, N>::try_from((&scratch.query[..], query.vertex_id()))
+            .map_err(|err| {
+                ANNError::log_index_error(format!(
+                    "TryFromSliceError: failed to get Vertex for query, err={}",
+                    err
+                ))
+            })?;
+
+        for id in init_ids {
+            if (id as usize) >= self.configuration.max_points + self.configuration.num_frozen_pts {
+                return Err(ANNError::log_index_error(format!(
+                    "vertex_id {} is out of valid range of points {}",
+                    id,
+                    self.configuration.max_points + self.configuration.num_frozen_pts
+                )));
+            }
+
+            if let Vacant(entry) = scratch.node_visited_robinset.entry(id) {
+                entry.insert();
+
+                let vertex = self.dataset.get_vertex(id)?;
+
+                let distance = vertex.compare(&query_vertex, self.configuration.dist_metric);
+                let neighbor = Neighbor::new(id, distance);
+                scratch.best_candidates.insert(neighbor);
+            }
+        }
+
+        Ok(())
+    }
+
+    /// GreedySearch against query node
+    /// Returns visited nodes
+    /// # Arguments
+    /// * `query` - query vertex
+    /// * `scratch` - in-memory query scratch
+    /// TODO: use_filter, filter_label, search_invocation
+    fn greedy_search(
+        &self,
+        query: &Vertex<T, N>,
+        scratch: &mut InMemQueryScratch<T, N>,
+    ) -> ANNResult<(Vec<Neighbor>, u32)> {
+        let mut visited_nodes =
+            Vec::with_capacity((3 * scratch.candidate_size + scratch.max_degree) as usize);
+
+        // TODO: uncomment hops?
+        // let mut hops: u32 = 0;
+        let mut cmps: u32 = 0;
+
+        let query_vertex = Vertex::<T, N>::try_from((&scratch.query[..], query.vertex_id()))
+            .map_err(|err| {
+                ANNError::log_index_error(format!(
+                    "TryFromSliceError: failed to get Vertex for query, err={}",
+                    err
+                ))
+            })?;
+
+        while scratch.best_candidates.has_notvisited_node() {
+            let closest_node = scratch.best_candidates.closest_notvisited();
+
+            // Add node to visited nodes to create pool for prune later
+            // TODO: search_invocation and use_filter
+            visited_nodes.push(closest_node);
+
+            // Find which of the neighbors have not been visited before
+            scratch.id_scratch.clear();
+
+            let max_vertex_id = self.configuration.max_points + self.configuration.num_frozen_pts;
+
+            for id in self
+                .final_graph
+                .read_vertex_and_neighbors(closest_node.id)?
+                .get_neighbors()
+            {
+                let current_vertex_id = *id;
+                debug_assert!(
+                    (current_vertex_id as usize) < max_vertex_id,
+                    "current_vertex_id {} is out of valid range of points {}",
+                    current_vertex_id,
+                    max_vertex_id
+                );
+                if current_vertex_id as usize >= max_vertex_id {
+                    continue;
+                }
+
+                // quickly de-dup. Remember, we are in a read lock
+                // we want to exit out of it quickly
+                if scratch.node_visited_robinset.insert(current_vertex_id) {
+                    scratch.id_scratch.push(current_vertex_id);
+                }
+            }
+
+            let len = scratch.id_scratch.len();
+            for (m, &id) in scratch.id_scratch.iter().enumerate() {
+                if m + 1 < len {
+                    let next_node = unsafe { *scratch.id_scratch.get_unchecked(m + 1) };
+                    self.dataset.prefetch_vector(next_node);
+                }
+
+                let vertex = self.dataset.get_vertex(id)?;
+                let distance = query_vertex.compare(&vertex, self.configuration.dist_metric);
+
+                // Insert pairs into the pool of candidates
+                scratch.best_candidates.insert(Neighbor::new(id, distance));
+            }
+
+            cmps += len as u32;
+        }
+
+        Ok((visited_nodes, cmps))
+    }
+}
+
+#[cfg(test)]
+mod search_test {
+    use vector::Metric;
+
+    use crate::model::configuration::index_write_parameters::IndexWriteParametersBuilder;
+    use crate::model::graph::AdjacencyList;
+    use crate::model::IndexConfiguration;
+    use crate::test_utils::inmem_index_initialization::create_index_with_test_data;
+
+    use super::*;
+
+    #[test]
+    fn get_init_ids_no_frozen_pts() {
+        let index_write_parameters = IndexWriteParametersBuilder::new(50, 4)
+            .with_alpha(1.2)
+            .build();
+        let config = IndexConfiguration::new(
+            Metric::L2,
+            256,
+            256,
+            256,
+            false,
+            0,
+            false,
+            0,
+            1f32,
+            index_write_parameters,
+        );
+
+        let index = InmemIndex::<f32, 256>::new(config).unwrap();
+        let init_ids = index.get_init_ids().unwrap();
+        assert_eq!(init_ids.len(), 1);
+        assert_eq!(init_ids[0], 256);
+    }
+
+    #[test]
+    fn get_init_ids_with_frozen_pts() {
+        let index_write_parameters = IndexWriteParametersBuilder::new(50, 4)
+            .with_alpha(1.2)
+            .build();
+        let config = IndexConfiguration::new(
+            Metric::L2,
+            256,
+            256,
+            256,
+            false,
+            0,
+            false,
+            2,
+            1f32,
+            index_write_parameters,
+        );
+
+        let index = InmemIndex::<f32, 256>::new(config).unwrap();
+        let init_ids = index.get_init_ids().unwrap();
+        assert_eq!(init_ids.len(), 2);
+        assert_eq!(init_ids[0], 256);
+        assert_eq!(init_ids[1], 257);
+    }
+
+    #[test]
+    fn search_for_point_initial_call() {
+        let index = create_index_with_test_data();
+        let query = index.dataset.get_vertex(0).unwrap();
+
+        let mut scratch = InMemQueryScratch::new(
+            index.configuration.index_write_parameter.search_list_size,
+            &index.configuration.index_write_parameter,
+            false,
+        )
+        .unwrap();
+        let visited_nodes = index.search_for_point(&query, &mut scratch).unwrap();
+        assert_eq!(visited_nodes.len(), 1);
+        assert_eq!(scratch.best_candidates.size(), 1);
+        assert_eq!(scratch.best_candidates[0].id, 72);
+        assert_eq!(scratch.best_candidates[0].distance, 125678.0_f32);
+        assert!(scratch.best_candidates[0].visited);
+    }
+
+    fn set_neighbors(index: &InmemIndex<f32, 128>, vertex_id: u32, neighbors: Vec<u32>) {
+        index
+            .final_graph
+            .write_vertex_and_neighbors(vertex_id)
+            .unwrap()
+            .set_neighbors(AdjacencyList::from(neighbors));
+    }
+    #[test]
+    fn search_for_point_works_with_edges() {
+        let index = create_index_with_test_data();
+        let query = index.dataset.get_vertex(14).unwrap();
+
+        set_neighbors(&index, 0, vec![12, 72, 5, 9]);
+        set_neighbors(&index, 1, vec![2, 12, 10, 4]);
+        set_neighbors(&index, 2, vec![1, 72, 9]);
+        set_neighbors(&index, 3, vec![13, 6, 5, 11]);
+        set_neighbors(&index, 4, vec![1, 3, 7, 9]);
+        set_neighbors(&index, 5, vec![3, 0, 8, 11, 13]);
+        set_neighbors(&index, 6, vec![3, 72, 7, 10, 13]);
+        set_neighbors(&index, 7, vec![72, 4, 6]);
+        set_neighbors(&index, 8, vec![72, 5, 9, 12]);
+        set_neighbors(&index, 9, vec![8, 4, 0, 2]);
+        set_neighbors(&index, 10, vec![72, 1, 9, 6]); +
set_neighbors(&index, 11, vec![3, 0, 5]); + set_neighbors(&index, 12, vec![1, 0, 8, 9]); + set_neighbors(&index, 13, vec![3, 72, 5, 6]); + set_neighbors(&index, 72, vec![7, 2, 10, 8, 13]); + + let mut scratch = InMemQueryScratch::new( + index.configuration.index_write_parameter.search_list_size, + &index.configuration.index_write_parameter, + false, + ) + .unwrap(); + let visited_nodes = index.search_for_point(&query, &mut scratch).unwrap(); + assert_eq!(visited_nodes.len(), 15); + assert_eq!(scratch.best_candidates.size(), 15); + assert_eq!(scratch.best_candidates[0].id, 2); + assert_eq!(scratch.best_candidates[0].distance, 120899.0_f32); + assert_eq!(scratch.best_candidates[1].id, 8); + assert_eq!(scratch.best_candidates[1].distance, 145538.0_f32); + assert_eq!(scratch.best_candidates[2].id, 72); + assert_eq!(scratch.best_candidates[2].distance, 146046.0_f32); + assert_eq!(scratch.best_candidates[3].id, 4); + assert_eq!(scratch.best_candidates[3].distance, 148462.0_f32); + assert_eq!(scratch.best_candidates[4].id, 7); + assert_eq!(scratch.best_candidates[4].distance, 148912.0_f32); + assert_eq!(scratch.best_candidates[5].id, 10); + assert_eq!(scratch.best_candidates[5].distance, 154570.0_f32); + assert_eq!(scratch.best_candidates[6].id, 1); + assert_eq!(scratch.best_candidates[6].distance, 159448.0_f32); + assert_eq!(scratch.best_candidates[7].id, 12); + assert_eq!(scratch.best_candidates[7].distance, 170698.0_f32); + assert_eq!(scratch.best_candidates[8].id, 9); + assert_eq!(scratch.best_candidates[8].distance, 177205.0_f32); + assert_eq!(scratch.best_candidates[9].id, 0); + assert_eq!(scratch.best_candidates[9].distance, 259996.0_f32); + assert_eq!(scratch.best_candidates[10].id, 6); + assert_eq!(scratch.best_candidates[10].distance, 371819.0_f32); + assert_eq!(scratch.best_candidates[11].id, 5); + assert_eq!(scratch.best_candidates[11].distance, 385240.0_f32); + assert_eq!(scratch.best_candidates[12].id, 3); + assert_eq!(scratch.best_candidates[12].distance, 413899.0_f32); + assert_eq!(scratch.best_candidates[13].id, 13); + assert_eq!(scratch.best_candidates[13].distance, 416386.0_f32); + assert_eq!(scratch.best_candidates[14].id, 11); + assert_eq!(scratch.best_candidates[14].distance, 449266.0_f32); + } +} diff --git a/rust/diskann/src/common/aligned_allocator.rs b/rust/diskann/src/common/aligned_allocator.rs new file mode 100644 index 000000000..6164a1f40 --- /dev/null +++ b/rust/diskann/src/common/aligned_allocator.rs @@ -0,0 +1,281 @@ +/* + * Copyright (c) Microsoft Corporation. All rights reserved. + * Licensed under the MIT license. + */ +#![warn(missing_debug_implementations, missing_docs)] + +//! Aligned allocator + +use std::alloc::Layout; +use std::ops::{Deref, DerefMut, Range}; +use std::ptr::copy_nonoverlapping; + +use super::{ANNResult, ANNError}; + +#[derive(Debug)] +/// A box that holds a slice but is aligned to the specified layout. +/// +/// This type is useful for working with types that require a certain alignment, +/// such as SIMD vectors or FFI structs. It allocates memory using the global allocator +/// and frees it when dropped. It also implements Deref and DerefMut to allow access +/// to the underlying slice. +pub struct AlignedBoxWithSlice { + /// The layout of the allocated memory. + layout: Layout, + + /// The slice that points to the allocated memory. + val: Box<[T]>, +} + +impl AlignedBoxWithSlice { + /// Creates a new `AlignedBoxWithSlice` with the given capacity and alignment. + /// The allocated memory are set to 0. 
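+    ///
+    /// A minimal usage sketch (illustrative only, and `ignore`d since the public
+    /// path `diskann::common::AlignedBoxWithSlice` is an assumption here):
+    ///
+    /// ```ignore
+    /// use diskann::common::AlignedBoxWithSlice;
+    ///
+    /// // 256 zero-initialized f32 values with 32-byte alignment.
+    /// let buf = AlignedBoxWithSlice::<f32>::new(256, 32)?;
+    /// assert_eq!(buf.as_ptr() as usize % 32, 0);
+    /// assert!(buf.iter().all(|&x| x == 0.0));
+    /// ```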
+    ///
+    /// # Errors
+    ///
+    /// Returns an `IndexError` if the alignment is not a power of two or if the layout is invalid.
+    ///
+    /// This function uses unsafe code internally to allocate zeroed memory and cast it to
+    /// a slice of `T`. The caller must ensure that the capacity and alignment are valid
+    /// for the type `T`.
+    pub fn new(capacity: usize, alignment: usize) -> ANNResult<Self> {
+        let allocsize = capacity.checked_mul(std::mem::size_of::<T>())
+            .ok_or_else(|| ANNError::log_index_error("capacity overflow".to_string()))?;
+        let layout = Layout::from_size_align(allocsize, alignment)
+            .map_err(ANNError::log_mem_alloc_layout_error)?;
+
+        let val = unsafe {
+            let mem = std::alloc::alloc_zeroed(layout);
+            let ptr = mem as *mut T;
+            let slice = std::slice::from_raw_parts_mut(ptr, capacity);
+            std::boxed::Box::from_raw(slice)
+        };
+
+        Ok(Self { layout, val })
+    }
+
+    /// Returns a reference to the slice.
+    pub fn as_slice(&self) -> &[T] {
+        &self.val
+    }
+
+    /// Returns a mutable reference to the slice.
+    pub fn as_mut_slice(&mut self) -> &mut [T] {
+        &mut self.val
+    }
+
+    /// Copies data from the source slice to the destination box.
+    pub fn memcpy(&mut self, src: &[T]) -> ANNResult<()> {
+        if src.len() > self.val.len() {
+            return Err(ANNError::log_index_error(format!("source slice is too large (src:{}, dst:{})", src.len(), self.val.len())));
+        }
+
+        // Check that they don't overlap
+        let src_ptr = src.as_ptr();
+        let src_end = unsafe { src_ptr.add(src.len()) };
+        let dst_ptr = self.val.as_mut_ptr() as *const T;
+        let dst_end = unsafe { dst_ptr.add(self.val.len()) };
+
+        if src_ptr < dst_end && src_end > dst_ptr {
+            return Err(ANNError::log_index_error("Source and destination overlap".to_string()));
+        }
+
+        unsafe {
+            copy_nonoverlapping(src.as_ptr(), self.val.as_mut_ptr(), src.len());
+        }
+
+        Ok(())
+    }
+
+    /// Split the range of memory into nonoverlapping mutable slices.
+    /// The number of returned slices is (range length / slice_len) and each has a length of slice_len.
+    pub fn split_into_nonoverlapping_mut_slices(&mut self, range: Range<usize>, slice_len: usize) -> ANNResult<Vec<&mut [T]>> {
+        if range.len() % slice_len != 0 || range.end > self.len() {
+            return Err(ANNError::log_index_error(format!(
+                "Cannot split range ({:?}) of AlignedBoxWithSlice (len: {}) into nonoverlapping mutable slices with length {}",
+                range,
+                self.len(),
+                slice_len,
+            )));
+        }
+
+        let mut slices = Vec::with_capacity(range.len() / slice_len);
+        let mut remaining_slice = &mut self.val[range];
+
+        while remaining_slice.len() >= slice_len {
+            let (left, right) = remaining_slice.split_at_mut(slice_len);
+            slices.push(left);
+            remaining_slice = right;
+        }
+
+        Ok(slices)
+    }
+}
+
+
+impl<T> Drop for AlignedBoxWithSlice<T> {
+    /// Frees the memory allocated for the slice using the global allocator.
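+    /// The slice is first moved out of `self` and wrapped in `ManuallyDrop` so the
+    /// `Box` never runs its own deallocation (which would not use the over-aligned
+    /// `layout` recorded at allocation time); the raw pointer is then freed
+    /// explicitly with the stored layout.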
+    fn drop(&mut self) {
+        let val = std::mem::take(&mut self.val);
+        let mut val2 = std::mem::ManuallyDrop::new(val);
+        let ptr = val2.as_mut_ptr();
+
+        unsafe {
+            // let nonNull = NonNull::new_unchecked(ptr as *mut u8);
+            std::alloc::dealloc(ptr as *mut u8, self.layout)
+        }
+    }
+}
+
+impl<T> Deref for AlignedBoxWithSlice<T> {
+    type Target = [T];
+
+    fn deref(&self) -> &Self::Target {
+        &self.val
+    }
+}
+
+impl<T> DerefMut for AlignedBoxWithSlice<T> {
+    fn deref_mut(&mut self) -> &mut Self::Target {
+        &mut self.val
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use rand::Rng;
+
+    use crate::utils::is_aligned;
+
+    use super::*;
+
+    #[test]
+    fn create_alignedvec_works_32() {
+        (0..100).for_each(|_| {
+            let size = 1_000_000;
+            println!("Attempting {}", size);
+            let data = AlignedBoxWithSlice::<f32>::new(size, 32).unwrap();
+            assert_eq!(data.len(), size, "Capacity should match");
+
+            let ptr = data.as_ptr() as usize;
+            assert_eq!(ptr % 32, 0, "Ptr should be aligned to 32");
+
+            // assert that the slice is initialized.
+            (0..size).for_each(|i| {
+                assert_eq!(data[i], f32::default());
+            });
+
+            drop(data);
+        });
+    }
+
+    #[test]
+    fn create_alignedvec_works_256() {
+        let mut rng = rand::thread_rng();
+
+        (0..100).for_each(|_| {
+            let n = rng.gen::<u16>();
+            let size = usize::from(n) + 1;
+            println!("Attempting {}", size);
+            let data = AlignedBoxWithSlice::<u8>::new(size, 256).unwrap();
+            assert_eq!(data.len(), size, "Capacity should match");
+
+            let ptr = data.as_ptr() as usize;
+            assert_eq!(ptr % 256, 0, "Ptr should be aligned to 256");
+
+            // assert that the slice is initialized.
+            (0..size).for_each(|i| {
+                assert_eq!(data[i], u8::default());
+            });
+
+            drop(data);
+        });
+    }
+
+    #[test]
+    fn as_slice_test() {
+        let size = 1_000_000;
+        let data = AlignedBoxWithSlice::<f32>::new(size, 32).unwrap();
+        // assert that the slice is initialized.
+        (0..size).for_each(|i| {
+            assert_eq!(data[i], f32::default());
+        });
+
+        let slice = data.as_slice();
+        (0..size).for_each(|i| {
+            assert_eq!(slice[i], f32::default());
+        });
+    }
+
+    #[test]
+    fn as_mut_slice_test() {
+        let size = 1_000_000;
+        let mut data = AlignedBoxWithSlice::<f32>::new(size, 32).unwrap();
+        let mut_slice = data.as_mut_slice();
+        (0..size).for_each(|i| {
+            assert_eq!(mut_slice[i], f32::default());
+        });
+    }
+
+    #[test]
+    fn memcpy_test() {
+        let size = 1_000_000;
+        let mut data = AlignedBoxWithSlice::<f32>::new(size, 32).unwrap();
+        let mut destination = AlignedBoxWithSlice::<f32>::new(size - 2, 32).unwrap();
+        let mut_destination = destination.as_mut_slice();
+        data.memcpy(mut_destination).unwrap();
+        (0..size - 2).for_each(|i| {
+            assert_eq!(data[i], mut_destination[i]);
+        });
+    }
+
+    #[test]
+    #[should_panic(expected = "source slice is too large (src:1000000, dst:999998)")]
+    fn memcpy_panic_test() {
+        let size = 1_000_000;
+        let mut data = AlignedBoxWithSlice::<f32>::new(size - 2, 32).unwrap();
+        let mut destination = AlignedBoxWithSlice::<f32>::new(size, 32).unwrap();
+        let mut_destination = destination.as_mut_slice();
+        data.memcpy(mut_destination).unwrap();
+    }
+
+    #[test]
+    fn is_aligned_test() {
+        assert!(is_aligned(256, 256));
+        assert!(!is_aligned(255, 256));
+    }
+
+    #[test]
+    fn split_into_nonoverlapping_mut_slices_test() {
+        let size = 10;
+        let slice_len = 2;
+        let mut data = AlignedBoxWithSlice::<f32>::new(size, 32).unwrap();
+        let slices = data.split_into_nonoverlapping_mut_slices(2..8, slice_len).unwrap();
+        assert_eq!(slices.len(), 3);
+        for (i, slice) in slices.into_iter().enumerate() {
+            assert_eq!(slice.len(), slice_len);
+            slice[0] = i as f32 + 1.0;
+            slice[1] = i as f32 + 1.0;
+        }
+        let expected_arr = [0.0f32, 0.0, 1.0, 1.0, 2.0, 2.0, 3.0, 3.0, 0.0, 0.0];
+        assert_eq!(data.as_ref(), &expected_arr);
+    }
+
+    #[test]
+    fn split_into_nonoverlapping_mut_slices_error_when_indivisible() {
+        let size = 10;
+        let slice_len = 2;
+        let range = 2..7;
+        let mut data = AlignedBoxWithSlice::<f32>::new(size, 32).unwrap();
+        let result = data.split_into_nonoverlapping_mut_slices(range.clone(), slice_len);
+        let expected_err_str = format!(
+            "IndexError: Cannot split range ({:?}) of AlignedBoxWithSlice (len: {}) into nonoverlapping mutable slices with length {}",
+            range,
+            size,
+            slice_len,
+        );
+        assert!(result.is_err_and(|e| e.to_string() == expected_err_str));
+    }
+}
+
diff --git a/rust/diskann/src/common/ann_result.rs b/rust/diskann/src/common/ann_result.rs
new file mode 100644
index 000000000..69fcf03f6
--- /dev/null
+++ b/rust/diskann/src/common/ann_result.rs
@@ -0,0 +1,179 @@
+/*
+ * Copyright (c) Microsoft Corporation. All rights reserved.
+ * Licensed under the MIT license.
+ */
+use std::alloc::LayoutError;
+use std::array::TryFromSliceError;
+use std::io;
+use std::num::TryFromIntError;
+
+use logger::error_logger::log_error;
+use logger::log_error::LogError;
+
+/// Result
+pub type ANNResult<T> = Result<T, ANNError>;
+
+/// DiskANN Error
+/// ANNError is `Send` (i.e., safe to send across threads)
+#[derive(thiserror::Error, Debug)]
+pub enum ANNError {
+    /// Index construction and search error
+    #[error("IndexError: {err}")]
+    IndexError { err: String },
+
+    /// Index configuration error
+    #[error("IndexConfigError: {parameter} is invalid, err={err}")]
+    IndexConfigError { parameter: String, err: String },
+
+    /// Integer conversion error
+    #[error("TryFromIntError: {err}")]
+    TryFromIntError {
+        #[from]
+        err: TryFromIntError,
+    },
+
+    /// IO error
+    #[error("IOError: {err}")]
+    IOError {
+        #[from]
+        err: io::Error,
+    },
+
+    /// Layout error in memory allocation
+    #[error("MemoryAllocLayoutError: {err}")]
+    MemoryAllocLayoutError {
+        #[from]
+        err: LayoutError,
+    },
+
+    /// PoisonError which can be returned whenever a lock is acquired
+    /// Both Mutexes and RwLocks are poisoned whenever a thread fails while the lock is held
+    #[error("LockPoisonError: {err}")]
+    LockPoisonError { err: String },
+
+    /// DiskIOAlignmentError which can be returned when calling windows API CreateFileA for the disk index file fails.
+    #[error("DiskIOAlignmentError: {err}")]
+    DiskIOAlignmentError { err: String },
+
+    /// Logging error
+    #[error("LogError: {err}")]
+    LogError {
+        #[from]
+        err: LogError,
+    },
+
+    /// PQ construction error.
+    /// Error happened when we construct PQ pivot or PQ compressed table
+    #[error("PQError: {err}")]
+    PQError { err: String },
+
+    /// Array conversion error
+    #[error("Error try creating array from slice: {err}")]
+    TryFromSliceError {
+        #[from]
+        err: TryFromSliceError,
+    },
+}
+
+impl ANNError {
+    /// Create, log and return IndexError
+    #[inline]
+    pub fn log_index_error(err: String) -> Self {
+        let ann_err = ANNError::IndexError { err };
+        match log_error(ann_err.to_string()) {
+            Ok(()) => ann_err,
+            Err(log_err) => ANNError::LogError { err: log_err },
+        }
+    }
+
+    /// Create, log and return IndexConfigError
+    #[inline]
+    pub fn log_index_config_error(parameter: String, err: String) -> Self {
+        let ann_err = ANNError::IndexConfigError { parameter, err };
+        match log_error(ann_err.to_string()) {
+            Ok(()) => ann_err,
+            Err(log_err) => ANNError::LogError { err: log_err },
+        }
+    }
+
+    /// Create, log and return TryFromIntError
+    #[inline]
+    pub fn log_try_from_int_error(err: TryFromIntError) -> Self {
+        let ann_err = ANNError::TryFromIntError { err };
+        match log_error(ann_err.to_string()) {
+            Ok(()) => ann_err,
+            Err(log_err) => ANNError::LogError { err: log_err },
+        }
+    }
+
+    /// Create, log and return IOError
+    #[inline]
+    pub fn log_io_error(err: io::Error) -> Self {
+        let ann_err = ANNError::IOError { err };
+        match log_error(ann_err.to_string()) {
+            Ok(()) => ann_err,
+            Err(log_err) => ANNError::LogError { err: log_err },
+        }
+    }
+
+    /// Create, log and return DiskIOAlignmentError
+    #[inline]
+    pub fn log_disk_io_request_alignment_error(err: String) -> Self {
+        let ann_err: ANNError = ANNError::DiskIOAlignmentError { err };
+        match log_error(ann_err.to_string()) {
+            Ok(()) => ann_err,
+            Err(log_err) => ANNError::LogError { err: log_err },
+        }
+    }
+
+    /// Create, log and return MemoryAllocLayoutError
+    #[inline]
+    pub fn log_mem_alloc_layout_error(err: LayoutError) -> Self {
+        let ann_err = ANNError::MemoryAllocLayoutError { err };
+        match log_error(ann_err.to_string()) {
+            Ok(()) => ann_err,
+            Err(log_err) => ANNError::LogError { err: log_err },
+        }
+    }
+
+    /// Create, log and return LockPoisonError
+    #[inline]
+    pub fn log_lock_poison_error(err: String) -> Self {
+        let ann_err = ANNError::LockPoisonError { err };
+        match log_error(ann_err.to_string()) {
+            Ok(()) => ann_err,
+            Err(log_err) => ANNError::LogError { err: log_err },
+        }
+    }
+
+    /// Create, log and return PQError
+    #[inline]
+    pub fn log_pq_error(err: String) -> Self {
+        let ann_err = ANNError::PQError { err };
+        match log_error(ann_err.to_string()) {
+            Ok(()) => ann_err,
+            Err(log_err) => ANNError::LogError { err: log_err },
+        }
+    }
+
+    /// Create, log and return TryFromSliceError
+    #[inline]
+    pub fn log_try_from_slice_error(err: TryFromSliceError) -> Self {
+        let ann_err = ANNError::TryFromSliceError { err };
+        match log_error(ann_err.to_string()) {
+            Ok(()) => ann_err,
+            Err(log_err) => ANNError::LogError { err: log_err },
+        }
+    }
+}
+
+#[cfg(test)]
+mod ann_result_test {
+    use super::*;
+
+    #[test]
+    fn ann_err_is_send() {
+        fn assert_send<T: Send>() {}
+        assert_send::<ANNError>();
+    }
+}
diff --git a/rust/diskann/src/common/mod.rs b/rust/diskann/src/common/mod.rs
new file mode 100644
index 000000000..d9da72bbc
--- /dev/null
+++ b/rust/diskann/src/common/mod.rs
@@ -0,0 +1,9 @@
+/*
+ * Copyright (c) Microsoft Corporation. All rights reserved.
+ * Licensed under the MIT license.
+ */
+mod aligned_allocator;
+pub use aligned_allocator::AlignedBoxWithSlice;
+
+mod ann_result;
+pub use ann_result::*;
diff --git a/rust/diskann/src/index/disk_index/ann_disk_index.rs b/rust/diskann/src/index/disk_index/ann_disk_index.rs
new file mode 100644
index 000000000..a6e053e17
--- /dev/null
+++ b/rust/diskann/src/index/disk_index/ann_disk_index.rs
@@ -0,0 +1,54 @@
+/*
+ * Copyright (c) Microsoft Corporation. All rights reserved.
+ * Licensed under the MIT license.
+ */
+#![warn(missing_docs)]
+
+//! ANN disk index abstraction
+
+use vector::FullPrecisionDistance;
+
+use crate::model::{IndexConfiguration, DiskIndexBuildParameters};
+use crate::storage::DiskIndexStorage;
+use crate::model::vertex::{DIM_128, DIM_256, DIM_104};
+
+use crate::common::{ANNResult, ANNError};
+
+use super::DiskIndex;
+
+/// ANN disk index abstraction for custom index implementations
+pub trait ANNDiskIndex<T>: Sync + Send
+where
+    T: Default + Copy + Sync + Send + Into<f32>,
+{
+    /// Build index
+    fn build(&mut self, codebook_prefix: &str) -> ANNResult<()>;
+}
+
+/// Create Index based on configuration
+pub fn create_disk_index<'a, T>(
+    disk_build_param: Option<DiskIndexBuildParameters>,
+    config: IndexConfiguration,
+    storage: DiskIndexStorage<T>,
+) -> ANNResult<Box<dyn ANNDiskIndex<T> + 'a>>
+where
+    T: Default + Copy + Sync + Send + Into<f32> + 'a,
+    [T; DIM_104]: FullPrecisionDistance<T, DIM_104>,
+    [T; DIM_128]: FullPrecisionDistance<T, DIM_128>,
+    [T; DIM_256]: FullPrecisionDistance<T, DIM_256>,
+{
+    match config.aligned_dim {
+        DIM_104 => {
+            let index = Box::new(DiskIndex::<T, DIM_104>::new(disk_build_param, config, storage));
+            Ok(index as Box<dyn ANNDiskIndex<T>>)
+        },
+        DIM_128 => {
+            let index = Box::new(DiskIndex::<T, DIM_128>::new(disk_build_param, config, storage));
+            Ok(index as Box<dyn ANNDiskIndex<T>>)
+        },
+        DIM_256 => {
+            let index = Box::new(DiskIndex::<T, DIM_256>::new(disk_build_param, config, storage));
+            Ok(index as Box<dyn ANNDiskIndex<T>>)
+        },
+        _ => Err(ANNError::log_index_error(format!("Invalid dimension: {}", config.aligned_dim))),
+    }
+}
diff --git a/rust/diskann/src/index/disk_index/disk_index.rs b/rust/diskann/src/index/disk_index/disk_index.rs
new file mode 100644
index 000000000..16f0d5969
--- /dev/null
+++ b/rust/diskann/src/index/disk_index/disk_index.rs
@@ -0,0 +1,161 @@
+/*
+ * Copyright (c) Microsoft Corporation. All rights reserved.
+ * Licensed under the MIT license.
+ */
+use std::mem;
+
+use logger::logger::indexlog::DiskIndexConstructionCheckpoint;
+use vector::FullPrecisionDistance;
+
+use crate::common::{ANNResult, ANNError};
+use crate::index::{InmemIndex, ANNInmemIndex};
+use crate::instrumentation::DiskIndexBuildLogger;
+use crate::model::configuration::DiskIndexBuildParameters;
+use crate::model::{IndexConfiguration, MAX_PQ_TRAINING_SET_SIZE, MAX_PQ_CHUNKS, generate_quantized_data, GRAPH_SLACK_FACTOR};
+use crate::storage::DiskIndexStorage;
+use crate::utils::set_rayon_num_threads;
+
+use super::ann_disk_index::ANNDiskIndex;
+
+pub const OVERHEAD_FACTOR: f64 = 1.1f64;
+
+pub const MAX_SAMPLE_POINTS_FOR_WARMUP: usize = 100_000;
+
+pub struct DiskIndex<T, const N: usize>
+where
+    [T; N]: FullPrecisionDistance<T, N>,
+{
+    /// Parameters for index construction
+    /// None for query path
+    disk_build_param: Option<DiskIndexBuildParameters>,
+
+    configuration: IndexConfiguration,
+
+    pub storage: DiskIndexStorage<T>,
+}
+
+impl<T, const N: usize> DiskIndex<T, N>
+where
+    T: Default + Copy + Sync + Send + Into<f32>,
+    [T; N]: FullPrecisionDistance<T, N>,
+{
+    pub fn new(
+        disk_build_param: Option<DiskIndexBuildParameters>,
+        configuration: IndexConfiguration,
+        storage: DiskIndexStorage<T>,
+    ) -> Self {
+        Self {
+            disk_build_param,
+            configuration,
+            storage,
+        }
+    }
+
+    pub fn disk_build_param(&self) -> &Option<DiskIndexBuildParameters> {
+        &self.disk_build_param
+    }
+
+    pub fn index_configuration(&self) -> &IndexConfiguration {
+        &self.configuration
+    }
+
+    fn build_inmem_index(&self, num_points: usize, data_path: &str, inmem_index_path: &str) -> ANNResult<()> {
+        let estimated_index_ram = self.estimate_ram_usage(num_points);
+        if estimated_index_ram >= self.fetch_disk_build_param()?.index_build_ram_limit() * 1024_f64 * 1024_f64 * 1024_f64 {
+            return Err(ANNError::log_index_error(format!(
+                "Insufficient memory budget for index build, index_build_ram_limit={}GB estimated_index_ram={}GB",
+                self.fetch_disk_build_param()?.index_build_ram_limit(),
+                estimated_index_ram / (1024_f64 * 1024_f64 * 1024_f64),
+            )));
+        }
+
+        let mut index = InmemIndex::<T, N>::new(self.configuration.clone())?;
+        index.build(data_path, num_points)?;
+        index.save(inmem_index_path)?;
+
+        Ok(())
+    }
+
+    #[inline]
+    fn estimate_ram_usage(&self, size: usize) -> f64 {
+        let degree = self.configuration.index_write_parameter.max_degree as usize;
+        let datasize = mem::size_of::<T>();
+
+        let dataset_size = (size * N * datasize) as f64;
+        let graph_size = (size * degree * mem::size_of::<u32>()) as f64 * GRAPH_SLACK_FACTOR;
+
+        OVERHEAD_FACTOR * (dataset_size + graph_size)
+    }
+
+    #[inline]
+    fn fetch_disk_build_param(&self) -> ANNResult<&DiskIndexBuildParameters> {
+        self.disk_build_param
+            .as_ref()
+            .ok_or_else(|| ANNError::log_index_config_error(
+                "disk_build_param".to_string(),
+                "disk_build_param is None".to_string()))
+    }
+}
+
+impl<T, const N: usize> ANNDiskIndex<T> for DiskIndex<T, N>
+where
+    T: Default + Copy + Sync + Send + Into<f32>,
+    [T; N]: FullPrecisionDistance<T, N>,
+{
+    fn build(&mut self, codebook_prefix: &str) -> ANNResult<()> {
+        if self.configuration.index_write_parameter.num_threads > 0 {
+            set_rayon_num_threads(self.configuration.index_write_parameter.num_threads);
+        }
+
+        println!("Starting index build: R={} L={} Query RAM budget={} Indexing RAM budget={} T={}",
+            self.configuration.index_write_parameter.max_degree,
+            self.configuration.index_write_parameter.search_list_size,
+            self.fetch_disk_build_param()?.search_ram_limit(),
+            self.fetch_disk_build_param()?.index_build_ram_limit(),
+            self.configuration.index_write_parameter.num_threads
+        );
+
+        let mut logger = DiskIndexBuildLogger::new(DiskIndexConstructionCheckpoint::PqConstruction);
+
+        // PQ memory consumption = PQ pivots + PQ compressed table
+        // PQ pivots: dim * num_centroids * sizeof::<f32>()
+        // PQ compressed table: num_pts * num_pq_chunks * (dim / num_pq_chunks) * sizeof::<u8>()
+        // * Because num_centroids is 256, centroid id can be represented by u8
+        let num_points = self.configuration.max_points;
+        let dim = self.configuration.dim;
+        let p_val = MAX_PQ_TRAINING_SET_SIZE / (num_points as f64);
+        let mut num_pq_chunks = ((self.fetch_disk_build_param()?.search_ram_limit() / (num_points as f64)).floor()) as usize;
+        num_pq_chunks = if num_pq_chunks == 0 { 1 } else { num_pq_chunks };
+        num_pq_chunks = if num_pq_chunks > dim { dim } else { num_pq_chunks };
+        num_pq_chunks = if num_pq_chunks > MAX_PQ_CHUNKS { MAX_PQ_CHUNKS } else { num_pq_chunks };
+
+        println!("Compressing {}-dimensional data into {} bytes per vector.", dim, num_pq_chunks);
+
+        // TODO: Decouple PQ from file access
+        generate_quantized_data::<T>(
+            p_val,
+            num_pq_chunks,
+            codebook_prefix,
+            self.storage.get_pq_storage(),
+        )?;
+        logger.log_checkpoint(DiskIndexConstructionCheckpoint::InmemIndexBuild)?;
+
+        // TODO: Decouple index from file access
+        let inmem_index_path = self.storage.index_path_prefix().clone() + "_mem.index";
+        self.build_inmem_index(num_points, self.storage.dataset_file(), inmem_index_path.as_str())?;
+        logger.log_checkpoint(DiskIndexConstructionCheckpoint::DiskLayout)?;
+
+        self.storage.create_disk_layout()?;
+        logger.log_checkpoint(DiskIndexConstructionCheckpoint::None)?;
+
+        let ten_percent_points = ((num_points as f64) * 0.1_f64).ceil();
+        let num_sample_points = if ten_percent_points > (MAX_SAMPLE_POINTS_FOR_WARMUP as f64) { MAX_SAMPLE_POINTS_FOR_WARMUP as f64 } else { ten_percent_points };
+        let sample_sampling_rate = num_sample_points / (num_points as f64);
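+        // Warmup set: sample roughly 10% of the base points, capped at
+        // MAX_SAMPLE_POINTS_FOR_WARMUP, and let storage generate the query
+        // warmup data at that sampling rate.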
+        self.storage.gen_query_warmup_data(sample_sampling_rate)?;
+
+        self.storage.index_build_cleanup()?;
+
+        Ok(())
+    }
+}
+
diff --git a/rust/diskann/src/index/disk_index/mod.rs b/rust/diskann/src/index/disk_index/mod.rs
new file mode 100644
index 000000000..4f07bd78d
--- /dev/null
+++ b/rust/diskann/src/index/disk_index/mod.rs
@@ -0,0 +1,9 @@
+/*
+ * Copyright (c) Microsoft Corporation. All rights reserved.
+ * Licensed under the MIT license.
+ */
+#[allow(clippy::module_inception)]
+mod disk_index;
+pub use disk_index::DiskIndex;
+
+pub mod ann_disk_index;
diff --git a/rust/diskann/src/index/inmem_index/ann_inmem_index.rs b/rust/diskann/src/index/inmem_index/ann_inmem_index.rs
new file mode 100644
index 000000000..dc8dfc876
--- /dev/null
+++ b/rust/diskann/src/index/inmem_index/ann_inmem_index.rs
@@ -0,0 +1,97 @@
+/*
+ * Copyright (c) Microsoft Corporation. All rights reserved.
+ * Licensed under the MIT license.
+ */
+#![warn(missing_docs)]
+
+//! ANN in-memory index abstraction
+
+use vector::FullPrecisionDistance;
+
+use crate::model::{vertex::{DIM_128, DIM_256, DIM_104}, IndexConfiguration};
+use crate::common::{ANNResult, ANNError};
+
+use super::InmemIndex;
+
+/// ANN inmem-index abstraction for custom index implementations
+pub trait ANNInmemIndex<T>: Sync + Send
+where
+    T: Default + Copy + Sync + Send + Into<f32>,
+{
+    /// Build index
+    fn build(&mut self, filename: &str, num_points_to_load: usize) -> ANNResult<()>;
+
+    /// Save index
+    fn save(&mut self, filename: &str) -> ANNResult<()>;
+
+    /// Load index
+    fn load(&mut self, filename: &str, expected_num_points: usize) -> ANNResult<()>;
+
+    /// Insert index
+    fn insert(&mut self, filename: &str, num_points_to_insert: usize) -> ANNResult<()>;
+
+    /// Search the index for K nearest neighbors of query using given L value, for benchmarking purposes
+    fn search(&self, query: &[T], k_value: usize, l_value: u32, indices: &mut [u32]) -> ANNResult<u32>;
+
+    /// Soft deletes the nodes with the ids in the given array.
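+    ///
+    /// # Arguments
+    ///
+    /// * `vertex_ids_to_delete` - ids of the vertices to mark as deleted.
+    /// * `num_points_to_delete` - how many entries from the front of `vertex_ids_to_delete` to process.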
+ fn soft_delete(&mut self, vertex_ids_to_delete: Vec, num_points_to_delete: usize) -> ANNResult<()>; +} + +/// Create Index based on configuration +pub fn create_inmem_index<'a, T>(config: IndexConfiguration) -> ANNResult + 'a>> +where + T: Default + Copy + Sync + Send + Into + 'a, + [T; DIM_104]: FullPrecisionDistance, + [T; DIM_128]: FullPrecisionDistance, + [T; DIM_256]: FullPrecisionDistance, +{ + match config.aligned_dim { + DIM_104 => { + let index = Box::new(InmemIndex::::new(config)?); + Ok(index as Box>) + }, + DIM_128 => { + let index = Box::new(InmemIndex::::new(config)?); + Ok(index as Box>) + }, + DIM_256 => { + let index = Box::new(InmemIndex::::new(config)?); + Ok(index as Box>) + }, + _ => Err(ANNError::log_index_error(format!("Invalid dimension: {}", config.aligned_dim))), + } +} + +#[cfg(test)] +mod dataset_test { + use vector::Metric; + + use crate::model::configuration::index_write_parameters::IndexWriteParametersBuilder; + + use super::*; + + #[test] + #[should_panic(expected = "ERROR: Data file fake_file does not exist.")] + fn create_index_test() { + let index_write_parameters = IndexWriteParametersBuilder::new(50, 4) + .with_alpha(1.2) + .with_saturate_graph(false) + .with_num_threads(1) + .build(); + + let config = IndexConfiguration::new( + Metric::L2, + 128, + 256, + 1_000_000, + false, + 0, + false, + 0, + 1f32, + index_write_parameters, + ); + let mut index = create_inmem_index::(config).unwrap(); + index.build("fake_file", 100).unwrap(); + } +} + diff --git a/rust/diskann/src/index/inmem_index/inmem_index.rs b/rust/diskann/src/index/inmem_index/inmem_index.rs new file mode 100644 index 000000000..871d21092 --- /dev/null +++ b/rust/diskann/src/index/inmem_index/inmem_index.rs @@ -0,0 +1,1033 @@ +/* + * Copyright (c) Microsoft Corporation. All rights reserved. + * Licensed under the MIT license. + */ +use std::cmp; +use std::sync::RwLock; +use std::time::Duration; + +use hashbrown::hash_set::Entry::*; +use hashbrown::HashSet; +use vector::FullPrecisionDistance; + +use crate::common::{ANNError, ANNResult}; +use crate::index::ANNInmemIndex; +use crate::instrumentation::IndexLogger; +use crate::model::graph::AdjacencyList; +use crate::model::{ + ArcConcurrentBoxedQueue, InMemQueryScratch, InMemoryGraph, IndexConfiguration, InmemDataset, + Neighbor, ScratchStoreManager, Vertex, +}; + +use crate::utils::file_util::{file_exists, load_metadata_from_file}; +use crate::utils::rayon_util::execute_with_rayon; +use crate::utils::{set_rayon_num_threads, Timer}; + +/// In-memory Index +pub struct InmemIndex +where + [T; N]: FullPrecisionDistance, +{ + /// Dataset + pub dataset: InmemDataset, + + /// Graph + pub final_graph: InMemoryGraph, + + /// Index configuration + pub configuration: IndexConfiguration, + + /// Start point of the search. When _num_frozen_pts is greater than zero, + /// this is the location of the first frozen point. Otherwise, this is a + /// location of one of the points in index. + pub start: u32, + + /// Max observed out degree + pub max_observed_degree: u32, + + /// Number of active points i.e. existing in the graph + pub num_active_pts: usize, + + /// query scratch queue. + query_scratch_queue: ArcConcurrentBoxedQueue>, + + pub delete_set: RwLock>, +} + +impl InmemIndex +where + T: Default + Copy + Sync + Send + Into, + [T; N]: FullPrecisionDistance, +{ + /// Create Index obj based on configuration + pub fn new(mut config: IndexConfiguration) -> ANNResult { + // Sanity check. 
While logically it is correct, max_points = 0 causes + // downstream problems. + if config.max_points == 0 { + config.max_points = 1; + } + + let total_internal_points = config.max_points + config.num_frozen_pts; + + if config.use_pq_dist { + // TODO: pq + todo!("PQ is not supported now"); + } + + let start = config.max_points.try_into()?; + + let query_scratch_queue = ArcConcurrentBoxedQueue::>::new(); + let delete_set = RwLock::new(HashSet::::new()); + + Ok(Self { + dataset: InmemDataset::::new(total_internal_points, config.growth_potential)?, + final_graph: InMemoryGraph::new( + total_internal_points, + config.index_write_parameter.max_degree, + ), + configuration: config, + start, + max_observed_degree: 0, + num_active_pts: 0, + query_scratch_queue, + delete_set, + }) + } + + /// Get distance between two vertices. + pub fn get_distance(&self, id1: u32, id2: u32) -> ANNResult { + self.dataset + .get_distance(id1, id2, self.configuration.dist_metric) + } + + fn build_with_data_populated(&mut self) -> ANNResult<()> { + println!( + "Starting index build with {} points...", + self.num_active_pts + ); + + if self.num_active_pts < 1 { + return Err(ANNError::log_index_error( + "Error: Trying to build an index with 0 points.".to_string(), + )); + } + + if self.query_scratch_queue.size()? == 0 { + self.initialize_query_scratch( + 5 + self.configuration.index_write_parameter.num_threads, + self.configuration.index_write_parameter.search_list_size, + )?; + } + + // TODO: generate_frozen_point() + + self.link()?; + + self.print_stats()?; + + Ok(()) + } + + fn link(&mut self) -> ANNResult<()> { + // visit_order is a vector that is initialized to the entire graph + let mut visit_order = + Vec::with_capacity(self.num_active_pts + self.configuration.num_frozen_pts); + for i in 0..self.num_active_pts { + visit_order.push(i as u32); + } + + // If there are any frozen points, add them all. 
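+        // Frozen points live in the tail slots [max_points, max_points + num_frozen_pts),
+        // so including them in visit_order links them into the graph like any other vertex.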
+ for frozen in self.configuration.max_points + ..(self.configuration.max_points + self.configuration.num_frozen_pts) + { + visit_order.push(frozen as u32); + } + + // if there are frozen points, the first such one is set to be the _start + if self.configuration.num_frozen_pts > 0 { + self.start = self.configuration.max_points as u32; + } else { + self.start = self.dataset.calculate_medoid_point_id()?; + } + + let timer = Timer::new(); + + let range = visit_order.len(); + let logger = IndexLogger::new(range); + + execute_with_rayon( + 0..range, + self.configuration.index_write_parameter.num_threads, + |idx| { + self.insert_vertex_id(visit_order[idx])?; + logger.vertex_processed()?; + + Ok(()) + }, + )?; + + self.cleanup_graph(&visit_order)?; + + if self.num_active_pts > 0 { + println!("{}", timer.elapsed_seconds_for_step("Link time: ")); + } + + Ok(()) + } + + fn insert_vertex_id(&self, vertex_id: u32) -> ANNResult<()> { + let mut scratch_manager = + ScratchStoreManager::new(self.query_scratch_queue.clone(), Duration::from_millis(10))?; + let scratch = scratch_manager.scratch_space().ok_or_else(|| { + ANNError::log_index_error( + "ScratchStoreManager doesn't have InMemQueryScratch instance available".to_string(), + ) + })?; + + let new_neighbors = self.search_for_point_and_prune(scratch, vertex_id)?; + self.update_vertex_with_neighbors(vertex_id, new_neighbors)?; + self.update_neighbors_of_vertex(vertex_id, scratch)?; + + Ok(()) + } + + fn update_neighbors_of_vertex( + &self, + vertex_id: u32, + scratch: &mut InMemQueryScratch, + ) -> Result<(), ANNError> { + let vertex = self.final_graph.read_vertex_and_neighbors(vertex_id)?; + assert!(vertex.size() <= self.configuration.index_write_parameter.max_degree as usize); + self.inter_insert( + vertex_id, + vertex.get_neighbors(), + self.configuration.index_write_parameter.max_degree, + scratch, + )?; + Ok(()) + } + + fn update_vertex_with_neighbors( + &self, + vertex_id: u32, + new_neighbors: AdjacencyList, + ) -> Result<(), ANNError> { + let vertex = &mut self.final_graph.write_vertex_and_neighbors(vertex_id)?; + vertex.set_neighbors(new_neighbors); + assert!(vertex.size() <= self.configuration.index_write_parameter.max_degree as usize); + Ok(()) + } + + fn search_for_point_and_prune( + &self, + scratch: &mut InMemQueryScratch, + vertex_id: u32, + ) -> ANNResult { + let mut pruned_list = + AdjacencyList::for_range(self.configuration.index_write_parameter.max_degree as usize); + let vertex = self.dataset.get_vertex(vertex_id)?; + let mut visited_nodes = self.search_for_point(&vertex, scratch)?; + + self.prune_neighbors(vertex_id, &mut visited_nodes, &mut pruned_list, scratch)?; + + if pruned_list.is_empty() { + return Err(ANNError::log_index_error( + "pruned_list is empty.".to_string(), + )); + } + + if self.final_graph.size() + != self.configuration.max_points + self.configuration.num_frozen_pts + { + return Err(ANNError::log_index_error(format!( + "final_graph has {} vertices instead of {}", + self.final_graph.size(), + self.configuration.max_points + self.configuration.num_frozen_pts, + ))); + } + + Ok(pruned_list) + } + + fn search( + &self, + query: &Vertex, + k_value: usize, + l_value: u32, + indices: &mut [u32], + ) -> ANNResult { + if k_value > l_value as usize { + return Err(ANNError::log_index_error(format!( + "Set L: {} to a value of at least K: {}", + l_value, k_value + ))); + } + + let mut scratch_manager = + ScratchStoreManager::new(self.query_scratch_queue.clone(), Duration::from_millis(10))?; + + let scratch = 
scratch_manager.scratch_space().ok_or_else(|| { + ANNError::log_index_error( + "ScratchStoreManager doesn't have InMemQueryScratch instance available".to_string(), + ) + })?; + + if l_value > scratch.candidate_size { + println!("Attempting to expand query scratch_space. Was created with Lsize: {} but search L is: {}", scratch.candidate_size, l_value); + scratch.resize_for_new_candidate_size(l_value); + println!( + "Resize completed. New scratch size is: {}", + scratch.candidate_size + ); + } + + let cmp = self.search_with_l_override(query, scratch, l_value as usize)?; + let mut pos = 0; + + for i in 0..scratch.best_candidates.size() { + if scratch.best_candidates[i].id < self.configuration.max_points as u32 { + // Filter out the deleted points. + if let Ok(delete_set_guard) = self.delete_set.read() { + if !delete_set_guard.contains(&scratch.best_candidates[i].id) { + indices[pos] = scratch.best_candidates[i].id; + pos += 1; + } + } else { + return Err(ANNError::log_lock_poison_error( + "failed to acquire the lock for delete_set.".to_string(), + )); + } + } + + if pos == k_value { + break; + } + } + + if pos < k_value { + eprintln!( + "Found fewer than K elements for query! Found: {} but K: {}", + pos, k_value + ); + } + + Ok(cmp) + } + + fn cleanup_graph(&mut self, visit_order: &Vec) -> ANNResult<()> { + if self.num_active_pts > 0 { + println!("Starting final cleanup.."); + } + + execute_with_rayon( + 0..visit_order.len(), + self.configuration.index_write_parameter.num_threads, + |idx| { + let vertex_id = visit_order[idx]; + let num_nbrs = self.get_neighbor_count(vertex_id)?; + + if num_nbrs <= self.configuration.index_write_parameter.max_degree as usize { + // Neighbor list is already small enough. + return Ok(()); + } + + let mut scratch_manager = ScratchStoreManager::new( + self.query_scratch_queue.clone(), + Duration::from_millis(10), + )?; + let scratch = scratch_manager.scratch_space().ok_or_else(|| { + ANNError::log_index_error( + "ScratchStoreManager doesn't have InMemQueryScratch instance available" + .to_string(), + ) + })?; + + let mut dummy_pool = self.get_neighbors_for_vertex(vertex_id)?; + + let mut new_out_neighbors = AdjacencyList::for_range( + self.configuration.index_write_parameter.max_degree as usize, + ); + self.prune_neighbors(vertex_id, &mut dummy_pool, &mut new_out_neighbors, scratch)?; + + self.final_graph + .write_vertex_and_neighbors(vertex_id)? + .set_neighbors(new_out_neighbors); + + Ok(()) + }, + ) + } + + /// Get the unique neighbors for a vertex. + /// + /// This code feels out of place here. This should have nothing to do with whether this + /// is in memory index? + /// # Errors + /// + /// This function will return an error if we are not able to get the read lock. + fn get_neighbors_for_vertex(&self, vertex_id: u32) -> ANNResult> { + let binding = self.final_graph.read_vertex_and_neighbors(vertex_id)?; + let neighbors = binding.get_neighbors(); + let dummy_pool = self.get_unique_neighbors(neighbors, vertex_id)?; + + Ok(dummy_pool) + } + + /// Returns a vector of unique neighbors for the given vertex, along with their distances. + /// + /// # Arguments + /// + /// * `neighbors` - A vector of neighbor id index for the given vertex. + /// * `vertex_id` - The given vertex id. + /// + /// # Errors + /// + /// Returns an `ANNError` if there is an error retrieving the vertex or one of its neighbors. 
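+    ///
+    /// # Example
+    ///
+    /// An illustrative sketch (`ignore`d because it needs a populated index):
+    ///
+    /// ```ignore
+    /// // Duplicate ids and the vertex itself are filtered out; each surviving id
+    /// // is paired with its full-precision distance to the given vertex.
+    /// let pool = index.get_unique_neighbors(&vec![5, 5, 7, 0], 0)?;
+    /// assert_eq!(pool.len(), 2); // ids 5 and 7; the duplicate 5 and self id 0 are dropped
+    /// ```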
+ pub fn get_unique_neighbors( + &self, + neighbors: &Vec, + vertex_id: u32, + ) -> Result, ANNError> { + let vertex = self.dataset.get_vertex(vertex_id)?; + + let len = neighbors.len(); + if len == 0 { + return Ok(Vec::new()); + } + + self.dataset.prefetch_vector(neighbors[0]); + + let mut dummy_visited: HashSet = HashSet::with_capacity(len); + let mut dummy_pool: Vec = Vec::with_capacity(len); + + // let slice = ['w', 'i', 'n', 'd', 'o', 'w', 's']; + // for window in slice.windows(2) { + // &println!{"[{}, {}]", window[0], window[1]}; + // } + // prints: [w, i] -> [i, n] -> [n, d] -> [d, o] -> [o, w] -> [w, s] + for current in neighbors.windows(2) { + // Prefetch the next item. + self.dataset.prefetch_vector(current[1]); + let current = current[0]; + + self.insert_neighbor_if_unique( + &mut dummy_visited, + current, + vertex_id, + &vertex, + &mut dummy_pool, + )?; + } + + // Insert the last neighbor + #[allow(clippy::unwrap_used)] + self.insert_neighbor_if_unique( + &mut dummy_visited, + *neighbors.last().unwrap(), // we know len != 0, so this is safe. + vertex_id, + &vertex, + &mut dummy_pool, + )?; + + Ok(dummy_pool) + } + + fn insert_neighbor_if_unique( + &self, + dummy_visited: &mut HashSet, + current: u32, + vertex_id: u32, + vertex: &Vertex<'_, T, N>, + dummy_pool: &mut Vec, + ) -> Result<(), ANNError> { + if current != vertex_id { + if let Vacant(entry) = dummy_visited.entry(current) { + let cur_nbr_vertex = self.dataset.get_vertex(current)?; + let dist = vertex.compare(&cur_nbr_vertex, self.configuration.dist_metric); + dummy_pool.push(Neighbor::new(current, dist)); + entry.insert(); + } + } + + Ok(()) + } + + /// Get count of neighbors for a given vertex. + /// + /// # Errors + /// + /// This function will return an error if we can't get a lock. + fn get_neighbor_count(&self, vertex_id: u32) -> ANNResult { + let num_nbrs = self + .final_graph + .read_vertex_and_neighbors(vertex_id)? + .size(); + Ok(num_nbrs) + } + + fn soft_delete_vertex(&self, vertex_id_to_delete: u32) -> ANNResult<()> { + if vertex_id_to_delete as usize > self.num_active_pts { + return Err(ANNError::log_index_error(format!( + "vertex_id_to_delete: {} is greater than the number of active points in the graph: {}", + vertex_id_to_delete, self.num_active_pts + ))); + } + + let mut delete_set_guard = match self.delete_set.write() { + Ok(guard) => guard, + Err(_) => { + return Err(ANNError::log_index_error(format!( + "Failed to acquire delete_set lock, cannot delete vertex {}", + vertex_id_to_delete + ))); + } + }; + + delete_set_guard.insert(vertex_id_to_delete); + Ok(()) + } + + fn initialize_query_scratch( + &mut self, + num_threads: u32, + search_candidate_size: u32, + ) -> ANNResult<()> { + self.query_scratch_queue.reserve(num_threads as usize)?; + for _ in 0..num_threads { + let scratch = Box::new(InMemQueryScratch::::new( + search_candidate_size, + &self.configuration.index_write_parameter, + false, + )?); + + self.query_scratch_queue.push(scratch)?; + } + + Ok(()) + } + + fn print_stats(&mut self) -> ANNResult<()> { + let mut max = 0; + let mut min = usize::MAX; + let mut total = 0; + let mut cnt = 0; + + for i in 0..self.num_active_pts { + let vertex_id = i.try_into()?; + let pool_size = self + .final_graph + .read_vertex_and_neighbors(vertex_id)? 
+ .size(); + max = cmp::max(max, pool_size); + min = cmp::min(min, pool_size); + total += pool_size; + if pool_size < 2 { + cnt += 1; + } + } + + println!( + "Index built with degree: max: {} avg: {} min: {} count(deg<2): {}", + max, + (total as f32) / ((self.num_active_pts + self.configuration.num_frozen_pts) as f32), + min, + cnt + ); + + match self.delete_set.read() { + Ok(guard) => { + println!( + "Number of soft deleted vertices {}, soft deleted percentage: {}", + guard.len(), + (guard.len() as f32) + / ((self.num_active_pts + self.configuration.num_frozen_pts) as f32), + ); + } + Err(_) => { + return Err(ANNError::log_lock_poison_error( + "Failed to acquire delete_set lock, cannot get the number of deleted vertices" + .to_string(), + )); + } + }; + + self.max_observed_degree = cmp::max(max as u32, self.max_observed_degree); + + Ok(()) + } +} + +impl ANNInmemIndex for InmemIndex +where + T: Default + Copy + Sync + Send + Into, + [T; N]: FullPrecisionDistance, +{ + fn build(&mut self, filename: &str, num_points_to_load: usize) -> ANNResult<()> { + // TODO: fresh-diskANN + // std::unique_lock ul(_update_lock); + + if !file_exists(filename) { + return Err(ANNError::log_index_error(format!( + "ERROR: Data file {} does not exist.", + filename + ))); + } + + let (file_num_points, file_dim) = load_metadata_from_file(filename)?; + if file_num_points > self.configuration.max_points { + return Err(ANNError::log_index_error(format!( + "ERROR: Driver requests loading {} points and file has {} points, + but index can support only {} points as specified in configuration.", + num_points_to_load, file_num_points, self.configuration.max_points + ))); + } + + if num_points_to_load > file_num_points { + return Err(ANNError::log_index_error(format!( + "ERROR: Driver requests loading {} points and file has only {} points.", + num_points_to_load, file_num_points + ))); + } + + if file_dim != self.configuration.dim { + return Err(ANNError::log_index_error(format!( + "ERROR: Driver requests loading {} dimension, but file has {} dimension.", + self.configuration.dim, file_dim + ))); + } + + if self.configuration.use_pq_dist { + // TODO: PQ + todo!("PQ is not supported now"); + } + + if self.configuration.index_write_parameter.num_threads > 0 { + set_rayon_num_threads(self.configuration.index_write_parameter.num_threads); + } + + self.dataset.build_from_file(filename, num_points_to_load)?; + + println!("Using only first {} from file.", num_points_to_load); + + // TODO: tag_lock + + self.num_active_pts = num_points_to_load; + self.build_with_data_populated()?; + + Ok(()) + } + + fn insert(&mut self, filename: &str, num_points_to_insert: usize) -> ANNResult<()> { + // fresh-diskANN + if !file_exists(filename) { + return Err(ANNError::log_index_error(format!( + "ERROR: Data file {} does not exist.", + filename + ))); + } + + let (file_num_points, file_dim) = load_metadata_from_file(filename)?; + + if num_points_to_insert > file_num_points { + return Err(ANNError::log_index_error(format!( + "ERROR: Driver requests loading {} points and file has only {} points.", + num_points_to_insert, file_num_points + ))); + } + + if file_dim != self.configuration.dim { + return Err(ANNError::log_index_error(format!( + "ERROR: Driver requests loading {} dimension, but file has {} dimension.", + self.configuration.dim, file_dim + ))); + } + + if self.configuration.use_pq_dist { + // TODO: PQ + todo!("PQ is not supported now"); + } + + if self.query_scratch_queue.size()? 
== 0 { + self.initialize_query_scratch( + 5 + self.configuration.index_write_parameter.num_threads, + self.configuration.index_write_parameter.search_list_size, + )?; + } + + if self.configuration.index_write_parameter.num_threads > 0 { + // set the thread count of Rayon, otherwise it will use threads as many as logical cores. + std::env::set_var( + "RAYON_NUM_THREADS", + self.configuration + .index_write_parameter + .num_threads + .to_string(), + ); + } + + self.dataset + .append_from_file(filename, num_points_to_insert)?; + self.final_graph.extend( + num_points_to_insert, + self.configuration.index_write_parameter.max_degree, + ); + + // TODO: this should not consider frozen points + let previous_last_pt = self.num_active_pts; + self.num_active_pts += num_points_to_insert; + self.configuration.max_points += num_points_to_insert; + + println!("Inserting {} vectors from file.", num_points_to_insert); + + // TODO: tag_lock + let logger = IndexLogger::new(num_points_to_insert); + let timer = Timer::new(); + execute_with_rayon( + previous_last_pt..self.num_active_pts, + self.configuration.index_write_parameter.num_threads, + |idx| { + self.insert_vertex_id(idx as u32)?; + logger.vertex_processed()?; + + Ok(()) + }, + )?; + + let mut visit_order = + Vec::with_capacity(self.num_active_pts + self.configuration.num_frozen_pts); + for i in 0..self.num_active_pts { + visit_order.push(i as u32); + } + + self.cleanup_graph(&visit_order)?; + println!("{}", timer.elapsed_seconds_for_step("Insert time: ")); + + self.print_stats()?; + + Ok(()) + } + + fn save(&mut self, filename: &str) -> ANNResult<()> { + let data_file = filename.to_string() + ".data"; + let delete_file = filename.to_string() + ".delete"; + + self.save_graph(filename)?; + self.save_data(data_file.as_str())?; + self.save_delete_list(delete_file.as_str())?; + + Ok(()) + } + + fn load(&mut self, filename: &str, expected_num_points: usize) -> ANNResult<()> { + self.num_active_pts = expected_num_points; + self.dataset + .build_from_file(&format!("{}.data", filename), expected_num_points)?; + + self.load_graph(filename, expected_num_points)?; + self.load_delete_list(&format!("{}.delete", filename))?; + + if self.query_scratch_queue.size()? 
== 0 { + self.initialize_query_scratch( + 5 + self.configuration.index_write_parameter.num_threads, + self.configuration.index_write_parameter.search_list_size, + )?; + } + + Ok(()) + } + + fn search( + &self, + query: &[T], + k_value: usize, + l_value: u32, + indices: &mut [u32], + ) -> ANNResult { + let query_vector = Vertex::new(<&[T; N]>::try_from(query)?, 0); + InmemIndex::search(self, &query_vector, k_value, l_value, indices) + } + + fn soft_delete( + &mut self, + vertex_ids_to_delete: Vec, + num_points_to_delete: usize, + ) -> ANNResult<()> { + println!("Deleting {} vectors from file.", num_points_to_delete); + + let logger = IndexLogger::new(num_points_to_delete); + let timer = Timer::new(); + + execute_with_rayon( + 0..num_points_to_delete, + self.configuration.index_write_parameter.num_threads, + |idx: usize| { + self.soft_delete_vertex(vertex_ids_to_delete[idx])?; + logger.vertex_processed()?; + + Ok(()) + }, + )?; + + println!("{}", timer.elapsed_seconds_for_step("Delete time: ")); + self.print_stats()?; + + Ok(()) + } +} + +#[cfg(test)] +mod index_test { + use vector::Metric; + + use super::*; + use crate::{ + model::{ + configuration::index_write_parameters::IndexWriteParametersBuilder, vertex::DIM_128, + }, + test_utils::get_test_file_path, + utils::file_util::load_ids_to_delete_from_file, + utils::round_up, + }; + + const TEST_DATA_FILE: &str = "tests/data/siftsmall_learn_256pts.fbin"; + const TRUTH_GRAPH: &str = "tests/data/truth_index_siftsmall_learn_256pts_R4_L50_A1.2"; + const TEST_DELETE_FILE: &str = "tests/data/delete_set_50pts.bin"; + const TRUTH_GRAPH_WITH_SATURATED: &str = + "tests/data/disk_index_siftsmall_learn_256pts_R4_L50_A1.2_mem.index"; + const R: u32 = 4; + const L: u32 = 50; + const ALPHA: f32 = 1.2; + + /// Build the index with TEST_DATA_FILE and compare the index graph with truth graph TRUTH_GRAPH + /// Change above constants if you want to test with different dataset + macro_rules! 
index_end_to_end_test_singlethread { + ($saturate_graph:expr, $truth_graph:expr) => {{ + let (data_num, dim) = + load_metadata_from_file(get_test_file_path(TEST_DATA_FILE).as_str()).unwrap(); + + let index_write_parameters = IndexWriteParametersBuilder::new(L, R) + .with_alpha(ALPHA) + .with_num_threads(1) + .with_saturate_graph($saturate_graph) + .build(); + let config = IndexConfiguration::new( + Metric::L2, + dim, + round_up(dim as u64, 16_u64) as usize, + data_num, + false, + 0, + false, + 0, + 1.0f32, + index_write_parameters, + ); + let mut index: InmemIndex = InmemIndex::new(config.clone()).unwrap(); + + index + .build(get_test_file_path(TEST_DATA_FILE).as_str(), data_num) + .unwrap(); + + let mut truth_index: InmemIndex = InmemIndex::new(config).unwrap(); + truth_index + .load_graph(get_test_file_path($truth_graph).as_str(), data_num) + .unwrap(); + + compare_graphs(&index, &truth_index); + }}; + } + + #[test] + fn index_end_to_end_test_singlethread() { + index_end_to_end_test_singlethread!(false, TRUTH_GRAPH); + } + + #[test] + fn index_end_to_end_test_singlethread_with_saturate_graph() { + index_end_to_end_test_singlethread!(true, TRUTH_GRAPH_WITH_SATURATED); + } + + #[test] + fn index_end_to_end_test_multithread() { + let (data_num, dim) = + load_metadata_from_file(get_test_file_path(TEST_DATA_FILE).as_str()).unwrap(); + + let index_write_parameters = IndexWriteParametersBuilder::new(L, R) + .with_alpha(ALPHA) + .with_num_threads(8) + .build(); + let config = IndexConfiguration::new( + Metric::L2, + dim, + round_up(dim as u64, 16_u64) as usize, + data_num, + false, + 0, + false, + 0, + 1f32, + index_write_parameters, + ); + let mut index: InmemIndex = InmemIndex::new(config).unwrap(); + + index + .build(get_test_file_path(TEST_DATA_FILE).as_str(), data_num) + .unwrap(); + + for i in 0..index.final_graph.size() { + assert_ne!( + index + .final_graph + .read_vertex_and_neighbors(i as u32) + .unwrap() + .size(), + 0 + ); + } + } + + const TEST_DATA_FILE_2: &str = "tests/data/siftsmall_learn_256pts_2.fbin"; + const INSERT_TRUTH_GRAPH: &str = + "tests/data/truth_index_siftsmall_learn_256pts_1+2_R4_L50_A1.2"; + const INSERT_TRUTH_GRAPH_WITH_SATURATED: &str = + "tests/data/truth_index_siftsmall_learn_256pts_1+2_saturated_R4_L50_A1.2"; + + /// Build the index with TEST_DATA_FILE, insert TEST_DATA_FILE_2 and compare the index graph with truth graph TRUTH_GRAPH + /// Change above constants if you want to test with different dataset + macro_rules! 
index_insert_end_to_end_test_singlethread { + ($saturate_graph:expr, $truth_graph:expr) => {{ + let (data_num, dim) = + load_metadata_from_file(get_test_file_path(TEST_DATA_FILE).as_str()).unwrap(); + + let index_write_parameters = IndexWriteParametersBuilder::new(L, R) + .with_alpha(ALPHA) + .with_num_threads(1) + .with_saturate_graph($saturate_graph) + .build(); + let config = IndexConfiguration::new( + Metric::L2, + dim, + round_up(dim as u64, 16_u64) as usize, + data_num, + false, + 0, + false, + 0, + 2.0f32, + index_write_parameters, + ); + let mut index: InmemIndex = InmemIndex::new(config.clone()).unwrap(); + + index + .build(get_test_file_path(TEST_DATA_FILE).as_str(), data_num) + .unwrap(); + index + .insert(get_test_file_path(TEST_DATA_FILE_2).as_str(), data_num) + .unwrap(); + + let config2 = IndexConfiguration::new( + Metric::L2, + dim, + round_up(dim as u64, 16_u64) as usize, + data_num * 2, + false, + 0, + false, + 0, + 1.0f32, + index_write_parameters, + ); + let mut truth_index: InmemIndex = InmemIndex::new(config2).unwrap(); + truth_index + .load_graph(get_test_file_path($truth_graph).as_str(), data_num) + .unwrap(); + + compare_graphs(&index, &truth_index); + }}; + } + + /// Build the index with TEST_DATA_FILE, and delete the vertices with id defined in TEST_DELETE_SET + macro_rules! index_delete_end_to_end_test_singlethread { + () => {{ + let (data_num, dim) = + load_metadata_from_file(get_test_file_path(TEST_DATA_FILE).as_str()).unwrap(); + + let index_write_parameters = IndexWriteParametersBuilder::new(L, R) + .with_alpha(ALPHA) + .with_num_threads(1) + .build(); + let config = IndexConfiguration::new( + Metric::L2, + dim, + round_up(dim as u64, 16_u64) as usize, + data_num, + false, + 0, + false, + 0, + 2.0f32, + index_write_parameters, + ); + let mut index: InmemIndex = InmemIndex::new(config.clone()).unwrap(); + + index + .build(get_test_file_path(TEST_DATA_FILE).as_str(), data_num) + .unwrap(); + + let (num_points_to_delete, vertex_ids_to_delete) = + load_ids_to_delete_from_file(TEST_DELETE_FILE).unwrap(); + index + .soft_delete(vertex_ids_to_delete, num_points_to_delete) + .unwrap(); + assert!(index.delete_set.read().unwrap().len() == num_points_to_delete); + }}; + } + + #[test] + fn index_insert_end_to_end_test_singlethread() { + index_insert_end_to_end_test_singlethread!(false, INSERT_TRUTH_GRAPH); + } + + #[test] + fn index_delete_end_to_end_test_singlethread() { + index_delete_end_to_end_test_singlethread!(); + } + + #[test] + fn index_insert_end_to_end_test_saturated_singlethread() { + index_insert_end_to_end_test_singlethread!(true, INSERT_TRUTH_GRAPH_WITH_SATURATED); + } + + fn compare_graphs(index: &InmemIndex, truth_index: &InmemIndex) { + assert_eq!(index.start, truth_index.start); + assert_eq!(index.max_observed_degree, truth_index.max_observed_degree); + assert_eq!(index.final_graph.size(), truth_index.final_graph.size()); + + for i in 0..index.final_graph.size() { + assert_eq!( + index + .final_graph + .read_vertex_and_neighbors(i as u32) + .unwrap() + .size(), + truth_index + .final_graph + .read_vertex_and_neighbors(i as u32) + .unwrap() + .size() + ); + assert_eq!( + index + .final_graph + .read_vertex_and_neighbors(i as u32) + .unwrap() + .get_neighbors(), + truth_index + .final_graph + .read_vertex_and_neighbors(i as u32) + .unwrap() + .get_neighbors() + ); + } + } +} diff --git a/rust/diskann/src/index/inmem_index/inmem_index_storage.rs b/rust/diskann/src/index/inmem_index/inmem_index_storage.rs new file mode 100644 index 000000000..fa14d70b2 
--- /dev/null
+++ b/rust/diskann/src/index/inmem_index/inmem_index_storage.rs
@@ -0,0 +1,304 @@
+/*
+ * Copyright (c) Microsoft Corporation. All rights reserved.
+ * Licensed under the MIT license.
+ */
+use std::fs::File;
+use std::io::{BufReader, BufWriter, Seek, SeekFrom, Write};
+use std::path::Path;
+
+use byteorder::{LittleEndian, ReadBytesExt};
+use vector::FullPrecisionDistance;
+
+use crate::common::{ANNError, ANNResult};
+use crate::model::graph::AdjacencyList;
+use crate::model::InMemoryGraph;
+use crate::utils::{file_exists, save_data_in_base_dimensions};
+
+use super::InmemIndex;
+
+impl<T, const N: usize> InmemIndex<T, N>
+where
+    T: Default + Copy + Sync + Send + Into<f32>,
+    [T; N]: FullPrecisionDistance<T, N>,
+{
+    pub fn load_graph(&mut self, filename: &str, expected_num_points: usize) -> ANNResult<usize> {
+        // let file_offset = 0; // will need this for single file format support
+
+        let mut in_file = BufReader::new(File::open(Path::new(filename))?);
+        // in_file.seek(SeekFrom::Start(file_offset as u64))?;
+
+        // Header: file size (u64), max observed degree (u32), start vertex (u32),
+        // number of frozen points (u64) -> 8 + 4 + 4 + 8 = 24 bytes.
+        let expected_file_size: usize = in_file.read_u64::<LittleEndian>()? as usize;
+        self.max_observed_degree = in_file.read_u32::<LittleEndian>()?;
+        self.start = in_file.read_u32::<LittleEndian>()?;
+        let file_frozen_pts: usize = in_file.read_u64::<LittleEndian>()? as usize;
+
+        let vamana_metadata_size = 24;
+
+        println!("From graph header, expected_file_size: {}, max_observed_degree: {}, start: {}, file_frozen_pts: {}",
+            expected_file_size, self.max_observed_degree, self.start, file_frozen_pts);
+
+        if file_frozen_pts != self.configuration.num_frozen_pts {
+            if file_frozen_pts == 1 {
+                return Err(ANNError::log_index_config_error(
+                    "num_frozen_pts".to_string(),
+                    "ERROR: When loading index, detected dynamic index, but constructor asks for static index. Exiting.".to_string())
+                );
+            } else {
+                return Err(ANNError::log_index_config_error(
+                    "num_frozen_pts".to_string(),
+                    "ERROR: When loading index, detected static index, but constructor asks for dynamic index. Exiting.".to_string())
+                );
+            }
+        }
+
+        println!("Loading vamana graph {}...", filename);
+
+        let expected_max_points = expected_num_points - file_frozen_pts;
+
+        // If the user provides more points than max_points,
+        // resize the _final_graph to the larger size.
+        if self.configuration.max_points < expected_max_points {
+            println!("Number of points in data: {} is greater than max_points: {}. Setting max_points to: {}",
+                expected_max_points, self.configuration.max_points, expected_max_points);
+
+            self.configuration.max_points = expected_max_points;
+            self.final_graph = InMemoryGraph::new(
+                self.configuration.max_points + self.configuration.num_frozen_pts,
+                self.configuration.index_write_parameter.max_degree,
+            );
+        }
+
+        let mut bytes_read = vamana_metadata_size;
+        let mut num_edges = 0;
+        let mut nodes_read = 0;
+        let mut max_observed_degree = 0;
+
+        while bytes_read != expected_file_size {
+            let num_nbrs = in_file.read_u32::<LittleEndian>()?;
+            max_observed_degree = max_observed_degree.max(num_nbrs);
+
+            if num_nbrs == 0 {
+                return Err(ANNError::log_index_error(format!(
+                    "ERROR: Point found with no out-neighbors, point# {}",
+                    nodes_read
+                )));
+            }
+
+            num_edges += num_nbrs;
+            nodes_read += 1;
+            let mut tmp: Vec<u32> = Vec::with_capacity(num_nbrs as usize);
+            for _ in 0..num_nbrs {
+                tmp.push(in_file.read_u32::<LittleEndian>()?);
+            }
+
+            self.final_graph
+                .write_vertex_and_neighbors(nodes_read - 1)?
+                .set_neighbors(AdjacencyList::from(tmp));
+            // 4 bytes for num_nbrs itself plus 4 bytes per neighbor id.
+            bytes_read += 4 * (num_nbrs as usize + 1);
+        }
+
+        println!(
+            "Done. Index has {} nodes and {} out-edges, _start is set to {}",
+            nodes_read, num_edges, self.start
+        );
+
+        self.max_observed_degree = max_observed_degree;
+        Ok(nodes_read as usize)
+    }
+
+    /// Save the graph index on a file as an adjacency list.
+    /// For each point, first store the number of neighbors,
+    /// and then the neighbor list (each as a 4 byte u32).
+    pub fn save_graph(&mut self, graph_file: &str) -> ANNResult<u64> {
+        let file: File = File::create(graph_file)?;
+        let mut out = BufWriter::new(file);
+
+        let file_offset: u64 = 0;
+        out.seek(SeekFrom::Start(file_offset))?;
+        let mut index_size: u64 = 24;
+        let mut max_degree: u32 = 0;
+        out.write_all(&index_size.to_le_bytes())?;
+        out.write_all(&self.max_observed_degree.to_le_bytes())?;
+        out.write_all(&self.start.to_le_bytes())?;
+        out.write_all(&(self.configuration.num_frozen_pts as u64).to_le_bytes())?;
+
+        // At this point, either nd == max_points or any frozen points have
+        // been temporarily moved to nd, so nd + num_frozen_points is the valid
+        // location limit
+        for i in 0..self.num_active_pts + self.configuration.num_frozen_pts {
+            let idx = i as u32;
+            let gk: u32 = self.final_graph.read_vertex_and_neighbors(idx)?.size() as u32;
+            out.write_all(&gk.to_le_bytes())?;
+            for neighbor in self
+                .final_graph
+                .read_vertex_and_neighbors(idx)?
+                .get_neighbors()
+                .iter()
+            {
+                out.write_all(&neighbor.to_le_bytes())?;
+            }
+            max_degree = max_degree.max(gk);
+            index_size += (std::mem::size_of::<u32>() * (gk as usize + 1)) as u64;
+        }
+        // Rewrite the header with the final index size and max degree.
+        out.seek(SeekFrom::Start(file_offset))?;
+        out.write_all(&index_size.to_le_bytes())?;
+        out.write_all(&max_degree.to_le_bytes())?;
+        out.flush()?;
+        Ok(index_size)
+    }
+
+    /// Save the data on a file.
+    pub fn save_data(&mut self, data_file: &str) -> ANNResult<usize> {
+        // Note: at this point, either _nd == _max_points or any frozen points have
+        // been temporarily moved to _nd, so _nd + _num_frozen_points is the valid
+        // location limit.
+        Ok(save_data_in_base_dimensions(
+            data_file,
+            &mut self.dataset.data,
+            self.num_active_pts + self.configuration.num_frozen_pts,
+            self.configuration.dim,
+            self.configuration.aligned_dim,
+            0,
+        )?)
+    }
+
+    /// Save the delete list to a file only if the delete list length is not zero.
+    pub fn save_delete_list(&mut self, delete_list_file: &str) -> ANNResult<usize> {
+        let mut delete_file_size = 0;
+        if let Ok(delete_set) = self.delete_set.read() {
+            let delete_set_len = delete_set.len() as u32;
+
+            if delete_set_len != 0 {
+                let file: File = File::create(delete_list_file)?;
+                let mut writer = BufWriter::new(file);
+
+                // Write the length of the set.
+                writer.write_all(&delete_set_len.to_le_bytes())?;
+                delete_file_size += std::mem::size_of::<u32>();
+
+                // Write the elements of the set, little-endian to match load_delete_list.
+                for &item in delete_set.iter() {
+                    writer.write_all(&item.to_le_bytes())?;
+                    delete_file_size += std::mem::size_of::<u32>();
+                }
+
+                writer.flush()?;
+            }
+        } else {
+            return Err(ANNError::log_lock_poison_error(
+                "Poisoned lock on delete set. Can't save deleted list.".to_string(),
+            ));
+        }
+
+        Ok(delete_file_size)
+    }
+
+    /// Load the deleted list from the delete file if it exists.
+    pub fn load_delete_list(&mut self, delete_list_file: &str) -> ANNResult<usize> {
+        let mut len = 0;
+
+        if file_exists(delete_list_file) {
+            let file = File::open(delete_list_file)?;
+            let mut reader = BufReader::new(file);
+
+            len = reader.read_u32::<LittleEndian>()? as usize;
+
+            if let Ok(mut delete_set) = self.delete_set.write() {
+                for _ in 0..len {
+                    let item = reader.read_u32::<LittleEndian>()?;
+                    delete_set.insert(item);
+                }
+            } else {
+                return Err(ANNError::log_lock_poison_error(
+                    "Poisoned lock on delete set. Can't load deleted list.".to_string(),
+                ));
+            }
+        }
+
+        Ok(len)
+    }
+}
+
+#[cfg(test)]
+mod index_test {
+    use std::fs;
+
+    use vector::Metric;
+
+    use super::*;
+    use crate::{
+        index::ANNInmemIndex,
+        model::{
+            configuration::index_write_parameters::IndexWriteParametersBuilder, vertex::DIM_128,
+            IndexConfiguration,
+        },
+        utils::{load_metadata_from_file, round_up},
+    };
+
+    const TEST_DATA_FILE: &str = "tests/data/siftsmall_learn_256pts.fbin";
+    const R: u32 = 4;
+    const L: u32 = 50;
+    const ALPHA: f32 = 1.2;
+
+    #[cfg_attr(not(coverage), test)]
+    fn save_graph_test() {
+        let parameters = IndexWriteParametersBuilder::new(50, 4)
+            .with_alpha(1.2)
+            .build();
+        let config =
+            IndexConfiguration::new(Metric::L2, 10, 16, 16, false, 0, false, 8, 1f32, parameters);
+        let mut index = InmemIndex::<f32, DIM_128>::new(config).unwrap();
+        let final_graph = InMemoryGraph::new(10, 3);
+        let num_active_pts = 2_usize;
+        index.final_graph = final_graph;
+        index.num_active_pts = num_active_pts;
+        let graph_file = "test_save_graph_data.bin";
+        let result = index.save_graph(graph_file);
+        assert!(result.is_ok());
+
+        fs::remove_file(graph_file).expect("Failed to delete file");
+    }
+
+    #[test]
+    fn save_data_test() {
+        let (data_num, dim) = load_metadata_from_file(TEST_DATA_FILE).unwrap();
+
+        let index_write_parameters = IndexWriteParametersBuilder::new(L, R)
+            .with_alpha(ALPHA)
+            .build();
+        let config = IndexConfiguration::new(
+            Metric::L2,
+            dim,
+            round_up(dim as u64, 16_u64) as usize,
+            data_num,
+            false,
+            0,
+            false,
+            0,
+            1f32,
+            index_write_parameters,
+        );
+        let mut index: InmemIndex<f32, DIM_128> = InmemIndex::new(config).unwrap();
+
+        index.build(TEST_DATA_FILE, data_num).unwrap();
+
+        let data_file = "test.data";
+        let result = index.save_data(data_file);
+        // Expected size: 2 metadata integers (num_points, dim) plus the raw f32 vectors.
+        assert_eq!(
+            result.unwrap(),
+            2 * std::mem::size_of::<u32>()
+                + (index.num_active_pts + index.configuration.num_frozen_pts)
+                    * index.configuration.dim
+                    * (std::mem::size_of::<f32>())
+        );
+        fs::remove_file(data_file).expect("Failed to delete file");
+    }
+}
diff --git a/rust/diskann/src/index/inmem_index/mod.rs b/rust/diskann/src/index/inmem_index/mod.rs
new file mode 100644
index 000000000..f2a091a09
--- /dev/null
+++ b/rust/diskann/src/index/inmem_index/mod.rs
@@ -0,0 +1,12 @@
+/*
+ * Copyright (c) Microsoft Corporation. All rights reserved.
+ * Licensed under the MIT license.
+ */
+#[allow(clippy::module_inception)]
+mod inmem_index;
+pub use inmem_index::InmemIndex;
+
+mod inmem_index_storage;
+
+pub mod ann_inmem_index;
+
diff --git a/rust/diskann/src/index/mod.rs b/rust/diskann/src/index/mod.rs
new file mode 100644
index 000000000..18c3bd5e9
--- /dev/null
+++ b/rust/diskann/src/index/mod.rs
@@ -0,0 +1,11 @@
+/*
+ * Copyright (c) Microsoft Corporation. All rights reserved.
+ * Licensed under the MIT license.
+ */
+mod inmem_index;
+pub use inmem_index::ann_inmem_index::*;
+pub use inmem_index::InmemIndex;
+
+mod disk_index;
+pub use disk_index::*;
+
diff --git a/rust/diskann/src/instrumentation/disk_index_build_logger.rs b/rust/diskann/src/instrumentation/disk_index_build_logger.rs
new file mode 100644
index 000000000..d34935342
--- /dev/null
+++ b/rust/diskann/src/instrumentation/disk_index_build_logger.rs
@@ -0,0 +1,57 @@
+/*
+ * Copyright (c) Microsoft Corporation. All rights reserved.
+ * Licensed under the MIT license. + */ +use logger::logger::indexlog::DiskIndexConstructionCheckpoint; +use logger::logger::indexlog::DiskIndexConstructionLog; +use logger::logger::indexlog::Log; +use logger::logger::indexlog::LogLevel; +use logger::message_handler::send_log; + +use crate::{utils::Timer, common::ANNResult}; + +pub struct DiskIndexBuildLogger { + timer: Timer, + checkpoint: DiskIndexConstructionCheckpoint, +} + +impl DiskIndexBuildLogger { + pub fn new(checkpoint: DiskIndexConstructionCheckpoint) -> Self { + Self { + timer: Timer::new(), + checkpoint, + } + } + + pub fn log_checkpoint(&mut self, next_checkpoint: DiskIndexConstructionCheckpoint) -> ANNResult<()> { + if self.checkpoint == DiskIndexConstructionCheckpoint::None { + return Ok(()); + } + + let mut log = Log::default(); + let disk_index_construction_log = DiskIndexConstructionLog { + checkpoint: self.checkpoint as i32, + time_spent_in_seconds: self.timer.elapsed().as_secs_f32(), + g_cycles_spent: self.timer.elapsed_gcycles(), + log_level: LogLevel::Info as i32, + }; + log.disk_index_construction_log = Some(disk_index_construction_log); + + send_log(log)?; + self.checkpoint = next_checkpoint; + self.timer.reset(); + Ok(()) + } +} + +#[cfg(test)] +mod dataset_test { + use super::*; + + #[test] + fn test_log() { + let mut logger = DiskIndexBuildLogger::new(DiskIndexConstructionCheckpoint::PqConstruction); + logger.log_checkpoint(DiskIndexConstructionCheckpoint::InmemIndexBuild).unwrap();logger.log_checkpoint(logger::logger::indexlog::DiskIndexConstructionCheckpoint::DiskLayout).unwrap(); + } +} + diff --git a/rust/diskann/src/instrumentation/index_logger.rs b/rust/diskann/src/instrumentation/index_logger.rs new file mode 100644 index 000000000..dfc81ad15 --- /dev/null +++ b/rust/diskann/src/instrumentation/index_logger.rs @@ -0,0 +1,47 @@ +/* + * Copyright (c) Microsoft Corporation. All rights reserved. + * Licensed under the MIT license. + */ +use std::sync::atomic::{AtomicUsize, Ordering}; + +use logger::logger::indexlog::IndexConstructionLog; +use logger::logger::indexlog::Log; +use logger::logger::indexlog::LogLevel; +use logger::message_handler::send_log; + +use crate::common::ANNResult; +use crate::utils::Timer; + +pub struct IndexLogger { + items_processed: AtomicUsize, + timer: Timer, + range: usize, +} + +impl IndexLogger { + pub fn new(range: usize) -> Self { + Self { + items_processed: AtomicUsize::new(0), + timer: Timer::new(), + range, + } + } + + pub fn vertex_processed(&self) -> ANNResult<()> { + let count = self.items_processed.fetch_add(1, Ordering::Relaxed); + if count % 100_000 == 0 { + let mut log = Log::default(); + let index_construction_log = IndexConstructionLog { + percentage_complete: (100_f32 * count as f32) / (self.range as f32), + time_spent_in_seconds: self.timer.elapsed().as_secs_f32(), + g_cycles_spent: self.timer.elapsed_gcycles(), + log_level: LogLevel::Info as i32, + }; + log.index_construction_log = Some(index_construction_log); + + send_log(log)?; + } + + Ok(()) + } +} diff --git a/rust/diskann/src/instrumentation/mod.rs b/rust/diskann/src/instrumentation/mod.rs new file mode 100644 index 000000000..234e53ce9 --- /dev/null +++ b/rust/diskann/src/instrumentation/mod.rs @@ -0,0 +1,9 @@ +/* + * Copyright (c) Microsoft Corporation. All rights reserved. + * Licensed under the MIT license. 
+ */ +mod index_logger; +pub use index_logger::IndexLogger; + +mod disk_index_build_logger; +pub use disk_index_build_logger::DiskIndexBuildLogger; diff --git a/rust/diskann/src/lib.rs b/rust/diskann/src/lib.rs new file mode 100644 index 000000000..1f89e33fc --- /dev/null +++ b/rust/diskann/src/lib.rs @@ -0,0 +1,26 @@ +/* + * Copyright (c) Microsoft Corporation. All rights reserved. + * Licensed under the MIT license. + */ +#![cfg_attr( + not(test), + warn(clippy::panic, clippy::unwrap_used, clippy::expect_used) +)] +#![cfg_attr(test, allow(clippy::unused_io_amount))] + +pub mod utils; + +pub mod algorithm; + +pub mod model; + +pub mod common; + +pub mod index; + +pub mod storage; + +pub mod instrumentation; + +#[cfg(test)] +pub mod test_utils; diff --git a/rust/diskann/src/model/configuration/disk_index_build_parameter.rs b/rust/diskann/src/model/configuration/disk_index_build_parameter.rs new file mode 100644 index 000000000..539192af0 --- /dev/null +++ b/rust/diskann/src/model/configuration/disk_index_build_parameter.rs @@ -0,0 +1,85 @@ +/* + * Copyright (c) Microsoft Corporation. All rights reserved. + * Licensed under the MIT license. + */ +#![warn(missing_debug_implementations, missing_docs)] + +//! Parameters for disk index construction. + +use crate::common::{ANNResult, ANNError}; + +/// Cached nodes size in GB +const SPACE_FOR_CACHED_NODES_IN_GB: f64 = 0.25; + +/// Threshold for caching in GB +const THRESHOLD_FOR_CACHING_IN_GB: f64 = 1.0; + +/// Parameters specific for disk index construction. +#[derive(Clone, Copy, PartialEq, Debug)] +pub struct DiskIndexBuildParameters { + /// Bound on the memory footprint of the index at search time in bytes. + /// Once built, the index will use up only the specified RAM limit, the rest will reside on disk. + /// This will dictate how aggressively we compress the data vectors to store in memory. + /// Larger will yield better performance at search time. + search_ram_limit: f64, + + /// Limit on the memory allowed for building the index in bytes. 
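+    /// (Stored in bytes; `new` converts from the GB value supplied by the caller.)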
+ index_build_ram_limit: f64, +} + +impl DiskIndexBuildParameters { + /// Create DiskIndexBuildParameters instance + pub fn new(search_ram_limit_gb: f64, index_build_ram_limit_gb: f64) -> ANNResult { + let param = Self { + search_ram_limit: Self::get_memory_budget(search_ram_limit_gb), + index_build_ram_limit: index_build_ram_limit_gb * 1024_f64 * 1024_f64 * 1024_f64, + }; + + if param.search_ram_limit <= 0f64 { + return Err(ANNError::log_index_config_error("search_ram_limit".to_string(), "RAM budget should be > 0".to_string())) + } + + if param.index_build_ram_limit <= 0f64 { + return Err(ANNError::log_index_config_error("index_build_ram_limit".to_string(), "RAM budget should be > 0".to_string())) + } + + Ok(param) + } + + /// Get search_ram_limit + pub fn search_ram_limit(&self) -> f64 { + self.search_ram_limit + } + + /// Get index_build_ram_limit + pub fn index_build_ram_limit(&self) -> f64 { + self.index_build_ram_limit + } + + fn get_memory_budget(mut index_ram_limit_gb: f64) -> f64 { + if index_ram_limit_gb - SPACE_FOR_CACHED_NODES_IN_GB > THRESHOLD_FOR_CACHING_IN_GB { + // slack for space used by cached nodes + index_ram_limit_gb -= SPACE_FOR_CACHED_NODES_IN_GB; + } + + index_ram_limit_gb * 1024_f64 * 1024_f64 * 1024_f64 + } +} + +#[cfg(test)] +mod dataset_test { + use super::*; + + #[test] + fn sufficient_ram_for_caching() { + let param = DiskIndexBuildParameters::new(1.26_f64, 1.0_f64).unwrap(); + assert_eq!(param.search_ram_limit, 1.01_f64 * 1024_f64 * 1024_f64 * 1024_f64); + } + + #[test] + fn insufficient_ram_for_caching() { + let param = DiskIndexBuildParameters::new(0.03_f64, 1.0_f64).unwrap(); + assert_eq!(param.search_ram_limit, 0.03_f64 * 1024_f64 * 1024_f64 * 1024_f64); + } +} + diff --git a/rust/diskann/src/model/configuration/index_configuration.rs b/rust/diskann/src/model/configuration/index_configuration.rs new file mode 100644 index 000000000..3e8c472ae --- /dev/null +++ b/rust/diskann/src/model/configuration/index_configuration.rs @@ -0,0 +1,92 @@ +/* + * Copyright (c) Microsoft Corporation. All rights reserved. + * Licensed under the MIT license. + */ +#![warn(missing_debug_implementations, missing_docs)] + +//! Index configuration. + +use vector::Metric; + +use super::index_write_parameters::IndexWriteParameters; + +/// The index configuration +#[derive(Debug, Clone)] +pub struct IndexConfiguration { + /// Index write parameter + pub index_write_parameter: IndexWriteParameters, + + /// Distance metric + pub dist_metric: Metric, + + /// Dimension of the raw data + pub dim: usize, + + /// Aligned dimension - round up dim to the nearest multiple of 8 + pub aligned_dim: usize, + + /// Total number of points in given data set + pub max_points: usize, + + /// Number of points which are used as initial candidates when iterating to + /// closest point(s). These are not visible externally and won't be returned + /// by search. DiskANN forces at least 1 frozen point for dynamic index. + /// The frozen points have consecutive locations. + pub num_frozen_pts: usize, + + /// Calculate distance by PQ or not + pub use_pq_dist: bool, + + /// Number of PQ chunks + pub num_pq_chunks: usize, + + /// Use optimized product quantization + /// Currently not supported + pub use_opq: bool, + + /// potential for growth. 1.2 means the index can grow by up to 20%. 
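+    /// For example, with max_points = 1000 and growth_potential = 1.2, space is
+    /// provisioned for up to 1200 points.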
+ pub growth_potential: f32, + + // TODO: below settings are not supported in current iteration + // pub concurrent_consolidate: bool, + // pub has_built: bool, + // pub save_as_one_file: bool, + // pub dynamic_index: bool, + // pub enable_tags: bool, + // pub normalize_vecs: bool, +} + +impl IndexConfiguration { + /// Create IndexConfiguration instance + #[allow(clippy::too_many_arguments)] + pub fn new( + dist_metric: Metric, + dim: usize, + aligned_dim: usize, + max_points: usize, + use_pq_dist: bool, + num_pq_chunks: usize, + use_opq: bool, + num_frozen_pts: usize, + growth_potential: f32, + index_write_parameter: IndexWriteParameters + ) -> Self { + Self { + index_write_parameter, + dist_metric, + dim, + aligned_dim, + max_points, + num_frozen_pts, + use_pq_dist, + num_pq_chunks, + use_opq, + growth_potential, + } + } + + /// Get the size of adjacency list that we build out. + pub fn write_range(&self) -> usize { + self.index_write_parameter.max_degree as usize + } +} diff --git a/rust/diskann/src/model/configuration/index_write_parameters.rs b/rust/diskann/src/model/configuration/index_write_parameters.rs new file mode 100644 index 000000000..cb71f4297 --- /dev/null +++ b/rust/diskann/src/model/configuration/index_write_parameters.rs @@ -0,0 +1,245 @@ +/* + * Copyright (c) Microsoft Corporation. All rights reserved. + * Licensed under the MIT license. + */ +#![warn(missing_debug_implementations, missing_docs)] + +//! Index write parameters. + +/// Default parameter values. +pub mod default_param_vals { + /// Default value of alpha. + pub const ALPHA: f32 = 1.2; + + /// Default value of number of threads. + pub const NUM_THREADS: u32 = 0; + + /// Default value of number of rounds. + pub const NUM_ROUNDS: u32 = 2; + + /// Default value of max occlusion size. + pub const MAX_OCCLUSION_SIZE: u32 = 750; + + /// Default value of filter list size. + pub const FILTER_LIST_SIZE: u32 = 0; + + /// Default value of number of frozen points. + pub const NUM_FROZEN_POINTS: u32 = 0; + + /// Default value of max degree. + pub const MAX_DEGREE: u32 = 64; + + /// Default value of build list size. + pub const BUILD_LIST_SIZE: u32 = 100; + + /// Default value of saturate graph. + pub const SATURATE_GRAPH: bool = false; + + /// Default value of search list size. + pub const SEARCH_LIST_SIZE: u32 = 100; +} + +/// Index write parameters. +#[derive(Clone, Copy, PartialEq, Debug)] +pub struct IndexWriteParameters { + /// Search list size - L. + pub search_list_size: u32, + + /// Max degree - R. + pub max_degree: u32, + + /// Saturate graph. + pub saturate_graph: bool, + + /// Max occlusion size - C. + pub max_occlusion_size: u32, + + /// Alpha. + pub alpha: f32, + + /// Number of rounds. + pub num_rounds: u32, + + /// Number of threads. + pub num_threads: u32, + + /// Number of frozen points. + pub num_frozen_points: u32, +} + +impl Default for IndexWriteParameters { + /// Create IndexWriteParameters with default values + fn default() -> Self { + Self { + search_list_size: default_param_vals::SEARCH_LIST_SIZE, + max_degree: default_param_vals::MAX_DEGREE, + saturate_graph: default_param_vals::SATURATE_GRAPH, + max_occlusion_size: default_param_vals::MAX_OCCLUSION_SIZE, + alpha: default_param_vals::ALPHA, + num_rounds: default_param_vals::NUM_ROUNDS, + num_threads: default_param_vals::NUM_THREADS, + num_frozen_points: default_param_vals::NUM_FROZEN_POINTS + } + } +} + +/// The builder for IndexWriteParameters. 
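+/// A minimal usage sketch (the L/R values below are illustrative, not defaults):
+///
+/// ```ignore
+/// let params = IndexWriteParametersBuilder::new(100, 64) // search_list_size L, max_degree R
+///     .with_alpha(1.2)
+///     .with_num_threads(8)
+///     .build();
+/// assert_eq!(params.max_degree, 64);
+/// assert_eq!(params.alpha, 1.2);
+/// ```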
+#[derive(Debug)] +pub struct IndexWriteParametersBuilder { + search_list_size: u32, + max_degree: u32, + max_occlusion_size: Option, + saturate_graph: Option, + alpha: Option, + num_rounds: Option, + num_threads: Option, + // filter_list_size: Option, + num_frozen_points: Option, +} + +impl IndexWriteParametersBuilder { + /// Initialize IndexWriteParametersBuilder + pub fn new(search_list_size: u32, max_degree: u32) -> Self { + Self { + search_list_size, + max_degree, + max_occlusion_size: None, + saturate_graph: None, + alpha: None, + num_rounds: None, + num_threads: None, + // filter_list_size: None, + num_frozen_points: None, + } + } + + /// Set max occlusion size. + pub fn with_max_occlusion_size(mut self, max_occlusion_size: u32) -> Self { + self.max_occlusion_size = Some(max_occlusion_size); + self + } + + /// Set saturate graph. + pub fn with_saturate_graph(mut self, saturate_graph: bool) -> Self { + self.saturate_graph = Some(saturate_graph); + self + } + + /// Set alpha. + pub fn with_alpha(mut self, alpha: f32) -> Self { + self.alpha = Some(alpha); + self + } + + /// Set number of rounds. + pub fn with_num_rounds(mut self, num_rounds: u32) -> Self { + self.num_rounds = Some(num_rounds); + self + } + + /// Set number of threads. + pub fn with_num_threads(mut self, num_threads: u32) -> Self { + self.num_threads = Some(num_threads); + self + } + + /* + pub fn with_filter_list_size(mut self, filter_list_size: u32) -> Self { + self.filter_list_size = Some(filter_list_size); + self + } + */ + + /// Set number of frozen points. + pub fn with_num_frozen_points(mut self, num_frozen_points: u32) -> Self { + self.num_frozen_points = Some(num_frozen_points); + self + } + + /// Build IndexWriteParameters from IndexWriteParametersBuilder. + pub fn build(self) -> IndexWriteParameters { + IndexWriteParameters { + search_list_size: self.search_list_size, + max_degree: self.max_degree, + saturate_graph: self.saturate_graph.unwrap_or(default_param_vals::SATURATE_GRAPH), + max_occlusion_size: self.max_occlusion_size.unwrap_or(default_param_vals::MAX_OCCLUSION_SIZE), + alpha: self.alpha.unwrap_or(default_param_vals::ALPHA), + num_rounds: self.num_rounds.unwrap_or(default_param_vals::NUM_ROUNDS), + num_threads: self.num_threads.unwrap_or(default_param_vals::NUM_THREADS), + // filter_list_size: self.filter_list_size.unwrap_or(default_param_vals::FILTER_LIST_SIZE), + num_frozen_points: self.num_frozen_points.unwrap_or(default_param_vals::NUM_FROZEN_POINTS), + } + } +} + +/// Construct IndexWriteParametersBuilder from IndexWriteParameters. 
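+/// Every field is carried over as `Some(...)`, so a round trip
+/// `IndexWriteParametersBuilder::from(p).build()` reproduces `p` exactly
+/// (the `test_index_write_parameters_builder` test below relies on this).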
+impl From for IndexWriteParametersBuilder { + fn from(param: IndexWriteParameters) -> Self { + Self { + search_list_size: param.search_list_size, + max_degree: param.max_degree, + max_occlusion_size: Some(param.max_occlusion_size), + saturate_graph: Some(param.saturate_graph), + alpha: Some(param.alpha), + num_rounds: Some(param.num_rounds), + num_threads: Some(param.num_threads), + // filter_list_size: Some(param.filter_list_size), + num_frozen_points: Some(param.num_frozen_points), + } + } +} + +#[cfg(test)] +mod parameters_test { + use crate::model::configuration::index_write_parameters::*; + + #[test] + fn test_default_index_params() { + let wp1 = IndexWriteParameters::default(); + assert_eq!(wp1.search_list_size, default_param_vals::SEARCH_LIST_SIZE); + assert_eq!(wp1.max_degree, default_param_vals::MAX_DEGREE); + assert_eq!(wp1.saturate_graph, default_param_vals::SATURATE_GRAPH); + assert_eq!(wp1.max_occlusion_size, default_param_vals::MAX_OCCLUSION_SIZE); + assert_eq!(wp1.alpha, default_param_vals::ALPHA); + assert_eq!(wp1.num_rounds, default_param_vals::NUM_ROUNDS); + assert_eq!(wp1.num_threads, default_param_vals::NUM_THREADS); + assert_eq!(wp1.num_frozen_points, default_param_vals::NUM_FROZEN_POINTS); + } + + #[test] + fn test_index_write_parameters_builder() { + // default value + let wp1 = IndexWriteParametersBuilder::new(10, 20).build(); + assert_eq!(wp1.search_list_size, 10); + assert_eq!(wp1.max_degree, 20); + assert_eq!(wp1.saturate_graph, default_param_vals::SATURATE_GRAPH); + assert_eq!(wp1.max_occlusion_size, default_param_vals::MAX_OCCLUSION_SIZE); + assert_eq!(wp1.alpha, default_param_vals::ALPHA); + assert_eq!(wp1.num_rounds, default_param_vals::NUM_ROUNDS); + assert_eq!(wp1.num_threads, default_param_vals::NUM_THREADS); + assert_eq!(wp1.num_frozen_points, default_param_vals::NUM_FROZEN_POINTS); + + // build with custom values + let wp2 = IndexWriteParametersBuilder::new(10, 20) + .with_max_occlusion_size(30) + .with_saturate_graph(true) + .with_alpha(0.5) + .with_num_rounds(40) + .with_num_threads(50) + .with_num_frozen_points(60) + .build(); + assert_eq!(wp2.search_list_size, 10); + assert_eq!(wp2.max_degree, 20); + assert!(wp2.saturate_graph); + assert_eq!(wp2.max_occlusion_size, 30); + assert_eq!(wp2.alpha, 0.5); + assert_eq!(wp2.num_rounds, 40); + assert_eq!(wp2.num_threads, 50); + assert_eq!(wp2.num_frozen_points, 60); + + // test from + let wp3 = IndexWriteParametersBuilder::from(wp2).build(); + assert_eq!(wp3, wp2); + } +} + diff --git a/rust/diskann/src/model/configuration/mod.rs b/rust/diskann/src/model/configuration/mod.rs new file mode 100644 index 000000000..201f97e98 --- /dev/null +++ b/rust/diskann/src/model/configuration/mod.rs @@ -0,0 +1,12 @@ +/* + * Copyright (c) Microsoft Corporation. All rights reserved. + * Licensed under the MIT license. + */ +pub mod index_configuration; +pub use index_configuration::IndexConfiguration; + +pub mod index_write_parameters; +pub use index_write_parameters::*; + +pub mod disk_index_build_parameter; +pub use disk_index_build_parameter::DiskIndexBuildParameters; diff --git a/rust/diskann/src/model/data_store/disk_scratch_dataset.rs b/rust/diskann/src/model/data_store/disk_scratch_dataset.rs new file mode 100644 index 000000000..0d9a007ab --- /dev/null +++ b/rust/diskann/src/model/data_store/disk_scratch_dataset.rs @@ -0,0 +1,76 @@ +/* + * Copyright (c) Microsoft Corporation. All rights reserved. + * Licensed under the MIT license. + */ +#![warn(missing_debug_implementations, missing_docs)] + +//! 
Disk scratch dataset + +use std::mem::{size_of, size_of_val}; +use std::ptr; + +use crate::common::{AlignedBoxWithSlice, ANNResult}; +use crate::model::MAX_N_CMPS; +use crate::utils::round_up; + +/// DiskScratchDataset alignment +pub const DISK_SCRATCH_DATASET_ALIGN: usize = 256; + +/// Disk scratch dataset storing fp vectors with aligned dim +#[derive(Debug)] +pub struct DiskScratchDataset +{ + /// fp vectors with aligned dim + pub data: AlignedBoxWithSlice, + + /// current index to store the next fp vector + pub cur_index: usize, +} + +impl DiskScratchDataset +{ + /// Create DiskScratchDataset instance + pub fn new() -> ANNResult { + Ok(Self { + // C++ code allocates round_up(MAX_N_CMPS * N, 256) bytes, shouldn't it be round_up(MAX_N_CMPS * N, 256) * size_of:: bytes? + data: AlignedBoxWithSlice::new( + round_up(MAX_N_CMPS * N, DISK_SCRATCH_DATASET_ALIGN), + DISK_SCRATCH_DATASET_ALIGN)?, + cur_index: 0, + }) + } + + /// memcpy from fp vector bytes (its len should be `dim * size_of::()`) to self.data + /// The dest slice is a fp vector with aligned dim + /// * fp_vector_buf's dim might not be aligned dim (N) + /// # Safety + /// Behavior is undefined if any of the following conditions are violated: + /// + /// * `fp_vector_buf`'s len must be `dim * size_of::()` bytes + /// + /// * `fp_vector_buf` must be smaller than or equal to `N * size_of::()` bytes. + /// + /// * `fp_vector_buf` and `self.data` must be nonoverlapping. + pub unsafe fn memcpy_from_fp_vector_buf(&mut self, fp_vector_buf: &[u8]) -> &[T] { + if self.cur_index == MAX_N_CMPS { + self.cur_index = 0; + } + + let aligned_dim_vector = &mut self.data[self.cur_index * N..(self.cur_index + 1) * N]; + + assert!(fp_vector_buf.len() % size_of::() == 0); + assert!(fp_vector_buf.len() <= size_of_val(aligned_dim_vector)); + + // memcpy from fp_vector_buf to aligned_dim_vector + unsafe { + ptr::copy_nonoverlapping( + fp_vector_buf.as_ptr(), + aligned_dim_vector.as_mut_ptr() as *mut u8, + fp_vector_buf.len(), + ); + } + + self.cur_index += 1; + aligned_dim_vector + } +} diff --git a/rust/diskann/src/model/data_store/inmem_dataset.rs b/rust/diskann/src/model/data_store/inmem_dataset.rs new file mode 100644 index 000000000..6d8b649a2 --- /dev/null +++ b/rust/diskann/src/model/data_store/inmem_dataset.rs @@ -0,0 +1,285 @@ +/* + * Copyright (c) Microsoft Corporation. All rights reserved. + * Licensed under the MIT license. + */ +#![warn(missing_debug_implementations, missing_docs)] + +//! In-memory Dataset + +use rayon::prelude::*; +use std::mem; +use vector::{FullPrecisionDistance, Metric}; + +use crate::common::{ANNError, ANNResult, AlignedBoxWithSlice}; +use crate::model::Vertex; +use crate::utils::copy_aligned_data_from_file; + +/// Dataset of all in-memory FP points +#[derive(Debug)] +pub struct InmemDataset +where + [T; N]: FullPrecisionDistance, +{ + /// All in-memory points + pub data: AlignedBoxWithSlice, + + /// Number of points we anticipate to have + pub num_points: usize, + + /// Number of active points i.e. existing in the graph + pub num_active_pts: usize, + + /// Capacity of the dataset + pub capacity: usize, +} + +impl<'a, T, const N: usize> InmemDataset +where + T: Default + Copy + Sync + Send + Into, + [T; N]: FullPrecisionDistance, +{ + /// Create the dataset with size num_points and growth factor. 
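+    /// The backing buffer is provisioned for num_points * N elements, scaled by the factor: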
+ /// growth factor=1 means no growth (provision 100% space of num_points) + /// growth factor=1.2 means provision 120% space of num_points (20% extra space) + pub fn new(num_points: usize, index_growth_factor: f32) -> ANNResult { + let capacity = (((num_points * N) as f32) * index_growth_factor) as usize; + + Ok(Self { + data: AlignedBoxWithSlice::new(capacity, mem::size_of::() * 16)?, + num_points, + num_active_pts: num_points, + capacity, + }) + } + + /// get immutable data slice + pub fn get_data(&self) -> &[T] { + &self.data + } + + /// Build the dataset from file + pub fn build_from_file(&mut self, filename: &str, num_points_to_load: usize) -> ANNResult<()> { + println!( + "Loading {} vectors from file {} into dataset...", + num_points_to_load, filename + ); + self.num_active_pts = num_points_to_load; + + copy_aligned_data_from_file(filename, self.into_dto(), 0)?; + + println!("Dataset loaded."); + Ok(()) + } + + /// Append the dataset from file + pub fn append_from_file( + &mut self, + filename: &str, + num_points_to_append: usize, + ) -> ANNResult<()> { + println!( + "Appending {} vectors from file {} into dataset...", + num_points_to_append, filename + ); + if self.num_points + num_points_to_append > self.capacity { + return Err(ANNError::log_index_error(format!( + "Cannot append {} points to dataset of capacity {}", + num_points_to_append, self.capacity + ))); + } + + let pts_offset = self.num_active_pts; + copy_aligned_data_from_file(filename, self.into_dto(), pts_offset)?; + + self.num_active_pts += num_points_to_append; + self.num_points += num_points_to_append; + + println!("Dataset appended."); + Ok(()) + } + + /// Get vertex by id + pub fn get_vertex(&'a self, id: u32) -> ANNResult> { + let start = id as usize * N; + let end = start + N; + + if end <= self.data.len() { + let val = <&[T; N]>::try_from(&self.data[start..end]).map_err(|err| { + ANNError::log_index_error(format!("Failed to get vertex {}, err={}", id, err)) + })?; + Ok(Vertex::new(val, id)) + } else { + Err(ANNError::log_index_error(format!( + "Invalid vertex id {}.", + id + ))) + } + } + + /// Get full precision distance between two nodes + pub fn get_distance(&self, id1: u32, id2: u32, metric: Metric) -> ANNResult { + let vertex1 = self.get_vertex(id1)?; + let vertex2 = self.get_vertex(id2)?; + + Ok(vertex1.compare(&vertex2, metric)) + } + + /// find out the medoid, the vertex in the dataset that is closest to the centroid + pub fn calculate_medoid_point_id(&self) -> ANNResult { + Ok(self.find_nearest_point_id(self.calculate_centroid_point()?)) + } + + /// calculate centroid, average of all vertices in the dataset + fn calculate_centroid_point(&self) -> ANNResult<[f32; N]> { + // Allocate and initialize the centroid vector + let mut center: [f32; N] = [0.0; N]; + + // Sum the data points' components + for i in 0..self.num_active_pts { + let vertex = self.get_vertex(i as u32)?; + let vertex_slice = vertex.vector(); + for j in 0..N { + center[j] += vertex_slice[j].into(); + } + } + + // Divide by the number of points to calculate the centroid + let capacity = self.num_active_pts as f32; + for item in center.iter_mut().take(N) { + *item /= capacity; + } + + Ok(center) + } + + /// find out the vertex closest to the given point + fn find_nearest_point_id(&self, point: [f32; N]) -> u32 { + // compute all to one distance + let mut distances = vec![0f32; self.num_active_pts]; + let slice = &self.data[..]; + distances.par_iter_mut().enumerate().for_each(|(i, dist)| { + let start = i * N; + for j in 0..N { + let diff: 
f32 = (point.as_slice()[j] - slice[start + j].into()) + * (point.as_slice()[j] - slice[start + j].into()); + *dist += diff; + } + }); + + let mut min_idx = 0; + let mut min_dist = f32::MAX; + for (i, distance) in distances.iter().enumerate().take(self.num_active_pts) { + if *distance < min_dist { + min_idx = i; + min_dist = *distance; + } + } + min_idx as u32 + } + + /// Prefetch vertex data in the memory hierarchy + /// NOTE: good efficiency when total_vec_size is integral multiple of 64 + #[inline] + pub fn prefetch_vector(&self, id: u32) { + let start = id as usize * N; + let end = start + N; + + if end <= self.data.len() { + let vec = &self.data[start..end]; + vector::prefetch_vector(vec); + } + } + + /// Convert into dto object + pub fn into_dto(&mut self) -> DatasetDto { + DatasetDto { + data: &mut self.data, + rounded_dim: N, + } + } +} + +/// Dataset dto used for other layer, such as storage +/// N is the aligned dimension +#[derive(Debug)] +pub struct DatasetDto<'a, T> { + /// data slice borrow from dataset + pub data: &'a mut [T], + + /// rounded dimension + pub rounded_dim: usize, +} + +#[cfg(test)] +mod dataset_test { + use std::fs; + + use super::*; + use crate::model::vertex::DIM_128; + + #[test] + fn get_vertex_within_range() { + let num_points = 1_000_000; + let id = 999_999; + let dataset = InmemDataset::::new(num_points, 1f32).unwrap(); + + let vertex = dataset.get_vertex(999_999).unwrap(); + + assert_eq!(vertex.vertex_id(), id); + assert_eq!(vertex.vector().len(), DIM_128); + assert_eq!(vertex.vector().as_ptr(), unsafe { + dataset.data.as_ptr().add((id as usize) * DIM_128) + }); + } + + #[test] + fn get_vertex_out_of_range() { + let num_points = 1_000_000; + let invalid_id = 1_000_000; + let dataset = InmemDataset::::new(num_points, 1f32).unwrap(); + + if dataset.get_vertex(invalid_id).is_ok() { + panic!("id ({}) should be out of range", invalid_id) + }; + } + + #[test] + fn load_data_test() { + let file_name = "dataset_test_load_data_test.bin"; + //npoints=2, dim=8, 2 vectors [1.0;8] [2.0;8] + let data: [u8; 72] = [ + 2, 0, 0, 0, 8, 0, 0, 0, 0x00, 0x00, 0x80, 0x3f, 0x00, 0x00, 0x00, 0x40, 0x00, 0x00, + 0x40, 0x40, 0x00, 0x00, 0x80, 0x40, 0x00, 0x00, 0xa0, 0x40, 0x00, 0x00, 0xc0, 0x40, + 0x00, 0x00, 0xe0, 0x40, 0x00, 0x00, 0x00, 0x41, 0x00, 0x00, 0x10, 0x41, 0x00, 0x00, + 0x20, 0x41, 0x00, 0x00, 0x30, 0x41, 0x00, 0x00, 0x40, 0x41, 0x00, 0x00, 0x50, 0x41, + 0x00, 0x00, 0x60, 0x41, 0x00, 0x00, 0x70, 0x41, 0x00, 0x00, 0x80, 0x41, + ]; + std::fs::write(file_name, data).expect("Failed to write sample file"); + + let mut dataset = InmemDataset::::new(2, 1f32).unwrap(); + + match copy_aligned_data_from_file( + file_name, + dataset.into_dto(), + 0, + ) { + Ok((npts, dim)) => { + fs::remove_file(file_name).expect("Failed to delete file"); + assert!(npts == 2); + assert!(dim == 8); + assert!(dataset.data.len() == 16); + + let first_vertex = dataset.get_vertex(0).unwrap(); + let second_vertex = dataset.get_vertex(1).unwrap(); + + assert!(*first_vertex.vector() == [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]); + assert!(*second_vertex.vector() == [9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0]); + } + Err(e) => { + fs::remove_file(file_name).expect("Failed to delete file"); + panic!("{}", e) + } + } + } +} + diff --git a/rust/diskann/src/model/data_store/mod.rs b/rust/diskann/src/model/data_store/mod.rs new file mode 100644 index 000000000..4e7e68393 --- /dev/null +++ b/rust/diskann/src/model/data_store/mod.rs @@ -0,0 +1,11 @@ +/* + * Copyright (c) Microsoft Corporation. 
All rights reserved. + * Licensed under the MIT license. + */ +#[allow(clippy::module_inception)] +mod inmem_dataset; +pub use inmem_dataset::InmemDataset; +pub use inmem_dataset::DatasetDto; + +mod disk_scratch_dataset; +pub use disk_scratch_dataset::*; diff --git a/rust/diskann/src/model/graph/adjacency_list.rs b/rust/diskann/src/model/graph/adjacency_list.rs new file mode 100644 index 000000000..7ad2d7d5b --- /dev/null +++ b/rust/diskann/src/model/graph/adjacency_list.rs @@ -0,0 +1,64 @@ +/* + * Copyright (c) Microsoft Corporation. All rights reserved. + * Licensed under the MIT license. + */ +#![warn(missing_debug_implementations, missing_docs)] + +//! Adjacency List + +use std::ops::{Deref, DerefMut}; + +#[derive(Debug, Eq, PartialEq)] +/// Represents the out neighbors of a vertex +pub struct AdjacencyList { + edges: Vec, +} + +/// In-mem index related limits +const GRAPH_SLACK_FACTOR: f32 = 1.3_f32; + +impl AdjacencyList { + /// Create AdjacencyList with capacity slack for a range. + pub fn for_range(range: usize) -> Self { + let capacity = (range as f32 * GRAPH_SLACK_FACTOR).ceil() as usize; + Self { + edges: Vec::with_capacity(capacity), + } + } + + /// Push a node to the list of neighbors for the given node. + pub fn push(&mut self, node_id: u32) { + debug_assert!(self.edges.len() < self.edges.capacity()); + self.edges.push(node_id); + } +} + +impl From> for AdjacencyList { + fn from(edges: Vec) -> Self { + Self { edges } + } +} + +impl Deref for AdjacencyList { + type Target = Vec; + + fn deref(&self) -> &Self::Target { + &self.edges + } +} + +impl DerefMut for AdjacencyList { + fn deref_mut(&mut self) -> &mut Self::Target { + &mut self.edges + } +} + +impl<'a> IntoIterator for &'a AdjacencyList { + type Item = &'a u32; + type IntoIter = std::slice::Iter<'a, u32>; + + fn into_iter(self) -> Self::IntoIter { + self.edges.iter() + } +} + diff --git a/rust/diskann/src/model/graph/disk_graph.rs b/rust/diskann/src/model/graph/disk_graph.rs new file mode 100644 index 000000000..49190b1cd --- /dev/null +++ b/rust/diskann/src/model/graph/disk_graph.rs @@ -0,0 +1,179 @@ +/* + * Copyright (c) Microsoft Corporation. All rights reserved. + * Licensed under the MIT license. + */ +#![warn(missing_docs)] + +//! 
Disk graph + +use byteorder::{LittleEndian, ByteOrder}; +use vector::FullPrecisionDistance; + +use crate::common::{ANNResult, ANNError}; +use crate::model::data_store::DiskScratchDataset; +use crate::model::Vertex; +use crate::storage::DiskGraphStorage; + +use super::{VertexAndNeighbors, SectorGraph, AdjacencyList}; + +/// Disk graph +pub struct DiskGraph { + /// dim of fp vector in disk sector + dim: usize, + + /// number of nodes per sector + num_nodes_per_sector: u64, + + /// max node length in bytes + max_node_len: u64, + + /// the len of fp vector + fp_vector_len: u64, + + /// list of nodes (vertex_id) to fetch from disk + nodes_to_fetch: Vec, + + /// Sector graph + sector_graph: SectorGraph, +} + +impl<'a> DiskGraph { + /// Create DiskGraph instance + pub fn new( + dim: usize, + num_nodes_per_sector: u64, + max_node_len: u64, + fp_vector_len: u64, + beam_width: usize, + graph_storage: DiskGraphStorage, + ) -> ANNResult { + let graph = Self { + dim, + num_nodes_per_sector, + max_node_len, + fp_vector_len, + nodes_to_fetch: Vec::with_capacity(2 * beam_width), + sector_graph: SectorGraph::new(graph_storage)?, + }; + + Ok(graph) + } + + /// Add vertex_id into the list to fetch from disk + pub fn add_vertex(&mut self, id: u32) { + self.nodes_to_fetch.push(id); + } + + /// Fetch nodes from disk index + pub fn fetch_nodes(&mut self) -> ANNResult<()> { + let sectors_to_fetch: Vec = self.nodes_to_fetch.iter().map(|&id| self.node_sector_index(id)).collect(); + self.sector_graph.read_graph(§ors_to_fetch)?; + + Ok(()) + } + + /// Copy disk fp vector to DiskScratchDataset + /// Return the fp vector with aligned dim from DiskScratchDataset + pub fn copy_fp_vector_to_disk_scratch_dataset( + &self, + node_index: usize, + disk_scratch_dataset: &'a mut DiskScratchDataset + ) -> ANNResult> + where + [T; N]: FullPrecisionDistance, + { + if self.dim > N { + return Err(ANNError::log_index_error(format!( + "copy_sector_fp_to_aligned_dataset: dim {} is greater than aligned dim {}", + self.dim, N))); + } + + let fp_vector_buf = self.node_fp_vector_buf(node_index); + + // Safety condition is met here + let aligned_dim_vector = unsafe { disk_scratch_dataset.memcpy_from_fp_vector_buf(fp_vector_buf) }; + + Vertex::<'a, T, N>::try_from((aligned_dim_vector, self.nodes_to_fetch[node_index])) + .map_err(|err| ANNError::log_index_error(format!("TryFromSliceError: failed to get Vertex for disk index node, err={}", err))) + } + + /// Reset graph + pub fn reset(&mut self) { + self.nodes_to_fetch.clear(); + self.sector_graph.reset(); + } + + fn get_vertex_and_neighbors(&self, node_index: usize) -> VertexAndNeighbors { + let node_disk_buf = self.node_disk_buf(node_index); + let buf = &node_disk_buf[self.fp_vector_len as usize..]; + let num_neighbors = LittleEndian::read_u32(&buf[0..4]) as usize; + let neighbors_buf = &buf[4..4 + num_neighbors * 4]; + + let mut adjacency_list = AdjacencyList::for_range(num_neighbors); + for chunk in neighbors_buf.chunks(4) { + let neighbor_id = LittleEndian::read_u32(chunk); + adjacency_list.push(neighbor_id); + } + + VertexAndNeighbors::new(self.nodes_to_fetch[node_index], adjacency_list) + } + + #[inline] + fn node_sector_index(&self, vertex_id: u32) -> u64 { + vertex_id as u64 / self.num_nodes_per_sector + 1 + } + + #[inline] + fn node_disk_buf(&self, node_index: usize) -> &[u8] { + let vertex_id = self.nodes_to_fetch[node_index]; + + // get sector_buf where this node is located + let sector_buf = self.sector_graph.get_sector_buf(node_index); + let node_offset = (vertex_id as u64 % 
self.num_nodes_per_sector * self.max_node_len) as usize;
+        &sector_buf[node_offset..node_offset + self.max_node_len as usize]
+    }
+
+    #[inline]
+    fn node_fp_vector_buf(&self, node_index: usize) -> &[u8] {
+        let node_disk_buf = self.node_disk_buf(node_index);
+        &node_disk_buf[..self.fp_vector_len as usize]
+    }
+}
+
+/// Iterator for DiskGraph
+pub struct DiskGraphIntoIterator<'a> {
+    graph: &'a DiskGraph,
+    index: usize,
+}
+
+impl<'a> IntoIterator for &'a DiskGraph {
+    type IntoIter = DiskGraphIntoIterator<'a>;
+    type Item = ANNResult<(usize, VertexAndNeighbors)>;
+
+    #[inline]
+    fn into_iter(self) -> Self::IntoIter {
+        DiskGraphIntoIterator {
+            graph: self,
+            index: 0,
+        }
+    }
+}
+
+impl<'a> Iterator for DiskGraphIntoIterator<'a> {
+    type Item = ANNResult<(usize, VertexAndNeighbors)>;
+
+    fn next(&mut self) -> Option<Self::Item> {
+        if self.index >= self.graph.nodes_to_fetch.len() {
+            return None;
+        }
+
+        let node_index = self.index;
+        let vertex_and_neighbors = self.graph.get_vertex_and_neighbors(self.index);
+
+        self.index += 1;
+        Some(Ok((node_index, vertex_and_neighbors)))
+    }
+}
+
diff --git a/rust/diskann/src/model/graph/inmem_graph.rs b/rust/diskann/src/model/graph/inmem_graph.rs
new file mode 100644
index 000000000..3d08db837
--- /dev/null
+++ b/rust/diskann/src/model/graph/inmem_graph.rs
@@ -0,0 +1,141 @@
+/*
+ * Copyright (c) Microsoft Corporation. All rights reserved.
+ * Licensed under the MIT license.
+ */
+#![warn(missing_debug_implementations, missing_docs)]
+
+//! In-memory graph
+
+use std::sync::{RwLock, RwLockReadGuard, RwLockWriteGuard};
+
+use crate::common::ANNError;
+
+use super::VertexAndNeighbors;
+
+/// The entire graph of in-memory index
+#[derive(Debug)]
+pub struct InMemoryGraph {
+    /// The entire graph
+    pub final_graph: Vec<RwLock<VertexAndNeighbors>>,
+}
+
+impl InMemoryGraph {
+    /// Create InMemoryGraph instance
+    pub fn new(size: usize, max_degree: u32) -> Self {
+        let mut graph = Vec::with_capacity(size);
+        for id in 0..size {
+            graph.push(RwLock::new(VertexAndNeighbors::for_range(
+                id as u32,
+                max_degree as usize,
+            )));
+        }
+        Self { final_graph: graph }
+    }
+
+    /// Size of graph
+    pub fn size(&self) -> usize {
+        self.final_graph.len()
+    }
+
+    /// Extend the graph by size vectors
+    pub fn extend(&mut self, size: usize, max_degree: u32) {
+        for id in 0..size {
+            self.final_graph
+                .push(RwLock::new(VertexAndNeighbors::for_range(
+                    id as u32,
+                    max_degree as usize,
+                )));
+        }
+    }
+
+    /// Get read guard of vertex_id
+    pub fn read_vertex_and_neighbors(
+        &self,
+        vertex_id: u32,
+    ) -> Result<RwLockReadGuard<VertexAndNeighbors>, ANNError> {
+        self.final_graph[vertex_id as usize].read().map_err(|err| {
+            ANNError::log_lock_poison_error(format!(
+                "PoisonError: Lock poisoned when reading final_graph for vertex_id {}, err={}",
+                vertex_id, err
+            ))
+        })
+    }
+
+    /// Get write guard of vertex_id
+    pub fn write_vertex_and_neighbors(
+        &self,
+        vertex_id: u32,
+    ) -> Result<RwLockWriteGuard<VertexAndNeighbors>, ANNError> {
+        self.final_graph[vertex_id as usize].write().map_err(|err| {
+            ANNError::log_lock_poison_error(format!(
+                "PoisonError: Lock poisoned when writing final_graph for vertex_id {}, err={}",
+                vertex_id, err
+            ))
+        })
+    }
+}
+
+#[cfg(test)]
+mod graph_tests {
+    use crate::model::{graph::AdjacencyList, GRAPH_SLACK_FACTOR};
+
+    use super::*;
+
+    #[test]
+    fn test_new() {
+        let graph = InMemoryGraph::new(10, 10);
+        let capacity = (GRAPH_SLACK_FACTOR * 10_f64).ceil() as usize;
+
+        assert_eq!(graph.final_graph.len(), 10);
+        for i in 0..10 {
+            let neighbor = graph.final_graph[i].read().unwrap();
+            assert_eq!(neighbor.vertex_id, i as u32);
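+            // for_range pre-allocates ceil(max_degree * GRAPH_SLACK_FACTOR) neighbor slots.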
assert_eq!(neighbor.get_neighbors().capacity(), capacity); + } + } + + #[test] + fn test_size() { + let graph = InMemoryGraph::new(10, 10); + assert_eq!(graph.size(), 10); + } + + #[test] + fn test_extend() { + let mut graph = InMemoryGraph::new(10, 10); + graph.extend(10, 10); + + assert_eq!(graph.size(), 20); + + let capacity = (GRAPH_SLACK_FACTOR * 10_f64).ceil() as usize; + let mut id: u32 = 0; + + for i in 10..20 { + let neighbor = graph.final_graph[i].read().unwrap(); + assert_eq!(neighbor.vertex_id, id); + assert_eq!(neighbor.get_neighbors().capacity(), capacity); + id += 1; + } + } + + #[test] + fn test_read_vertex_and_neighbors() { + let graph = InMemoryGraph::new(10, 10); + let neighbor = graph.read_vertex_and_neighbors(0); + assert!(neighbor.is_ok()); + assert_eq!(neighbor.unwrap().vertex_id, 0); + } + + #[test] + fn test_write_vertex_and_neighbors() { + let graph = InMemoryGraph::new(10, 10); + { + let neighbor = graph.write_vertex_and_neighbors(0); + assert!(neighbor.is_ok()); + neighbor.unwrap().add_to_neighbors(10, 10); + } + + let neighbor = graph.read_vertex_and_neighbors(0).unwrap(); + assert_eq!(neighbor.get_neighbors(), &AdjacencyList::from(vec![10_u32])); + } +} diff --git a/rust/diskann/src/model/graph/mod.rs b/rust/diskann/src/model/graph/mod.rs new file mode 100644 index 000000000..d1457f1c2 --- /dev/null +++ b/rust/diskann/src/model/graph/mod.rs @@ -0,0 +1,20 @@ +/* + * Copyright (c) Microsoft Corporation. All rights reserved. + * Licensed under the MIT license. + */ +#[allow(clippy::module_inception)] +mod inmem_graph; +pub use inmem_graph::InMemoryGraph; + +pub mod vertex_and_neighbors; +pub use vertex_and_neighbors::VertexAndNeighbors; + +mod adjacency_list; +pub use adjacency_list::AdjacencyList; + +mod sector_graph; +pub use sector_graph::*; + +mod disk_graph; +pub use disk_graph::*; + diff --git a/rust/diskann/src/model/graph/sector_graph.rs b/rust/diskann/src/model/graph/sector_graph.rs new file mode 100644 index 000000000..e51e0bf03 --- /dev/null +++ b/rust/diskann/src/model/graph/sector_graph.rs @@ -0,0 +1,87 @@ +/* + * Copyright (c) Microsoft Corporation. All rights reserved. + * Licensed under the MIT license. + */ +#![warn(missing_docs)] + +//! 
Sector graph + +use std::ops::Deref; + +use crate::common::{AlignedBoxWithSlice, ANNResult, ANNError}; +use crate::model::{MAX_N_SECTOR_READS, SECTOR_LEN, AlignedRead}; +use crate::storage::DiskGraphStorage; + +/// Sector graph read from disk index +pub struct SectorGraph { + /// Sector bytes from disk + /// One sector has num_nodes_per_sector nodes + /// Each node's layout: {full precision vector:[T; DIM]}{num_nbrs: u32}{neighbors: [u32; num_nbrs]} + /// The fp vector is not aligned + sectors_data: AlignedBoxWithSlice, + + /// Graph storage to read sectors + graph_storage: DiskGraphStorage, + + /// Current sector index into which the next read reads data + cur_sector_idx: u64, +} + +impl SectorGraph { + /// Create SectorGraph instance + pub fn new(graph_storage: DiskGraphStorage) -> ANNResult { + Ok(Self { + sectors_data: AlignedBoxWithSlice::new(MAX_N_SECTOR_READS * SECTOR_LEN, SECTOR_LEN)?, + graph_storage, + cur_sector_idx: 0, + }) + } + + /// Reset SectorGraph + pub fn reset(&mut self) { + self.cur_sector_idx = 0; + } + + /// Read sectors into sectors_data + /// They are in the same order as sectors_to_fetch + pub fn read_graph(&mut self, sectors_to_fetch: &[u64]) -> ANNResult<()> { + let cur_sector_idx_usize: usize = self.cur_sector_idx.try_into()?; + if sectors_to_fetch.len() > MAX_N_SECTOR_READS - cur_sector_idx_usize { + return Err(ANNError::log_index_error(format!( + "Trying to read too many sectors. number of sectors to read: {}, max number of sectors can read: {}", + sectors_to_fetch.len(), + MAX_N_SECTOR_READS - cur_sector_idx_usize, + ))); + } + + let mut sector_slices = self.sectors_data.split_into_nonoverlapping_mut_slices( + cur_sector_idx_usize * SECTOR_LEN..(cur_sector_idx_usize + sectors_to_fetch.len()) * SECTOR_LEN, + SECTOR_LEN)?; + + let mut read_requests = Vec::with_capacity(sector_slices.len()); + for (local_sector_idx, slice) in sector_slices.iter_mut().enumerate() { + let sector_id = sectors_to_fetch[local_sector_idx]; + read_requests.push(AlignedRead::new(sector_id * SECTOR_LEN as u64, slice)?); + } + + self.graph_storage.read(&mut read_requests)?; + self.cur_sector_idx += sectors_to_fetch.len() as u64; + + Ok(()) + } + + /// Get sector data by local index + #[inline] + pub fn get_sector_buf(&self, local_sector_idx: usize) -> &[u8] { + &self.sectors_data[local_sector_idx * SECTOR_LEN..(local_sector_idx + 1) * SECTOR_LEN] + } +} + +impl Deref for SectorGraph { + type Target = [u8]; + + fn deref(&self) -> &Self::Target { + &self.sectors_data + } +} + diff --git a/rust/diskann/src/model/graph/vertex_and_neighbors.rs b/rust/diskann/src/model/graph/vertex_and_neighbors.rs new file mode 100644 index 000000000..a9fa38932 --- /dev/null +++ b/rust/diskann/src/model/graph/vertex_and_neighbors.rs @@ -0,0 +1,159 @@ +/* + * Copyright (c) Microsoft Corporation. All rights reserved. + * Licensed under the MIT license. + */ +#![warn(missing_debug_implementations, missing_docs)] + +//! 
Vertex and its Adjacency List + +use crate::model::GRAPH_SLACK_FACTOR; + +use super::AdjacencyList; + +/// The out neighbors of vertex_id +#[derive(Debug)] +pub struct VertexAndNeighbors { + /// The id of the vertex + pub vertex_id: u32, + + /// All out neighbors (id) of vertex_id + neighbors: AdjacencyList, +} + +impl VertexAndNeighbors { + /// Create VertexAndNeighbors with id and capacity + pub fn for_range(id: u32, range: usize) -> Self { + Self { + vertex_id: id, + neighbors: AdjacencyList::for_range(range), + } + } + + /// Create VertexAndNeighbors with id and neighbors + pub fn new(vertex_id: u32, neighbors: AdjacencyList) -> Self { + Self { + vertex_id, + neighbors, + } + } + + /// Get size of neighbors + #[inline(always)] + pub fn size(&self) -> usize { + self.neighbors.len() + } + + /// Update the neighbors vector (post a pruning exercise) + #[inline(always)] + pub fn set_neighbors(&mut self, new_neighbors: AdjacencyList) { + // Replace the graph entry with the pruned neighbors + self.neighbors = new_neighbors; + } + + /// Get the neighbors + #[inline(always)] + pub fn get_neighbors(&self) -> &AdjacencyList { + &self.neighbors + } + + /// Adds a node to the list of neighbors for the given node. + /// + /// # Arguments + /// + /// * `node_id` - The ID of the node to add. + /// * `range` - The range of the graph. + /// + /// # Return + /// + /// Returns `None` if the node is already in the list of neighbors, or a `Vec` containing the updated list of neighbors if the list of neighbors is full. + pub fn add_to_neighbors(&mut self, node_id: u32, range: u32) -> Option> { + // Check if n is already in the graph entry + if self.neighbors.contains(&node_id) { + return None; + } + + let neighbor_len = self.neighbors.len(); + + // If not, check if the graph entry has enough space + if neighbor_len < (GRAPH_SLACK_FACTOR * range as f64) as usize { + // If yes, add n to the graph entry + self.neighbors.push(node_id); + return None; + } + + let mut copy_of_neighbors = Vec::with_capacity(neighbor_len + 1); + unsafe { + let dst = copy_of_neighbors.as_mut_ptr(); + std::ptr::copy_nonoverlapping(self.neighbors.as_ptr(), dst, neighbor_len); + dst.add(neighbor_len).write(node_id); + copy_of_neighbors.set_len(neighbor_len + 1); + } + + Some(copy_of_neighbors) + } +} + +#[cfg(test)] +mod vertex_and_neighbors_tests { + use crate::model::GRAPH_SLACK_FACTOR; + + use super::*; + + #[test] + fn test_set_with_capacity() { + let neighbors = VertexAndNeighbors::for_range(20, 10); + assert_eq!(neighbors.vertex_id, 20); + assert_eq!( + neighbors.neighbors.capacity(), + (10_f32 * GRAPH_SLACK_FACTOR as f32).ceil() as usize + ); + } + + #[test] + fn test_size() { + let mut neighbors = VertexAndNeighbors::for_range(20, 10); + + for i in 0..5 { + neighbors.neighbors.push(i); + } + + assert_eq!(neighbors.size(), 5); + } + + #[test] + fn test_set_neighbors() { + let mut neighbors = VertexAndNeighbors::for_range(20, 10); + let new_vec = AdjacencyList::from(vec![1, 2, 3, 4, 5]); + neighbors.set_neighbors(AdjacencyList::from(new_vec.clone())); + + assert_eq!(neighbors.neighbors, new_vec); + } + + #[test] + fn test_get_neighbors() { + let mut neighbors = VertexAndNeighbors::for_range(20, 10); + neighbors.set_neighbors(AdjacencyList::from(vec![1, 2, 3, 4, 5])); + let neighbor_ref = neighbors.get_neighbors(); + + assert!(std::ptr::eq(&neighbors.neighbors, neighbor_ref)) + } + + #[test] + fn test_add_to_neighbors() { + let mut neighbors = VertexAndNeighbors::for_range(20, 10); + + assert_eq!(neighbors.add_to_neighbors(1, 1), 
None); + assert_eq!(neighbors.neighbors, AdjacencyList::from(vec![1])); + + assert_eq!(neighbors.add_to_neighbors(1, 1), None); + assert_eq!(neighbors.neighbors, AdjacencyList::from(vec![1])); + + let ret = neighbors.add_to_neighbors(2, 1); + assert!(ret.is_some()); + assert_eq!(ret.unwrap(), vec![1, 2]); + assert_eq!(neighbors.neighbors, AdjacencyList::from(vec![1])); + + assert_eq!(neighbors.add_to_neighbors(2, 2), None); + assert_eq!(neighbors.neighbors, AdjacencyList::from(vec![1, 2])); + } +} diff --git a/rust/diskann/src/model/mod.rs b/rust/diskann/src/model/mod.rs new file mode 100644 index 000000000..a4f15ee52 --- /dev/null +++ b/rust/diskann/src/model/mod.rs @@ -0,0 +1,29 @@ +/* + * Copyright (c) Microsoft Corporation. All rights reserved. + * Licensed under the MIT license. + */ +pub mod neighbor; +pub use neighbor::Neighbor; +pub use neighbor::NeighborPriorityQueue; + +pub mod data_store; +pub use data_store::InmemDataset; + +pub mod graph; +pub use graph::InMemoryGraph; +pub use graph::VertexAndNeighbors; + +pub mod configuration; +pub use configuration::*; + +pub mod scratch; +pub use scratch::*; + +pub mod vertex; +pub use vertex::Vertex; + +pub mod pq; +pub use pq::*; + +pub mod windows_aligned_file_reader; +pub use windows_aligned_file_reader::*; diff --git a/rust/diskann/src/model/neighbor/mod.rs b/rust/diskann/src/model/neighbor/mod.rs new file mode 100644 index 000000000..cd0dbad2a --- /dev/null +++ b/rust/diskann/src/model/neighbor/mod.rs @@ -0,0 +1,13 @@ +/* + * Copyright (c) Microsoft Corporation. All rights reserved. + * Licensed under the MIT license. + */ +#[allow(clippy::module_inception)] +mod neighbor; +pub use neighbor::*; + +mod neighbor_priority_queue; +pub use neighbor_priority_queue::*; + +mod sorted_neighbor_vector; +pub use sorted_neighbor_vector::SortedNeighborVector; diff --git a/rust/diskann/src/model/neighbor/neighbor.rs b/rust/diskann/src/model/neighbor/neighbor.rs new file mode 100644 index 000000000..8c712bcd3 --- /dev/null +++ b/rust/diskann/src/model/neighbor/neighbor.rs @@ -0,0 +1,104 @@ +/* + * Copyright (c) Microsoft Corporation. All rights reserved. + * Licensed under the MIT license. 
+ */
+use std::cmp::Ordering;
+
+/// Neighbor node
+#[derive(Debug, Clone, Copy)]
+pub struct Neighbor {
+ /// The id of the node
+ pub id: u32,
+
+ /// The distance from the query node to the current node
+ pub distance: f32,
+
+ /// Whether the current node has been visited
+ pub visited: bool,
+}
+
+impl Neighbor {
+ /// Create a neighbor node, initially unvisited
+ pub fn new(id: u32, distance: f32) -> Self {
+ Self {
+ id,
+ distance,
+ visited: false
+ }
+ }
+}
+
+impl Default for Neighbor {
+ fn default() -> Self {
+ Self { id: 0, distance: 0.0_f32, visited: false }
+ }
+}
+
+impl PartialEq for Neighbor {
+ #[inline]
+ fn eq(&self, other: &Self) -> bool {
+ self.id == other.id
+ }
+}
+
+impl Eq for Neighbor {}
+
+impl Ord for Neighbor {
+ fn cmp(&self, other: &Self) -> Ordering {
+ let ord = self.distance.partial_cmp(&other.distance).unwrap_or(std::cmp::Ordering::Equal);
+
+ if ord == Ordering::Equal {
+ return self.id.cmp(&other.id);
+ }
+
+ ord
+ }
+}
+
+impl PartialOrd for Neighbor {
+ #[inline]
+ fn lt(&self, other: &Self) -> bool {
+ self.distance < other.distance || (self.distance == other.distance && self.id < other.id)
+ }
+
+ // Reason for allowing panic = "Does not support comparing Neighbor with partial_cmp"
+ #[allow(clippy::panic)]
+ fn partial_cmp(&self, _: &Self) -> Option<Ordering> {
+ panic!("Neighbor only allows eq and lt")
+ }
+}
+
+#[cfg(test)]
+mod neighbor_test {
+ use super::*;
+
+ #[test]
+ fn eq_lt_works() {
+ let n1 = Neighbor::new(1, 1.1);
+ let n2 = Neighbor::new(2, 2.0);
+ let n3 = Neighbor::new(1, 1.1);
+
+ assert!(n1 != n2);
+ assert!(n1 < n2);
+ assert!(n1 == n3);
+ }
+
+ #[test]
+ #[should_panic]
+ fn gt_should_panic() {
+ let n1 = Neighbor::new(1, 1.1);
+ let n2 = Neighbor::new(2, 2.0);
+
+ assert!(n2 > n1);
+ }
+
+ #[test]
+ #[should_panic]
+ fn le_should_panic() {
+ let n1 = Neighbor::new(1, 1.1);
+ let n2 = Neighbor::new(2, 2.0);
+
+ assert!(n1 <= n2);
+ }
+}
+
diff --git a/rust/diskann/src/model/neighbor/neighbor_priority_queue.rs b/rust/diskann/src/model/neighbor/neighbor_priority_queue.rs
new file mode 100644
index 000000000..81b161026
--- /dev/null
+++ b/rust/diskann/src/model/neighbor/neighbor_priority_queue.rs
@@ -0,0 +1,241 @@
+/*
+ * Copyright (c) Microsoft Corporation. All rights reserved.
+ * Licensed under the MIT license.
+ */
+use crate::model::Neighbor;
+
+/// Neighbor priority queue ordered by distance to the query node
+#[derive(Debug)]
+pub struct NeighborPriorityQueue {
+ /// The size of the priority queue
+ size: usize,
+
+ /// The capacity of the priority queue
+ capacity: usize,
+
+ /// Index of the closest unvisited neighbor (smallest distance among all unvisited neighbors)
+ cur: usize,
+
+ /// The neighbor collection
+ data: Vec<Neighbor>,
+}
+
+impl Default for NeighborPriorityQueue {
+ fn default() -> Self {
+ Self::new()
+ }
+}
+
+impl NeighborPriorityQueue {
+ /// Create a NeighborPriorityQueue without capacity
+ pub fn new() -> Self {
+ Self {
+ size: 0,
+ capacity: 0,
+ cur: 0,
+ data: Vec::new(),
+ }
+ }
+
+ /// Create a NeighborPriorityQueue with capacity
+ pub fn with_capacity(capacity: usize) -> Self {
+ Self {
+ size: 0,
+ capacity,
+ cur: 0,
+ data: vec![Neighbor::default(); capacity + 1],
+ }
+ }
+
+ /// Inserts an item in sorted order.
+ /// The item is dropped if the queue is full, the item already exists in the queue, or it has a greater distance than the last item.
+ /// The cursor used by pop() for the next item is set to the lowest index of an unvisited item.
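+ ///
+ /// A usage sketch (editor's illustration, exercising the API defined in this file):
+ ///
+ /// ```ignore
+ /// let mut q = NeighborPriorityQueue::with_capacity(2);
+ /// q.insert(Neighbor::new(7, 1.5));
+ /// q.insert(Neighbor::new(3, 0.5)); // queue: [(3, 0.5), (7, 1.5)]
+ /// q.insert(Neighbor::new(9, 1.0)); // (7, 1.5) falls off the end: [(3, 0.5), (9, 1.0)]
+ /// q.insert(Neighbor::new(8, 2.0)); // dropped: queue is full and 2.0 > 1.0
+ /// assert_eq!(q[0].id, 3);
+ /// ```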
+ pub fn insert(&mut self, nbr: Neighbor) { + if self.size == self.capacity && self.get_at(self.size - 1) < &nbr { + return; + } + + let mut lo = 0; + let mut hi = self.size; + while lo < hi { + let mid = (lo + hi) >> 1; + if &nbr < self.get_at(mid) { + hi = mid; + } else if self.get_at(mid).id == nbr.id { + // Make sure the same neighbor isn't inserted into the set + return; + } else { + lo = mid + 1; + } + } + + if lo < self.capacity { + self.data.copy_within(lo..self.size, lo + 1); + } + self.data[lo] = Neighbor::new(nbr.id, nbr.distance); + if self.size < self.capacity { + self.size += 1; + } + if lo < self.cur { + self.cur = lo; + } + } + + /// Get the neighbor at index - SAFETY: index must be less than size + fn get_at(&self, index: usize) -> &Neighbor { + unsafe { self.data.get_unchecked(index) } + } + + /// Get the closest and notvisited neighbor + pub fn closest_notvisited(&mut self) -> Neighbor { + self.data[self.cur].visited = true; + let pre = self.cur; + while self.cur < self.size && self.get_at(self.cur).visited { + self.cur += 1; + } + self.data[pre] + } + + /// Whether there is notvisited node or not + pub fn has_notvisited_node(&self) -> bool { + self.cur < self.size + } + + /// Get the size of the NeighborPriorityQueue + pub fn size(&self) -> usize { + self.size + } + + /// Get the capacity of the NeighborPriorityQueue + pub fn capacity(&self) -> usize { + self.capacity + } + + /// Sets an artificial capacity of the NeighborPriorityQueue. For benchmarking purposes only. + pub fn set_capacity(&mut self, capacity: usize) { + if capacity < self.data.len() { + self.capacity = capacity; + } + } + + /// Reserve capacity + pub fn reserve(&mut self, capacity: usize) { + if capacity > self.capacity { + self.data.resize(capacity + 1, Neighbor::default()); + self.capacity = capacity; + } + } + + /// Set size and cur to 0 + pub fn clear(&mut self) { + self.size = 0; + self.cur = 0; + } +} + +impl std::ops::Index for NeighborPriorityQueue { + type Output = Neighbor; + + fn index(&self, i: usize) -> &Self::Output { + &self.data[i] + } +} + +#[cfg(test)] +mod neighbor_priority_queue_test { + use super::*; + + #[test] + fn test_reserve_capacity() { + let mut queue = NeighborPriorityQueue::with_capacity(10); + assert_eq!(queue.capacity(), 10); + queue.reserve(20); + assert_eq!(queue.capacity(), 20); + } + + #[test] + fn test_insert() { + let mut queue = NeighborPriorityQueue::with_capacity(3); + assert_eq!(queue.size(), 0); + queue.insert(Neighbor::new(1, 1.0)); + queue.insert(Neighbor::new(2, 0.5)); + assert_eq!(queue.size(), 2); + queue.insert(Neighbor::new(2, 0.5)); // should be ignored as the same neighbor + assert_eq!(queue.size(), 2); + queue.insert(Neighbor::new(3, 0.9)); + assert_eq!(queue.size(), 3); + assert_eq!(queue[2].id, 1); + queue.insert(Neighbor::new(4, 2.0)); // should be dropped as queue is full and distance is greater than last item + assert_eq!(queue.size(), 3); + assert_eq!(queue[0].id, 2); // node id in queue should be [2,3,1] + assert_eq!(queue[1].id, 3); + assert_eq!(queue[2].id, 1); + println!("{:?}", queue); + } + + #[test] + fn test_index() { + let mut queue = NeighborPriorityQueue::with_capacity(3); + queue.insert(Neighbor::new(1, 1.0)); + queue.insert(Neighbor::new(2, 0.5)); + queue.insert(Neighbor::new(3, 1.5)); + assert_eq!(queue[0].id, 2); + assert_eq!(queue[0].distance, 0.5); + } + + #[test] + fn test_visit() { + let mut queue = NeighborPriorityQueue::with_capacity(3); + queue.insert(Neighbor::new(1, 1.0)); + queue.insert(Neighbor::new(2, 0.5)); + 
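+ // (Editor's note) closest_notvisited() marks the entry under the cursor as
+ // visited and advances the cursor past visited entries, so the calls below
+ // return each neighbor exactly once, in ascending distance order.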
queue.insert(Neighbor::new(3, 1.5)); // node id in queue should be [2,1,3] + assert!(queue.has_notvisited_node()); + let nbr = queue.closest_notvisited(); + assert_eq!(nbr.id, 2); + assert_eq!(nbr.distance, 0.5); + assert!(nbr.visited); + assert!(queue.has_notvisited_node()); + let nbr = queue.closest_notvisited(); + assert_eq!(nbr.id, 1); + assert_eq!(nbr.distance, 1.0); + assert!(nbr.visited); + assert!(queue.has_notvisited_node()); + let nbr = queue.closest_notvisited(); + assert_eq!(nbr.id, 3); + assert_eq!(nbr.distance, 1.5); + assert!(nbr.visited); + assert!(!queue.has_notvisited_node()); + } + + #[test] + fn test_clear_queue() { + let mut queue = NeighborPriorityQueue::with_capacity(3); + queue.insert(Neighbor::new(1, 1.0)); + queue.insert(Neighbor::new(2, 0.5)); + assert_eq!(queue.size(), 2); + assert!(queue.has_notvisited_node()); + queue.clear(); + assert_eq!(queue.size(), 0); + assert!(!queue.has_notvisited_node()); + } + + #[test] + fn test_reserve() { + let mut queue = NeighborPriorityQueue::new(); + queue.reserve(10); + assert_eq!(queue.data.len(), 11); + assert_eq!(queue.capacity, 10); + } + + #[test] + fn test_set_capacity() { + let mut queue = NeighborPriorityQueue::with_capacity(10); + queue.set_capacity(5); + assert_eq!(queue.capacity, 5); + assert_eq!(queue.data.len(), 11); + + queue.set_capacity(11); + assert_eq!(queue.capacity, 5); + } +} + diff --git a/rust/diskann/src/model/neighbor/sorted_neighbor_vector.rs b/rust/diskann/src/model/neighbor/sorted_neighbor_vector.rs new file mode 100644 index 000000000..4c3eff00f --- /dev/null +++ b/rust/diskann/src/model/neighbor/sorted_neighbor_vector.rs @@ -0,0 +1,37 @@ +/* + * Copyright (c) Microsoft Corporation. All rights reserved. + * Licensed under the MIT license. + */ +#![warn(missing_debug_implementations, missing_docs)] + +//! Sorted Neighbor Vector + +use std::ops::{Deref, DerefMut}; + +use super::Neighbor; + +/// A newtype on top of vector of neighbors, is sorted by distance +#[derive(Debug)] +pub struct SortedNeighborVector<'a>(&'a mut Vec); + +impl<'a> SortedNeighborVector<'a> { + /// Create a new SortedNeighborVector + pub fn new(vec: &'a mut Vec) -> Self { + vec.sort_unstable(); + Self(vec) + } +} + +impl<'a> Deref for SortedNeighborVector<'a> { + type Target = Vec; + + fn deref(&self) -> &Self::Target { + self.0 + } +} + +impl<'a> DerefMut for SortedNeighborVector<'a> { + fn deref_mut(&mut self) -> &mut Self::Target { + self.0 + } +} diff --git a/rust/diskann/src/model/pq/fixed_chunk_pq_table.rs b/rust/diskann/src/model/pq/fixed_chunk_pq_table.rs new file mode 100644 index 000000000..bfedcae6e --- /dev/null +++ b/rust/diskann/src/model/pq/fixed_chunk_pq_table.rs @@ -0,0 +1,483 @@ +/* + * Copyright (c) Microsoft Corporation. All rights reserved. + * Licensed under the MIT license. 
+ */
+#![warn(missing_debug_implementations)]
+
+use hashbrown::HashMap;
+use rayon::prelude::{
+ IndexedParallelIterator, IntoParallelRefMutIterator, ParallelIterator, ParallelSliceMut,
+};
+use std::arch::x86_64::{_mm_prefetch, _MM_HINT_T0};
+
+use crate::{
+ common::{ANNError, ANNResult},
+ model::NUM_PQ_CENTROIDS,
+};
+
+/// PQ pivot table loading and distance calculation
+#[derive(Debug)]
+pub struct FixedChunkPQTable {
+ /// pq_table = float array of size [256 * ndims]
+ pq_table: Vec<f32>,
+
+ /// ndims = true dimension of vectors
+ dim: usize,
+
+ /// num_pq_chunks = the number of pq chunks
+ num_pq_chunks: usize,
+
+ /// chunk_offsets = the offset of each chunk, starting from 0
+ chunk_offsets: Vec<usize>,
+
+ /// centroid of each dimension
+ centroids: Vec<f32>,
+
+ /// Because we're using L2 distance, this is not needed now.
+ /// Transpose of pq_table. transport_pq_table = float array of size [ndims * 256].
+ /// e.g. if pq_table is 2 centroids * 3 dims
+ /// [ 1, 2, 3,
+ /// 4, 5, 6]
+ /// then transport_pq_table would be 3 dims * 2 centroids
+ /// [ 1, 4,
+ /// 2, 5,
+ /// 3, 6]
+ /// transport_pq_table: Vec<f32>,
+
+ /// Map dim offset to chunk index e.g., 8 dims into 2 chunks
+ /// would be [(0,0), (1,0), (2,0), (3,0), (4,1), (5,1), (6,1), (7,1)]
+ dimoffset_chunk_mapping: HashMap<usize, usize>,
+}
+
+impl FixedChunkPQTable {
+ /// Create the FixedChunkPQTable with dim, chunk count, and pivot file data (pivot table + centroids + chunk offsets)
+ pub fn new(
+ dim: usize,
+ num_pq_chunks: usize,
+ pq_table: Vec<f32>,
+ centroids: Vec<f32>,
+ chunk_offsets: Vec<usize>,
+ ) -> Self {
+ let mut dimoffset_chunk_mapping = HashMap::new();
+ for chunk_index in 0..num_pq_chunks {
+ for dim_offset in chunk_offsets[chunk_index]..chunk_offsets[chunk_index + 1] {
+ dimoffset_chunk_mapping.insert(dim_offset, chunk_index);
+ }
+ }
+
+ Self {
+ pq_table,
+ dim,
+ num_pq_chunks,
+ chunk_offsets,
+ centroids,
+ dimoffset_chunk_mapping,
+ }
+ }
+
+ /// Get the chunk count
+ pub fn get_num_chunks(&self) -> usize {
+ self.num_pq_chunks
+ }
+
+ /// Shift the query by the mean of the whole corpus
+ pub fn preprocess_query(&self, query_vec: &mut [f32]) {
+ for (query, &centroid) in query_vec.iter_mut().zip(self.centroids.iter()) {
+ *query -= centroid;
+ }
+ }
+
+ /// Pre-calculate the distance between the query and each centroid using L2 distance
+ /// * `query_vec` - query vector: 1 * dim
+ /// * `dist_vec` - pre-calculated distances between the query and each centroid: num_pq_chunks * num_centroids
+ #[allow(clippy::needless_range_loop)]
+ pub fn populate_chunk_distances(&self, query_vec: &[f32]) -> Vec<f32> {
+ let mut dist_vec = vec![0.0; self.num_pq_chunks * NUM_PQ_CENTROIDS];
+ for centroid_index in 0..NUM_PQ_CENTROIDS {
+ for chunk_index in 0..self.num_pq_chunks {
+ for dim_offset in
+ self.chunk_offsets[chunk_index]..self.chunk_offsets[chunk_index + 1]
+ {
+ let diff: f32 = self.pq_table[self.dim * centroid_index + dim_offset]
+ - query_vec[dim_offset];
+ dist_vec[chunk_index * NUM_PQ_CENTROIDS + centroid_index] += diff * diff;
+ }
+ }
+ }
+ dist_vec
+ }
+
+ /// Pre-calculate the distance between the query and each centroid using inner product
+ /// * `query_vec` - query vector: 1 * dim
+ /// * `dist_vec` - pre-calculated distances between the query and each centroid: num_pq_chunks * num_centroids
+ ///
+ /// Reason to allow clippy::needless_range_loop:
+ /// The inner loop is operating over a range that is different for each iteration of the outer loop.
+ /// This isn't a scenario where using iter().enumerate() would be easily applicable, + /// because the inner loop isn't iterating directly over the contents of a slice or array. + /// Thus, using indexing might be the most straightforward way to express this logic. + #[allow(clippy::needless_range_loop)] + pub fn populate_chunk_inner_products(&self, query_vec: &[f32]) -> Vec { + let mut dist_vec = vec![0.0; self.num_pq_chunks * NUM_PQ_CENTROIDS]; + for centroid_index in 0..NUM_PQ_CENTROIDS { + for chunk_index in 0..self.num_pq_chunks { + for dim_offset in + self.chunk_offsets[chunk_index]..self.chunk_offsets[chunk_index + 1] + { + // assumes that we are not shifting the vectors to mean zero, i.e., centroid + // array should be all zeros returning negative to keep the search code + // clean (max inner product vs min distance) + let diff: f32 = self.pq_table[self.dim * centroid_index + dim_offset] + * query_vec[dim_offset]; + dist_vec[chunk_index * NUM_PQ_CENTROIDS + centroid_index] -= diff; + } + } + } + dist_vec + } + + /// Calculate the distance between query and given centroid by l2 distance + /// * `query_vec` - query vector: 1 * dim + /// * `base_vec` - given centroid array: 1 * num_pq_chunks + #[allow(clippy::needless_range_loop)] + pub fn l2_distance(&self, query_vec: &[f32], base_vec: &[u8]) -> f32 { + let mut res_vec: Vec = vec![0.0; self.num_pq_chunks]; + res_vec + .par_iter_mut() + .enumerate() + .for_each(|(chunk_index, chunk_diff)| { + for dim_offset in + self.chunk_offsets[chunk_index]..self.chunk_offsets[chunk_index + 1] + { + let diff = self.pq_table + [self.dim * base_vec[chunk_index] as usize + dim_offset] + - query_vec[dim_offset]; + *chunk_diff += diff * diff; + } + }); + + let res: f32 = res_vec.iter().sum::(); + + res + } + + /// Calculate the distance between query and given centroid by inner product + /// * `query_vec` - query vector: 1 * dim + /// * `base_vec` - given centroid array: 1 * num_pq_chunks + #[allow(clippy::needless_range_loop)] + pub fn inner_product(&self, query_vec: &[f32], base_vec: &[u8]) -> f32 { + let mut res_vec: Vec = vec![0.0; self.num_pq_chunks]; + res_vec + .par_iter_mut() + .enumerate() + .for_each(|(chunk_index, chunk_diff)| { + for dim_offset in + self.chunk_offsets[chunk_index]..self.chunk_offsets[chunk_index + 1] + { + *chunk_diff += self.pq_table + [self.dim * base_vec[chunk_index] as usize + dim_offset] + * query_vec[dim_offset]; + } + }); + + let res: f32 = res_vec.iter().sum::(); + + // returns negative value to simulate distances (max -> min conversion) + -res + } + + /// Revert vector by adding centroid + /// * `base_vec` - given centroid array: 1 * num_pq_chunks + /// * `out_vec` - reverted vector + pub fn inflate_vector(&self, base_vec: &[u8]) -> ANNResult> { + let mut out_vec: Vec = vec![0.0; self.dim]; + for (dim_offset, value) in out_vec.iter_mut().enumerate() { + let chunk_index = + self.dimoffset_chunk_mapping + .get(&dim_offset) + .ok_or(ANNError::log_pq_error( + "ERROR: dim_offset not found in dimoffset_chunk_mapping".to_string(), + ))?; + *value = self.pq_table[self.dim * base_vec[*chunk_index] as usize + dim_offset] + + self.centroids[dim_offset]; + } + + Ok(out_vec) + } +} + +/// Given a batch input nodes, return a batch of PQ distance +/// * `pq_ids` - batch nodes: n_pts * pq_nchunks +/// * `n_pts` - batch number +/// * `pq_nchunks` - pq chunk number number +/// * `pq_dists` - pre-calculated the distance between query and each centroid: chunk_size * num_centroids +/// * `dists_out` - n_pts * 1 +pub fn pq_dist_lookup( 
+ pq_ids: &[u8], + n_pts: usize, + pq_nchunks: usize, + pq_dists: &[f32], +) -> Vec { + let mut dists_out: Vec = vec![0.0; n_pts]; + unsafe { + _mm_prefetch(dists_out.as_ptr() as *const i8, _MM_HINT_T0); + _mm_prefetch(pq_ids.as_ptr() as *const i8, _MM_HINT_T0); + _mm_prefetch(pq_ids.as_ptr().add(64) as *const i8, _MM_HINT_T0); + _mm_prefetch(pq_ids.as_ptr().add(128) as *const i8, _MM_HINT_T0); + } + for chunk in 0..pq_nchunks { + let chunk_dists = &pq_dists[256 * chunk..]; + if chunk < pq_nchunks - 1 { + unsafe { + _mm_prefetch( + chunk_dists.as_ptr().offset(256 * chunk as isize).add(256) as *const i8, + _MM_HINT_T0, + ); + } + } + dists_out + .par_iter_mut() + .enumerate() + .for_each(|(n_iter, dist)| { + let pq_centerid = pq_ids[pq_nchunks * n_iter + chunk]; + *dist += chunk_dists[pq_centerid as usize]; + }); + } + dists_out +} + +pub fn aggregate_coords(ids: &[u32], all_coords: &[u8], ndims: usize) -> Vec { + let mut out: Vec = vec![0u8; ids.len() * ndims]; + let ndim_u32 = ndims as u32; + out.par_chunks_mut(ndims) + .enumerate() + .for_each(|(index, chunk)| { + let id_compressed_pivot = &all_coords + [(ids[index] * ndim_u32) as usize..(ids[index] * ndim_u32 + ndim_u32) as usize]; + let temp_slice = + unsafe { std::slice::from_raw_parts(id_compressed_pivot.as_ptr(), ndims) }; + chunk.copy_from_slice(temp_slice); + }); + + out +} + +#[cfg(test)] +mod fixed_chunk_pq_table_test { + + use super::*; + use crate::common::{ANNError, ANNResult}; + use crate::utils::{convert_types_u32_usize, convert_types_u64_usize, file_exists, load_bin}; + + const DIM: usize = 128; + + #[test] + fn load_pivot_test() { + let pq_pivots_path: &str = "tests/data/siftsmall_learn.bin_pq_pivots.bin"; + let (dim, pq_table, centroids, chunk_offsets) = + load_pq_pivots_bin(pq_pivots_path, &1).unwrap(); + let fixed_chunk_pq_table = + FixedChunkPQTable::new(dim, 1, pq_table, centroids, chunk_offsets); + + assert_eq!(dim, DIM); + assert_eq!(fixed_chunk_pq_table.pq_table.len(), DIM * NUM_PQ_CENTROIDS); + assert_eq!(fixed_chunk_pq_table.centroids.len(), DIM); + + assert_eq!(fixed_chunk_pq_table.chunk_offsets[0], 0); + assert_eq!(fixed_chunk_pq_table.chunk_offsets[1], DIM); + assert_eq!(fixed_chunk_pq_table.chunk_offsets.len(), 2); + } + + #[test] + fn get_num_chunks_test() { + let num_chunks = 7; + let pa_table = vec![0.0; DIM * NUM_PQ_CENTROIDS]; + let centroids = vec![0.0; DIM]; + let chunk_offsets = vec![0, 7, 9, 11, 22, 34, 78, 127]; + let fixed_chunk_pq_table = + FixedChunkPQTable::new(DIM, num_chunks, pa_table, centroids, chunk_offsets); + let chunk: usize = fixed_chunk_pq_table.get_num_chunks(); + assert_eq!(chunk, num_chunks); + } + + #[test] + fn preprocess_query_test() { + let pq_pivots_path: &str = "tests/data/siftsmall_learn.bin_pq_pivots.bin"; + let (dim, pq_table, centroids, chunk_offsets) = + load_pq_pivots_bin(pq_pivots_path, &1).unwrap(); + let fixed_chunk_pq_table = + FixedChunkPQTable::new(dim, 1, pq_table, centroids, chunk_offsets); + + let mut query_vec: Vec = vec![ + 32.39f32, 78.57f32, 50.32f32, 80.46f32, 6.47f32, 69.76f32, 94.2f32, 83.36f32, 5.8f32, + 68.78f32, 42.32f32, 61.77f32, 90.26f32, 60.41f32, 3.86f32, 61.21f32, 16.6f32, 54.46f32, + 7.29f32, 54.24f32, 92.49f32, 30.18f32, 65.36f32, 99.09f32, 3.8f32, 36.4f32, 86.72f32, + 65.18f32, 29.87f32, 62.21f32, 58.32f32, 43.23f32, 94.3f32, 79.61f32, 39.67f32, + 11.18f32, 48.88f32, 38.19f32, 93.95f32, 10.46f32, 36.7f32, 14.75f32, 81.64f32, + 59.18f32, 99.03f32, 74.23f32, 1.26f32, 82.69f32, 35.7f32, 38.39f32, 46.17f32, 64.75f32, + 7.15f32, 36.55f32, 
77.32f32, 18.65f32, 32.8f32, 74.84f32, 18.12f32, 20.19f32, 70.06f32, + 48.37f32, 40.18f32, 45.69f32, 88.3f32, 39.15f32, 60.97f32, 71.29f32, 61.79f32, + 47.23f32, 94.71f32, 58.04f32, 52.4f32, 34.66f32, 59.1f32, 47.11f32, 30.2f32, 58.72f32, + 74.35f32, 83.68f32, 66.8f32, 28.57f32, 29.45f32, 52.02f32, 91.95f32, 92.44f32, + 65.25f32, 38.3f32, 35.6f32, 41.67f32, 91.33f32, 76.81f32, 74.88f32, 33.17f32, 48.36f32, + 41.42f32, 23f32, 8.31f32, 81.69f32, 80.08f32, 50.55f32, 54.46f32, 23.79f32, 43.46f32, + 84.5f32, 10.42f32, 29.51f32, 19.73f32, 46.48f32, 35.01f32, 52.3f32, 66.97f32, 4.8f32, + 74.81f32, 2.82f32, 61.82f32, 25.06f32, 17.3f32, 17.29f32, 63.2f32, 64.1f32, 61.68f32, + 37.42f32, 3.39f32, 97.45f32, 5.32f32, 59.02f32, 35.6f32, + ]; + fixed_chunk_pq_table.preprocess_query(&mut query_vec); + assert_eq!(query_vec[0], 32.39f32 - fixed_chunk_pq_table.centroids[0]); + assert_eq!( + query_vec[127], + 35.6f32 - fixed_chunk_pq_table.centroids[127] + ); + } + + #[test] + fn calculate_distances_tests() { + let pq_pivots_path: &str = "tests/data/siftsmall_learn.bin_pq_pivots.bin"; + + let (dim, pq_table, centroids, chunk_offsets) = + load_pq_pivots_bin(pq_pivots_path, &1).unwrap(); + let fixed_chunk_pq_table = + FixedChunkPQTable::new(dim, 1, pq_table, centroids, chunk_offsets); + + let query_vec: Vec = vec![ + 32.39f32, 78.57f32, 50.32f32, 80.46f32, 6.47f32, 69.76f32, 94.2f32, 83.36f32, 5.8f32, + 68.78f32, 42.32f32, 61.77f32, 90.26f32, 60.41f32, 3.86f32, 61.21f32, 16.6f32, 54.46f32, + 7.29f32, 54.24f32, 92.49f32, 30.18f32, 65.36f32, 99.09f32, 3.8f32, 36.4f32, 86.72f32, + 65.18f32, 29.87f32, 62.21f32, 58.32f32, 43.23f32, 94.3f32, 79.61f32, 39.67f32, + 11.18f32, 48.88f32, 38.19f32, 93.95f32, 10.46f32, 36.7f32, 14.75f32, 81.64f32, + 59.18f32, 99.03f32, 74.23f32, 1.26f32, 82.69f32, 35.7f32, 38.39f32, 46.17f32, 64.75f32, + 7.15f32, 36.55f32, 77.32f32, 18.65f32, 32.8f32, 74.84f32, 18.12f32, 20.19f32, 70.06f32, + 48.37f32, 40.18f32, 45.69f32, 88.3f32, 39.15f32, 60.97f32, 71.29f32, 61.79f32, + 47.23f32, 94.71f32, 58.04f32, 52.4f32, 34.66f32, 59.1f32, 47.11f32, 30.2f32, 58.72f32, + 74.35f32, 83.68f32, 66.8f32, 28.57f32, 29.45f32, 52.02f32, 91.95f32, 92.44f32, + 65.25f32, 38.3f32, 35.6f32, 41.67f32, 91.33f32, 76.81f32, 74.88f32, 33.17f32, 48.36f32, + 41.42f32, 23f32, 8.31f32, 81.69f32, 80.08f32, 50.55f32, 54.46f32, 23.79f32, 43.46f32, + 84.5f32, 10.42f32, 29.51f32, 19.73f32, 46.48f32, 35.01f32, 52.3f32, 66.97f32, 4.8f32, + 74.81f32, 2.82f32, 61.82f32, 25.06f32, 17.3f32, 17.29f32, 63.2f32, 64.1f32, 61.68f32, + 37.42f32, 3.39f32, 97.45f32, 5.32f32, 59.02f32, 35.6f32, + ]; + + let dist_vec = fixed_chunk_pq_table.populate_chunk_distances(&query_vec); + assert_eq!(dist_vec.len(), 256); + + // populate_chunk_distances_test + let mut sampled_output = 0.0; + (0..DIM).for_each(|dim_offset| { + let diff = fixed_chunk_pq_table.pq_table[dim_offset] - query_vec[dim_offset]; + sampled_output += diff * diff; + }); + assert_eq!(sampled_output, dist_vec[0]); + + // populate_chunk_inner_products_test + let dist_vec = fixed_chunk_pq_table.populate_chunk_inner_products(&query_vec); + assert_eq!(dist_vec.len(), 256); + + let mut sampled_output = 0.0; + (0..DIM).for_each(|dim_offset| { + sampled_output -= fixed_chunk_pq_table.pq_table[dim_offset] * query_vec[dim_offset]; + }); + assert_eq!(sampled_output, dist_vec[0]); + + // l2_distance_test + let base_vec: Vec = vec![3u8]; + let dist = fixed_chunk_pq_table.l2_distance(&query_vec, &base_vec); + let mut l2_output = 0.0; + (0..DIM).for_each(|dim_offset| { + let diff = 
fixed_chunk_pq_table.pq_table[3 * DIM + dim_offset] - query_vec[dim_offset]; + l2_output += diff * diff; + }); + assert_eq!(l2_output, dist); + + // inner_product_test + let dist = fixed_chunk_pq_table.inner_product(&query_vec, &base_vec); + let mut l2_output = 0.0; + (0..DIM).for_each(|dim_offset| { + l2_output -= + fixed_chunk_pq_table.pq_table[3 * DIM + dim_offset] * query_vec[dim_offset]; + }); + assert_eq!(l2_output, dist); + + // inflate_vector_test + let inflate_vector = fixed_chunk_pq_table.inflate_vector(&base_vec).unwrap(); + assert_eq!(inflate_vector.len(), DIM); + assert_eq!( + inflate_vector[0], + fixed_chunk_pq_table.pq_table[3 * DIM] + fixed_chunk_pq_table.centroids[0] + ); + assert_eq!( + inflate_vector[1], + fixed_chunk_pq_table.pq_table[3 * DIM + 1] + fixed_chunk_pq_table.centroids[1] + ); + assert_eq!( + inflate_vector[127], + fixed_chunk_pq_table.pq_table[3 * DIM + 127] + fixed_chunk_pq_table.centroids[127] + ); + } + + fn load_pq_pivots_bin( + pq_pivots_path: &str, + num_pq_chunks: &usize, + ) -> ANNResult<(usize, Vec, Vec, Vec)> { + if !file_exists(pq_pivots_path) { + return Err(ANNError::log_pq_error( + "ERROR: PQ k-means pivot file not found.".to_string(), + )); + } + + let (data, offset_num, offset_dim) = load_bin::(pq_pivots_path, 0)?; + let file_offset_data = convert_types_u64_usize(&data, offset_num, offset_dim); + if offset_num != 4 { + let error_message = format!("Error reading pq_pivots file {}. Offsets don't contain correct metadata, # offsets = {}, but expecting 4.", pq_pivots_path, offset_num); + return Err(ANNError::log_pq_error(error_message)); + } + + let (data, pq_center_num, dim) = load_bin::(pq_pivots_path, file_offset_data[0])?; + let pq_table = data.to_vec(); + if pq_center_num != NUM_PQ_CENTROIDS { + let error_message = format!( + "Error reading pq_pivots file {}. file_num_centers = {}, but expecting {} centers.", + pq_pivots_path, pq_center_num, NUM_PQ_CENTROIDS + ); + return Err(ANNError::log_pq_error(error_message)); + } + + let (data, centroid_dim, nc) = load_bin::(pq_pivots_path, file_offset_data[1])?; + let centroids = data.to_vec(); + if centroid_dim != dim || nc != 1 { + let error_message = format!("Error reading pq_pivots file {}. 
file_dim = {}, file_cols = {} but expecting {} entries in 1 dimension.", pq_pivots_path, centroid_dim, nc, dim); + return Err(ANNError::log_pq_error(error_message)); + } + + let (data, chunk_offset_num, nc) = load_bin::(pq_pivots_path, file_offset_data[2])?; + let chunk_offsets = convert_types_u32_usize(&data, chunk_offset_num, nc); + if chunk_offset_num != num_pq_chunks + 1 || nc != 1 { + let error_message = format!("Error reading pq_pivots file at chunk offsets; file has nr={}, nc={} but expecting nr={} and nc=1.", chunk_offset_num, nc, num_pq_chunks + 1); + return Err(ANNError::log_pq_error(error_message)); + } + + Ok((dim, pq_table, centroids, chunk_offsets)) + } +} + +#[cfg(test)] +mod pq_index_prune_query_test { + + use super::*; + + #[test] + fn pq_dist_lookup_test() { + let pq_ids: Vec = vec![1u8, 3u8, 2u8, 2u8]; + let mut pq_dists: Vec = Vec::with_capacity(256 * 2); + for _ in 0..pq_dists.capacity() { + pq_dists.push(rand::random()); + } + + let dists_out = pq_dist_lookup(&pq_ids, 2, 2, &pq_dists); + assert_eq!(dists_out.len(), 2); + assert_eq!(dists_out[0], pq_dists[0 + 1] + pq_dists[256 + 3]); + assert_eq!(dists_out[1], pq_dists[0 + 2] + pq_dists[256 + 2]); + } +} diff --git a/rust/diskann/src/model/pq/mod.rs b/rust/diskann/src/model/pq/mod.rs new file mode 100644 index 000000000..85daaa7c6 --- /dev/null +++ b/rust/diskann/src/model/pq/mod.rs @@ -0,0 +1,9 @@ +/* + * Copyright (c) Microsoft Corporation. All rights reserved. + * Licensed under the MIT license. + */ +mod fixed_chunk_pq_table; +pub use fixed_chunk_pq_table::*; + +mod pq_construction; +pub use pq_construction::*; diff --git a/rust/diskann/src/model/pq/pq_construction.rs b/rust/diskann/src/model/pq/pq_construction.rs new file mode 100644 index 000000000..0a7b0784e --- /dev/null +++ b/rust/diskann/src/model/pq/pq_construction.rs @@ -0,0 +1,398 @@ +/* + * Copyright (c) Microsoft Corporation. All rights reserved. + * Licensed under the MIT license. 
+ */
+#![warn(missing_debug_implementations)]
+
+use rayon::prelude::{IndexedParallelIterator, ParallelIterator};
+use rayon::slice::ParallelSliceMut;
+
+use crate::common::{ANNError, ANNResult};
+use crate::storage::PQStorage;
+use crate::utils::{compute_closest_centers, file_exists, k_means_clustering};
+
+/// Max size of PQ training set
+pub const MAX_PQ_TRAINING_SET_SIZE: f64 = 256_000f64;
+
+/// Max number of PQ chunks
+pub const MAX_PQ_CHUNKS: usize = 512;
+
+/// Number of PQ centroids
+pub const NUM_PQ_CENTROIDS: usize = 256;
+/// Block size for reading/processing large files and matrices in blocks
+const BLOCK_SIZE: usize = 5000000;
+const NUM_KMEANS_REPS_PQ: usize = 12;
+
+/// Given training data in train_data of dimensions num_train * dim, generates
+/// PQ pivots using the k-means algorithm to partition the coordinates into
+/// num_pq_chunks chunks (if it divides the dimension, else rounded), runs
+/// k-means in each chunk to compute the PQ pivots, and stores them in bin format
+/// in file pq_pivots_path as a num_centers * dim floating point binary file.
+/// PQ pivot table layout: {pivot offsets data: METADATA_SIZE}{pivot vector:[dim; num_centroid]}{centroid vector:[dim; 1]}{chunk offsets:[chunk_num+1; 1]}
+fn generate_pq_pivots(
+ train_data: &mut [f32],
+ num_train: usize,
+ dim: usize,
+ num_centers: usize,
+ num_pq_chunks: usize,
+ max_k_means_reps: usize,
+ pq_storage: &mut PQStorage,
+) -> ANNResult<()> {
+ if num_pq_chunks > dim {
+ return Err(ANNError::log_pq_error(
+ "Error: number of chunks more than dimension.".to_string(),
+ ));
+ }
+
+ if pq_storage.pivot_data_exist() {
+ let (file_num_centers, file_dim) = pq_storage.read_pivot_metadata()?;
+ if file_dim == dim && file_num_centers == num_centers {
+ // PQ pivot file exists. Not generating again.
+ return Ok(());
+ }
+ }
+
+ // Calculate the centroid and center the training data.
+ // If we use L2 distance, there is an option to
+ // translate all vectors to make them centered and
+ // then compute PQ. This needs to be set to false
+ // when using PQ for MIPS as such translations don't
+ // preserve inner products.
+ // For now, we're using L2 as the default.
+ let mut centroid: Vec<f32> = vec![0.0; dim];
+ for dim_index in 0..dim {
+ for train_data_index in 0..num_train {
+ centroid[dim_index] += train_data[train_data_index * dim + dim_index];
+ }
+ centroid[dim_index] /= num_train as f32;
+ }
+ for dim_index in 0..dim {
+ for train_data_index in 0..num_train {
+ train_data[train_data_index * dim + dim_index] -= centroid[dim_index];
+ }
+ }
+
+ // Calculate each chunk's offset.
+ // If we have 8 dimensions and 3 chunks, the offsets would be [0, 3, 6, 8].
+ let mut chunk_offsets: Vec<usize> = vec![0; num_pq_chunks + 1];
+ let mut chunk_offset: usize = 0;
+ for chunk_index in 0..num_pq_chunks {
+ chunk_offset += dim / num_pq_chunks;
+ if chunk_index < (dim % num_pq_chunks) {
+ chunk_offset += 1;
+ }
+ chunk_offsets[chunk_index + 1] = chunk_offset;
+ }
+
+ let mut full_pivot_data: Vec<f32> = vec![0.0; num_centers * dim];
+ for chunk_index in 0..num_pq_chunks {
+ let chunk_size = chunk_offsets[chunk_index + 1] - chunk_offsets[chunk_index];
+
+ let mut cur_train_data: Vec<f32> = vec![0.0; num_train * chunk_size];
+ let mut cur_pivot_data: Vec<f32> = vec![0.0; num_centers * chunk_size];
+
+ cur_train_data
+ .par_chunks_mut(chunk_size)
+ .enumerate()
+ .for_each(|(train_data_index, chunk)| {
+ for (dim_offset, item) in chunk.iter_mut().enumerate() {
+ *item = train_data
+ [train_data_index * dim + chunk_offsets[chunk_index] + dim_offset];
+ }
+ });
+
+ // Run k-means to get the centroids of this chunk.
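+ // (Editor's note) Shapes, for reference: cur_train_data is num_train rows of
+ // chunk_size dims (row-major); k_means_clustering fills cur_pivot_data with
+ // num_centers centroids of chunk_size dims for this chunk.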
+ let (_closest_docs, _closest_center, _residual) = k_means_clustering(
+ &cur_train_data,
+ num_train,
+ chunk_size,
+ &mut cur_pivot_data,
+ num_centers,
+ max_k_means_reps,
+ )?;
+
+ // Copy centroids from this chunk's table to the full table
+ for center_index in 0..num_centers {
+ full_pivot_data[center_index * dim + chunk_offsets[chunk_index]
+ ..center_index * dim + chunk_offsets[chunk_index + 1]]
+ .copy_from_slice(
+ &cur_pivot_data[center_index * chunk_size..(center_index + 1) * chunk_size],
+ );
+ }
+ }
+
+ pq_storage.write_pivot_data(
+ &full_pivot_data,
+ &centroid,
+ &chunk_offsets,
+ num_centers,
+ dim,
+ )?;
+
+ Ok(())
+}
+
+/// Streams the base file (data_file) and computes the closest centers in each
+/// chunk to generate the compressed data_file, storing it in
+/// pq_compressed_vectors_path.
+/// If the number of centers is < 256, it stores as a byte vector, else as a
+/// 4-byte vector, in binary format.
+/// Compressed PQ table layout: {num_points: usize}{num_chunks: usize}{compressed pq table: [num_points; num_chunks]}
+fn generate_pq_data_from_pivots<T: Default + Copy + Into<f32>>(
+ num_centers: usize,
+ num_pq_chunks: usize,
+ pq_storage: &mut PQStorage,
+) -> ANNResult<()> {
+ let (num_points, dim) = pq_storage.read_pq_data_metadata()?;
+
+ let full_pivot_data: Vec<f32>;
+ let centroid: Vec<f32>;
+ let chunk_offsets: Vec<usize>;
+
+ if !pq_storage.pivot_data_exist() {
+ return Err(ANNError::log_pq_error(
+ "ERROR: PQ k-means pivot file not found.".to_string(),
+ ));
+ } else {
+ (full_pivot_data, centroid, chunk_offsets) =
+ pq_storage.load_pivot_data(&num_pq_chunks, &num_centers, &dim)?;
+ }
+
+ pq_storage.write_compressed_pivot_metadata(num_points as i32, num_pq_chunks as i32)?;
+
+ let block_size = if num_points <= BLOCK_SIZE {
+ num_points
+ } else {
+ BLOCK_SIZE
+ };
+ let num_blocks = (num_points / block_size) + (num_points % block_size != 0) as usize;
+
+ for block_index in 0..num_blocks {
+ let start_index: usize = block_index * block_size;
+ let end_index: usize = std::cmp::min((block_index + 1) * block_size, num_points);
+ let cur_block_size: usize = end_index - start_index;
+
+ let mut block_compressed_base: Vec<usize> = vec![0; cur_block_size * num_pq_chunks];
+
+ let block_data: Vec<T> = pq_storage.read_pq_block_data(cur_block_size, dim)?;
+
+ let mut adjusted_block_data: Vec<f32> = vec![0.0; cur_block_size * dim];
+
+ for block_data_index in 0..cur_block_size {
+ for dim_index in 0..dim {
+ adjusted_block_data[block_data_index * dim + dim_index] =
+ block_data[block_data_index * dim + dim_index].into() - centroid[dim_index];
+ }
+ }
+
+ for chunk_index in 0..num_pq_chunks {
+ let cur_chunk_size = chunk_offsets[chunk_index + 1] - chunk_offsets[chunk_index];
+ if cur_chunk_size == 0 {
+ continue;
+ }
+
+ let mut cur_pivot_data: Vec<f32> = vec![0.0; num_centers * cur_chunk_size];
+ let mut cur_data: Vec<f32> = vec![0.0; cur_block_size * cur_chunk_size];
+ let mut closest_center: Vec<u32> = vec![0; cur_block_size];
+
+ // Divide the data into chunks and process each chunk in parallel.
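+ // (Editor's note) cur_data gathers, for each vector in the block, only the
+ // dims belonging to this chunk; each vector is then represented by one
+ // centroid code per chunk, so e.g. 128-dim f32 vectors (512 bytes) with 16
+ // chunks compress to 16 one-byte codes.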
+ cur_data
+ .par_chunks_mut(cur_chunk_size)
+ .enumerate()
+ .for_each(|(block_data_index, chunk)| {
+ for (dim_offset, item) in chunk.iter_mut().enumerate() {
+ *item = adjusted_block_data
+ [block_data_index * dim + chunk_offsets[chunk_index] + dim_offset];
+ }
+ });
+
+ cur_pivot_data
+ .par_chunks_mut(cur_chunk_size)
+ .enumerate()
+ .for_each(|(center_index, chunk)| {
+ for (dim_offset, item) in chunk.iter_mut().enumerate() {
+ *item = full_pivot_data
+ [center_index * dim + chunk_offsets[chunk_index] + dim_offset];
+ }
+ });
+
+ // Compute the closest centers
+ compute_closest_centers(
+ &cur_data,
+ cur_block_size,
+ cur_chunk_size,
+ &cur_pivot_data,
+ num_centers,
+ 1,
+ &mut closest_center,
+ None,
+ None,
+ )?;
+
+ block_compressed_base
+ .par_chunks_mut(num_pq_chunks)
+ .enumerate()
+ .for_each(|(block_data_index, slice)| {
+ slice[chunk_index] = closest_center[block_data_index] as usize;
+ });
+ }
+
+ _ = pq_storage.write_compressed_pivot_data(
+ &block_compressed_base,
+ num_centers,
+ cur_block_size,
+ num_pq_chunks,
+ );
+ }
+ Ok(())
+}
+
+/// Generate quantized (PQ-compressed) data and save it to a file.
+/// # Arguments
+/// * `p_val` - fraction of the data sampled as training data for pivot generation
+/// * `num_pq_chunks` - number of pq chunks
+/// * `codebook_prefix` - path of a predefined pivots (codebook) file
+/// * `pq_storage` - pq file access
+pub fn generate_quantized_data<T: Default + Copy + Into<f32>>(
+ p_val: f64,
+ num_pq_chunks: usize,
+ codebook_prefix: &str,
+ pq_storage: &mut PQStorage,
+) -> ANNResult<()> {
+ // If predefined pivots already exist, skip training.
+ if !file_exists(codebook_prefix) {
+ // Instantiate train data with a random sample; updates train_data_vector.
+ // Training data with train_size samples loaded.
+ // Each sampled vector has train_dim dimensions.
+ let (mut train_data_vector, train_size, train_dim) =
+ pq_storage.gen_random_slice::<T>(p_val)?;
+
+ generate_pq_pivots(
+ &mut train_data_vector,
+ train_size,
+ train_dim,
+ NUM_PQ_CENTROIDS,
+ num_pq_chunks,
+ NUM_KMEANS_REPS_PQ,
+ pq_storage,
+ )?;
+ }
+ generate_pq_data_from_pivots::<T>(NUM_PQ_CENTROIDS, num_pq_chunks, pq_storage)?;
+ Ok(())
+}
+
+#[cfg(test)]
+mod pq_test {
+
+ use std::fs::File;
+ use std::io::Write;
+
+ use super::*;
+ use crate::utils::{convert_types_u32_usize, convert_types_u64_usize, load_bin, METADATA_SIZE};
+
+ #[test]
+ fn generate_pq_pivots_test() {
+ let pivot_file_name = "generate_pq_pivots_test.bin";
+ let compressed_file_name = "compressed.bin";
+ let pq_training_file_name = "tests/data/siftsmall_learn.bin";
+ let mut pq_storage =
+ PQStorage::new(pivot_file_name, compressed_file_name, pq_training_file_name).unwrap();
+ let mut train_data: Vec<f32> = vec![
+ 1.0f32, 1.0f32, 1.0f32, 1.0f32, 1.0f32, 1.0f32, 1.0f32, 1.0f32, 2.0f32, 2.0f32, 2.0f32,
+ 2.0f32, 2.0f32, 2.0f32, 2.0f32, 2.0f32, 2.1f32, 2.1f32, 2.1f32, 2.1f32, 2.1f32, 2.1f32,
+ 2.1f32, 2.1f32, 2.2f32, 2.2f32, 2.2f32, 2.2f32, 2.2f32, 2.2f32, 2.2f32, 2.2f32,
+ 100.0f32, 100.0f32, 100.0f32, 100.0f32, 100.0f32, 100.0f32, 100.0f32, 100.0f32,
+ ];
+ generate_pq_pivots(&mut train_data, 5, 8, 2, 2, 5, &mut pq_storage).unwrap();
+
+ let (data, nr, nc) = load_bin::<u64>(pivot_file_name, 0).unwrap();
+ let file_offset_data = convert_types_u64_usize(&data, nr, nc);
+ assert_eq!(file_offset_data[0], METADATA_SIZE);
+ assert_eq!(nr, 4);
+ assert_eq!(nc, 1);
+
+ let (data, nr, nc) = load_bin::<f32>(pivot_file_name, file_offset_data[0]).unwrap();
+ let full_pivot_data = data.to_vec();
+ assert_eq!(full_pivot_data.len(), 16);
+ assert_eq!(nr, 2);
+ assert_eq!(nc, 8);
+
+ let (data, nr, nc) = load_bin::<f32>(pivot_file_name,
file_offset_data[1]).unwrap(); + let centroid = data.to_vec(); + assert_eq!( + centroid[0], + (1.0f32 + 2.0f32 + 2.1f32 + 2.2f32 + 100.0f32) / 5.0f32 + ); + assert_eq!(nr, 8); + assert_eq!(nc, 1); + + let (data, nr, nc) = load_bin::(pivot_file_name, file_offset_data[2]).unwrap(); + let chunk_offsets = convert_types_u32_usize(&data, nr, nc); + assert_eq!(chunk_offsets[0], 0); + assert_eq!(chunk_offsets[1], 4); + assert_eq!(chunk_offsets[2], 8); + assert_eq!(nr, 3); + assert_eq!(nc, 1); + std::fs::remove_file(pivot_file_name).unwrap(); + } + + #[test] + fn generate_pq_data_from_pivots_test() { + let data_file = "generate_pq_data_from_pivots_test_data.bin"; + //npoints=5, dim=8, 5 vectors [1.0;8] [2.0;8] [2.1;8] [2.2;8] [100.0;8] + let mut train_data: Vec = vec![ + 1.0f32, 1.0f32, 1.0f32, 1.0f32, 1.0f32, 1.0f32, 1.0f32, 1.0f32, 2.0f32, 2.0f32, 2.0f32, + 2.0f32, 2.0f32, 2.0f32, 2.0f32, 2.0f32, 2.1f32, 2.1f32, 2.1f32, 2.1f32, 2.1f32, 2.1f32, + 2.1f32, 2.1f32, 2.2f32, 2.2f32, 2.2f32, 2.2f32, 2.2f32, 2.2f32, 2.2f32, 2.2f32, + 100.0f32, 100.0f32, 100.0f32, 100.0f32, 100.0f32, 100.0f32, 100.0f32, 100.0f32, + ]; + let my_nums_unstructured: &[u8] = unsafe { + std::slice::from_raw_parts(train_data.as_ptr() as *const u8, train_data.len() * 4) + }; + let meta: Vec = vec![5, 8]; + let meta_unstructured: &[u8] = + unsafe { std::slice::from_raw_parts(meta.as_ptr() as *const u8, meta.len() * 4) }; + let mut data_file_writer = File::create(data_file).unwrap(); + data_file_writer + .write_all(meta_unstructured) + .expect("Failed to write sample file"); + data_file_writer + .write_all(my_nums_unstructured) + .expect("Failed to write sample file"); + + let pq_pivots_path = "generate_pq_data_from_pivots_test_pivot.bin"; + let pq_compressed_vectors_path = "generate_pq_data_from_pivots_test.bin"; + let mut pq_storage = + PQStorage::new(pq_pivots_path, pq_compressed_vectors_path, data_file).unwrap(); + generate_pq_pivots(&mut train_data, 5, 8, 2, 2, 5, &mut pq_storage).unwrap(); + generate_pq_data_from_pivots::(2, 2, &mut pq_storage).unwrap(); + let (data, nr, nc) = load_bin::(pq_compressed_vectors_path, 0).unwrap(); + assert_eq!(nr, 5); + assert_eq!(nc, 2); + assert_eq!(data[0], data[2]); + assert_ne!(data[0], data[8]); + + std::fs::remove_file(data_file).unwrap(); + std::fs::remove_file(pq_pivots_path).unwrap(); + std::fs::remove_file(pq_compressed_vectors_path).unwrap(); + } + + #[test] + fn pq_end_to_end_validation_with_codebook_test() { + let data_file = "tests/data/siftsmall_learn.bin"; + let pq_pivots_path = "tests/data/siftsmall_learn.bin_pq_pivots.bin"; + let gound_truth_path = "tests/data/siftsmall_learn.bin_pq_compressed.bin"; + let pq_compressed_vectors_path = "validation.bin"; + let mut pq_storage = + PQStorage::new(pq_pivots_path, pq_compressed_vectors_path, data_file).unwrap(); + generate_quantized_data::(0.5, 1, pq_pivots_path, &mut pq_storage).unwrap(); + + let (data, nr, nc) = load_bin::(pq_compressed_vectors_path, 0).unwrap(); + let (gt_data, gt_nr, gt_nc) = load_bin::(gound_truth_path, 0).unwrap(); + assert_eq!(nr, gt_nr); + assert_eq!(nc, gt_nc); + for i in 0..data.len() { + assert_eq!(data[i], gt_data[i]); + } + std::fs::remove_file(pq_compressed_vectors_path).unwrap(); + } +} diff --git a/rust/diskann/src/model/scratch/concurrent_queue.rs b/rust/diskann/src/model/scratch/concurrent_queue.rs new file mode 100644 index 000000000..8c72bab02 --- /dev/null +++ b/rust/diskann/src/model/scratch/concurrent_queue.rs @@ -0,0 +1,312 @@ +/* + * Copyright (c) Microsoft Corporation. All rights reserved. 
+ * Licensed under the MIT license.
+ */
+#![warn(missing_debug_implementations, missing_docs)]
+
+//! Concurrent queue
+
+use std::collections::VecDeque;
+use std::ops::Deref;
+use std::sync::{Arc, Condvar, Mutex, MutexGuard};
+use std::time::Duration;
+
+use crate::common::{ANNError, ANNResult};
+
+#[derive(Debug)]
+/// A thread-safe FIFO queue guarded by a mutex, with condvar push notifications
+pub struct ConcurrentQueue<T> {
+ q: Mutex<VecDeque<T>>,
+ c: Mutex<bool>,
+ push_cv: Condvar,
+}
+
+impl<T> Default for ConcurrentQueue<T> {
+ fn default() -> Self {
+ Self::new()
+ }
+}
+
+impl<T> ConcurrentQueue<T> {
+ /// Create a concurrent queue
+ pub fn new() -> Self {
+ Self {
+ q: Mutex::new(VecDeque::new()),
+ c: Mutex::new(false),
+ push_cv: Condvar::new(),
+ }
+ }
+
+ /// Reserve capacity; blocks the current thread until it acquires the mutex
+ pub fn reserve(&self, size: usize) -> ANNResult<()> {
+ let mut guard = lock(&self.q)?;
+ guard.reserve(size);
+ Ok(())
+ }
+
+ /// Current queue length
+ pub fn size(&self) -> ANNResult<usize> {
+ let guard = lock(&self.q)?;
+
+ Ok(guard.len())
+ }
+
+ /// Whether the queue is empty
+ pub fn is_empty(&self) -> ANNResult<bool> {
+ Ok(self.size()? == 0)
+ }
+
+ /// Push to the back of the queue
+ pub fn push(&self, new_val: T) -> ANNResult<()> {
+ let mut guard = lock(&self.q)?;
+ self.push_internal(&mut guard, new_val);
+ self.push_cv.notify_all();
+ Ok(())
+ }
+
+ /// Push to the back of the queue while holding the lock
+ fn push_internal(&self, guard: &mut MutexGuard<VecDeque<T>>, new_val: T) {
+ guard.push_back(new_val);
+ }
+
+ /// Insert all items from an iterator into the queue
+ pub fn insert<I>(&self, iter: I) -> ANNResult<()>
+ where
+ I: IntoIterator<Item = T>,
+ {
+ let mut guard = lock(&self.q)?;
+ for item in iter {
+ self.push_internal(&mut guard, item);
+ }
+
+ self.push_cv.notify_all();
+ Ok(())
+ }
+
+ /// Pop from the front of the queue
+ pub fn pop(&self) -> ANNResult<Option<T>> {
+ let mut guard = lock(&self.q)?;
+ Ok(guard.pop_front())
+ }
+
+ /// Drain all items from the queue
+ pub fn empty_queue(&self) -> ANNResult<()> {
+ let mut guard = lock(&self.q)?;
+ while !guard.is_empty() {
+ let _ = guard.pop_front();
+ }
+ Ok(())
+ }
+
+ /// Wait (up to wait_time) for a push notification
+ pub fn wait_for_push_notify(&self, wait_time: Duration) -> ANNResult<()> {
+ let guard_lock = lock(&self.c)?;
+ let _ = self
+ .push_cv
+ .wait_timeout(guard_lock, wait_time)
+ .map_err(|err| {
+ ANNError::log_lock_poison_error(format!(
+ "ConcurrentQueue Lock is poisoned, err={}",
+ err
+ ))
+ })?;
+ Ok(())
+ }
+}
+
+fn lock<T>(mutex: &Mutex<T>) -> ANNResult<MutexGuard<T>> {
+ let guard = mutex.lock().map_err(|err| {
+ ANNError::log_lock_poison_error(format!("ConcurrentQueue lock is poisoned, err={}", err))
+ })?;
+ Ok(guard)
+}
+
+/// A thread-safe queue that holds instances of `T`.
+/// Each instance is stored in a `Box` to keep the size of the queue node constant.
+#[derive(Debug)]
+pub struct ArcConcurrentBoxedQueue<T> {
+ internal_queue: Arc<ConcurrentQueue<Box<T>>>,
+}
+
+impl<T> ArcConcurrentBoxedQueue<T> {
+ /// Create a new `ArcConcurrentBoxedQueue`.
+ pub fn new() -> Self {
+ Self {
+ internal_queue: Arc::new(ConcurrentQueue::new()),
+ }
+ }
+}
+
+impl<T> Default for ArcConcurrentBoxedQueue<T> {
+ fn default() -> Self {
+ Self::new()
+ }
+}
+
+impl<T> Clone for ArcConcurrentBoxedQueue<T> {
+ /// Create a new `ArcConcurrentBoxedQueue` that shares the same internal queue
+ /// with the existing one. This allows multiple `ArcConcurrentBoxedQueue`s to
+ /// operate on the same underlying queue.
+ fn clone(&self) -> Self {
+ Self {
+ internal_queue: Arc::clone(&self.internal_queue),
+ }
+ }
+}
+
+/// Deref to the ConcurrentQueue.
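+/// This lets callers invoke the queue API directly on the wrapper, e.g.
+/// (editor's sketch; hypothetical usage of the API above):
+/// `let q = ArcConcurrentBoxedQueue::<u64>::new();`
+/// `q.push(Box::new(1))?;` then `q.pop()?` yields `Some(Box::new(1))`.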
+impl Deref for ArcConcurrentBoxedQueue { + type Target = ConcurrentQueue>; + + fn deref(&self) -> &Self::Target { + &self.internal_queue + } +} + +#[cfg(test)] +mod tests { + use crate::model::ConcurrentQueue; + use std::sync::Arc; + use std::thread; + use std::time::Duration; + + #[test] + fn test_push_pop() { + let queue = ConcurrentQueue::::new(); + + queue.push(1).unwrap(); + queue.push(2).unwrap(); + queue.push(3).unwrap(); + + assert_eq!(queue.pop().unwrap(), Some(1)); + assert_eq!(queue.pop().unwrap(), Some(2)); + assert_eq!(queue.pop().unwrap(), Some(3)); + assert_eq!(queue.pop().unwrap(), None); + } + + #[test] + fn test_size_empty() { + let queue = ConcurrentQueue::new(); + + assert_eq!(queue.size().unwrap(), 0); + assert!(queue.is_empty().unwrap()); + + queue.push(1).unwrap(); + queue.push(2).unwrap(); + + assert_eq!(queue.size().unwrap(), 2); + assert!(!queue.is_empty().unwrap()); + + queue.pop().unwrap(); + queue.pop().unwrap(); + + assert_eq!(queue.size().unwrap(), 0); + assert!(queue.is_empty().unwrap()); + } + + #[test] + fn test_insert() { + let queue = ConcurrentQueue::new(); + + let data = vec![1, 2, 3]; + queue.insert(data.into_iter()).unwrap(); + + assert_eq!(queue.pop().unwrap(), Some(1)); + assert_eq!(queue.pop().unwrap(), Some(2)); + assert_eq!(queue.pop().unwrap(), Some(3)); + assert_eq!(queue.pop().unwrap(), None); + } + + #[test] + fn test_notifications() { + let queue = Arc::new(ConcurrentQueue::new()); + let queue_clone = Arc::clone(&queue); + + let producer = thread::spawn(move || { + for i in 0..3 { + thread::sleep(Duration::from_millis(50)); + queue_clone.push(i).unwrap(); + } + }); + + let consumer = thread::spawn(move || { + let mut values = vec![]; + + for _ in 0..3 { + let mut val = -1; + while val == -1 { + queue + .wait_for_push_notify(Duration::from_millis(10)) + .unwrap(); + val = queue.pop().unwrap().unwrap_or(-1); + } + + values.push(val); + } + + values + }); + + producer.join().unwrap(); + let consumer_results = consumer.join().unwrap(); + + assert_eq!(consumer_results, vec![0, 1, 2]); + } + + #[test] + fn test_multithreaded_push_pop() { + let queue = Arc::new(ConcurrentQueue::new()); + let queue_clone = Arc::clone(&queue); + + let producer = thread::spawn(move || { + for i in 0..10 { + queue_clone.push(i).unwrap(); + thread::sleep(Duration::from_millis(50)); + } + }); + + let consumer = thread::spawn(move || { + let mut values = vec![]; + + for _ in 0..10 { + let mut val = -1; + while val == -1 { + val = queue.pop().unwrap().unwrap_or(-1); + thread::sleep(Duration::from_millis(10)); + } + + values.push(val); + } + + values + }); + + producer.join().unwrap(); + let consumer_results = consumer.join().unwrap(); + + assert_eq!(consumer_results, (0..10).collect::>()); + } + + /// This is a single value test. It avoids the unlimited wait until the collectin got empty on the previous test. + /// It will make sure the signal mutex is matching the waiting mutex. 
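+ /// (Editor's note) The condvar waits on its own mutex (`c`), separate from the
+ /// queue mutex (`q`), so a push may complete before the consumer starts waiting;
+ /// the bounded timeout guarantees progress in that case instead of blocking forever.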
+ #[test] + fn test_wait_for_push_notify() { + let queue = Arc::new(ConcurrentQueue::::new()); + let queue_clone = Arc::clone(&queue); + + let producer = thread::spawn(move || { + thread::sleep(Duration::from_millis(100)); + queue_clone.push(1).unwrap(); + }); + + let consumer = thread::spawn(move || { + queue + .wait_for_push_notify(Duration::from_millis(200)) + .unwrap(); + assert_eq!(queue.pop().unwrap(), Some(1)); + }); + + producer.join().unwrap(); + consumer.join().unwrap(); + } +} diff --git a/rust/diskann/src/model/scratch/inmem_query_scratch.rs b/rust/diskann/src/model/scratch/inmem_query_scratch.rs new file mode 100644 index 000000000..f0fa432c2 --- /dev/null +++ b/rust/diskann/src/model/scratch/inmem_query_scratch.rs @@ -0,0 +1,186 @@ +/* + * Copyright (c) Microsoft Corporation. All rights reserved. + * Licensed under the MIT license. + */ +#![warn(missing_debug_implementations, missing_docs)] + +//! Scratch space for in-memory index based search + +use std::cmp::max; +use std::mem; + +use hashbrown::HashSet; + +use crate::common::{ANNError, ANNResult, AlignedBoxWithSlice}; +use crate::model::configuration::index_write_parameters::IndexWriteParameters; +use crate::model::{Neighbor, NeighborPriorityQueue, PQScratch}; + +use super::Scratch; + +/// In-mem index related limits +pub const GRAPH_SLACK_FACTOR: f64 = 1.3_f64; + +/// Max number of points for using bitset +pub const MAX_POINTS_FOR_USING_BITSET: usize = 100000; + +/// TODO: SSD Index related limits +pub const MAX_GRAPH_DEGREE: usize = 512; + +/// TODO: SSD Index related limits +pub const MAX_N_CMPS: usize = 16384; + +/// TODO: SSD Index related limits +pub const SECTOR_LEN: usize = 4096; + +/// TODO: SSD Index related limits +pub const MAX_N_SECTOR_READS: usize = 128; + +/// The alignment required for memory access. This will be multiplied with size of T to get the actual alignment +pub const QUERY_ALIGNMENT_OF_T_SIZE: usize = 16; + +/// Scratch space for in-memory index based search +#[derive(Debug)] +pub struct InMemQueryScratch { + /// Size of the candidate queue + pub candidate_size: u32, + + /// Max degree for each vertex + pub max_degree: u32, + + /// Max occlusion size + pub max_occlusion_size: u32, + + /// Query node + pub query: AlignedBoxWithSlice, + + /// Best candidates, whose size is candidate_queue_size + pub best_candidates: NeighborPriorityQueue, + + /// Occlude factor + pub occlude_factor: Vec, + + /// Visited neighbor id + pub id_scratch: Vec, + + /// The distance between visited neighbor and query node + pub dist_scratch: Vec, + + /// The PQ Scratch, keey it private since this class use the Box to own the memory. 
+    pub pq_scratch: Option<Box<PQScratch>>,
+
+    /// Buffers used in process delete, capacity increases as needed
+    pub expanded_nodes_set: HashSet<u32>,
+
+    /// Expanded neighbors
+    pub expanded_neighbors_vector: Vec<Neighbor>,
+
+    /// Occlude list
+    pub occlude_list_output: Vec<u32>,
+
+    /// RobinSet for larger dataset
+    pub node_visited_robinset: HashSet<u32>,
+}
+
+impl<T, const N: usize> InMemQueryScratch<T, N> {
+    /// Create InMemQueryScratch instance
+    pub fn new(
+        search_candidate_size: u32,
+        index_write_parameter: &IndexWriteParameters,
+        init_pq_scratch: bool,
+    ) -> ANNResult<Self> {
+        let indexing_candidate_size = index_write_parameter.search_list_size;
+        let max_degree = index_write_parameter.max_degree;
+        let max_occlusion_size = index_write_parameter.max_occlusion_size;
+
+        if search_candidate_size == 0 || indexing_candidate_size == 0 || max_degree == 0 || N == 0 {
+            return Err(ANNError::log_index_error(format!(
+                "In InMemQueryScratch, one of search_candidate_size = {}, indexing_candidate_size = {}, dim = {} or max_degree = {} is zero.",
+                search_candidate_size, indexing_candidate_size, N, max_degree)));
+        }
+
+        let query = AlignedBoxWithSlice::new(N, mem::size_of::<T>() * QUERY_ALIGNMENT_OF_T_SIZE)?;
+        let pq_scratch = if init_pq_scratch {
+            Some(Box::new(PQScratch::new(MAX_GRAPH_DEGREE, N)?))
+        } else {
+            None
+        };
+
+        let occlude_factor = Vec::with_capacity(max_occlusion_size as usize);
+
+        let capacity = (1.5 * GRAPH_SLACK_FACTOR * (max_degree as f64)).ceil() as usize;
+        let id_scratch = Vec::with_capacity(capacity);
+        let dist_scratch = Vec::with_capacity(capacity);
+
+        let expanded_nodes_set = HashSet::<u32>::new();
+        let expanded_neighbors_vector = Vec::<Neighbor>::new();
+        let occlude_list_output = Vec::<u32>::new();
+
+        let candidate_size = max(search_candidate_size, indexing_candidate_size);
+        let node_visited_robinset = HashSet::<u32>::with_capacity(20 * candidate_size as usize);
+        let scratch = Self {
+            candidate_size,
+            max_degree,
+            max_occlusion_size,
+            query,
+            best_candidates: NeighborPriorityQueue::with_capacity(candidate_size as usize),
+            occlude_factor,
+            id_scratch,
+            dist_scratch,
+            pq_scratch,
+            expanded_nodes_set,
+            expanded_neighbors_vector,
+            occlude_list_output,
+            node_visited_robinset,
+        };
+
+        Ok(scratch)
+    }
+
+    /// Resize the scratch with a new candidate size
+    pub fn resize_for_new_candidate_size(&mut self, new_candidate_size: u32) {
+        if new_candidate_size > self.candidate_size {
+            let delta = new_candidate_size - self.candidate_size;
+            self.candidate_size = new_candidate_size;
+            self.best_candidates.reserve(delta as usize);
+            self.node_visited_robinset.reserve((20 * delta) as usize);
+        }
+    }
+}
+
+impl<T, const N: usize> Scratch for InMemQueryScratch<T, N> {
+    fn clear(&mut self) {
+        self.best_candidates.clear();
+        self.occlude_factor.clear();
+
+        self.node_visited_robinset.clear();
+
+        self.id_scratch.clear();
+        self.dist_scratch.clear();
+
+        self.expanded_nodes_set.clear();
+        self.expanded_neighbors_vector.clear();
+        self.occlude_list_output.clear();
+    }
+}
+
+#[cfg(test)]
+mod inmemory_query_scratch_test {
+    use crate::model::configuration::index_write_parameters::IndexWriteParametersBuilder;
+
+    use super::*;
+
+    #[test]
+    fn node_visited_robinset_test() {
+        let index_write_parameter = IndexWriteParametersBuilder::new(10, 10)
+            .with_max_occlusion_size(5)
+            .build();
+
+        let mut scratch =
+            InMemQueryScratch::<f32, 128>::new(100, &index_write_parameter, false).unwrap();
+
+        assert_eq!(scratch.node_visited_robinset.len(), 0);
+
+        scratch.clear();
+        assert_eq!(scratch.node_visited_robinset.len(), 0);
+    }
+}
diff --git
a/rust/diskann/src/model/scratch/mod.rs b/rust/diskann/src/model/scratch/mod.rs new file mode 100644 index 000000000..cf9ee2900 --- /dev/null +++ b/rust/diskann/src/model/scratch/mod.rs @@ -0,0 +1,28 @@ +/* + * Copyright (c) Microsoft Corporation. All rights reserved. + * Licensed under the MIT license. + */ +pub mod scratch_traits; +pub use scratch_traits::*; + +pub mod concurrent_queue; +pub use concurrent_queue::*; + +pub mod pq_scratch; +pub use pq_scratch::*; + + +pub mod inmem_query_scratch; +pub use inmem_query_scratch::*; + +pub mod scratch_store_manager; +pub use scratch_store_manager::*; + +pub mod ssd_query_scratch; +pub use ssd_query_scratch::*; + +pub mod ssd_thread_data; +pub use ssd_thread_data::*; + +pub mod ssd_io_context; +pub use ssd_io_context::*; diff --git a/rust/diskann/src/model/scratch/pq_scratch.rs b/rust/diskann/src/model/scratch/pq_scratch.rs new file mode 100644 index 000000000..bf9d6c547 --- /dev/null +++ b/rust/diskann/src/model/scratch/pq_scratch.rs @@ -0,0 +1,105 @@ +/* + * Copyright (c) Microsoft Corporation. All rights reserved. + * Licensed under the MIT license. + */ +#![warn(missing_debug_implementations, missing_docs)] + +//! Aligned allocator + +use std::mem::size_of; + +use crate::common::{ANNResult, AlignedBoxWithSlice}; + +const MAX_PQ_CHUNKS: usize = 512; + +#[derive(Debug)] +/// PQ scratch +pub struct PQScratch { + /// Aligned pq table dist scratch, must be at least [256 * NCHUNKS] + pub aligned_pqtable_dist_scratch: AlignedBoxWithSlice, + /// Aligned dist scratch, must be at least diskann MAX_DEGREE + pub aligned_dist_scratch: AlignedBoxWithSlice, + /// Aligned pq coord scratch, must be at least [N_CHUNKS * MAX_DEGREE] + pub aligned_pq_coord_scratch: AlignedBoxWithSlice, + /// Rotated query + pub rotated_query: AlignedBoxWithSlice, + /// Aligned query float + pub aligned_query_float: AlignedBoxWithSlice, +} + +impl PQScratch { + const ALIGNED_ALLOC_256: usize = 256; + + /// Create a new pq scratch + pub fn new(graph_degree: usize, aligned_dim: usize) -> ANNResult { + let aligned_pq_coord_scratch = + AlignedBoxWithSlice::new(graph_degree * MAX_PQ_CHUNKS, PQScratch::ALIGNED_ALLOC_256)?; + let aligned_pqtable_dist_scratch = + AlignedBoxWithSlice::new(256 * MAX_PQ_CHUNKS, PQScratch::ALIGNED_ALLOC_256)?; + let aligned_dist_scratch = + AlignedBoxWithSlice::new(graph_degree, PQScratch::ALIGNED_ALLOC_256)?; + let aligned_query_float = AlignedBoxWithSlice::new(aligned_dim, 8 * size_of::())?; + let rotated_query = AlignedBoxWithSlice::new(aligned_dim, 8 * size_of::())?; + + Ok(Self { + aligned_pqtable_dist_scratch, + aligned_dist_scratch, + aligned_pq_coord_scratch, + rotated_query, + aligned_query_float, + }) + } + + /// Set rotated_query and aligned_query_float values + pub fn set(&mut self, dim: usize, query: &[T], norm: f32) + where + T: Into + Copy, + { + for (d, item) in query.iter().enumerate().take(dim) { + let query_val: f32 = (*item).into(); + if (norm - 1.0).abs() > f32::EPSILON { + self.rotated_query[d] = query_val / norm; + self.aligned_query_float[d] = query_val / norm; + } else { + self.rotated_query[d] = query_val; + self.aligned_query_float[d] = query_val; + } + } + } +} + +#[cfg(test)] +mod tests { + use crate::model::PQScratch; + + #[test] + fn test_pq_scratch() { + let graph_degree = 512; + let aligned_dim = 8; + + let mut pq_scratch: PQScratch = PQScratch::new(graph_degree, aligned_dim).unwrap(); + + // Check alignment + assert_eq!( + (pq_scratch.aligned_pqtable_dist_scratch.as_ptr() as usize) % 256, + 0 + ); + 
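+        // The 256-byte expectations here and below mirror
+        // PQScratch::ALIGNED_ALLOC_256; rotated_query and aligned_query_float
+        // are only allocated with 8 * size_of::<f32>() = 32-byte alignment,
+        // hence the weaker % 32 assertions further down.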
assert_eq!((pq_scratch.aligned_dist_scratch.as_ptr() as usize) % 256, 0); + assert_eq!( + (pq_scratch.aligned_pq_coord_scratch.as_ptr() as usize) % 256, + 0 + ); + assert_eq!((pq_scratch.rotated_query.as_ptr() as usize) % 32, 0); + assert_eq!((pq_scratch.aligned_query_float.as_ptr() as usize) % 32, 0); + + // Test set() method + let query = vec![1u8, 2, 3, 4, 5, 6, 7, 8]; + let norm = 2.0f32; + pq_scratch.set::(query.len(), &query, norm); + + (0..query.len()).for_each(|i| { + assert_eq!(pq_scratch.rotated_query[i], query[i] as f32 / norm); + assert_eq!(pq_scratch.aligned_query_float[i], query[i] as f32 / norm); + }); + } +} diff --git a/rust/diskann/src/model/scratch/scratch_store_manager.rs b/rust/diskann/src/model/scratch/scratch_store_manager.rs new file mode 100644 index 000000000..4e2397f49 --- /dev/null +++ b/rust/diskann/src/model/scratch/scratch_store_manager.rs @@ -0,0 +1,84 @@ +/* + * Copyright (c) Microsoft Corporation. All rights reserved. + * Licensed under the MIT license. + */ +use crate::common::ANNResult; + +use super::ArcConcurrentBoxedQueue; +use super::{scratch_traits::Scratch}; +use std::time::Duration; + +pub struct ScratchStoreManager { + scratch: Option>, + scratch_pool: ArcConcurrentBoxedQueue, +} + +impl ScratchStoreManager { + pub fn new(scratch_pool: ArcConcurrentBoxedQueue, wait_time: Duration) -> ANNResult { + let mut scratch = scratch_pool.pop()?; + while scratch.is_none() { + scratch_pool.wait_for_push_notify(wait_time)?; + scratch = scratch_pool.pop()?; + } + + Ok(ScratchStoreManager { + scratch, + scratch_pool, + }) + } + + pub fn scratch_space(&mut self) -> Option<&mut T> { + self.scratch.as_deref_mut() + } +} + +impl Drop for ScratchStoreManager { + fn drop(&mut self) { + if let Some(mut scratch) = self.scratch.take() { + scratch.clear(); + let _ = self.scratch_pool.push(scratch); + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[derive(Debug)] + struct MyScratch { + data: Vec, + } + + impl Scratch for MyScratch { + fn clear(&mut self) { + self.data.clear(); + } + } + + #[test] + fn test_scratch_store_manager() { + let wait_time = Duration::from_millis(100); + + let scratch_pool = ArcConcurrentBoxedQueue::new(); + for i in 1..3 { + scratch_pool.push(Box::new(MyScratch { + data: vec![i, 2 * i, 3 * i], + })).unwrap(); + } + + let mut manager = ScratchStoreManager::new(scratch_pool.clone(), wait_time).unwrap(); + let scratch_space = manager.scratch_space().unwrap(); + + assert_eq!(scratch_space.data, vec![1, 2, 3]); + + // At this point, the ScratchStoreManager will go out of scope, + // causing the Drop implementation to be called, which should + // call the clear method on MyScratch. + drop(manager); + + let current_scratch = scratch_pool.pop().unwrap().unwrap(); + assert_eq!(current_scratch.data, vec![2, 4, 6]); + } +} + diff --git a/rust/diskann/src/model/scratch/scratch_traits.rs b/rust/diskann/src/model/scratch/scratch_traits.rs new file mode 100644 index 000000000..71e4b932d --- /dev/null +++ b/rust/diskann/src/model/scratch/scratch_traits.rs @@ -0,0 +1,8 @@ +/* + * Copyright (c) Microsoft Corporation. All rights reserved. + * Licensed under the MIT license. + */ +pub trait Scratch { + fn clear(&mut self); +} + diff --git a/rust/diskann/src/model/scratch/ssd_io_context.rs b/rust/diskann/src/model/scratch/ssd_io_context.rs new file mode 100644 index 000000000..d4dff0cec --- /dev/null +++ b/rust/diskann/src/model/scratch/ssd_io_context.rs @@ -0,0 +1,38 @@ +/* + * Copyright (c) Microsoft Corporation. All rights reserved. 
+ * Licensed under the MIT license. + */ +#![allow(dead_code)] // Todo: Remove this when the disk index query code is complete. +use crate::common::ANNError; + +use platform::{FileHandle, IOCompletionPort}; + +// The IOContext struct for disk I/O. One for each thread. +pub struct IOContext { + pub status: Status, + pub file_handle: FileHandle, + pub io_completion_port: IOCompletionPort, +} + +impl Default for IOContext { + fn default() -> Self { + IOContext { + status: Status::ReadWait, + file_handle: FileHandle::default(), + io_completion_port: IOCompletionPort::default(), + } + } +} + +impl IOContext { + pub fn new() -> Self { + Self::default() + } +} + +pub enum Status { + ReadWait, + ReadSuccess, + ReadFailed(ANNError), + ProcessComplete, +} diff --git a/rust/diskann/src/model/scratch/ssd_query_scratch.rs b/rust/diskann/src/model/scratch/ssd_query_scratch.rs new file mode 100644 index 000000000..b36669303 --- /dev/null +++ b/rust/diskann/src/model/scratch/ssd_query_scratch.rs @@ -0,0 +1,132 @@ +/* + * Copyright (c) Microsoft Corporation. All rights reserved. + * Licensed under the MIT license. + */ +#![allow(dead_code)] // Todo: Remove this when the disk index query code is complete. +use std::mem; +use std::vec::Vec; + +use hashbrown::HashSet; + +use crate::{ + common::{ANNResult, AlignedBoxWithSlice}, + model::{Neighbor, NeighborPriorityQueue}, + model::data_store::DiskScratchDataset, +}; + +use super::{PQScratch, Scratch, MAX_GRAPH_DEGREE, QUERY_ALIGNMENT_OF_T_SIZE}; + +// Scratch space for disk index based search. +pub struct SSDQueryScratch +{ + // Disk scratch dataset storing fp vectors with aligned dim (N) + pub scratch_dataset: DiskScratchDataset, + + // The query scratch. + pub query: AlignedBoxWithSlice, + + /// The PQ Scratch. + pub pq_scratch: Option>, + + // The visited set. + pub id_scratch: HashSet, + + /// Best candidates, whose size is candidate_queue_size + pub best_candidates: NeighborPriorityQueue, + + // Full return set. 
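+    // (Assumption from how this scratch is used elsewhere: unlike
+    // best_candidates, which is bounded by candidate_queue_size, this vector
+    // accumulates every expanded neighbor so the caller can rank the final
+    // results afterwards.)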
+ pub full_return_set: Vec, +} + +// +impl SSDQueryScratch +{ + pub fn new( + visited_reserve: usize, + candidate_queue_size: usize, + init_pq_scratch: bool, + ) -> ANNResult { + let scratch_dataset = DiskScratchDataset::::new()?; + + let query = AlignedBoxWithSlice::::new(N, mem::size_of::() * QUERY_ALIGNMENT_OF_T_SIZE)?; + + let id_scratch = HashSet::::with_capacity(visited_reserve); + let full_return_set = Vec::::with_capacity(visited_reserve); + let best_candidates = NeighborPriorityQueue::with_capacity(candidate_queue_size); + + let pq_scratch = if init_pq_scratch { + Some(Box::new(PQScratch::new(MAX_GRAPH_DEGREE, N)?)) + } else { + None + }; + + Ok(Self { + scratch_dataset, + query, + pq_scratch, + id_scratch, + best_candidates, + full_return_set, + }) + } + + pub fn pq_scratch(&mut self) -> &Option> { + &self.pq_scratch + } +} + +impl Scratch for SSDQueryScratch +{ + fn clear(&mut self) { + self.id_scratch.clear(); + self.best_candidates.clear(); + self.full_return_set.clear(); + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_new() { + // Arrange + let visited_reserve = 100; + let candidate_queue_size = 10; + let init_pq_scratch = true; + + // Act + let result = + SSDQueryScratch::::new(visited_reserve, candidate_queue_size, init_pq_scratch); + + // Assert + assert!(result.is_ok()); + + let scratch = result.unwrap(); + + // Assert the properties of the scratch instance + assert!(scratch.pq_scratch.is_some()); + assert!(scratch.id_scratch.is_empty()); + assert!(scratch.best_candidates.size() == 0); + assert!(scratch.full_return_set.is_empty()); + } + + #[test] + fn test_clear() { + // Arrange + let mut scratch = SSDQueryScratch::::new(100, 10, true).unwrap(); + + // Add some data to scratch fields + scratch.id_scratch.insert(1); + scratch.best_candidates.insert(Neighbor::new(2, 0.5)); + scratch.full_return_set.push(Neighbor::new(3, 0.8)); + + // Act + scratch.clear(); + + // Assert + assert!(scratch.id_scratch.is_empty()); + assert!(scratch.best_candidates.size() == 0); + assert!(scratch.full_return_set.is_empty()); + } +} diff --git a/rust/diskann/src/model/scratch/ssd_thread_data.rs b/rust/diskann/src/model/scratch/ssd_thread_data.rs new file mode 100644 index 000000000..e37495901 --- /dev/null +++ b/rust/diskann/src/model/scratch/ssd_thread_data.rs @@ -0,0 +1,92 @@ +/* + * Copyright (c) Microsoft Corporation. All rights reserved. + * Licensed under the MIT license. + */ +#![allow(dead_code)] // Todo: Remove this when the disk index query code is complete. +use std::sync::Arc; + +use super::{scratch_traits::Scratch, IOContext, SSDQueryScratch}; +use crate::common::ANNResult; + +// The thread data struct for SSD I/O. One for each thread, contains the ScratchSpace and the IOContext. 
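+// A minimal construction sketch (the <f32, 104> instantiation and the
+// IOContext wiring are illustrative assumptions; the disk index query code
+// that owns this type is not complete yet):
+//
+//   let mut data = SSDThreadData::<f32, 104>::new(10, 100, true)?;
+//   data.io_context = Some(Arc::new(IOContext::new()));
+//   data.clear(); // resets the per-thread scratch between queries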
+pub struct SSDThreadData { + pub scratch: SSDQueryScratch, + pub io_context: Option>, +} + +impl SSDThreadData { + pub fn new( + aligned_dim: usize, + visited_reserve: usize, + init_pq_scratch: bool, + ) -> ANNResult { + let scratch = SSDQueryScratch::new(aligned_dim, visited_reserve, init_pq_scratch)?; + Ok(SSDThreadData { + scratch, + io_context: None, + }) + } + + pub fn clear(&mut self) { + self.scratch.clear(); + } +} + +#[cfg(test)] +mod tests { + use crate::model::Neighbor; + + use super::*; + + #[test] + fn test_new() { + // Arrange + let aligned_dim = 10; + let visited_reserve = 100; + let init_pq_scratch = true; + + // Act + let result = SSDThreadData::::new(aligned_dim, visited_reserve, init_pq_scratch); + + // Assert + assert!(result.is_ok()); + + let thread_data = result.unwrap(); + + // Assert the properties of the thread data instance + assert!(thread_data.io_context.is_none()); + + let scratch = &thread_data.scratch; + // Assert the properties of the scratch instance + assert!(scratch.pq_scratch.is_some()); + assert!(scratch.id_scratch.is_empty()); + assert!(scratch.best_candidates.size() == 0); + assert!(scratch.full_return_set.is_empty()); + } + + #[test] + fn test_clear() { + // Arrange + let mut thread_data = SSDThreadData::::new(10, 100, true).unwrap(); + + // Add some data to scratch fields + thread_data.scratch.id_scratch.insert(1); + thread_data + .scratch + .best_candidates + .insert(Neighbor::new(2, 0.5)); + thread_data + .scratch + .full_return_set + .push(Neighbor::new(3, 0.8)); + + // Act + thread_data.clear(); + + // Assert + assert!(thread_data.scratch.id_scratch.is_empty()); + assert!(thread_data.scratch.best_candidates.size() == 0); + assert!(thread_data.scratch.full_return_set.is_empty()); + } +} + diff --git a/rust/diskann/src/model/vertex/dimension.rs b/rust/diskann/src/model/vertex/dimension.rs new file mode 100644 index 000000000..32670a8db --- /dev/null +++ b/rust/diskann/src/model/vertex/dimension.rs @@ -0,0 +1,22 @@ +/* + * Copyright (c) Microsoft Corporation. All rights reserved. + * Licensed under the MIT license. + */ +#![warn(missing_debug_implementations, missing_docs)] + +//! Vertex dimension + +/// 32 vertex dimension +pub const DIM_32: usize = 32; + +/// 64 vertex dimension +pub const DIM_64: usize = 64; + +/// 104 vertex dimension +pub const DIM_104: usize = 104; + +/// 128 vertex dimension +pub const DIM_128: usize = 128; + +/// 256 vertex dimension +pub const DIM_256: usize = 256; diff --git a/rust/diskann/src/model/vertex/mod.rs b/rust/diskann/src/model/vertex/mod.rs new file mode 100644 index 000000000..224d476dc --- /dev/null +++ b/rust/diskann/src/model/vertex/mod.rs @@ -0,0 +1,10 @@ +/* + * Copyright (c) Microsoft Corporation. All rights reserved. + * Licensed under the MIT license. + */ +#[allow(clippy::module_inception)] +mod vertex; +pub use vertex::Vertex; + +mod dimension; +pub use dimension::*; diff --git a/rust/diskann/src/model/vertex/vertex.rs b/rust/diskann/src/model/vertex/vertex.rs new file mode 100644 index 000000000..55369748e --- /dev/null +++ b/rust/diskann/src/model/vertex/vertex.rs @@ -0,0 +1,68 @@ +/* + * Copyright (c) Microsoft Corporation. All rights reserved. + * Licensed under the MIT license. + */ +#![warn(missing_debug_implementations, missing_docs)] + +//! 
Vertex + +use std::array::TryFromSliceError; + +use vector::{FullPrecisionDistance, Metric}; + +/// Vertex with data type T and dimension N +#[derive(Debug)] +pub struct Vertex<'a, T, const N: usize> +where + [T; N]: FullPrecisionDistance, +{ + /// Vertex value + val: &'a [T; N], + + /// Vertex Id + id: u32, +} + +impl<'a, T, const N: usize> Vertex<'a, T, N> +where + [T; N]: FullPrecisionDistance, +{ + /// Create the vertex with data + pub fn new(val: &'a [T; N], id: u32) -> Self { + Self { + val, + id, + } + } + + /// Compare the vertex with another. + #[inline(always)] + pub fn compare(&self, other: &Vertex<'a, T, N>, metric: Metric) -> f32 { + <[T; N]>::distance_compare(self.val, other.val, metric) + } + + /// Get the vector associated with the vertex. + #[inline] + pub fn vector(&self) -> &[T; N] { + self.val + } + + /// Get the vertex id. + #[inline] + pub fn vertex_id(&self) -> u32 { + self.id + } +} + +impl<'a, T, const N: usize> TryFrom<(&'a [T], u32)> for Vertex<'a, T, N> +where + [T; N]: FullPrecisionDistance, +{ + type Error = TryFromSliceError; + + fn try_from((mem_slice, id): (&'a [T], u32)) -> Result { + let array: &[T; N] = mem_slice.try_into()?; + Ok(Vertex::new(array, id)) + } +} + diff --git a/rust/diskann/src/model/windows_aligned_file_reader/mod.rs b/rust/diskann/src/model/windows_aligned_file_reader/mod.rs new file mode 100644 index 000000000..0e63df0a6 --- /dev/null +++ b/rust/diskann/src/model/windows_aligned_file_reader/mod.rs @@ -0,0 +1,7 @@ +/* + * Copyright (c) Microsoft Corporation. All rights reserved. + * Licensed under the MIT license. + */ +#[allow(clippy::module_inception)] +mod windows_aligned_file_reader; +pub use windows_aligned_file_reader::*; diff --git a/rust/diskann/src/model/windows_aligned_file_reader/windows_aligned_file_reader.rs b/rust/diskann/src/model/windows_aligned_file_reader/windows_aligned_file_reader.rs new file mode 100644 index 000000000..1cc3dc032 --- /dev/null +++ b/rust/diskann/src/model/windows_aligned_file_reader/windows_aligned_file_reader.rs @@ -0,0 +1,414 @@ +/* + * Copyright (c) Microsoft Corporation. All rights reserved. + * Licensed under the MIT license. + */ +use std::sync::Arc; +use std::time::Duration; +use std::{ptr, thread}; + +use crossbeam::sync::ShardedLock; +use hashbrown::HashMap; +use once_cell::sync::Lazy; + +use platform::file_handle::{AccessMode, ShareMode}; +use platform::{ + file_handle::FileHandle, + file_io::{get_queued_completion_status, read_file_to_slice}, + io_completion_port::IOCompletionPort, +}; + +use winapi::{ + shared::{basetsd::ULONG_PTR, minwindef::DWORD}, + um::minwinbase::OVERLAPPED, +}; + +use crate::common::{ANNError, ANNResult}; +use crate::model::IOContext; + +pub const MAX_IO_CONCURRENCY: usize = 128; // To do: explore the optimal value for this. The current value is taken from C++ code. +pub const FILE_ATTRIBUTE_READONLY: DWORD = 0x00000001; +pub const IO_COMPLETION_TIMEOUT: DWORD = u32::MAX; // Infinite timeout. +pub const DISK_IO_ALIGNMENT: usize = 512; +pub const ASYNC_IO_COMPLETION_CHECK_INTERVAL: Duration = Duration::from_micros(5); + +/// Aligned read struct for disk IO, it takes the ownership of the AlignedBoxedSlice and returns the AlignedBoxWithSlice data immutably. 
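+/// A usage sketch mirroring the tests further down (both the offset and the
+/// byte length of the buffer must be multiples of DISK_IO_ALIGNMENT, or
+/// new() rejects the request):
+///
+///   let mut mem = AlignedBoxWithSlice::<u8>::new(1024, 512)?;
+///   let mut slices = mem.split_into_nonoverlapping_mut_slices(0..1024, 512)?;
+///   let read = AlignedRead::new(512, slices.pop().unwrap())?;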
+pub struct AlignedRead<'a, T> { + /// where to read from + /// offset needs to be aligned with DISK_IO_ALIGNMENT + offset: u64, + + /// where to read into + /// aligned_buf and its len need to be aligned with DISK_IO_ALIGNMENT + aligned_buf: &'a mut [T], +} + +impl<'a, T> AlignedRead<'a, T> { + pub fn new(offset: u64, aligned_buf: &'a mut [T]) -> ANNResult { + Self::assert_is_aligned(offset as usize)?; + Self::assert_is_aligned(std::mem::size_of_val(aligned_buf))?; + + Ok(Self { + offset, + aligned_buf, + }) + } + + fn assert_is_aligned(val: usize) -> ANNResult<()> { + match val % DISK_IO_ALIGNMENT { + 0 => Ok(()), + _ => Err(ANNError::log_disk_io_request_alignment_error(format!( + "The offset or length of AlignedRead request is not {} bytes aligned", + DISK_IO_ALIGNMENT + ))), + } + } + + pub fn aligned_buf(&self) -> &[T] { + self.aligned_buf + } +} + +pub struct WindowsAlignedFileReader { + file_name: String, + + // ctx_map is the mapping from thread id to io context. It is hashmap behind a sharded lock to allow concurrent access from multiple threads. + // ShardedLock: shardedlock provides an implementation of a reader-writer lock that offers concurrent read access to the shared data while allowing exclusive write access. + // It achieves better scalability by dividing the shared data into multiple shards, and each with its own internal lock. + // Multiple threads can read from different shards simultaneously, reducing contention. + // https://docs.rs/crossbeam/0.8.2/crossbeam/sync/struct.ShardedLock.html + // Comparing to RwLock, ShardedLock provides higher concurrency for read operations and is suitable for read heavy workloads. + // The value of the hashmap is an Arc to allow immutable access to IOContext with automatic reference counting. + ctx_map: Lazy>>>, +} + +impl WindowsAlignedFileReader { + pub fn new(fname: &str) -> ANNResult { + let reader: WindowsAlignedFileReader = WindowsAlignedFileReader { + file_name: fname.to_string(), + ctx_map: Lazy::new(|| ShardedLock::new(HashMap::new())), + }; + + reader.register_thread()?; + Ok(reader) + } + + // Register the io context for a thread if it hasn't been registered. + pub fn register_thread(&self) -> ANNResult<()> { + let mut ctx_map = self.ctx_map.write().map_err(|_| { + ANNError::log_lock_poison_error("unable to acquire read lock on ctx_map".to_string()) + })?; + + let id = thread::current().id(); + if ctx_map.contains_key(&id) { + println!( + "Warning:: Duplicate registration for thread_id : {:?}. Directly call get_ctx to get the thread context data.", + id); + + return Ok(()); + } + + let mut ctx = IOContext::new(); + + match unsafe { FileHandle::new(&self.file_name, AccessMode::Read, ShareMode::Read) } { + Ok(file_handle) => ctx.file_handle = file_handle, + Err(err) => { + return Err(ANNError::log_io_error(err)); + } + } + + // Create a io completion port for the file handle, later it will be used to get the completion status. + match IOCompletionPort::new(&ctx.file_handle, None, 0, 0) { + Ok(io_completion_port) => ctx.io_completion_port = io_completion_port, + Err(err) => { + return Err(ANNError::log_io_error(err)); + } + } + + ctx_map.insert(id, Arc::new(ctx)); + + Ok(()) + } + + // Get the reference counted io context for the current thread. 
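+    // Note: a thread must have gone through register_thread() (new() does this
+    // once for the constructing thread) before calling get_ctx; otherwise the
+    // lookup below fails with an ANNError for the missing thread id.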
+    pub fn get_ctx(&self) -> ANNResult<Arc<IOContext>> {
+        let ctx_map = self.ctx_map.read().map_err(|_| {
+            ANNError::log_lock_poison_error("unable to acquire read lock on ctx_map".to_string())
+        })?;
+
+        let id = thread::current().id();
+        match ctx_map.get(&id) {
+            Some(ctx) => Ok(Arc::clone(ctx)),
+            None => Err(ANNError::log_index_error(format!(
+                "unable to find IOContext for thread_id {:?}",
+                id
+            ))),
+        }
+    }
+
+    // Read the data from the file by sending concurrent io requests in batches.
+    pub fn read<T>(&self, read_requests: &mut [AlignedRead<T>], ctx: &IOContext) -> ANNResult<()> {
+        let n_requests = read_requests.len();
+        let n_batches = (n_requests + MAX_IO_CONCURRENCY - 1) / MAX_IO_CONCURRENCY;
+
+        let mut overlapped_in_out =
+            vec![unsafe { std::mem::zeroed::<OVERLAPPED>() }; MAX_IO_CONCURRENCY];
+
+        for batch_idx in 0..n_batches {
+            let batch_start = MAX_IO_CONCURRENCY * batch_idx;
+            let batch_size = std::cmp::min(n_requests - batch_start, MAX_IO_CONCURRENCY);
+
+            for j in 0..batch_size {
+                let req = &mut read_requests[batch_start + j];
+                let os = &mut overlapped_in_out[j];
+
+                match unsafe {
+                    read_file_to_slice(&ctx.file_handle, req.aligned_buf, os, req.offset)
+                } {
+                    Ok(_) => {}
+                    Err(error) => {
+                        return Err(ANNError::IOError { err: (error) });
+                    }
+                }
+            }
+
+            let mut n_read: DWORD = 0;
+            let mut n_complete: u64 = 0;
+            let mut completion_key: ULONG_PTR = 0;
+            let mut lp_os: *mut OVERLAPPED = ptr::null_mut();
+            while n_complete < batch_size as u64 {
+                match unsafe {
+                    get_queued_completion_status(
+                        &ctx.io_completion_port,
+                        &mut n_read,
+                        &mut completion_key,
+                        &mut lp_os,
+                        IO_COMPLETION_TIMEOUT,
+                    )
+                } {
+                    // An IO request completed.
+                    Ok(true) => n_complete += 1,
+                    // No IO request completed, continue to wait.
+                    Ok(false) => {
+                        thread::sleep(ASYNC_IO_COMPLETION_CHECK_INTERVAL);
+                    }
+                    // An error occurred.
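+                    // The completion-port wait itself failed; surface it to
+                    // the caller as an ANNError::IOError.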
+ Err(error) => return Err(ANNError::IOError { err: (error) }), + } + } + } + + Ok(()) + } +} + +#[cfg(test)] +mod tests { + use std::{fs::File, io::BufReader}; + + use bincode::deserialize_from; + use serde::{Deserialize, Serialize}; + + use crate::{common::AlignedBoxWithSlice, model::SECTOR_LEN}; + + use super::*; + pub const TEST_INDEX_PATH: &str = + "./tests/data/disk_index_siftsmall_learn_256pts_R4_L50_A1.2_alligned_reader_test.index"; + pub const TRUTH_NODE_DATA_PATH: &str = + "./tests/data/disk_index_node_data_aligned_reader_truth.bin"; + + #[derive(Debug, Serialize, Deserialize)] + struct NodeData { + num_neighbors: u32, + coordinates: Vec, + neighbors: Vec, + } + + impl PartialEq for NodeData { + fn eq(&self, other: &Self) -> bool { + self.num_neighbors == other.num_neighbors + && self.coordinates == other.coordinates + && self.neighbors == other.neighbors + } + } + + #[test] + fn test_new_aligned_file_reader() { + // Replace "test_file_path" with actual file path + let result = WindowsAlignedFileReader::new(TEST_INDEX_PATH); + assert!(result.is_ok()); + + let reader = result.unwrap(); + assert_eq!(reader.file_name, TEST_INDEX_PATH); + } + + #[test] + fn test_read() { + let reader = WindowsAlignedFileReader::new(TEST_INDEX_PATH).unwrap(); + let ctx = reader.get_ctx().unwrap(); + + let read_length = 512; // adjust according to your logic + let num_read = 10; + let mut aligned_mem = AlignedBoxWithSlice::::new(read_length * num_read, 512).unwrap(); + + // create and add AlignedReads to the vector + let mut mem_slices = aligned_mem + .split_into_nonoverlapping_mut_slices(0..aligned_mem.len(), read_length) + .unwrap(); + + let mut aligned_reads: Vec> = mem_slices + .iter_mut() + .enumerate() + .map(|(i, slice)| { + let offset = (i * read_length) as u64; + AlignedRead::new(offset, slice).unwrap() + }) + .collect(); + + let result = reader.read(&mut aligned_reads, &ctx); + assert!(result.is_ok()); + } + + #[test] + fn test_read_disk_index_by_sector() { + let reader = WindowsAlignedFileReader::new(TEST_INDEX_PATH).unwrap(); + let ctx = reader.get_ctx().unwrap(); + + let read_length = SECTOR_LEN; // adjust according to your logic + let num_sector = 10; + let mut aligned_mem = + AlignedBoxWithSlice::::new(read_length * num_sector, 512).unwrap(); + + // Each slice will be used as the buffer for a read request of a sector. + let mut mem_slices = aligned_mem + .split_into_nonoverlapping_mut_slices(0..aligned_mem.len(), read_length) + .unwrap(); + + let mut aligned_reads: Vec> = mem_slices + .iter_mut() + .enumerate() + .map(|(sector_id, slice)| { + let offset = (sector_id * read_length) as u64; + AlignedRead::new(offset, slice).unwrap() + }) + .collect(); + + let result = reader.read(&mut aligned_reads, &ctx); + assert!(result.is_ok()); + + aligned_reads.iter().for_each(|read| { + assert_eq!(read.aligned_buf.len(), SECTOR_LEN); + }); + + let disk_layout_meta = reconstruct_disk_meta(aligned_reads[0].aligned_buf); + assert!(disk_layout_meta.len() > 9); + + let dims = disk_layout_meta[1]; + let num_pts = disk_layout_meta[0]; + let max_node_len = disk_layout_meta[3]; + let max_num_nodes_per_sector = disk_layout_meta[4]; + + assert!(max_node_len * max_num_nodes_per_sector < SECTOR_LEN as u64); + + let num_nbrs_start = (dims as usize) * std::mem::size_of::(); + let nbrs_buf_start = num_nbrs_start + std::mem::size_of::(); + + let mut node_data_array = Vec::with_capacity(max_num_nodes_per_sector as usize * 9); + + // Only validate the first 9 sectors with graph nodes. 
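+        // (Sector 0 holds the disk layout metadata, so the node sectors
+        // inspected here are sectors 1 through 8.)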
+ (1..9).for_each(|sector_id| { + let sector_data = &mem_slices[sector_id]; + for node_data in sector_data.chunks_exact(max_node_len as usize) { + // Extract coordinates data from the start of the node_data + let coordinates_end = (dims as usize) * std::mem::size_of::(); + let coordinates = node_data[0..coordinates_end] + .chunks_exact(std::mem::size_of::()) + .map(|chunk| f32::from_le_bytes(chunk.try_into().unwrap())) + .collect(); + + // Extract number of neighbors from the node_data + let neighbors_num = u32::from_le_bytes( + node_data[num_nbrs_start..nbrs_buf_start] + .try_into() + .unwrap(), + ); + + let nbors_buf_end = + nbrs_buf_start + (neighbors_num as usize) * std::mem::size_of::(); + + // Extract neighbors from the node data. + let mut neighbors = Vec::new(); + for nbors_data in node_data[nbrs_buf_start..nbors_buf_end] + .chunks_exact(std::mem::size_of::()) + { + let nbors_id = u32::from_le_bytes(nbors_data.try_into().unwrap()); + assert!(nbors_id < num_pts as u32); + neighbors.push(nbors_id); + } + + // Create NodeData struct and push it to the node_data_array + node_data_array.push(NodeData { + num_neighbors: neighbors_num, + coordinates, + neighbors, + }); + } + }); + + // Compare that each node read from the disk index are expected. + let node_data_truth_file = File::open(TRUTH_NODE_DATA_PATH).unwrap(); + let reader = BufReader::new(node_data_truth_file); + + let node_data_vec: Vec = deserialize_from(reader).unwrap(); + for (node_from_node_data_file, node_from_disk_index) in + node_data_vec.iter().zip(node_data_array.iter()) + { + // Verify that the NodeData from the file is equal to the NodeData in node_data_array + assert_eq!(node_from_node_data_file, node_from_disk_index); + } + } + + #[test] + fn test_read_fail_invalid_file() { + let reader = WindowsAlignedFileReader::new("/invalid_path"); + assert!(reader.is_err()); + } + + #[test] + fn test_read_no_requests() { + let reader = WindowsAlignedFileReader::new(TEST_INDEX_PATH).unwrap(); + let ctx = reader.get_ctx().unwrap(); + + let mut read_requests = Vec::>::new(); + let result = reader.read(&mut read_requests, &ctx); + assert!(result.is_ok()); + } + + #[test] + fn test_get_ctx() { + let reader = WindowsAlignedFileReader::new(TEST_INDEX_PATH).unwrap(); + let result = reader.get_ctx(); + assert!(result.is_ok()); + } + + #[test] + fn test_register_thread() { + let reader = WindowsAlignedFileReader::new(TEST_INDEX_PATH).unwrap(); + let result = reader.register_thread(); + assert!(result.is_ok()); + } + + fn reconstruct_disk_meta(buffer: &[u8]) -> Vec { + let size_of_u64 = std::mem::size_of::(); + + let num_values = buffer.len() / size_of_u64; + let mut disk_layout_meta = Vec::with_capacity(num_values); + let meta_data = &buffer[8..]; + + for chunk in meta_data.chunks_exact(size_of_u64) { + let value = u64::from_le_bytes(chunk.try_into().unwrap()); + disk_layout_meta.push(value); + } + + disk_layout_meta + } +} diff --git a/rust/diskann/src/storage/disk_graph_storage.rs b/rust/diskann/src/storage/disk_graph_storage.rs new file mode 100644 index 000000000..448175212 --- /dev/null +++ b/rust/diskann/src/storage/disk_graph_storage.rs @@ -0,0 +1,37 @@ +/* + * Copyright (c) Microsoft Corporation. All rights reserved. + * Licensed under the MIT license. + */ +#![warn(missing_docs)] + +//! 
Disk graph storage + +use std::sync::Arc; + +use crate::{model::{WindowsAlignedFileReader, IOContext, AlignedRead}, common::ANNResult}; + +/// Graph storage for disk index +/// One thread has one storage instance +pub struct DiskGraphStorage { + /// Disk graph reader + disk_graph_reader: Arc, + + /// IOContext of current thread + ctx: Arc, +} + +impl DiskGraphStorage { + /// Create a new DiskGraphStorage instance + pub fn new(disk_graph_reader: Arc) -> ANNResult { + let ctx = disk_graph_reader.get_ctx()?; + Ok(Self { + disk_graph_reader, + ctx, + }) + } + + /// Read disk graph data + pub fn read(&self, read_requests: &mut [AlignedRead]) -> ANNResult<()> { + self.disk_graph_reader.read(read_requests, &self.ctx) + } +} diff --git a/rust/diskann/src/storage/disk_index_storage.rs b/rust/diskann/src/storage/disk_index_storage.rs new file mode 100644 index 000000000..0c558084d --- /dev/null +++ b/rust/diskann/src/storage/disk_index_storage.rs @@ -0,0 +1,363 @@ +/* + * Copyright (c) Microsoft Corporation. All rights reserved. + * Licensed under the MIT license. + */ +use byteorder::{ByteOrder, LittleEndian, ReadBytesExt}; +use std::fs::File; +use std::io::Read; +use std::marker::PhantomData; +use std::{fs, mem}; + +use crate::common::{ANNError, ANNResult}; +use crate::model::NUM_PQ_CENTROIDS; +use crate::storage::PQStorage; +use crate::utils::{convert_types_u32_usize, convert_types_u64_usize, load_bin, save_bin_u64}; +use crate::utils::{ + file_exists, gen_sample_data, get_file_size, round_up, CachedReader, CachedWriter, +}; + +const SECTOR_LEN: usize = 4096; + +/// Todo: Remove the allow(dead_code) when the disk search code is complete +#[allow(dead_code)] +pub struct PQPivotData { + dim: usize, + pq_table: Vec, + centroids: Vec, + chunk_offsets: Vec, +} + +pub struct DiskIndexStorage { + /// Dataset file + dataset_file: String, + + /// Index file path prefix + index_path_prefix: String, + + // TODO: Only a placeholder for T, will be removed later + _marker: PhantomData, + + pq_storage: PQStorage, +} + +impl DiskIndexStorage { + /// Create DiskIndexStorage instance + pub fn new(dataset_file: String, index_path_prefix: String) -> ANNResult { + let pq_storage: PQStorage = PQStorage::new( + &(index_path_prefix.clone() + ".bin_pq_pivots.bin"), + &(index_path_prefix.clone() + ".bin_pq_compressed.bin"), + &dataset_file, + )?; + + Ok(DiskIndexStorage { + dataset_file, + index_path_prefix, + _marker: PhantomData, + pq_storage, + }) + } + + pub fn get_pq_storage(&mut self) -> &mut PQStorage { + &mut self.pq_storage + } + + pub fn dataset_file(&self) -> &String { + &self.dataset_file + } + + pub fn index_path_prefix(&self) -> &String { + &self.index_path_prefix + } + + /// Create disk layout + /// Sector #1: disk_layout_meta + /// Sector #n: num_nodes_per_sector nodes + /// Each node's layout: {full precision vector:[T; DIM]}{num_nbrs: u32}{neighbors: [u32; num_nbrs]} + /// # Arguments + /// * `dataset_file` - dataset file containing full precision vectors + /// * `mem_index_file` - in-memory index graph file + /// * `disk_layout_file` - output disk layout file + pub fn create_disk_layout(&self) -> ANNResult<()> { + let mem_index_file = self.mem_index_file(); + let disk_layout_file = self.disk_index_file(); + + // amount to read or write in one shot + let read_blk_size = 64 * 1024 * 1024; + let write_blk_size = read_blk_size; + let mut dataset_reader = CachedReader::new(self.dataset_file.as_str(), read_blk_size)?; + + let num_pts = dataset_reader.read_u32()? as u64; + let dims = dataset_reader.read_u32()? 
as u64;
+
+        // Create cached reader + writer
+        let actual_file_size = get_file_size(mem_index_file.as_str())?;
+        println!("Vamana index file size={}", actual_file_size);
+
+        let mut vamana_reader = File::open(mem_index_file)?;
+        let mut diskann_writer = CachedWriter::new(disk_layout_file.as_str(), write_blk_size)?;
+
+        let index_file_size = vamana_reader.read_u64::<LittleEndian>()?;
+        if index_file_size != actual_file_size {
+            println!(
+                "Vamana index file size does not match the expected size per its metadata. size from metadata: {}, actual file size: {}",
+                index_file_size, actual_file_size
+            );
+        }
+
+        let max_degree = vamana_reader.read_u32::<LittleEndian>()?;
+        let medoid = vamana_reader.read_u32::<LittleEndian>()?;
+        let vamana_frozen_num = vamana_reader.read_u64::<LittleEndian>()?;
+
+        let mut vamana_frozen_loc = 0;
+        if vamana_frozen_num == 1 {
+            vamana_frozen_loc = medoid;
+        }
+
+        let max_node_len = ((max_degree as u64 + 1) * (mem::size_of::<u32>() as u64))
+            + (dims * (mem::size_of::<T>() as u64));
+        let num_nodes_per_sector = (SECTOR_LEN as u64) / max_node_len;
+
+        println!("medoid: {}", medoid);
+        println!("max_node_len: {}B", max_node_len);
+        println!("num_nodes_per_sector: {}", num_nodes_per_sector);
+
+        // SECTOR_LEN buffer for each sector
+        let mut sector_buf = vec![0u8; SECTOR_LEN];
+        let mut node_buf = vec![0u8; max_node_len as usize];
+
+        let num_nbrs_start = (dims as usize) * mem::size_of::<T>();
+        let nbrs_buf_start = num_nbrs_start + mem::size_of::<u32>();
+
+        // number of sectors (plus 1 for the metadata sector)
+        let num_sectors = round_up(num_pts, num_nodes_per_sector) / num_nodes_per_sector;
+        let disk_index_file_size = (num_sectors + 1) * (SECTOR_LEN as u64);
+
+        let disk_layout_meta = vec![
+            num_pts,
+            dims,
+            medoid as u64,
+            max_node_len,
+            num_nodes_per_sector,
+            vamana_frozen_num,
+            vamana_frozen_loc as u64,
+            // append_reorder_data
+            // We are not supporting this.
Temporarily write it into the layout so that + // we can leverage C++ query driver to test the disk index + false as u64, + disk_index_file_size, + ]; + + diskann_writer.write(§or_buf)?; + + let mut cur_node_coords = vec![0u8; (dims as usize) * mem::size_of::()]; + let mut cur_node_id = 0u64; + + for sector in 0..num_sectors { + if sector % 100_000 == 0 { + println!("Sector #{} written", sector); + } + sector_buf.fill(0); + + for sector_node_id in 0..num_nodes_per_sector { + if cur_node_id >= num_pts { + break; + } + + node_buf.fill(0); + + // read cur node's num_nbrs + let num_nbrs = vamana_reader.read_u32::()?; + + // sanity checks on num_nbrs + debug_assert!(num_nbrs > 0); + debug_assert!(num_nbrs <= max_degree); + + // write coords of node first + dataset_reader.read(&mut cur_node_coords)?; + node_buf[..cur_node_coords.len()].copy_from_slice(&cur_node_coords); + + // write num_nbrs + LittleEndian::write_u32( + &mut node_buf[num_nbrs_start..(num_nbrs_start + mem::size_of::())], + num_nbrs, + ); + + // write neighbors + let nbrs_buf = &mut node_buf[nbrs_buf_start + ..(nbrs_buf_start + (num_nbrs as usize) * mem::size_of::())]; + vamana_reader.read_exact(nbrs_buf)?; + + // get offset into sector_buf + let sector_node_buf_start = (sector_node_id * max_node_len) as usize; + let sector_node_buf = &mut sector_buf + [sector_node_buf_start..(sector_node_buf_start + max_node_len as usize)]; + sector_node_buf.copy_from_slice(&node_buf[..(max_node_len as usize)]); + + cur_node_id += 1; + } + + // flush sector to disk + diskann_writer.write(§or_buf)?; + } + + diskann_writer.flush()?; + save_bin_u64( + disk_layout_file.as_str(), + &disk_layout_meta, + disk_layout_meta.len(), + 1, + 0, + )?; + + Ok(()) + } + + pub fn index_build_cleanup(&self) -> ANNResult<()> { + fs::remove_file(self.mem_index_file())?; + Ok(()) + } + + pub fn gen_query_warmup_data(&self, sampling_rate: f64) -> ANNResult<()> { + gen_sample_data::( + &self.dataset_file, + &self.warmup_query_prefix(), + sampling_rate, + )?; + Ok(()) + } + + /// Load pre-trained pivot table + pub fn load_pq_pivots_bin( + &self, + num_pq_chunks: &usize, + ) -> ANNResult { + let pq_pivots_path = &self.pq_pivot_file(); + if !file_exists(pq_pivots_path) { + return Err(ANNError::log_pq_error( + "ERROR: PQ k-means pivot file not found.".to_string(), + )); + } + + let (data, offset_num, offset_dim) = load_bin::(pq_pivots_path, 0)?; + let file_offset_data = convert_types_u64_usize(&data, offset_num, offset_dim); + if offset_num != 4 { + let error_message = format!("Error reading pq_pivots file {}. Offsets don't contain correct metadata, # offsets = {}, but expecting 4.", pq_pivots_path, offset_num); + return Err(ANNError::log_pq_error(error_message)); + } + + let (data, pivot_num, dim) = load_bin::(pq_pivots_path, file_offset_data[0])?; + let pq_table = data.to_vec(); + if pivot_num != NUM_PQ_CENTROIDS { + let error_message = format!( + "Error reading pq_pivots file {}. file_num_centers = {}, but expecting {} centers.", + pq_pivots_path, pivot_num, NUM_PQ_CENTROIDS + ); + return Err(ANNError::log_pq_error(error_message)); + } + + let (data, centroid_dim, nc) = load_bin::(pq_pivots_path, file_offset_data[1])?; + let centroids = data.to_vec(); + if centroid_dim != dim || nc != 1 { + let error_message = format!("Error reading pq_pivots file {}. 
file_dim = {}, file_cols = {} but expecting {} entries in 1 dimension.", pq_pivots_path, centroid_dim, nc, dim); + return Err(ANNError::log_pq_error(error_message)); + } + + let (data, chunk_offset_num, nc) = load_bin::(pq_pivots_path, file_offset_data[2])?; + let chunk_offsets = convert_types_u32_usize(&data, chunk_offset_num, nc); + if chunk_offset_num != num_pq_chunks + 1 || nc != 1 { + let error_message = format!("Error reading pq_pivots file at chunk offsets; file has nr={}, nc={} but expecting nr={} and nc=1.", chunk_offset_num, nc, num_pq_chunks + 1); + return Err(ANNError::log_pq_error(error_message)); + } + + Ok(PQPivotData { + dim, + pq_table, + centroids, + chunk_offsets + }) + } + + fn mem_index_file(&self) -> String { + self.index_path_prefix.clone() + "_mem.index" + } + + fn disk_index_file(&self) -> String { + self.index_path_prefix.clone() + "_disk.index" + } + + fn warmup_query_prefix(&self) -> String { + self.index_path_prefix.clone() + "_sample" + } + + pub fn pq_pivot_file(&self) -> String { + self.index_path_prefix.clone() + ".bin_pq_pivots.bin" + } + + pub fn compressed_pq_pivot_file(&self) -> String { + self.index_path_prefix.clone() + ".bin_pq_compressed.bin" + } +} + +#[cfg(test)] +mod disk_index_storage_test { + use std::fs; + + use crate::test_utils::get_test_file_path; + + use super::*; + + const TEST_DATA_FILE: &str = "tests/data/siftsmall_learn_256pts.fbin"; + const DISK_INDEX_PATH_PREFIX: &str = "tests/data/disk_index_siftsmall_learn_256pts_R4_L50_A1.2"; + const TRUTH_DISK_LAYOUT: &str = + "tests/data/truth_disk_index_siftsmall_learn_256pts_R4_L50_A1.2_disk.index"; + + #[test] + fn create_disk_layout_test() { + let storage = DiskIndexStorage::::new( + get_test_file_path(TEST_DATA_FILE), + get_test_file_path(DISK_INDEX_PATH_PREFIX), + ).unwrap(); + storage.create_disk_layout().unwrap(); + + let disk_layout_file = storage.disk_index_file(); + let rust_disk_layout = fs::read(disk_layout_file.as_str()).unwrap(); + let truth_disk_layout = fs::read(get_test_file_path(TRUTH_DISK_LAYOUT).as_str()).unwrap(); + + assert!(rust_disk_layout == truth_disk_layout); + + fs::remove_file(disk_layout_file.as_str()).expect("Failed to delete file"); + } + + #[test] + fn load_pivot_test() { + let dim: usize = 128; + let num_pq_chunk: usize = 1; + let pivot_file_prefix: &str = "tests/data/siftsmall_learn"; + let storage = DiskIndexStorage::::new( + get_test_file_path(TEST_DATA_FILE), + pivot_file_prefix.to_string(), + ).unwrap(); + + let pq_pivot_data = + storage.load_pq_pivots_bin(&num_pq_chunk).unwrap(); + + assert_eq!(pq_pivot_data.pq_table.len(), NUM_PQ_CENTROIDS * dim); + assert_eq!(pq_pivot_data.centroids.len(), dim); + + assert_eq!(pq_pivot_data.chunk_offsets[0], 0); + assert_eq!(pq_pivot_data.chunk_offsets[1], dim); + assert_eq!(pq_pivot_data.chunk_offsets.len(), num_pq_chunk + 1); + } + + #[test] + #[should_panic(expected = "ERROR: PQ k-means pivot file not found.")] + fn load_pivot_file_not_exist_test() { + let num_pq_chunk: usize = 1; + let pivot_file_prefix: &str = "tests/data/siftsmall_learn_file_not_exist"; + let storage = DiskIndexStorage::::new( + get_test_file_path(TEST_DATA_FILE), + pivot_file_prefix.to_string(), + ).unwrap(); + let _ = storage.load_pq_pivots_bin(&num_pq_chunk).unwrap(); + } +} diff --git a/rust/diskann/src/storage/mod.rs b/rust/diskann/src/storage/mod.rs new file mode 100644 index 000000000..03c5b8e82 --- /dev/null +++ b/rust/diskann/src/storage/mod.rs @@ -0,0 +1,12 @@ +/* + * Copyright (c) Microsoft Corporation. All rights reserved. 
+ * Licensed under the MIT license.
+ */
+mod disk_index_storage;
+pub use disk_index_storage::*;
+
+mod disk_graph_storage;
+pub use disk_graph_storage::*;
+
+mod pq_storage;
+pub use pq_storage::*;
diff --git a/rust/diskann/src/storage/pq_storage.rs b/rust/diskann/src/storage/pq_storage.rs new file mode 100644 index 000000000..b1d3fa05a --- /dev/null +++ b/rust/diskann/src/storage/pq_storage.rs @@ -0,0 +1,367 @@
+/*
+ * Copyright (c) Microsoft Corporation. All rights reserved.
+ * Licensed under the MIT license.
+ */
+use byteorder::{LittleEndian, ReadBytesExt};
+use rand::distributions::{Distribution, Uniform};
+use std::fs::File;
+use std::io::{Read, Seek, SeekFrom, Write};
+use std::mem;
+
+use crate::common::{ANNError, ANNResult};
+use crate::utils::CachedReader;
+use crate::utils::{
+    convert_types_u32_usize, convert_types_u64_usize, convert_types_usize_u32,
+    convert_types_usize_u64, convert_types_usize_u8, save_bin_f32, save_bin_u32, save_bin_u64,
+};
+use crate::utils::{file_exists, load_bin, open_file_to_write, METADATA_SIZE};
+
+#[derive(Debug)]
+pub struct PQStorage {
+    /// Pivot table path
+    pivot_file: String,
+
+    /// Compressed pivot path
+    compressed_pivot_file: String,
+
+    /// Data used to construct the PQ table and the PQ compressed table
+    pq_data_file: String,
+
+    /// PQ data reader
+    pq_data_file_reader: File,
+}
+
+impl PQStorage {
+    pub fn new(
+        pivot_file: &str,
+        compressed_pivot_file: &str,
+        pq_data_file: &str,
+    ) -> std::io::Result<Self> {
+        let pq_data_file_reader = File::open(pq_data_file)?;
+        Ok(Self {
+            pivot_file: pivot_file.to_string(),
+            compressed_pivot_file: compressed_pivot_file.to_string(),
+            pq_data_file: pq_data_file.to_string(),
+            pq_data_file_reader,
+        })
+    }
+
+    pub fn write_compressed_pivot_metadata(&self, npts: i32, pq_chunk: i32) -> std::io::Result<()> {
+        let mut writer = open_file_to_write(&self.compressed_pivot_file)?;
+        writer.write_all(&npts.to_le_bytes())?;
+        writer.write_all(&pq_chunk.to_le_bytes())?;
+        Ok(())
+    }
+
+    pub fn write_compressed_pivot_data(
+        &self,
+        compressed_base: &[usize],
+        num_centers: usize,
+        block_size: usize,
+        num_pq_chunks: usize,
+    ) -> std::io::Result<()> {
+        let mut writer = open_file_to_write(&self.compressed_pivot_file)?;
+        writer.seek(SeekFrom::Start((std::mem::size_of::<i32>() * 2) as u64))?;
+        if num_centers > 256 {
+            writer.write_all(unsafe {
+                std::slice::from_raw_parts(
+                    compressed_base.as_ptr() as *const u8,
+                    block_size * num_pq_chunks * std::mem::size_of::<usize>(),
+                )
+            })?;
+        } else {
+            let compressed_base_u8 =
+                convert_types_usize_u8(compressed_base, block_size, num_pq_chunks);
+            writer.write_all(&compressed_base_u8)?;
+        }
+        Ok(())
+    }
+
+    pub fn write_pivot_data(
+        &self,
+        full_pivot_data: &[f32],
+        centroid: &[f32],
+        chunk_offsets: &[usize],
+        num_centers: usize,
+        dim: usize,
+    ) -> std::io::Result<()> {
+        let mut cumul_bytes: Vec<usize> = vec![0; 4];
+        cumul_bytes[0] = METADATA_SIZE;
+        cumul_bytes[1] = cumul_bytes[0]
+            + save_bin_f32(
+                &self.pivot_file,
+                full_pivot_data,
+                num_centers,
+                dim,
+                cumul_bytes[0],
+            )?;
+        cumul_bytes[2] =
+            cumul_bytes[1] + save_bin_f32(&self.pivot_file, centroid, dim, 1, cumul_bytes[1])?;
+
+        // The writer can only persist u32/u64 values, not usize, so convert the types first.
+        let chunk_offsets_u32 = convert_types_usize_u32(chunk_offsets, chunk_offsets.len(), 1);
+        cumul_bytes[3] = cumul_bytes[2]
+            + save_bin_u32(
+                &self.pivot_file,
+                &chunk_offsets_u32,
+                chunk_offsets.len(),
+                1,
+                cumul_bytes[2],
+            )?;
+
+        let cumul_bytes_u64 = convert_types_usize_u64(&cumul_bytes, 4, 1);
+        save_bin_u64(&self.pivot_file, &cumul_bytes_u64, cumul_bytes.len(), 1, 0)?;
+
+        Ok(())
+    }
+
+    pub fn pivot_data_exist(&self) -> bool {
+        file_exists(&self.pivot_file)
+    }
+
+    pub fn read_pivot_metadata(&self) -> std::io::Result<(usize, usize)> {
+        let (_, file_num_centers, file_dim) = load_bin::<f32>(&self.pivot_file, METADATA_SIZE)?;
+        Ok((file_num_centers, file_dim))
+    }
+
+    pub fn load_pivot_data(
+        &self,
+        num_pq_chunks: &usize,
+        num_centers: &usize,
+        dim: &usize,
+    ) -> ANNResult<(Vec<f32>, Vec<f32>, Vec<usize>)> {
+        // Load the file offset data. The file layout is: offset data (4 x 1) ->
+        // pivot data (num_centers x dim) -> centroid data (dim x 1) ->
+        // chunk offset data ((num_chunks + 1) x 1).
+        // Offsets are stored as u64 because the writer cannot persist usize;
+        // convert them back to usize on load.
+        let (data, offset_num, nc) = load_bin::<u64>(&self.pivot_file, 0)?;
+        let file_offset_data = convert_types_u64_usize(&data, offset_num, nc);
+        if offset_num != 4 {
+            let error_message = format!("Error reading pq_pivots file {}. Offsets don't contain correct metadata, # offsets = {}, but expecting 4.", &self.pivot_file, offset_num);
+            return Err(ANNError::log_pq_error(error_message));
+        }
+
+        let (data, pivot_num, pivot_dim) = load_bin::<f32>(&self.pivot_file, file_offset_data[0])?;
+        let full_pivot_data = data;
+        if pivot_num != *num_centers || pivot_dim != *dim {
+            let error_message = format!("Error reading pq_pivots file {}. file_num_centers = {}, file_dim = {} but expecting {} centers in {} dimensions.", &self.pivot_file, pivot_num, pivot_dim, num_centers, dim);
+            return Err(ANNError::log_pq_error(error_message));
+        }
+
+        let (data, centroid_dim, nc) = load_bin::<f32>(&self.pivot_file, file_offset_data[1])?;
+        let centroid = data;
+        if centroid_dim != *dim || nc != 1 {
+            let error_message = format!("Error reading pq_pivots file {}. 
file_dim = {}, file_cols = {} but expecting {} entries in 1 dimension.", &self.pivot_file, centroid_dim, nc, dim); + return Err(ANNError::log_pq_error(error_message)); + } + + let (data, chunk_offset_number, nc) = + load_bin::(&self.pivot_file, file_offset_data[2])?; + let chunk_offsets = convert_types_u32_usize(&data, chunk_offset_number, nc); + if chunk_offset_number != *num_pq_chunks + 1 || nc != 1 { + let error_message = format!("Error reading pq_pivots file at chunk offsets; file has nr={}, nc={} but expecting nr={} and nc=1.", chunk_offset_number, nc, num_pq_chunks + 1); + return Err(ANNError::log_pq_error(error_message)); + } + Ok((full_pivot_data, centroid, chunk_offsets)) + } + + pub fn read_pq_data_metadata(&mut self) -> std::io::Result<(usize, usize)> { + let npts_i32 = self.pq_data_file_reader.read_i32::()?; + let dim_i32 = self.pq_data_file_reader.read_i32::()?; + let num_points = npts_i32 as usize; + let dim = dim_i32 as usize; + Ok((num_points, dim)) + } + + pub fn read_pq_block_data( + &mut self, + cur_block_size: usize, + dim: usize, + ) -> std::io::Result> { + let mut buf = vec![0u8; cur_block_size * dim * std::mem::size_of::()]; + self.pq_data_file_reader.read_exact(&mut buf)?; + + let ptr = buf.as_ptr() as *const T; + let block_data = unsafe { std::slice::from_raw_parts(ptr, cur_block_size * dim) }; + Ok(block_data.to_vec()) + } + + /// streams data from the file, and samples each vector with probability p_val + /// and returns a matrix of size slice_size* ndims as floating point type. + /// the slice_size and ndims are set inside the function. + /// # Arguments + /// * `file_name` - filename where the data is + /// * `p_val` - possibility to sample data + /// * `sampled_vectors` - sampled vector chose by p_val possibility + /// * `slice_size` - how many sampled data return + /// * `dim` - each sample data dimension + pub fn gen_random_slice>( + &self, + mut p_val: f64, + ) -> ANNResult<(Vec, usize, usize)> { + let read_blk_size = 64 * 1024 * 1024; + let mut reader = CachedReader::new(&self.pq_data_file, read_blk_size)?; + + let npts = reader.read_u32()? as usize; + let dim = reader.read_u32()? 
as usize; + let mut sampled_vectors: Vec = Vec::new(); + let mut slice_size = 0; + p_val = if p_val < 1f64 { p_val } else { 1f64 }; + + let mut generator = rand::thread_rng(); + let distribution = Uniform::from(0.0..1.0); + + for _ in 0..npts { + let mut cur_vector_bytes = vec![0u8; dim * mem::size_of::()]; + reader.read(&mut cur_vector_bytes)?; + let random_value = distribution.sample(&mut generator); + if random_value < p_val { + let ptr = cur_vector_bytes.as_ptr() as *const T; + let cur_vector_t = unsafe { std::slice::from_raw_parts(ptr, dim) }; + sampled_vectors.extend(cur_vector_t.iter().map(|&t| t.into())); + slice_size += 1; + } + } + + Ok((sampled_vectors, slice_size, dim)) + } +} + +#[cfg(test)] +mod pq_storage_tests { + use rand::Rng; + + use super::*; + use crate::utils::gen_random_slice; + + const DATA_FILE: &str = "tests/data/siftsmall_learn.bin"; + const PQ_PIVOT_PATH: &str = "tests/data/siftsmall_learn.bin_pq_pivots.bin"; + const PQ_COMPRESSED_PATH: &str = "tests/data/empty_pq_compressed.bin"; + + #[test] + fn new_test() { + let result = PQStorage::new(PQ_PIVOT_PATH, PQ_COMPRESSED_PATH, DATA_FILE); + assert!(result.is_ok()); + } + + #[test] + fn write_compressed_pivot_metadata_test() { + let compress_pivot_path = "write_compressed_pivot_metadata_test.bin"; + let result = PQStorage::new(PQ_PIVOT_PATH, compress_pivot_path, DATA_FILE).unwrap(); + + _ = result.write_compressed_pivot_metadata(100, 20); + let mut result_reader = File::open(compress_pivot_path).unwrap(); + let npts_i32 = result_reader.read_i32::().unwrap(); + let dim_i32 = result_reader.read_i32::().unwrap(); + + assert_eq!(npts_i32, 100); + assert_eq!(dim_i32, 20); + + std::fs::remove_file(compress_pivot_path).unwrap(); + } + + #[test] + fn write_compressed_pivot_data_test() { + let compress_pivot_path = "write_compressed_pivot_data_test.bin"; + let result = PQStorage::new(PQ_PIVOT_PATH, compress_pivot_path, DATA_FILE).unwrap(); + + let mut rng = rand::thread_rng(); + + let num_centers = 256; + let block_size = 4; + let num_pq_chunks = 2; + let compressed_base: Vec = (0..block_size * num_pq_chunks) + .map(|_| rng.gen_range(0..num_centers)) + .collect(); + _ = result.write_compressed_pivot_data( + &compressed_base, + num_centers, + block_size, + num_pq_chunks, + ); + + let mut result_reader = File::open(compress_pivot_path).unwrap(); + _ = result_reader.read_i32::().unwrap(); + _ = result_reader.read_i32::().unwrap(); + let mut buf = vec![0u8; block_size * num_pq_chunks * std::mem::size_of::()]; + result_reader.read_exact(&mut buf).unwrap(); + + let ptr = buf.as_ptr() as *const u8; + let block_data = unsafe { std::slice::from_raw_parts(ptr, block_size * num_pq_chunks) }; + + for index in 0..block_data.len() { + assert_eq!(compressed_base[index], block_data[index] as usize); + } + std::fs::remove_file(compress_pivot_path).unwrap(); + } + + #[test] + fn pivot_data_exist_test() { + let result = PQStorage::new(PQ_PIVOT_PATH, PQ_COMPRESSED_PATH, DATA_FILE).unwrap(); + assert!(result.pivot_data_exist()); + + let pivot_path = "not_exist_pivot_path.bin"; + let result = PQStorage::new(pivot_path, PQ_COMPRESSED_PATH, DATA_FILE).unwrap(); + assert!(!result.pivot_data_exist()); + } + + #[test] + fn read_pivot_metadata_test() { + let result = PQStorage::new(PQ_PIVOT_PATH, PQ_COMPRESSED_PATH, DATA_FILE).unwrap(); + let (npt, dim) = result.read_pivot_metadata().unwrap(); + + assert_eq!(npt, 256); + assert_eq!(dim, 128); + } + + #[test] + fn load_pivot_data_test() { + let result = PQStorage::new(PQ_PIVOT_PATH, PQ_COMPRESSED_PATH, 
+
+    #[test]
+    fn load_pivot_data_test() {
+        let result = PQStorage::new(PQ_PIVOT_PATH, PQ_COMPRESSED_PATH, DATA_FILE).unwrap();
+        let (pq_pivot_data, centroids, chunk_offsets) =
+            result.load_pivot_data(&1, &256, &128).unwrap();
+
+        assert_eq!(pq_pivot_data.len(), 256 * 128);
+        assert_eq!(centroids.len(), 128);
+        assert_eq!(chunk_offsets.len(), 2);
+    }
+
+    #[test]
+    fn read_pq_data_metadata_test() {
+        let mut result = PQStorage::new(PQ_PIVOT_PATH, PQ_COMPRESSED_PATH, DATA_FILE).unwrap();
+        let (npt, dim) = result.read_pq_data_metadata().unwrap();
+
+        assert_eq!(npt, 25000);
+        assert_eq!(dim, 128);
+    }
+
+    #[test]
+    fn gen_random_slice_test() {
+        let file_name = "gen_random_slice_test.bin";
+        // npoints=2, dim=8
+        let data: [u8; 72] = [
+            2, 0, 0, 0, 8, 0, 0, 0, 0x00, 0x00, 0x80, 0x3f, 0x00, 0x00, 0x00, 0x40, 0x00, 0x00,
+            0x40, 0x40, 0x00, 0x00, 0x80, 0x40, 0x00, 0x00, 0xa0, 0x40, 0x00, 0x00, 0xc0, 0x40,
+            0x00, 0x00, 0xe0, 0x40, 0x00, 0x00, 0x00, 0x41, 0x00, 0x00, 0x10, 0x41, 0x00, 0x00,
+            0x20, 0x41, 0x00, 0x00, 0x30, 0x41, 0x00, 0x00, 0x40, 0x41, 0x00, 0x00, 0x50, 0x41,
+            0x00, 0x00, 0x60, 0x41, 0x00, 0x00, 0x70, 0x41, 0x00, 0x00, 0x80, 0x41,
+        ];
+        std::fs::write(file_name, data).expect("Failed to write sample file");
+
+        let (sampled_vectors, slice_size, ndims) =
+            gen_random_slice::<f32>(file_name, 1f64).unwrap();
+        let mut start = 8;
+        (0..sampled_vectors.len()).for_each(|i| {
+            assert_eq!(sampled_vectors[i].to_le_bytes(), data[start..start + 4]);
+            start += 4;
+        });
+        assert_eq!(sampled_vectors.len(), 16);
+        assert_eq!(slice_size, 2);
+        assert_eq!(ndims, 8);
+
+        let (sampled_vectors, slice_size, ndims) =
+            gen_random_slice::<f32>(file_name, 0f64).unwrap();
+        assert_eq!(sampled_vectors.len(), 0);
+        assert_eq!(slice_size, 0);
+        assert_eq!(ndims, 8);
+
+        std::fs::remove_file(file_name).expect("Failed to delete file");
+    }
+}
diff --git a/rust/diskann/src/test_utils/inmem_index_initialization.rs b/rust/diskann/src/test_utils/inmem_index_initialization.rs
new file mode 100644
index 000000000..db3b58179
--- /dev/null
+++ b/rust/diskann/src/test_utils/inmem_index_initialization.rs
@@ -0,0 +1,74 @@
+/*
+ * Copyright (c) Microsoft Corporation. All rights reserved.
+ * Licensed under the MIT license.
+ */ +use vector::Metric; + +use crate::index::InmemIndex; +use crate::model::configuration::index_write_parameters::IndexWriteParametersBuilder; +use crate::model::{IndexConfiguration}; +use crate::model::vertex::DIM_128; +use crate::utils::{file_exists, load_metadata_from_file}; + +use super::get_test_file_path; + +// f32, 128 DIM and 256 points source data +const TEST_DATA_FILE: &str = "tests/data/siftsmall_learn_256pts.fbin"; +const NUM_POINTS_TO_LOAD: usize = 256; + +pub fn create_index_with_test_data() -> InmemIndex { + let index_write_parameters = IndexWriteParametersBuilder::new(50, 4).with_alpha(1.2).build(); + let config = IndexConfiguration::new( + Metric::L2, + 128, + 128, + 256, + false, + 0, + false, + 0, + 1.0f32, + index_write_parameters); + let mut index: InmemIndex = InmemIndex::new(config).unwrap(); + + build_test_index(&mut index, get_test_file_path(TEST_DATA_FILE).as_str(), NUM_POINTS_TO_LOAD); + + index.start = index.dataset.calculate_medoid_point_id().unwrap(); + + index +} + +fn build_test_index(index: &mut InmemIndex, filename: &str, num_points_to_load: usize) { + if !file_exists(filename) { + panic!("ERROR: Data file {} does not exist.", filename); + } + + let (file_num_points, file_dim) = load_metadata_from_file(filename).unwrap(); + if file_num_points > index.configuration.max_points { + panic!( + "ERROR: Driver requests loading {} points and file has {} points, + but index can support only {} points as specified in configuration.", + num_points_to_load, file_num_points, index.configuration.max_points + ); + } + + if num_points_to_load > file_num_points { + panic!( + "ERROR: Driver requests loading {} points and file has only {} points.", + num_points_to_load, file_num_points + ); + } + + if file_dim != index.configuration.dim { + panic!( + "ERROR: Driver requests loading {} dimension, but file has {} dimension.", + index.configuration.dim, file_dim + ); + } + + index.dataset.build_from_file(filename, num_points_to_load).unwrap(); + + println!("Using only first {} from file.", num_points_to_load); + + index.num_active_pts = num_points_to_load; +} diff --git a/rust/diskann/src/test_utils/mod.rs b/rust/diskann/src/test_utils/mod.rs new file mode 100644 index 000000000..fc8de5f30 --- /dev/null +++ b/rust/diskann/src/test_utils/mod.rs @@ -0,0 +1,11 @@ +/* + * Copyright (c) Microsoft Corporation. All rights reserved. + * Licensed under the MIT license. + */ +pub mod inmem_index_initialization; + +/// test files should be placed under tests folder +pub fn get_test_file_path(relative_path: &str) -> String { + format!("{}/{}", env!("CARGO_MANIFEST_DIR"), relative_path) +} + diff --git a/rust/diskann/src/utils/bit_vec_extension.rs b/rust/diskann/src/utils/bit_vec_extension.rs new file mode 100644 index 000000000..9571a726e --- /dev/null +++ b/rust/diskann/src/utils/bit_vec_extension.rs @@ -0,0 +1,45 @@ +/* + * Copyright (c) Microsoft Corporation. All rights reserved. + * Licensed under the MIT license. 
+ */ +use std::cmp::Ordering; + +use bit_vec::BitVec; + +pub trait BitVecExtension { + fn resize(&mut self, new_len: usize, value: bool); +} + +impl BitVecExtension for BitVec { + fn resize(&mut self, new_len: usize, value: bool) { + let old_len = self.len(); + match new_len.cmp(&old_len) { + Ordering::Less => self.truncate(new_len), + Ordering::Greater => self.grow(new_len - old_len, value), + Ordering::Equal => {} + } + } +} + +#[cfg(test)] +mod bit_vec_extension_test { + use super::*; + + #[test] + fn resize_test() { + let mut bitset = BitVec::new(); + + bitset.resize(10, false); + assert_eq!(bitset.len(), 10); + assert!(bitset.none()); + + bitset.resize(11, true); + assert_eq!(bitset.len(), 11); + assert!(bitset[10]); + + bitset.resize(5, false); + assert_eq!(bitset.len(), 5); + assert!(bitset.none()); + } +} + diff --git a/rust/diskann/src/utils/cached_reader.rs b/rust/diskann/src/utils/cached_reader.rs new file mode 100644 index 000000000..1a21f1a77 --- /dev/null +++ b/rust/diskann/src/utils/cached_reader.rs @@ -0,0 +1,160 @@ +/* + * Copyright (c) Microsoft Corporation. All rights reserved. + * Licensed under the MIT license. + */ +use std::fs::File; +use std::io::{Seek, Read}; + +use crate::common::{ANNResult, ANNError}; + +/// Sequential cached reads +pub struct CachedReader { + /// File reader + reader: File, + + /// # bytes to cache in one shot read + cache_size: u64, + + /// Underlying buf for cache + cache_buf: Vec, + + /// Offset into cache_buf for cur_pos + cur_off: u64, + + /// File size + fsize: u64, +} + +impl CachedReader { + pub fn new(filename: &str, cache_size: u64) -> std::io::Result { + let mut reader = File::open(filename)?; + let metadata = reader.metadata()?; + let fsize = metadata.len(); + + let cache_size = cache_size.min(fsize); + let mut cache_buf = vec![0; cache_size as usize]; + reader.read_exact(&mut cache_buf)?; + println!("Opened: {}, size: {}, cache_size: {}", filename, fsize, cache_size); + + Ok(Self { + reader, + cache_size, + cache_buf, + cur_off: 0, + fsize, + }) + } + + pub fn get_file_size(&self) -> u64 { + self.fsize + } + + pub fn read(&mut self, read_buf: &mut [u8]) -> ANNResult<()> { + let n_bytes = read_buf.len() as u64; + if n_bytes <= (self.cache_size - self.cur_off) { + // case 1: cache contains all data + read_buf.copy_from_slice(&self.cache_buf[(self.cur_off as usize)..(self.cur_off as usize + n_bytes as usize)]); + self.cur_off += n_bytes; + } else { + // case 2: cache contains some data + let cached_bytes = self.cache_size - self.cur_off; + if n_bytes - cached_bytes > self.fsize - self.reader.stream_position()? 
{ + return Err(ANNError::log_index_error(format!( + "Reading beyond end of file, n_bytes: {} cached_bytes: {} fsize: {} current pos: {}", + n_bytes, cached_bytes, self.fsize, self.reader.stream_position()?)) + ); + } + + read_buf[..cached_bytes as usize].copy_from_slice(&self.cache_buf[self.cur_off as usize..]); + // go to disk and fetch more data + self.reader.read_exact(&mut read_buf[cached_bytes as usize..])?; + // reset cur off + self.cur_off = self.cache_size; + + let size_left = self.fsize - self.reader.stream_position()?; + if size_left >= self.cache_size { + self.reader.read_exact(&mut self.cache_buf)?; + self.cur_off = 0; + } + // note that if size_left < cache_size, then cur_off = cache_size, + // so subsequent reads will all be directly from file + } + Ok(()) + } + + pub fn read_u32(&mut self) -> ANNResult { + let mut bytes = [0u8; 4]; + self.read(&mut bytes)?; + Ok(u32::from_le_bytes(bytes)) + } +} + +#[cfg(test)] +mod cached_reader_test { + use std::fs; + + use super::*; + + #[test] + fn cached_reader_works() { + let file_name = "cached_reader_works_test.bin"; + //npoints=2, dim=8, 2 vectors [1.0;8] [2.0;8] + let data: [u8; 72] = [2, 0, 1, 2, 8, 0, 1, 3, + 0x00, 0x01, 0x80, 0x3f, 0x00, 0x00, 0x00, 0x40, 0x00, 0x00, 0x40, 0x40, 0x00, 0x00, 0x80, 0x40, + 0x00, 0x00, 0xa0, 0x40, 0x00, 0x00, 0xc0, 0x40, 0x00, 0x00, 0xe0, 0x40, 0x00, 0x00, 0x00, 0x41, + 0x00, 0x00, 0x10, 0x41, 0x00, 0x00, 0x20, 0x41, 0x00, 0x00, 0x30, 0x41, 0x00, 0x00, 0x40, 0x41, + 0x00, 0x00, 0x50, 0x41, 0x00, 0x00, 0x60, 0x41, 0x00, 0x00, 0x70, 0x41, 0x00, 0x11, 0x80, 0x41]; + std::fs::write(file_name, data).expect("Failed to write sample file"); + + let mut reader = CachedReader::new(file_name, 8).unwrap(); + assert_eq!(reader.get_file_size(), 72); + assert_eq!(reader.cache_size, 8); + + let mut all_from_cache_buf = vec![0; 4]; + reader.read(all_from_cache_buf.as_mut_slice()).unwrap(); + assert_eq!(all_from_cache_buf, [2, 0, 1, 2]); + assert_eq!(reader.cur_off, 4); + + let mut partial_from_cache_buf = vec![0; 6]; + reader.read(partial_from_cache_buf.as_mut_slice()).unwrap(); + assert_eq!(partial_from_cache_buf, [8, 0, 1, 3, 0x00, 0x01]); + assert_eq!(reader.cur_off, 0); + + let mut over_cache_size_buf = vec![0; 60]; + reader.read(over_cache_size_buf.as_mut_slice()).unwrap(); + assert_eq!( + over_cache_size_buf, + [0x80, 0x3f, 0x00, 0x00, 0x00, 0x40, 0x00, 0x00, 0x40, 0x40, 0x00, 0x00, 0x80, 0x40, + 0x00, 0x00, 0xa0, 0x40, 0x00, 0x00, 0xc0, 0x40, 0x00, 0x00, 0xe0, 0x40, 0x00, 0x00, 0x00, 0x41, + 0x00, 0x00, 0x10, 0x41, 0x00, 0x00, 0x20, 0x41, 0x00, 0x00, 0x30, 0x41, 0x00, 0x00, 0x40, 0x41, + 0x00, 0x00, 0x50, 0x41, 0x00, 0x00, 0x60, 0x41, 0x00, 0x00, 0x70, 0x41, 0x00, 0x11] + ); + + let mut remaining_less_than_cache_size_buf = vec![0; 2]; + reader.read(remaining_less_than_cache_size_buf.as_mut_slice()).unwrap(); + assert_eq!(remaining_less_than_cache_size_buf, [0x80, 0x41]); + assert_eq!(reader.cur_off, reader.cache_size); + + fs::remove_file(file_name).expect("Failed to delete file"); + } + + #[test] + #[should_panic(expected = "n_bytes: 73 cached_bytes: 8 fsize: 72 current pos: 8")] + fn failed_for_reading_beyond_end_of_file() { + let file_name = "failed_for_reading_beyond_end_of_file_test.bin"; + //npoints=2, dim=8, 2 vectors [1.0;8] [2.0;8] + let data: [u8; 72] = [2, 0, 1, 2, 8, 0, 1, 3, + 0x00, 0x01, 0x80, 0x3f, 0x00, 0x00, 0x00, 0x40, 0x00, 0x00, 0x40, 0x40, 0x00, 0x00, 0x80, 0x40, + 0x00, 0x00, 0xa0, 0x40, 0x00, 0x00, 0xc0, 0x40, 0x00, 0x00, 0xe0, 0x40, 0x00, 0x00, 0x00, 0x41, + 0x00, 0x00, 0x10, 0x41, 
0x00, 0x00, 0x20, 0x41, 0x00, 0x00, 0x30, 0x41, 0x00, 0x00, 0x40, 0x41, + 0x00, 0x00, 0x50, 0x41, 0x00, 0x00, 0x60, 0x41, 0x00, 0x00, 0x70, 0x41, 0x00, 0x11, 0x80, 0x41]; + std::fs::write(file_name, data).expect("Failed to write sample file"); + + let mut reader = CachedReader::new(file_name, 8).unwrap(); + fs::remove_file(file_name).expect("Failed to delete file"); + + let mut over_size_buf = vec![0; 73]; + reader.read(over_size_buf.as_mut_slice()).unwrap(); + } +} + diff --git a/rust/diskann/src/utils/cached_writer.rs b/rust/diskann/src/utils/cached_writer.rs new file mode 100644 index 000000000..d3929bef2 --- /dev/null +++ b/rust/diskann/src/utils/cached_writer.rs @@ -0,0 +1,142 @@ +/* + * Copyright (c) Microsoft Corporation. All rights reserved. + * Licensed under the MIT license. + */ +use std::io::{Write, Seek, SeekFrom}; +use std::fs::{OpenOptions, File}; +use std::path::Path; + +pub struct CachedWriter { + /// File writer + writer: File, + + /// # bytes to cache for one shot write + cache_size: u64, + + /// Underlying buf for cache + cache_buf: Vec, + + /// Offset into cache_buf for cur_pos + cur_off: u64, + + /// File size + fsize: u64, +} + +impl CachedWriter { + pub fn new(filename: &str, cache_size: u64) -> std::io::Result { + let writer = OpenOptions::new() + .write(true) + .create(true) + .open(Path::new(filename))?; + + if cache_size == 0 { + return Err(std::io::Error::new(std::io::ErrorKind::Other, "Cache size must be greater than 0")); + } + + println!("Opened: {}, cache_size: {}", filename, cache_size); + Ok(Self { + writer, + cache_size, + cache_buf: vec![0; cache_size as usize], + cur_off: 0, + fsize: 0, + }) + } + + pub fn flush(&mut self) -> std::io::Result<()> { + // dump any remaining data in memory + if self.cur_off > 0 { + self.flush_cache()?; + } + + self.writer.flush()?; + println!("Finished writing {}B", self.fsize); + Ok(()) + } + + pub fn get_file_size(&self) -> u64 { + self.fsize + } + + /// Writes n_bytes from write_buf to the underlying cache + pub fn write(&mut self, write_buf: &[u8]) -> std::io::Result<()> { + let n_bytes = write_buf.len() as u64; + if n_bytes <= (self.cache_size - self.cur_off) { + // case 1: cache can take all data + self.cache_buf[(self.cur_off as usize)..((self.cur_off + n_bytes) as usize)].copy_from_slice(&write_buf[..n_bytes as usize]); + self.cur_off += n_bytes; + } else { + // case 2: cache cant take all data + // go to disk and write existing cache data + self.writer.write_all(&self.cache_buf[..self.cur_off as usize])?; + self.fsize += self.cur_off; + // write the new data to disk + self.writer.write_all(write_buf)?; + self.fsize += n_bytes; + // clear cache data and reset cur_off + self.cache_buf.fill(0); + self.cur_off = 0; + } + Ok(()) + } + + pub fn reset(&mut self) -> std::io::Result<()> { + self.flush_cache()?; + self.writer.seek(SeekFrom::Start(0))?; + Ok(()) + } + + fn flush_cache(&mut self) -> std::io::Result<()> { + self.writer.write_all(&self.cache_buf[..self.cur_off as usize])?; + self.fsize += self.cur_off; + self.cache_buf.fill(0); + self.cur_off = 0; + Ok(()) + } +} + +impl Drop for CachedWriter { + fn drop(&mut self) { + let _ = self.flush(); + } +} + +#[cfg(test)] +mod cached_writer_test { + use std::fs; + + use super::*; + + #[test] + fn cached_writer_works() { + let file_name = "cached_writer_works_test.bin"; + //npoints=2, dim=8, 2 vectors [1.0;8] [2.0;8] + let data: [u8; 72] = [2, 0, 1, 2, 8, 0, 1, 3, + 0x00, 0x01, 0x80, 0x3f, 0x00, 0x00, 0x00, 0x40, 0x00, 0x00, 0x40, 0x40, 0x00, 0x00, 0x80, 0x40, + 0x00, 
0x00, 0xa0, 0x40, 0x00, 0x00, 0xc0, 0x40, 0x00, 0x00, 0xe0, 0x40, 0x00, 0x00, 0x00, 0x41,
+            0x00, 0x00, 0x10, 0x41, 0x00, 0x00, 0x20, 0x41, 0x00, 0x00, 0x30, 0x41, 0x00, 0x00, 0x40, 0x41,
+            0x00, 0x00, 0x50, 0x41, 0x00, 0x00, 0x60, 0x41, 0x00, 0x00, 0x70, 0x41, 0x00, 0x11, 0x80, 0x41];
+
+        let mut writer = CachedWriter::new(file_name, 8).unwrap();
+        assert_eq!(writer.get_file_size(), 0);
+        assert_eq!(writer.cache_size, 8);
+        assert_eq!(writer.get_file_size(), 0);
+
+        let cache_all_buf = &data[0..4];
+        writer.write(cache_all_buf).unwrap();
+        assert_eq!(&writer.cache_buf[..4], cache_all_buf);
+        assert_eq!(&writer.cache_buf[4..], vec![0; 4]);
+        assert_eq!(writer.cur_off, 4);
+        assert_eq!(writer.get_file_size(), 0);
+
+        let write_all_buf = &data[4..10];
+        writer.write(write_all_buf).unwrap();
+        assert_eq!(writer.cache_buf, vec![0; 8]);
+        assert_eq!(writer.cur_off, 0);
+        assert_eq!(writer.get_file_size(), 10);
+
+        fs::remove_file(file_name).expect("Failed to delete file");
+    }
+}
diff --git a/rust/diskann/src/utils/file_util.rs b/rust/diskann/src/utils/file_util.rs
new file mode 100644
index 000000000..f187d0128
--- /dev/null
+++ b/rust/diskann/src/utils/file_util.rs
@@ -0,0 +1,377 @@
+/*
+ * Copyright (c) Microsoft Corporation. All rights reserved.
+ * Licensed under the MIT license.
+ */
+#![warn(missing_debug_implementations, missing_docs)]
+
+//! File operations
+
+use byteorder::{LittleEndian, ReadBytesExt, WriteBytesExt};
+use std::{mem, io};
+use std::fs::{self, File, OpenOptions};
+use std::io::{Read, BufReader, Write, Seek, SeekFrom};
+use std::path::Path;
+
+use crate::model::data_store::DatasetDto;
+
+/// Read metadata of data file.
+pub fn load_metadata_from_file(file_name: &str) -> std::io::Result<(usize, usize)> {
+    let file = File::open(file_name)?;
+    let mut reader = BufReader::new(file);
+
+    let npoints = reader.read_i32::<LittleEndian>()? as usize;
+    let ndims = reader.read_i32::<LittleEndian>()? as usize;
+
+    Ok((npoints, ndims))
+}
+
+/// Read the deleted vertex ids from file.
+pub fn load_ids_to_delete_from_file(file_name: &str) -> std::io::Result<(usize, Vec<u32>)> {
+    // The first 4 bytes are the number of vector ids.
+    // The rest of the file are the vector ids in the format of u32.
+    // The vector ids are sorted in ascending order.
+    let mut file = File::open(file_name)?;
+    let num_ids = file.read_u32::<LittleEndian>()? as usize;
+
+    let mut ids = Vec::with_capacity(num_ids);
+    for _ in 0..num_ids {
+        let id = file.read_u32::<LittleEndian>()?;
+        ids.push(id);
+    }
+
+    Ok((num_ids, ids))
+}
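Note: copy_aligned_data_from_file below copies each dim-wide vector into a rounded_dim-wide row and zero-fills the tail, so that every row starts at an offset friendly to aligned SIMD loads. A sketch of the rounding arithmetic; the helper name is illustrative and rounding to 8 floats is an assumption for the example, not taken from this diff (the real helper is round_up in utils.rs further down):

    // Pad 10-dimensional vectors to rows of 16 floats: 6 trailing zeros per row.
    fn rounded_dim(dim: usize, align: usize) -> usize {
        (dim + align - 1) / align * align
    }

    fn main() {
        assert_eq!(rounded_dim(10, 8), 16);
        assert_eq!(rounded_dim(16, 8), 16);
    }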
+
+/// Copy data from file
+/// # Arguments
+/// * `bin_file` - filename where the data is
+/// * `dataset_dto` - destination dataset dto to which the data is copied
+/// * `pts_offset` - offset of points; data will be loaded after this point in the dataset
+/// # Return
+/// * `npts` - number of points read from bin_file
+/// * `dim` - point dimension read from bin_file (padded with zeros up to rounded_dim)
+pub fn copy_aligned_data_from_file<T: Default + Copy>(
+    bin_file: &str,
+    dataset_dto: DatasetDto<T>,
+    pts_offset: usize,
+) -> std::io::Result<(usize, usize)> {
+    let mut reader = File::open(bin_file)?;
+
+    let npts = reader.read_i32::<LittleEndian>()? as usize;
+    let dim = reader.read_i32::<LittleEndian>()? as usize;
+    let rounded_dim = dataset_dto.rounded_dim;
+    let offset = pts_offset * rounded_dim;
+
+    for i in 0..npts {
+        let data_slice = &mut dataset_dto.data[offset + i * rounded_dim..offset + i * rounded_dim + dim];
+        let mut buf = vec![0u8; dim * mem::size_of::<T>()];
+        reader.read_exact(&mut buf)?;
+
+        let ptr = buf.as_ptr() as *const T;
+        let temp_slice = unsafe { std::slice::from_raw_parts(ptr, dim) };
+        data_slice.copy_from_slice(temp_slice);
+
+        (i * rounded_dim + dim..i * rounded_dim + rounded_dim).for_each(|j| {
+            dataset_dto.data[j] = T::default();
+        });
+    }
+
+    Ok((npts, dim))
+}
+
+/// Open a file to write
+/// # Arguments
+/// * `file_name` - file name
+#[inline]
+pub fn open_file_to_write(file_name: &str) -> std::io::Result<File> {
+    OpenOptions::new()
+        .write(true)
+        .create(true)
+        .open(Path::new(file_name))
+}
+
+/// Delete a file
+/// # Arguments
+/// * `file_name` - file name
+pub fn delete_file(file_name: &str) -> std::io::Result<()> {
+    if file_exists(file_name) {
+        fs::remove_file(file_name)?;
+    }
+
+    Ok(())
+}
+
+/// Check whether file exists or not
+pub fn file_exists(filename: &str) -> bool {
+    std::path::Path::new(filename).exists()
+}
+
+/// Save data to file
+/// # Arguments
+/// * `filename` - filename where the data is
+/// * `data` - information data
+/// * `npts` - number of points
+/// * `ndims` - point dimension
+/// * `aligned_dim` - aligned dimension
+/// * `offset` - data offset in file
+pub fn save_data_in_base_dimensions<T>(
+    filename: &str,
+    data: &mut [T],
+    npts: usize,
+    ndims: usize,
+    aligned_dim: usize,
+    offset: usize,
+) -> std::io::Result<usize> {
+    let mut writer = open_file_to_write(filename)?;
+    let npts_i32 = npts as i32;
+    let ndims_i32 = ndims as i32;
+    let bytes_written = 2 * std::mem::size_of::<i32>() + npts * ndims * std::mem::size_of::<T>();
+
+    writer.seek(std::io::SeekFrom::Start(offset as u64))?;
+    writer.write_all(&npts_i32.to_le_bytes())?;
+    writer.write_all(&ndims_i32.to_le_bytes())?;
+    let data_ptr = data.as_ptr() as *const u8;
+    for i in 0..npts {
+        let middle_offset = i * aligned_dim * std::mem::size_of::<T>();
+        let middle_slice = unsafe { std::slice::from_raw_parts(data_ptr.add(middle_offset), ndims * std::mem::size_of::<T>()) };
+        writer.write_all(middle_slice)?;
+    }
+    writer.flush()?;
+    Ok(bytes_written)
+}
+
+/// Read data file
+/// # Arguments
+/// * `bin_file` - filename where the data is
+/// * `file_offset` - data offset in file
+/// # Return
+/// * `data` - information data
+/// * `npts` - number of points
+/// * `dim` - point dimension
+pub fn load_bin<T: Copy>(
+    bin_file: &str,
+    file_offset: usize) -> std::io::Result<(Vec<T>, usize, usize)>
+{
+    let mut reader = File::open(bin_file)?;
+    reader.seek(std::io::SeekFrom::Start(file_offset as u64))?;
+    let npts = reader.read_i32::<LittleEndian>()? as usize;
+    let dim = reader.read_i32::<LittleEndian>()? as usize;
+
+    let size = npts * dim * std::mem::size_of::<T>();
+    let mut buf = vec![0u8; size];
+    reader.read_exact(&mut buf)?;
+
+    let ptr = buf.as_ptr() as *const T;
+    let data = unsafe { std::slice::from_raw_parts(ptr, npts * dim) };
+
+    Ok((data.to_vec(), npts, dim))
+}
+
+/// Get file size
+pub fn get_file_size(filename: &str) -> io::Result<u64> {
+    let reader = File::open(filename)?;
+    let metadata = reader.metadata()?;
+    Ok(metadata.len())
+}
+
+macro_rules! save_bin {
+    ($name:ident, $t:ty, $write_func:ident) => {
+        /// Write data into file
+        pub fn $name(filename: &str, data: &[$t], num_pts: usize, dims: usize, offset: usize) -> std::io::Result<usize> {
+            let mut writer = open_file_to_write(filename)?;
+
+            println!("Writing bin: {}", filename);
+            writer.seek(SeekFrom::Start(offset as u64))?;
+            let num_pts_i32 = num_pts as i32;
+            let dims_i32 = dims as i32;
+            let bytes_written = num_pts * dims * mem::size_of::<$t>() + 2 * mem::size_of::<i32>();
+
+            writer.write_i32::<LittleEndian>(num_pts_i32)?;
+            writer.write_i32::<LittleEndian>(dims_i32)?;
+            println!("bin: #pts = {}, #dims = {}, size = {}B", num_pts, dims, bytes_written);
+
+            for item in data.iter() {
+                writer.$write_func::<LittleEndian>(*item)?;
+            }
+
+            writer.flush()?;
+
+            println!("Finished writing bin.");
+            Ok(bytes_written)
+        }
+    };
+}
+
+save_bin!(save_bin_f32, f32, write_f32);
+save_bin!(save_bin_u64, u64, write_u64);
+save_bin!(save_bin_u32, u32, write_u32);
+
+#[cfg(test)]
+mod file_util_test {
+    use crate::model::data_store::InmemDataset;
+    use std::fs;
+    use super::*;
+
+    pub const DIM_8: usize = 8;
+
+    #[test]
+    fn load_metadata_test() {
+        let file_name = "test_load_metadata_test.bin";
+        let data = [200, 0, 0, 0, 128, 0, 0, 0]; // 200 and 128 in little-endian bytes
+        std::fs::write(file_name, data).expect("Failed to write sample file");
+        match load_metadata_from_file(file_name) {
+            Ok((npoints, ndims)) => {
+                assert!(npoints == 200);
+                assert!(ndims == 128);
+            },
+            Err(_e) => {},
+        }
+        fs::remove_file(file_name).expect("Failed to delete file");
+    }
+
+    #[test]
+    fn load_data_test() {
+        let file_name = "test_load_data_test.bin";
+        // npoints=2, dim=8; the two vectors are [1.0..=8.0] and [9.0..=16.0]
+        let data: [u8; 72] = [2, 0, 0, 0, 8, 0, 0, 0,
+            0x00, 0x00, 0x80, 0x3f, 0x00, 0x00, 0x00, 0x40, 0x00, 0x00, 0x40, 0x40, 0x00, 0x00, 0x80, 0x40,
+            0x00, 0x00, 0xa0, 0x40, 0x00, 0x00, 0xc0, 0x40, 0x00, 0x00, 0xe0, 0x40, 0x00, 0x00, 0x00, 0x41,
+            0x00, 0x00, 0x10, 0x41, 0x00, 0x00, 0x20, 0x41, 0x00, 0x00, 0x30, 0x41, 0x00, 0x00, 0x40, 0x41,
+            0x00, 0x00, 0x50, 0x41, 0x00, 0x00, 0x60, 0x41, 0x00, 0x00, 0x70, 0x41, 0x00, 0x00, 0x80, 0x41];
+        std::fs::write(file_name, data).expect("Failed to write sample file");
+
+        let mut dataset = InmemDataset::<f32>::new(2, 1f32).unwrap();
+
+        match copy_aligned_data_from_file(file_name, dataset.into_dto(), 0) {
+            Ok((num_points, dim)) => {
+                fs::remove_file(file_name).expect("Failed to delete file");
+                assert!(num_points == 2);
+                assert!(dim == 8);
+                assert!(dataset.data.len() == 16);
+
+                let first_vertex = dataset.get_vertex(0).unwrap();
+                let second_vertex = dataset.get_vertex(1).unwrap();
+
+                assert!(*first_vertex.vector() == [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]);
+                assert!(*second_vertex.vector() == [9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0]);
+            },
+            Err(e) => {
+                fs::remove_file(file_name).expect("Failed to delete file");
+                panic!("{}", e)
+            },
+        }
+    }
+
+    #[test]
+    fn open_file_to_write_test() {
+        let file_name = "test_open_file_to_write_test.bin";
+        let mut writer = File::create(file_name).unwrap();
+        let data = [200, 0, 0, 0, 128, 0, 0, 0];
+        writer.write_all(&data).expect("Failed to write sample file");
+
+        let _ = open_file_to_write(file_name);
+
+        fs::remove_file(file_name).expect("Failed to delete file");
+    }
+
+    #[test]
+    fn delete_file_test() {
+        let file_name = "test_delete_file_test.bin";
+        let mut file = File::create(file_name).unwrap();
+        writeln!(file, "test delete file").unwrap();
+
+        let result = delete_file(file_name);
+
+        assert!(result.is_ok());
+        assert!(fs::metadata(file_name).is_err());
+    }
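Note: save_bin above is generated by a macro rather than written as one generic function because byteorder exposes a separate method per type (write_f32, write_u64, write_u32) and no single trait method covering them all. A hypothetical generic alternative, sketched only to illustrate the trade-off (the WriteLe trait is not part of this diff):

    use byteorder::{LittleEndian, WriteBytesExt};
    use std::io::Write;

    trait WriteLe: Copy {
        fn write_le<W: Write>(self, w: &mut W) -> std::io::Result<()>;
    }

    impl WriteLe for f32 {
        fn write_le<W: Write>(self, w: &mut W) -> std::io::Result<()> {
            w.write_f32::<LittleEndian>(self)
        }
    }

    impl WriteLe for u64 {
        fn write_le<W: Write>(self, w: &mut W) -> std::io::Result<()> {
            w.write_u64::<LittleEndian>(self)
        }
    }

    // A single fn save_bin<T: WriteLe>(...) could then replace the three macro expansions.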
+
+    #[test]
+    fn save_data_in_base_dimensions_test() {
+        // npoints=2, dim=8
+        let mut data: [u8; 72] = [2, 0, 0, 0, 8, 0, 0, 0,
+            0x00, 0x00, 0x80, 0x3f, 0x00, 0x00, 0x00, 0x40, 0x00, 0x00, 0x40, 0x40, 0x00, 0x00, 0x80, 0x40,
+            0x00, 0x00, 0xa0, 0x40, 0x00, 0x00, 0xc0, 0x40, 0x00, 0x00, 0xe0, 0x40, 0x00, 0x00, 0x00, 0x41,
+            0x00, 0x00, 0x10, 0x41, 0x00, 0x00, 0x20, 0x41, 0x00, 0x00, 0x30, 0x41, 0x00, 0x00, 0x40, 0x41,
+            0x00, 0x00, 0x50, 0x41, 0x00, 0x00, 0x60, 0x41, 0x00, 0x00, 0x70, 0x41, 0x00, 0x00, 0x80, 0x41];
+        let num_points = 2;
+        let dim = DIM_8;
+        let data_file = "save_data_in_base_dimensions_test.data";
+        match save_data_in_base_dimensions(data_file, &mut data, num_points, dim, DIM_8, 0) {
+            Ok(num) => {
+                assert!(file_exists(data_file));
+                assert_eq!(num, 2 * std::mem::size_of::<i32>() + num_points * dim * std::mem::size_of::<u8>());
+                fs::remove_file(data_file).expect("Failed to delete file");
+            },
+            Err(e) => {
+                fs::remove_file(data_file).expect("Failed to delete file");
+                panic!("{}", e)
+            }
+        }
+    }
+
+    #[test]
+    fn save_bin_test() {
+        let filename = "save_bin_test";
+        let data = vec![0u64, 1u64, 2u64];
+        let num_pts = data.len();
+        let dims = 1;
+        let bytes_written = save_bin_u64(filename, &data, num_pts, dims, 0).unwrap();
+        assert_eq!(bytes_written, 32);
+
+        let mut file = File::open(filename).unwrap();
+        let mut buffer = vec![];
+
+        let npts_read = file.read_i32::<LittleEndian>().unwrap() as usize;
+        let dims_read = file.read_i32::<LittleEndian>().unwrap() as usize;
+
+        file.read_to_end(&mut buffer).unwrap();
+        let data_read: Vec<u64> = buffer
+            .chunks_exact(8)
+            .map(|b| u64::from_le_bytes([b[0], b[1], b[2], b[3], b[4], b[5], b[6], b[7]]))
+            .collect();
+
+        std::fs::remove_file(filename).unwrap();
+
+        assert_eq!(num_pts, npts_read);
+        assert_eq!(dims, dims_read);
+        assert_eq!(data, data_read);
+    }
+
+    #[test]
+    fn load_bin_test() {
+        let file_name = "load_bin_test";
+        let data = vec![0u64, 1u64, 2u64];
+        let num_pts = data.len();
+        let dims = 1;
+        let bytes_written = save_bin_u64(file_name, &data, num_pts, dims, 0).unwrap();
+        assert_eq!(bytes_written, 32);
+
+        let (load_data, load_num_pts, load_dims) = load_bin::<u64>(file_name, 0).unwrap();
+        assert_eq!(load_num_pts, num_pts);
+        assert_eq!(load_dims, dims);
+        assert_eq!(load_data, data);
+        std::fs::remove_file(file_name).unwrap();
+    }
+
+    #[test]
+    fn load_bin_offset_test() {
+        let offset: usize = 32;
+        let file_name = "load_bin_offset_test";
+        let data = vec![0u64, 1u64, 2u64];
+        let num_pts = data.len();
+        let dims = 1;
+        let bytes_written = save_bin_u64(file_name, &data, num_pts, dims, offset).unwrap();
+        assert_eq!(bytes_written, 32);
+
+        let (load_data, load_num_pts, load_dims) = load_bin::<u64>(file_name, offset).unwrap();
+        assert_eq!(load_num_pts, num_pts);
+        assert_eq!(load_dims, dims);
+        assert_eq!(load_data, data);
+        std::fs::remove_file(file_name).unwrap();
+    }
+}
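Note: the next file, hashset_u32.rs, pins hashbrown's HashSet to fxhash's FxHasher. FxHash is a fast non-cryptographic hash with no HashDoS resistance, a reasonable trade for small fixed-width u32 vertex ids on hot lookup paths. Because the wrapper implements Deref/DerefMut, the usual HashSet API applies directly; a short usage sketch:

    let mut visited = HashSetForU32::with_capacity(1024);
    visited.insert(42u32);
    assert!(visited.contains(&42u32));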
diff --git a/rust/diskann/src/utils/hashset_u32.rs b/rust/diskann/src/utils/hashset_u32.rs
new file mode 100644
index 000000000..15db687d6
--- /dev/null
+++ b/rust/diskann/src/utils/hashset_u32.rs
@@ -0,0 +1,46 @@
+/*
+ * Copyright (c) Microsoft Corporation. All rights reserved.
+ * Licensed under the MIT license.
+ */
+use hashbrown::HashSet;
+use std::{hash::BuildHasherDefault, ops::{Deref, DerefMut}};
+use fxhash::FxHasher;
+
+lazy_static::lazy_static! {
+    /// Singleton hasher.
+    static ref HASHER: BuildHasherDefault<FxHasher> = {
+        BuildHasherDefault::<FxHasher>::default()
+    };
+}
+
+pub struct HashSetForU32 {
+    hashset: HashSet<u32, BuildHasherDefault<FxHasher>>,
+}
+
+impl HashSetForU32 {
+    pub fn with_capacity(capacity: usize) -> HashSetForU32 {
+        let hashset = HashSet::<u32, BuildHasherDefault<FxHasher>>::with_capacity_and_hasher(capacity, HASHER.clone());
+        HashSetForU32 {
+            hashset
+        }
+    }
+}
+
+impl Deref for HashSetForU32 {
+    type Target = HashSet<u32, BuildHasherDefault<FxHasher>>;
+
+    fn deref(&self) -> &Self::Target {
+        &self.hashset
+    }
+}
+
+impl DerefMut for HashSetForU32 {
+    fn deref_mut(&mut self) -> &mut Self::Target {
+        &mut self.hashset
+    }
+}
diff --git a/rust/diskann/src/utils/kmeans.rs b/rust/diskann/src/utils/kmeans.rs
new file mode 100644
index 000000000..d1edffad7
--- /dev/null
+++ b/rust/diskann/src/utils/kmeans.rs
@@ -0,0 +1,430 @@
+/*
+ * Copyright (c) Microsoft Corporation. All rights reserved.
+ * Licensed under the MIT license.
+ */
+#![warn(missing_debug_implementations, missing_docs)]
+
+//! K-means clustering utilities
+
+use rand::{distributions::Uniform, prelude::Distribution, thread_rng};
+use rayon::prelude::*;
+use std::cmp::min;
+
+use crate::common::ANNResult;
+use crate::utils::math_util::{calc_distance, compute_closest_centers, compute_vecs_l2sq};
+
+/// Run one Lloyd's iteration.
+/// Given data in row-major num_points * dim, centers in row-major
+/// num_centers * dim, and squared lengths of data points, output the closest
+/// center to each data point, update centers, and also return the inverted index.
+/// If closest_centers == NULL, will allocate memory and return.
+/// Similarly, if closest_docs == NULL, will allocate memory and return.
+#[allow(clippy::too_many_arguments)]
+fn lloyds_iter(
+    data: &[f32],
+    num_points: usize,
+    dim: usize,
+    centers: &mut [f32],
+    num_centers: usize,
+    docs_l2sq: &[f32],
+    mut closest_docs: &mut Vec<Vec<usize>>,
+    closest_center: &mut [u32],
+) -> ANNResult<f32> {
+    let compute_residual = true;
+
+    closest_docs.iter_mut().for_each(|doc| doc.clear());
+
+    compute_closest_centers(
+        data,
+        num_points,
+        dim,
+        centers,
+        num_centers,
+        1,
+        closest_center,
+        Some(&mut closest_docs),
+        Some(docs_l2sq),
+    )?;
+
+    centers.fill(0.0);
+
+    centers
+        .par_chunks_mut(dim)
+        .enumerate()
+        .for_each(|(c, center)| {
+            let mut cluster_sum = vec![0.0; dim];
+            for &doc_index in &closest_docs[c] {
+                let current = &data[doc_index * dim..(doc_index + 1) * dim];
+                for (j, current_val) in current.iter().enumerate() {
+                    cluster_sum[j] += *current_val as f64;
+                }
+            }
+            if !closest_docs[c].is_empty() {
+                for (i, sum_val) in cluster_sum.iter().enumerate() {
+                    center[i] = (*sum_val / closest_docs[c].len() as f64) as f32;
+                }
+            }
+        });
+
+    let mut residual = 0.0;
+    if compute_residual {
+        let buf_pad: usize = 32;
+        let chunk_size: usize = 2 * 8192;
+        let nchunks =
+            num_points / chunk_size + (if num_points % chunk_size == 0 { 0 } else { 1 });
+
+        let mut residuals: Vec<f32> = vec![0.0; nchunks * buf_pad];
+
+        residuals
+            .par_iter_mut()
+            .enumerate()
+            .for_each(|(chunk, res)| {
+                for d in (chunk * chunk_size)..min(num_points, (chunk + 1) * chunk_size) {
+                    *res += calc_distance(
+                        &data[d * dim..(d + 1) * dim],
+                        &centers[closest_center[d] as usize * dim..],
+                        dim,
+                    );
+                }
+            });
+
+        for chunk in 0..nchunks {
+            residual += residuals[chunk * buf_pad];
+        }
+    }
+
+    Ok(residual)
+}
+
+/// Run Lloyds until max_reps or stopping criterion
+/// If you pass NULL for closest_docs and closest_center, it will NOT return
+/// the results, else it will assume appropriate allocation as closest_docs =
+/// new vec [num_centers], and closest_center = new
size_t[num_points] +/// Final centers are output in centers as row-major num_centers * dim. +fn run_lloyds( + data: &[f32], + num_points: usize, + dim: usize, + centers: &mut [f32], + num_centers: usize, + max_reps: usize, +) -> ANNResult<(Vec>, Vec, f32)> { + let mut residual = f32::MAX; + + let mut closest_docs = vec![Vec::new(); num_centers]; + let mut closest_center = vec![0; num_points]; + + let mut docs_l2sq = vec![0.0; num_points]; + compute_vecs_l2sq(&mut docs_l2sq, data, num_points, dim); + + let mut old_residual; + + for i in 0..max_reps { + old_residual = residual; + + residual = lloyds_iter( + data, + num_points, + dim, + centers, + num_centers, + &docs_l2sq, + &mut closest_docs, + &mut closest_center, + )?; + + if (i != 0 && (old_residual - residual) / residual < 0.00001) || (residual < f32::EPSILON) { + println!( + "Residuals unchanged: {} becomes {}. Early termination.", + old_residual, residual + ); + break; + } + } + + Ok((closest_docs, closest_center, residual)) +} + +/// Assume memory allocated for pivot_data as new float[num_centers * dim] +/// and select randomly num_centers points as pivots +fn selecting_pivots( + data: &[f32], + num_points: usize, + dim: usize, + pivot_data: &mut [f32], + num_centers: usize, +) { + let mut picked = Vec::new(); + let mut rng = thread_rng(); + let distribution = Uniform::from(0..num_points); + + for j in 0..num_centers { + let mut tmp_pivot = distribution.sample(&mut rng); + while picked.contains(&tmp_pivot) { + tmp_pivot = distribution.sample(&mut rng); + } + picked.push(tmp_pivot); + let data_offset = tmp_pivot * dim; + let pivot_offset = j * dim; + pivot_data[pivot_offset..pivot_offset + dim] + .copy_from_slice(&data[data_offset..data_offset + dim]); + } +} + +/// Select pivots in k-means++ algorithm +/// Points that are farther away from the already chosen centroids +/// have a higher probability of being selected as the next centroid. +/// The k-means++ algorithm helps avoid poor initial centroid +/// placement that can result in suboptimal clustering. +fn k_meanspp_selecting_pivots( + data: &[f32], + num_points: usize, + dim: usize, + pivot_data: &mut [f32], + num_centers: usize, +) { + if num_points > (1 << 23) { + println!("ERROR: n_pts {} currently not supported for k-means++, maximum is 8388608. 
Falling back to random pivot selection.", num_points); + selecting_pivots(data, num_points, dim, pivot_data, num_centers); + return; + } + + let mut picked: Vec = Vec::new(); + let mut rng = thread_rng(); + let real_distribution = Uniform::from(0.0..1.0); + let int_distribution = Uniform::from(0..num_points); + + let init_id = int_distribution.sample(&mut rng); + let mut num_picked = 1; + + picked.push(init_id); + let init_data_offset = init_id * dim; + pivot_data[0..dim].copy_from_slice(&data[init_data_offset..init_data_offset + dim]); + + let mut dist = vec![0.0; num_points]; + + dist.par_iter_mut().enumerate().for_each(|(i, dist_i)| { + *dist_i = calc_distance( + &data[i * dim..(i + 1) * dim], + &data[init_id * dim..(init_id + 1) * dim], + dim, + ); + }); + + let mut dart_val: f64; + let mut tmp_pivot = 0; + let mut sum_flag = false; + + while num_picked < num_centers { + dart_val = real_distribution.sample(&mut rng); + + let mut sum: f64 = 0.0; + for item in dist.iter().take(num_points) { + sum += *item as f64; + } + if sum == 0.0 { + sum_flag = true; + } + + dart_val *= sum; + + let mut prefix_sum: f64 = 0.0; + for (i, pivot) in dist.iter().enumerate().take(num_points) { + tmp_pivot = i; + if dart_val >= prefix_sum && dart_val < (prefix_sum + *pivot as f64) { + break; + } + + prefix_sum += *pivot as f64; + } + + if picked.contains(&tmp_pivot) && !sum_flag { + continue; + } + + picked.push(tmp_pivot); + let pivot_offset = num_picked * dim; + let data_offset = tmp_pivot * dim; + pivot_data[pivot_offset..pivot_offset + dim] + .copy_from_slice(&data[data_offset..data_offset + dim]); + + dist.par_iter_mut().enumerate().for_each(|(i, dist_i)| { + *dist_i = (*dist_i).min(calc_distance( + &data[i * dim..(i + 1) * dim], + &data[tmp_pivot * dim..(tmp_pivot + 1) * dim], + dim, + )); + }); + + num_picked += 1; + } +} + +/// k-means algorithm interface +pub fn k_means_clustering( + data: &[f32], + num_points: usize, + dim: usize, + centers: &mut [f32], + num_centers: usize, + max_reps: usize, +) -> ANNResult<(Vec>, Vec, f32)> { + k_meanspp_selecting_pivots(data, num_points, dim, centers, num_centers); + let (closest_docs, closest_center, residual) = + run_lloyds(data, num_points, dim, centers, num_centers, max_reps)?; + Ok((closest_docs, closest_center, residual)) +} + +#[cfg(test)] +mod kmeans_test { + use super::*; + use approx::assert_relative_eq; + use rand::Rng; + + #[test] + fn lloyds_iter_test() { + let dim = 2; + let num_points = 10; + let num_centers = 3; + + let data: Vec = (1..=num_points * dim).map(|x| x as f32).collect(); + let mut centers = [1.0, 2.0, 7.0, 8.0, 19.0, 20.0]; + + let mut closest_docs: Vec> = vec![vec![]; num_centers]; + let mut closest_center: Vec = vec![0; num_points]; + let docs_l2sq: Vec = data + .chunks(dim) + .map(|chunk| chunk.iter().map(|val| val.powi(2)).sum()) + .collect(); + + let residual = lloyds_iter( + &data, + num_points, + dim, + &mut centers, + num_centers, + &docs_l2sq, + &mut closest_docs, + &mut closest_center, + ) + .unwrap(); + + let expected_centers: [f32; 6] = [2.0, 3.0, 9.0, 10.0, 17.0, 18.0]; + let expected_closest_docs: Vec> = + vec![vec![0, 1], vec![2, 3, 4, 5, 6], vec![7, 8, 9]]; + let expected_closest_center: [u32; 10] = [0, 0, 1, 1, 1, 1, 1, 2, 2, 2]; + let expected_residual: f32 = 100.0; + + // sort data for assert + centers.sort_by(|a, b| a.partial_cmp(b).unwrap()); + for inner_vec in &mut closest_docs { + inner_vec.sort(); + } + closest_center.sort_by(|a, b| a.partial_cmp(b).unwrap()); + + assert_eq!(centers, expected_centers); + 
assert_eq!(closest_docs, expected_closest_docs); + assert_eq!(closest_center, expected_closest_center); + assert_relative_eq!(residual, expected_residual, epsilon = 1.0e-6_f32); + } + + #[test] + fn run_lloyds_test() { + let dim = 2; + let num_points = 10; + let num_centers = 3; + let max_reps = 5; + + let data: Vec = (1..=num_points * dim).map(|x| x as f32).collect(); + let mut centers = [1.0, 2.0, 7.0, 8.0, 19.0, 20.0]; + + let (mut closest_docs, mut closest_center, residual) = + run_lloyds(&data, num_points, dim, &mut centers, num_centers, max_reps).unwrap(); + + let expected_centers: [f32; 6] = [3.0, 4.0, 10.0, 11.0, 17.0, 18.0]; + let expected_closest_docs: Vec> = + vec![vec![0, 1, 2], vec![3, 4, 5, 6], vec![7, 8, 9]]; + let expected_closest_center: [u32; 10] = [0, 0, 0, 1, 1, 1, 1, 2, 2, 2]; + let expected_residual: f32 = 72.0; + + // sort data for assert + centers.sort_by(|a, b| a.partial_cmp(b).unwrap()); + for inner_vec in &mut closest_docs { + inner_vec.sort(); + } + closest_center.sort_by(|a, b| a.partial_cmp(b).unwrap()); + + assert_eq!(centers, expected_centers); + assert_eq!(closest_docs, expected_closest_docs); + assert_eq!(closest_center, expected_closest_center); + assert_relative_eq!(residual, expected_residual, epsilon = 1.0e-6_f32); + } + + #[test] + fn selecting_pivots_test() { + let dim = 2; + let num_points = 10; + let num_centers = 3; + + // Generate some random data points + let mut rng = rand::thread_rng(); + let data: Vec = (0..num_points * dim).map(|_| rng.gen()).collect(); + + let mut pivot_data = vec![0.0; num_centers * dim]; + + selecting_pivots(&data, num_points, dim, &mut pivot_data, num_centers); + + // Verify that each pivot point corresponds to a point in the data + for i in 0..num_centers { + let pivot_offset = i * dim; + let pivot = &pivot_data[pivot_offset..(pivot_offset + dim)]; + + // Make sure the pivot is found in the data + let mut found = false; + for j in 0..num_points { + let data_offset = j * dim; + let point = &data[data_offset..(data_offset + dim)]; + + if pivot == point { + found = true; + break; + } + } + assert!(found, "Pivot not found in data"); + } + } + + #[test] + fn k_meanspp_selecting_pivots_test() { + let dim = 2; + let num_points = 10; + let num_centers = 3; + + // Generate some random data points + let mut rng = rand::thread_rng(); + let data: Vec = (0..num_points * dim).map(|_| rng.gen()).collect(); + + let mut pivot_data = vec![0.0; num_centers * dim]; + + k_meanspp_selecting_pivots(&data, num_points, dim, &mut pivot_data, num_centers); + + // Verify that each pivot point corresponds to a point in the data + for i in 0..num_centers { + let pivot_offset = i * dim; + let pivot = &pivot_data[pivot_offset..pivot_offset + dim]; + + // Make sure the pivot is found in the data + let mut found = false; + for j in 0..num_points { + let data_offset = j * dim; + let point = &data[data_offset..data_offset + dim]; + + if pivot == point { + found = true; + break; + } + } + assert!(found, "Pivot not found in data"); + } + } +} diff --git a/rust/diskann/src/utils/math_util.rs b/rust/diskann/src/utils/math_util.rs new file mode 100644 index 000000000..ef30c76ff --- /dev/null +++ b/rust/diskann/src/utils/math_util.rs @@ -0,0 +1,481 @@ +/* + * Copyright (c) Microsoft Corporation. All rights reserved. + * Licensed under the MIT license. + */ +#![warn(missing_debug_implementations, missing_docs)] + +//! 
Aligned allocator + +extern crate cblas; +extern crate openblas_src; + +use cblas::{sgemm, snrm2, Layout, Transpose}; +use rayon::prelude::*; +use std::{ + cmp::{min, Ordering}, + collections::BinaryHeap, + sync::{Arc, Mutex}, +}; + +use crate::common::{ANNError, ANNResult}; + +struct PivotContainer { + piv_id: usize, + piv_dist: f32, +} + +impl PartialOrd for PivotContainer { + fn partial_cmp(&self, other: &Self) -> Option { + other.piv_dist.partial_cmp(&self.piv_dist) + } +} + +impl Ord for PivotContainer { + fn cmp(&self, other: &Self) -> std::cmp::Ordering { + // Treat NaN as less than all other values. + // piv_dist should never be NaN. + self.partial_cmp(other).unwrap_or(Ordering::Less) + } +} + +impl PartialEq for PivotContainer { + fn eq(&self, other: &Self) -> bool { + self.piv_dist == other.piv_dist + } +} + +impl Eq for PivotContainer {} + +/// Calculate the Euclidean distance between two vectors +pub fn calc_distance(vec_1: &[f32], vec_2: &[f32], dim: usize) -> f32 { + let mut dist = 0.0; + for j in 0..dim { + let diff = vec_1[j] - vec_2[j]; + dist += diff * diff; + } + dist +} + +/// Compute L2-squared norms of data stored in row-major num_points * dim, +/// need to be pre-allocated +pub fn compute_vecs_l2sq(vecs_l2sq: &mut [f32], data: &[f32], num_points: usize, dim: usize) { + assert_eq!(vecs_l2sq.len(), num_points); + + vecs_l2sq + .par_iter_mut() + .enumerate() + .for_each(|(n_iter, vec_l2sq)| { + let slice = &data[n_iter * dim..(n_iter + 1) * dim]; + let norm = unsafe { snrm2(dim as i32, slice, 1) }; + *vec_l2sq = norm * norm; + }); +} + +/// Calculate k closest centers to data of num_points * dim (row-major) +/// Centers is num_centers * dim (row-major) +/// data_l2sq has pre-computed squared norms of data +/// centers_l2sq has pre-computed squared norms of centers +/// Pre-allocated center_index will contain id of nearest center +/// Pre-allocated dist_matrix should be num_points * num_centers and contain squared distances +/// Default value of k is 1 +/// Ideally used only by compute_closest_centers +#[allow(clippy::too_many_arguments)] +pub fn compute_closest_centers_in_block( + data: &[f32], + num_points: usize, + dim: usize, + centers: &[f32], + num_centers: usize, + docs_l2sq: &[f32], + centers_l2sq: &[f32], + center_index: &mut [u32], + dist_matrix: &mut [f32], + k: usize, +) -> ANNResult<()> { + if k > num_centers { + return Err(ANNError::log_index_error(format!( + "ERROR: k ({}) > num_centers({})", + k, num_centers + ))); + } + + let ones_a: Vec = vec![1.0; num_centers]; + let ones_b: Vec = vec![1.0; num_points]; + + unsafe { + sgemm( + Layout::RowMajor, + Transpose::None, + Transpose::Ordinary, + num_points as i32, + num_centers as i32, + 1, + 1.0, + docs_l2sq, + 1, + &ones_a, + 1, + 0.0, + dist_matrix, + num_centers as i32, + ); + } + + unsafe { + sgemm( + Layout::RowMajor, + Transpose::None, + Transpose::Ordinary, + num_points as i32, + num_centers as i32, + 1, + 1.0, + &ones_b, + 1, + centers_l2sq, + 1, + 1.0, + dist_matrix, + num_centers as i32, + ); + } + + unsafe { + sgemm( + Layout::RowMajor, + Transpose::None, + Transpose::Ordinary, + num_points as i32, + num_centers as i32, + dim as i32, + -2.0, + data, + dim as i32, + centers, + dim as i32, + 1.0, + dist_matrix, + num_centers as i32, + ); + } + + if k == 1 { + center_index + .par_iter_mut() + .enumerate() + .for_each(|(i, center_idx)| { + let mut min = f32::MAX; + let current = &dist_matrix[i * num_centers..(i + 1) * num_centers]; + let mut min_idx = 0; + for (j, &distance) in 
current.iter().enumerate() { + if distance < min { + min = distance; + min_idx = j; + } + } + *center_idx = min_idx as u32; + }); + } else { + center_index + .par_chunks_mut(k) + .enumerate() + .for_each(|(i, center_chunk)| { + let current = &dist_matrix[i * num_centers..(i + 1) * num_centers]; + let mut top_k_queue = BinaryHeap::new(); + for (j, &distance) in current.iter().enumerate() { + let this_piv = PivotContainer { + piv_id: j, + piv_dist: distance, + }; + if top_k_queue.len() < k { + top_k_queue.push(this_piv); + } else { + // Safe unwrap, top_k_queue is not empty + #[allow(clippy::unwrap_used)] + let mut top = top_k_queue.peek_mut().unwrap(); + if this_piv.piv_dist < top.piv_dist { + *top = this_piv; + } + } + } + for (_j, center_idx) in center_chunk.iter_mut().enumerate() { + if let Some(this_piv) = top_k_queue.pop() { + *center_idx = this_piv.piv_id as u32; + } else { + break; + } + } + }); + } + + Ok(()) +} + +/// Given data in num_points * new_dim row major +/// Pivots stored in full_pivot_data as num_centers * new_dim row major +/// Calculate the k closest pivot for each point and store it in vector +/// closest_centers_ivf (row major, num_points*k) (which needs to be allocated +/// outside) Additionally, if inverted index is not null (and pre-allocated), +/// it will return inverted index for each center, assuming each of the inverted +/// indices is an empty vector. Additionally, if pts_norms_squared is not null, +/// then it will assume that point norms are pre-computed and use those values +#[allow(clippy::too_many_arguments)] +pub fn compute_closest_centers( + data: &[f32], + num_points: usize, + dim: usize, + pivot_data: &[f32], + num_centers: usize, + k: usize, + closest_centers_ivf: &mut [u32], + mut inverted_index: Option<&mut Vec>>, + pts_norms_squared: Option<&[f32]>, +) -> ANNResult<()> { + if k > num_centers { + return Err(ANNError::log_index_error(format!( + "ERROR: k ({}) > num_centers({})", + k, num_centers + ))); + } + + let _is_norm_given_for_pts = pts_norms_squared.is_some(); + + let mut pivs_norms_squared = vec![0.0; num_centers]; + + let mut pts_norms_squared = if let Some(pts_norms) = pts_norms_squared { + pts_norms.to_vec() + } else { + let mut norms_squared = vec![0.0; num_points]; + compute_vecs_l2sq(&mut norms_squared, data, num_points, dim); + norms_squared + }; + + compute_vecs_l2sq(&mut pivs_norms_squared, pivot_data, num_centers, dim); + + let par_block_size = num_points; + let n_blocks = if num_points % par_block_size == 0 { + num_points / par_block_size + } else { + num_points / par_block_size + 1 + }; + + let mut closest_centers = vec![0u32; par_block_size * k]; + let mut distance_matrix = vec![0.0; num_centers * par_block_size]; + + for cur_blk in 0..n_blocks { + let data_cur_blk = &data[cur_blk * par_block_size * dim..]; + let num_pts_blk = min(par_block_size, num_points - cur_blk * par_block_size); + let pts_norms_blk = &mut pts_norms_squared[cur_blk * par_block_size..]; + + compute_closest_centers_in_block( + data_cur_blk, + num_pts_blk, + dim, + pivot_data, + num_centers, + pts_norms_blk, + &pivs_norms_squared, + &mut closest_centers, + &mut distance_matrix, + k, + )?; + + closest_centers_ivf.clone_from_slice(&closest_centers); + + if let Some(inverted_index_inner) = inverted_index.as_mut() { + let inverted_index_arc = Arc::new(Mutex::new(inverted_index_inner)); + + (0..num_points) + .into_par_iter() + .try_for_each(|j| -> ANNResult<()> { + let this_center_id = closest_centers[j] as usize; + let mut guard = 
inverted_index_arc.lock().map_err(|err| { + ANNError::log_index_error(format!( + "PoisonError: Lock poisoned when acquiring inverted_index_arc, err={}", + err + )) + })?; + guard[this_center_id].push(j); + + Ok(()) + })?; + } + } + + Ok(()) +} + +/// If to_subtract is true, will subtract nearest center from each row. +/// Else will add. +/// Output will be in data_load itself. +/// Nearest centers need to be provided in closest_centers. +pub fn process_residuals( + data_load: &mut [f32], + num_points: usize, + dim: usize, + cur_pivot_data: &[f32], + num_centers: usize, + closest_centers: &[u32], + to_subtract: bool, +) { + println!( + "Processing residuals of {} points in {} dimensions using {} centers", + num_points, dim, num_centers + ); + + data_load + .par_chunks_mut(dim) + .enumerate() + .for_each(|(n_iter, chunk)| { + let cur_pivot_index = closest_centers[n_iter] as usize * dim; + for d_iter in 0..dim { + if to_subtract { + chunk[d_iter] -= cur_pivot_data[cur_pivot_index + d_iter]; + } else { + chunk[d_iter] += cur_pivot_data[cur_pivot_index + d_iter]; + } + } + }); +} + +#[cfg(test)] +mod math_util_test { + use super::*; + use approx::assert_abs_diff_eq; + + #[test] + fn calc_distance_test() { + let vec1 = vec![1.0, 2.0, 3.0]; + let vec2 = vec![4.0, 5.0, 6.0]; + let dim = vec1.len(); + + let dist = calc_distance(&vec1, &vec2, dim); + + let expected = 27.0; + + assert_eq!(dist, expected); + } + + #[test] + fn compute_vecs_l2sq_test() { + let data = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]; + let num_points = 2; + let dim = 3; + let mut vecs_l2sq = vec![0.0; num_points]; + + compute_vecs_l2sq(&mut vecs_l2sq, &data, num_points, dim); + + let expected = vec![14.0, 77.0]; + + assert_eq!(vecs_l2sq.len(), num_points); + assert_abs_diff_eq!(vecs_l2sq[0], expected[0], epsilon = 1e-6); + assert_abs_diff_eq!(vecs_l2sq[1], expected[1], epsilon = 1e-6); + } + + #[test] + fn compute_closest_centers_in_block_test() { + let num_points = 10; + let dim = 5; + let num_centers = 3; + let data = vec![ + 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0, + 17.0, 18.0, 19.0, 20.0, 21.0, 22.0, 23.0, 24.0, 25.0, 26.0, 27.0, 28.0, 29.0, 30.0, + 31.0, 32.0, 33.0, 34.0, 35.0, 36.0, 37.0, 38.0, 39.0, 40.0, 41.0, 42.0, 43.0, 44.0, + 45.0, 46.0, 47.0, 48.0, 49.0, 50.0, + ]; + let centers = vec![ + 1.0, 2.0, 3.0, 4.0, 5.0, 21.0, 22.0, 23.0, 24.0, 25.0, 31.0, 32.0, 33.0, 34.0, 35.0, + ]; + let mut docs_l2sq = vec![0.0; num_points]; + compute_vecs_l2sq(&mut docs_l2sq, &data, num_points, dim); + let mut centers_l2sq = vec![0.0; num_centers]; + compute_vecs_l2sq(&mut centers_l2sq, ¢ers, num_centers, dim); + let mut center_index = vec![0; num_points]; + let mut dist_matrix = vec![0.0; num_points * num_centers]; + let k = 1; + + compute_closest_centers_in_block( + &data, + num_points, + dim, + ¢ers, + num_centers, + &docs_l2sq, + ¢ers_l2sq, + &mut center_index, + &mut dist_matrix, + k, + ) + .unwrap(); + + assert_eq!(center_index.len(), num_points); + let expected_center_index = vec![0, 0, 0, 1, 1, 1, 2, 2, 2, 2]; + assert_abs_diff_eq!(*center_index, expected_center_index); + + assert_eq!(dist_matrix.len(), num_points * num_centers); + let expected_dist_matrix = vec![ + 0.0, 2000.0, 4500.0, 125.0, 1125.0, 3125.0, 500.0, 500.0, 2000.0, 1125.0, 125.0, + 1125.0, 2000.0, 0.0, 500.0, 3125.0, 125.0, 125.0, 4500.0, 500.0, 0.0, 6125.0, 1125.0, + 125.0, 8000.0, 2000.0, 500.0, 10125.0, 3125.0, 1125.0, + ]; + assert_abs_diff_eq!(*dist_matrix, expected_dist_matrix, epsilon = 1e-2); + } + + #[test] + fn 
test_compute_closest_centers() {
+        let num_points = 4;
+        let dim = 3;
+        let num_centers = 2;
+        let data = vec![
+            1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0,
+        ];
+        let pivot_data = vec![1.0, 2.0, 3.0, 10.0, 11.0, 12.0];
+        let k = 1;
+
+        let mut closest_centers_ivf = vec![0u32; num_points * k];
+        let mut inverted_index: Vec<Vec<usize>> = vec![vec![], vec![]];
+
+        compute_closest_centers(
+            &data,
+            num_points,
+            dim,
+            &pivot_data,
+            num_centers,
+            k,
+            &mut closest_centers_ivf,
+            Some(&mut inverted_index),
+            None,
+        )
+        .unwrap();
+
+        assert_eq!(closest_centers_ivf, vec![0, 0, 1, 1]);
+
+        for vec in inverted_index.iter_mut() {
+            vec.sort_unstable();
+        }
+        assert_eq!(inverted_index, vec![vec![0, 1], vec![2, 3]]);
+    }
+
+    #[test]
+    fn process_residuals_test() {
+        let mut data_load = vec![1.0, 2.0, 3.0, 4.0];
+        let num_points = 2;
+        let dim = 2;
+        let cur_pivot_data = vec![0.5, 1.5, 2.5, 3.5];
+        let num_centers = 2;
+        let closest_centers = vec![0, 1];
+        let to_subtract = true;
+
+        process_residuals(
+            &mut data_load,
+            num_points,
+            dim,
+            &cur_pivot_data,
+            num_centers,
+            &closest_centers,
+            to_subtract,
+        );
+
+        assert_eq!(data_load, vec![0.5, 0.5, 0.5, 0.5]);
+    }
+}
diff --git a/rust/diskann/src/utils/mod.rs b/rust/diskann/src/utils/mod.rs
new file mode 100644
index 000000000..df174f8f0
--- /dev/null
+++ b/rust/diskann/src/utils/mod.rs
@@ -0,0 +1,34 @@
+/*
+ * Copyright (c) Microsoft Corporation. All rights reserved.
+ * Licensed under the MIT license.
+ */
+pub mod file_util;
+pub use file_util::*;
+
+#[allow(clippy::module_inception)]
+pub mod utils;
+pub use utils::*;
+
+pub mod bit_vec_extension;
+pub use bit_vec_extension::*;
+
+pub mod rayon_util;
+pub use rayon_util::*;
+
+pub mod timer;
+pub use timer::*;
+
+pub mod cached_reader;
+pub use cached_reader::*;
+
+pub mod cached_writer;
+pub use cached_writer::*;
+
+pub mod partition;
+pub use partition::*;
+
+pub mod math_util;
+pub use math_util::*;
+
+pub mod kmeans;
+pub use kmeans::*;
diff --git a/rust/diskann/src/utils/partition.rs b/rust/diskann/src/utils/partition.rs
new file mode 100644
index 000000000..dbe686226
--- /dev/null
+++ b/rust/diskann/src/utils/partition.rs
@@ -0,0 +1,151 @@
+/*
+ * Copyright (c) Microsoft Corporation. All rights reserved.
+ * Licensed under the MIT license.
+ */
+use std::mem;
+use std::{fs::File, path::Path};
+use std::io::{Write, Seek, SeekFrom};
+use rand::distributions::{Distribution, Uniform};
+
+use crate::common::ANNResult;
+
+use super::CachedReader;
+
+/// Streams data from the file, samples each vector with probability p_val,
+/// and returns a matrix of size slice_size * ndims as floating point type.
+/// The slice_size and ndims are set inside the function.
+/// # Arguments
+/// * `data_file` - filename where the data is
+/// * `p_val` - probability with which to sample each vector
+/// # Return
+/// * `sampled_vectors` - vectors sampled with probability p_val
+/// * `slice_size` - how many vectors were sampled
+/// * `dim` - dimension of each sampled vector
+pub fn gen_random_slice<T: Copy + Into<f32>>(data_file: &str, mut p_val: f64) -> ANNResult<(Vec<f32>, usize, usize)> {
+    let read_blk_size = 64 * 1024 * 1024;
+    let mut reader = CachedReader::new(data_file, read_blk_size)?;
+
+    let npts = reader.read_u32()? as usize;
+    let dim = reader.read_u32()? as usize;
+    let mut sampled_vectors: Vec<f32> = Vec::new();
+    let mut slice_size = 0;
+    p_val = if p_val < 1f64 { p_val } else { 1f64 };
+
+    let mut generator = rand::thread_rng();
+    let distribution = Uniform::from(0.0..1.0);
+
+    for _ in 0..npts {
+        let mut cur_vector_bytes = vec![0u8; dim * mem::size_of::<T>()];
+        reader.read(&mut cur_vector_bytes)?;
+        let random_value = distribution.sample(&mut generator);
+        if random_value < p_val {
+            let ptr = cur_vector_bytes.as_ptr() as *const T;
+            let cur_vector_t = unsafe { std::slice::from_raw_parts(ptr, dim) };
+            sampled_vectors.extend(cur_vector_t.iter().map(|&t| t.into()));
+            slice_size += 1;
+        }
+    }
+
+    Ok((sampled_vectors, slice_size, dim))
+}
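Note: gen_random_slice is Bernoulli sampling, so slice_size is itself random with expectation p_val * npts; callers must use the returned count rather than assume an exact size. gen_sample_data below faces the same unknown-count problem when writing its output header, so it writes a placeholder count of 0, streams the sampled rows, then seeks back to offset 0 and patches in the real count. A minimal sketch of that placeholder-then-patch pattern (names are illustrative):

    use std::fs::File;
    use std::io::{Seek, SeekFrom, Write};

    fn write_rows_with_patched_count(mut out: File, rows: &[Vec<u8>]) -> std::io::Result<()> {
        out.write_all(&0u32.to_le_bytes())?; // placeholder count
        let mut n = 0u32;
        for row in rows {
            out.write_all(row)?;
            n += 1;
        }
        out.seek(SeekFrom::Start(0))?;   // rewind to the header...
        out.write_all(&n.to_le_bytes())  // ...and patch in the real count
    }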
+
+/// Generate random sample data and write into output_file
+pub fn gen_sample_data<T>(data_file: &str, output_file: &str, sampling_rate: f64) -> ANNResult<()> {
+    let read_blk_size = 64 * 1024 * 1024;
+    let mut reader = CachedReader::new(data_file, read_blk_size)?;
+
+    let sample_data_path = format!("{}_data.bin", output_file);
+    let sample_ids_path = format!("{}_ids.bin", output_file);
+    let mut sample_data_writer = File::create(Path::new(&sample_data_path))?;
+    let mut sample_id_writer = File::create(Path::new(&sample_ids_path))?;
+
+    let mut num_sampled_pts = 0u32;
+    let one_const = 1u32;
+    let mut generator = rand::thread_rng();
+    let distribution = Uniform::from(0.0..1.0);
+
+    let npts_u32 = reader.read_u32()?;
+    let dim_u32 = reader.read_u32()?;
+    let dim = dim_u32 as usize;
+    sample_data_writer.write_all(&num_sampled_pts.to_le_bytes())?;
+    sample_data_writer.write_all(&dim_u32.to_le_bytes())?;
+    sample_id_writer.write_all(&num_sampled_pts.to_le_bytes())?;
+    sample_id_writer.write_all(&one_const.to_le_bytes())?;
+
+    for id in 0..npts_u32 {
+        let mut cur_row_bytes = vec![0u8; dim * mem::size_of::<T>()];
+        reader.read(&mut cur_row_bytes)?;
+        let random_value = distribution.sample(&mut generator);
+        if random_value < sampling_rate {
+            sample_data_writer.write_all(&cur_row_bytes)?;
+            sample_id_writer.write_all(&id.to_le_bytes())?;
+            num_sampled_pts += 1;
+        }
+    }
+
+    sample_data_writer.seek(SeekFrom::Start(0))?;
+    sample_data_writer.write_all(&num_sampled_pts.to_le_bytes())?;
+    sample_id_writer.seek(SeekFrom::Start(0))?;
+    sample_id_writer.write_all(&num_sampled_pts.to_le_bytes())?;
+    println!("Wrote {} points to sample file: {}", num_sampled_pts, sample_data_path);
+
+    Ok(())
+}
+
+#[cfg(test)]
+mod partition_test {
+    use std::{fs, io::Read};
+    use byteorder::{ReadBytesExt, LittleEndian};
+
+    use crate::utils::file_exists;
+
+    use super::*;
+
+    #[test]
+    fn gen_sample_data_test() {
+        let file_name = "gen_sample_data_test.bin";
+        // npoints=2, dim=8
+        let data: [u8; 72] = [2, 0, 0, 0, 8, 0, 0, 0,
+            0x00, 0x00, 0x80, 0x3f, 0x00, 0x00, 0x00, 0x40, 0x00, 0x00, 0x40, 0x40, 0x00, 0x00, 0x80, 0x40,
+            0x00, 0x00, 0xa0, 0x40, 0x00, 0x00, 0xc0, 0x40, 0x00, 0x00, 0xe0, 0x40, 0x00, 0x00, 0x00, 0x41,
+            0x00, 0x00, 0x10, 0x41, 0x00, 0x00, 0x20, 0x41, 0x00, 0x00, 0x30, 0x41, 0x00, 0x00, 0x40, 0x41,
+            0x00, 0x00, 0x50, 0x41, 0x00, 0x00, 0x60, 0x41, 0x00, 0x00, 0x70, 0x41, 0x00, 0x00, 0x80, 0x41];
+        std::fs::write(file_name, data).expect("Failed to write sample file");
+
+        let sample_file_prefix = file_name.to_string() + "_sample";
+        gen_sample_data::<f32>(file_name, sample_file_prefix.as_str(), 1f64).unwrap();
+
+        let sample_data_path = format!("{}_data.bin", sample_file_prefix);
+        let sample_ids_path = format!("{}_ids.bin", sample_file_prefix);
+        assert!(file_exists(sample_data_path.as_str()));
assert!(file_exists(sample_ids_path.as_str())); + + let mut data_file_reader = File::open(sample_data_path.as_str()).unwrap(); + let mut ids_file_reader = File::open(sample_ids_path.as_str()).unwrap(); + + let mut num_sampled_pts = data_file_reader.read_u32::().unwrap(); + assert_eq!(num_sampled_pts, 2); + num_sampled_pts = ids_file_reader.read_u32::().unwrap(); + assert_eq!(num_sampled_pts, 2); + + let dim = data_file_reader.read_u32::().unwrap() as usize; + assert_eq!(dim, 8); + assert_eq!(ids_file_reader.read_u32::().unwrap(), 1); + + let mut start = 8; + for i in 0..num_sampled_pts { + let mut data_bytes = vec![0u8; dim * 4]; + data_file_reader.read_exact(&mut data_bytes).unwrap(); + assert_eq!(data_bytes, data[start..start + dim * 4]); + + let id = ids_file_reader.read_u32::().unwrap(); + assert_eq!(id, i); + + start += dim * 4; + } + + fs::remove_file(file_name).expect("Failed to delete file"); + fs::remove_file(sample_data_path.as_str()).expect("Failed to delete file"); + fs::remove_file(sample_ids_path.as_str()).expect("Failed to delete file"); + } +} + diff --git a/rust/diskann/src/utils/rayon_util.rs b/rust/diskann/src/utils/rayon_util.rs new file mode 100644 index 000000000..f8174ee59 --- /dev/null +++ b/rust/diskann/src/utils/rayon_util.rs @@ -0,0 +1,33 @@ +/* + * Copyright (c) Microsoft Corporation. All rights reserved. + * Licensed under the MIT license. + */ +use std::ops::Range; +use rayon::prelude::{IntoParallelIterator, ParallelIterator}; + +use crate::common::ANNResult; + +/// based on thread_num, execute the task in parallel using Rayon or serial +#[inline] +pub fn execute_with_rayon(range: Range, num_threads: u32, f: F) -> ANNResult<()> +where F: Fn(usize) -> ANNResult<()> + Sync + Send + Copy +{ + if num_threads == 1 { + for i in range { + f(i)?; + } + Ok(()) + } else { + range.into_par_iter().try_for_each(f) + } +} + +/// set the thread count of Rayon, otherwise it will use threads as many as logical cores. +#[inline] +pub fn set_rayon_num_threads(num_threads: u32) { + std::env::set_var( + "RAYON_NUM_THREADS", + num_threads.to_string(), + ); +} + diff --git a/rust/diskann/src/utils/timer.rs b/rust/diskann/src/utils/timer.rs new file mode 100644 index 000000000..2f4b38ba7 --- /dev/null +++ b/rust/diskann/src/utils/timer.rs @@ -0,0 +1,101 @@ +/* + * Copyright (c) Microsoft Corporation. All rights reserved. + * Licensed under the MIT license. 
+ */ +use platform::*; +use std::time::{Duration, Instant}; + +#[derive(Clone)] +pub struct Timer { + check_point: Instant, + pid: Option, + cycles: Option, +} + +impl Default for Timer { + fn default() -> Self { + Self::new() + } +} + +impl Timer { + pub fn new() -> Timer { + let pid = get_process_handle(); + let cycles = get_process_cycle_time(pid); + Timer { + check_point: Instant::now(), + pid, + cycles, + } + } + + pub fn reset(&mut self) { + self.check_point = Instant::now(); + self.cycles = get_process_cycle_time(self.pid); + } + + pub fn elapsed(&self) -> Duration { + Instant::now().duration_since(self.check_point) + } + + pub fn elapsed_seconds(&self) -> f64 { + self.elapsed().as_secs_f64() + } + + pub fn elapsed_gcycles(&self) -> f32 { + let cur_cycles = get_process_cycle_time(self.pid); + if let (Some(cur_cycles), Some(cycles)) = (cur_cycles, self.cycles) { + let spent_cycles = + ((cur_cycles - cycles) as f64 * 1.0f64) / (1024 * 1024 * 1024) as f64; + return spent_cycles as f32; + } + + 0.0 + } + + pub fn elapsed_seconds_for_step(&self, step: &str) -> String { + format!( + "Time for {}: {:.3} seconds, {:.3}B cycles", + step, + self.elapsed_seconds(), + self.elapsed_gcycles() + ) + } +} + +#[cfg(test)] +mod timer_tests { + use super::*; + use std::{thread, time}; + + #[test] + fn test_new() { + let timer = Timer::new(); + assert!(timer.check_point.elapsed().as_secs() < 1); + if cfg!(windows) { + assert!(timer.pid.is_some()); + assert!(timer.cycles.is_some()); + } + else { + assert!(timer.pid.is_none()); + assert!(timer.cycles.is_none()); + } + } + + #[test] + fn test_reset() { + let mut timer = Timer::new(); + thread::sleep(time::Duration::from_millis(100)); + timer.reset(); + assert!(timer.check_point.elapsed().as_millis() < 10); + } + + #[test] + fn test_elapsed() { + let timer = Timer::new(); + thread::sleep(time::Duration::from_millis(100)); + assert!(timer.elapsed().as_millis() > 100); + assert!(timer.elapsed_seconds() > 0.1); + } +} + diff --git a/rust/diskann/src/utils/utils.rs b/rust/diskann/src/utils/utils.rs new file mode 100644 index 000000000..2e80676af --- /dev/null +++ b/rust/diskann/src/utils/utils.rs @@ -0,0 +1,154 @@ +/* + * Copyright (c) Microsoft Corporation. All rights reserved. + * Licensed under the MIT license. + */ +use std::sync::Mutex; +use num_traits::Num; + +/// Non recursive mutex +pub type NonRecursiveMutex = Mutex<()>; + +/// Round up X to the nearest multiple of Y +#[inline] +pub fn round_up(x: T, y: T) -> T +where T : Num + Copy +{ + div_round_up(x, y) * y +} + +/// Rounded-up division +#[inline] +pub fn div_round_up(x: T, y: T) -> T +where T : Num + Copy +{ + (x / y) + if x % y != T::zero() {T::one()} else {T::zero()} +} + +/// Round down X to the nearest multiple of Y +#[inline] +pub fn round_down(x: T, y: T) -> T +where T : Num + Copy +{ + (x / y) * y +} + +/// Is aligned +#[inline] +pub fn is_aligned(x: T, y: T) -> bool +where T : Num + Copy +{ + x % y == T::zero() +} + +#[inline] +pub fn is_512_aligned(x: u64) -> bool { + is_aligned(x, 512) +} + +#[inline] +pub fn is_4096_aligned(x: u64) -> bool { + is_aligned(x, 4096) +} + +/// all metadata of individual sub-component files is written in first 4KB for unified files +pub const METADATA_SIZE: usize = 4096; + +pub const BUFFER_SIZE_FOR_CACHED_IO: usize = 1024 * 1048576; + +pub const PBSTR: &str = "||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||"; + +pub const PBWIDTH: usize = 60; + +macro_rules! 
convert_types {
+    ($name:ident, $input_type:ty, $output_type:ty) => {
+        /// Convert an `npts` x `dim` matrix element-wise from the input type to the output type via `as` casting
+        pub fn $name(srcmat: &[$input_type], npts: usize, dim: usize) -> Vec<$output_type> {
+            let mut destmat: Vec<$output_type> = Vec::new();
+            for i in 0..npts {
+                for j in 0..dim {
+                    destmat.push(srcmat[i * dim + j] as $output_type);
+                }
+            }
+            destmat
+        }
+    };
+}
+convert_types!(convert_types_usize_u8, usize, u8);
+convert_types!(convert_types_usize_u32, usize, u32);
+convert_types!(convert_types_usize_u64, usize, u64);
+convert_types!(convert_types_u64_usize, u64, usize);
+convert_types!(convert_types_u32_usize, u32, usize);
+
+#[cfg(test)]
+mod file_util_test {
+    use super::*;
+    use std::any::type_name;
+
+    #[test]
+    fn round_up_test() {
+        assert_eq!(round_up(252, 8), 256);
+        assert_eq!(round_up(256, 8), 256);
+    }
+
+    #[test]
+    fn div_round_up_test() {
+        assert_eq!(div_round_up(252, 8), 32);
+        assert_eq!(div_round_up(256, 8), 32);
+    }
+
+    #[test]
+    fn round_down_test() {
+        assert_eq!(round_down(252, 8), 248);
+        assert_eq!(round_down(256, 8), 256);
+    }
+
+    #[test]
+    fn is_aligned_test() {
+        assert!(!is_aligned(252, 8));
+        assert!(is_aligned(256, 8));
+    }
+
+    #[test]
+    fn is_512_aligned_test() {
+        assert!(!is_512_aligned(520));
+        assert!(is_512_aligned(512));
+    }
+
+    #[test]
+    fn is_4096_aligned_test() {
+        assert!(!is_4096_aligned(4090));
+        assert!(is_4096_aligned(4096));
+    }
+
+    #[test]
+    fn convert_types_test() {
+        let data = vec![0u64, 1u64, 2u64];
+        let output = convert_types_u64_usize(&data, 3, 1);
+        assert_eq!(output.len(), 3);
+        assert_eq!(type_of(output[0]), "usize");
+        assert_eq!(output[0], 0usize);
+
+        let data = vec![0usize, 1usize, 2usize];
+        let output = convert_types_usize_u8(&data, 3, 1);
+        assert_eq!(output.len(), 3);
+        assert_eq!(type_of(output[0]), "u8");
+        assert_eq!(output[0], 0u8);
+
+        let data = vec![0usize, 1usize, 2usize];
+        let output = convert_types_usize_u64(&data, 3, 1);
+        assert_eq!(output.len(), 3);
+        assert_eq!(type_of(output[0]), "u64");
+        assert_eq!(output[0], 0u64);
+
+        let data = vec![0u32, 1u32, 2u32];
+        let output = convert_types_u32_usize(&data, 3, 1);
+        assert_eq!(output.len(), 3);
+        assert_eq!(type_of(output[0]), "usize");
+        assert_eq!(output[0], 0usize);
+    }
+
+    fn type_of<T>(_: T) -> &'static str {
+        type_name::<T>()
+    }
+}
+
diff --git a/rust/diskann/tests/data/delete_set_50pts.bin b/rust/diskann/tests/data/delete_set_50pts.bin new file mode 100644 index 000000000..8d520e7c7 Binary files /dev/null and b/rust/diskann/tests/data/delete_set_50pts.bin differ
diff --git a/rust/diskann/tests/data/disk_index_node_data_aligned_reader_truth.bin b/rust/diskann/tests/data/disk_index_node_data_aligned_reader_truth.bin new file mode 100644 index 000000000..737a1a34d Binary files /dev/null and b/rust/diskann/tests/data/disk_index_node_data_aligned_reader_truth.bin differ
diff --git a/rust/diskann/tests/data/disk_index_siftsmall_learn_256pts_R4_L50_A1.2_alligned_reader_test.index b/rust/diskann/tests/data/disk_index_siftsmall_learn_256pts_R4_L50_A1.2_alligned_reader_test.index new file mode 100644 index 000000000..55fcbb58d Binary files /dev/null and b/rust/diskann/tests/data/disk_index_siftsmall_learn_256pts_R4_L50_A1.2_alligned_reader_test.index differ
diff --git a/rust/diskann/tests/data/disk_index_siftsmall_learn_256pts_R4_L50_A1.2_disk.index b/rust/diskann/tests/data/disk_index_siftsmall_learn_256pts_R4_L50_A1.2_disk.index new file mode 100644 index 000000000..88a86b7da Binary files /dev/null and
b/rust/diskann/tests/data/disk_index_siftsmall_learn_256pts_R4_L50_A1.2_disk.index differ diff --git a/rust/diskann/tests/data/disk_index_siftsmall_learn_256pts_R4_L50_A1.2_mem.index b/rust/diskann/tests/data/disk_index_siftsmall_learn_256pts_R4_L50_A1.2_mem.index new file mode 100644 index 000000000..974535776 Binary files /dev/null and b/rust/diskann/tests/data/disk_index_siftsmall_learn_256pts_R4_L50_A1.2_mem.index differ diff --git a/rust/diskann/tests/data/siftsmall_learn.bin b/rust/diskann/tests/data/siftsmall_learn.bin new file mode 100644 index 000000000..e08c7af7a Binary files /dev/null and b/rust/diskann/tests/data/siftsmall_learn.bin differ diff --git a/rust/diskann/tests/data/siftsmall_learn.bin_pq_compressed.bin b/rust/diskann/tests/data/siftsmall_learn.bin_pq_compressed.bin new file mode 100644 index 000000000..5f1ddab29 Binary files /dev/null and b/rust/diskann/tests/data/siftsmall_learn.bin_pq_compressed.bin differ diff --git a/rust/diskann/tests/data/siftsmall_learn.bin_pq_pivots.bin b/rust/diskann/tests/data/siftsmall_learn.bin_pq_pivots.bin new file mode 100644 index 000000000..e84f8d8a9 Binary files /dev/null and b/rust/diskann/tests/data/siftsmall_learn.bin_pq_pivots.bin differ diff --git a/rust/diskann/tests/data/siftsmall_learn_256pts.fbin b/rust/diskann/tests/data/siftsmall_learn_256pts.fbin new file mode 100644 index 000000000..357a9db87 Binary files /dev/null and b/rust/diskann/tests/data/siftsmall_learn_256pts.fbin differ diff --git a/rust/diskann/tests/data/siftsmall_learn_256pts_2.fbin b/rust/diskann/tests/data/siftsmall_learn_256pts_2.fbin new file mode 100644 index 000000000..9528e4bd9 Binary files /dev/null and b/rust/diskann/tests/data/siftsmall_learn_256pts_2.fbin differ diff --git a/rust/diskann/tests/data/truth_disk_index_siftsmall_learn_256pts_R4_L50_A1.2_disk.index b/rust/diskann/tests/data/truth_disk_index_siftsmall_learn_256pts_R4_L50_A1.2_disk.index new file mode 100644 index 000000000..55fcbb58d Binary files /dev/null and b/rust/diskann/tests/data/truth_disk_index_siftsmall_learn_256pts_R4_L50_A1.2_disk.index differ diff --git a/rust/diskann/tests/data/truth_index_siftsmall_learn_256pts_1+2_R4_L50_A1.2 b/rust/diskann/tests/data/truth_index_siftsmall_learn_256pts_1+2_R4_L50_A1.2 new file mode 100644 index 000000000..9c803c3fa Binary files /dev/null and b/rust/diskann/tests/data/truth_index_siftsmall_learn_256pts_1+2_R4_L50_A1.2 differ diff --git a/rust/diskann/tests/data/truth_index_siftsmall_learn_256pts_1+2_saturated_R4_L50_A1.2 b/rust/diskann/tests/data/truth_index_siftsmall_learn_256pts_1+2_saturated_R4_L50_A1.2 new file mode 100644 index 000000000..a9dac1013 Binary files /dev/null and b/rust/diskann/tests/data/truth_index_siftsmall_learn_256pts_1+2_saturated_R4_L50_A1.2 differ diff --git a/rust/diskann/tests/data/truth_index_siftsmall_learn_256pts_R4_L50_A1.2 b/rust/diskann/tests/data/truth_index_siftsmall_learn_256pts_R4_L50_A1.2 new file mode 100644 index 000000000..817009044 Binary files /dev/null and b/rust/diskann/tests/data/truth_index_siftsmall_learn_256pts_R4_L50_A1.2 differ diff --git a/rust/diskann/tests/data/truth_index_siftsmall_learn_256pts_R4_L50_A1.2.data b/rust/diskann/tests/data/truth_index_siftsmall_learn_256pts_R4_L50_A1.2.data new file mode 100644 index 000000000..357a9db87 Binary files /dev/null and b/rust/diskann/tests/data/truth_index_siftsmall_learn_256pts_R4_L50_A1.2.data differ diff --git a/rust/logger/Cargo.toml b/rust/logger/Cargo.toml new file mode 100644 index 000000000..e750d9530 --- /dev/null +++ 
b/rust/logger/Cargo.toml @@ -0,0 +1,29 @@ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT license. +[package] +name = "logger" +version = "0.1.0" +edition = "2021" + +# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html + +[dependencies] +lazy_static = "1.4.0" +log="0.4.17" +once_cell = "1.17.1" +prost = "0.11.9" +prost-types = "0.11.9" +thiserror = "1.0.40" +win_etw_macros="0.1.8" +win_etw_provider="0.1.8" + +[build-dependencies] +prost-build = "0.11.9" + +[[example]] +name="trace_example" +path= "src/examples/trace_example.rs" + +[target."cfg(target_os=\"windows\")".build-dependencies.vcpkg] +version = "0.2" + diff --git a/rust/logger/build.rs b/rust/logger/build.rs new file mode 100644 index 000000000..76058f768 --- /dev/null +++ b/rust/logger/build.rs @@ -0,0 +1,33 @@ +/* + * Copyright (c) Microsoft Corporation. All rights reserved. + * Licensed under the MIT license. + */ +use std::env; + +extern crate prost_build; + +fn main() { + let protopkg = vcpkg::find_package("protobuf").unwrap(); + let protobuf_path = protopkg.link_paths[0].parent().unwrap(); + + let protobuf_bin_path = protobuf_path + .join("tools") + .join("protobuf") + .join("protoc.exe") + .to_str() + .unwrap() + .to_string(); + env::set_var("PROTOC", protobuf_bin_path); + + let protobuf_inc_path = protobuf_path + .join("include") + .join("google") + .join("protobuf") + .to_str() + .unwrap() + .to_string(); + env::set_var("PROTOC_INCLUDE", protobuf_inc_path); + + prost_build::compile_protos(&["src/indexlog.proto"], &["src/"]).unwrap(); +} + diff --git a/rust/logger/src/error_logger.rs b/rust/logger/src/error_logger.rs new file mode 100644 index 000000000..50069b477 --- /dev/null +++ b/rust/logger/src/error_logger.rs @@ -0,0 +1,29 @@ +/* + * Copyright (c) Microsoft Corporation. All rights reserved. + * Licensed under the MIT license. + */ +use crate::log_error::LogError; +use crate::logger::indexlog::{ErrorLog, Log, LogLevel}; +use crate::message_handler::send_log; + +pub fn log_error(error_message: String) -> Result<(), LogError> { + let mut log = Log::default(); + let error_log = ErrorLog { + log_level: LogLevel::Error as i32, + error_message, + }; + log.error_log = Some(error_log); + + send_log(log) +} + +#[cfg(test)] +mod error_logger_test { + use super::*; + + #[test] + fn log_error_works() { + log_error(String::from("Error")).unwrap(); + } +} + diff --git a/rust/logger/src/examples/trace_example.rs b/rust/logger/src/examples/trace_example.rs new file mode 100644 index 000000000..7933a5699 --- /dev/null +++ b/rust/logger/src/examples/trace_example.rs @@ -0,0 +1,30 @@ +/* + * Copyright (c) Microsoft Corporation. All rights reserved. + * Licensed under the MIT license. + */ +use log::{debug, info, log_enabled, warn, Level}; +use logger::trace_logger::TraceLogger; + +// cargo run --example trace_example + +fn main() { + static LOGGER: TraceLogger = TraceLogger {}; + log::set_logger(&LOGGER) + .map(|()| log::set_max_level(log::LevelFilter::Trace)) + .unwrap(); + + info!("Rust logging n = {}", 42); + warn!("This is too much fun!"); + debug!("Maybe we can make this code work"); + + let error_is_enabled = log_enabled!(Level::Error); + let warn_is_enabled = log_enabled!(Level::Warn); + let info_is_enabled = log_enabled!(Level::Info); + let debug_is_enabled = log_enabled!(Level::Debug); + let trace_is_enabled = log_enabled!(Level::Trace); + println!( + "is_enabled? 
error: {:5?}, warn: {:5?}, info: {:5?}, debug: {:5?}, trace: {:5?}", + error_is_enabled, warn_is_enabled, info_is_enabled, debug_is_enabled, trace_is_enabled, + ); +} + diff --git a/rust/logger/src/indexlog.proto b/rust/logger/src/indexlog.proto new file mode 100644 index 000000000..68310ae41 --- /dev/null +++ b/rust/logger/src/indexlog.proto @@ -0,0 +1,50 @@ +syntax = "proto3"; + +package diskann_logger; + +message Log { + IndexConstructionLog IndexConstructionLog = 1; + DiskIndexConstructionLog DiskIndexConstructionLog = 2; + ErrorLog ErrorLog = 3; + TraceLog TraceLog = 100; +} + +enum LogLevel { + UNSPECIFIED = 0; + Error = 1; + Warn = 2; + Info = 3; + Debug = 4; + Trace = 5; +} + +message IndexConstructionLog { + float PercentageComplete = 1; + float TimeSpentInSeconds = 2; + float GCyclesSpent = 3; + LogLevel LogLevel = 4; +} + +message DiskIndexConstructionLog { + DiskIndexConstructionCheckpoint checkpoint = 1; + float TimeSpentInSeconds = 2; + float GCyclesSpent = 3; + LogLevel LogLevel = 4; +} + +enum DiskIndexConstructionCheckpoint { + None = 0; + PqConstruction = 1; + InmemIndexBuild = 2; + DiskLayout = 3; +} + +message TraceLog { + string LogLine = 1; + LogLevel LogLevel = 2; +} + +message ErrorLog { + string ErrorMessage = 1; + LogLevel LogLevel = 2; +} \ No newline at end of file diff --git a/rust/logger/src/lib.rs b/rust/logger/src/lib.rs new file mode 100644 index 000000000..6cfe2d589 --- /dev/null +++ b/rust/logger/src/lib.rs @@ -0,0 +1,19 @@ +/* + * Copyright (c) Microsoft Corporation. All rights reserved. + * Licensed under the MIT license. + */ +#![cfg_attr( + not(test), + warn(clippy::panic, clippy::unwrap_used, clippy::expect_used) +)] + +pub mod logger { + pub mod indexlog { + include!(concat!(env!("OUT_DIR"), "/diskann_logger.rs")); + } +} + +pub mod error_logger; +pub mod log_error; +pub mod message_handler; +pub mod trace_logger; diff --git a/rust/logger/src/log_error.rs b/rust/logger/src/log_error.rs new file mode 100644 index 000000000..149d094a2 --- /dev/null +++ b/rust/logger/src/log_error.rs @@ -0,0 +1,27 @@ +/* + * Copyright (c) Microsoft Corporation. All rights reserved. + * Licensed under the MIT license. + */ +use std::sync::mpsc::SendError; + +use crate::logger::indexlog::Log; + +#[derive(thiserror::Error, Debug, Clone)] +pub enum LogError { + /// Sender failed to send message to the channel + #[error("IOError: {err}")] + SendError { + #[from] + err: SendError, + }, + + /// PoisonError which can be returned whenever a lock is acquired + /// Both Mutexes and RwLocks are poisoned whenever a thread fails while the lock is held + #[error("LockPoisonError: {err}")] + LockPoisonError { err: String }, + + /// Failed to create EtwPublisher + #[error("EtwProviderError: {err:?}")] + ETWProviderError { err: win_etw_provider::Error }, +} + diff --git a/rust/logger/src/message_handler.rs b/rust/logger/src/message_handler.rs new file mode 100644 index 000000000..37f352a28 --- /dev/null +++ b/rust/logger/src/message_handler.rs @@ -0,0 +1,167 @@ +/* + * Copyright (c) Microsoft Corporation. All rights reserved. + * Licensed under the MIT license. 
+ */ +use crate::log_error::LogError; +use crate::logger::indexlog::DiskIndexConstructionCheckpoint; +use crate::logger::indexlog::Log; +use crate::logger::indexlog::LogLevel; + +use std::sync::mpsc::{self, Sender}; +use std::sync::Mutex; +use std::thread; + +use win_etw_macros::trace_logging_provider; + +trait MessagePublisher { + fn publish(&self, log_level: LogLevel, message: &str); +} + +// ETW provider - the GUID specified here is that of the default provider for Geneva Metric Extensions +// We are just using it as a placeholder until we have a version of OpenTelemetry exporter for Rust +#[trace_logging_provider(guid = "edc24920-e004-40f6-a8e1-0e6e48f39d84")] +trait EtwTraceProvider { + fn write(msg: &str); +} + +struct EtwPublisher { + provider: EtwTraceProvider, + publish_to_stdout: bool, +} + +impl EtwPublisher { + pub fn new() -> Result { + let provider = EtwTraceProvider::new(); + Ok(EtwPublisher { + provider, + publish_to_stdout: true, + }) + } +} + +fn log_level_to_etw(level: LogLevel) -> win_etw_provider::Level { + match level { + LogLevel::Error => win_etw_provider::Level::ERROR, + LogLevel::Warn => win_etw_provider::Level::WARN, + LogLevel::Info => win_etw_provider::Level::INFO, + LogLevel::Debug => win_etw_provider::Level::VERBOSE, + LogLevel::Trace => win_etw_provider::Level(6), + LogLevel::Unspecified => win_etw_provider::Level(6), + } +} + +fn i32_to_log_level(value: i32) -> LogLevel { + match value { + 0 => LogLevel::Unspecified, + 1 => LogLevel::Error, + 2 => LogLevel::Warn, + 3 => LogLevel::Info, + 4 => LogLevel::Debug, + 5 => LogLevel::Trace, + _ => LogLevel::Unspecified, + } +} + +impl MessagePublisher for EtwPublisher { + fn publish(&self, log_level: LogLevel, message: &str) { + let options = win_etw_provider::EventOptions { + level: Some(log_level_to_etw(log_level)), + ..Default::default() + }; + self.provider.write(Some(&options), message); + + if self.publish_to_stdout { + println!("{}", message); + } + } +} + +struct MessageProcessor { + sender: Mutex>, +} + +impl MessageProcessor { + pub fn start_processing() -> Self { + let (sender, receiver) = mpsc::channel::(); + thread::spawn(move || -> Result<(), LogError> { + for message in receiver { + // Process the received message + if let Some(indexlog) = message.index_construction_log { + let str = format!( + "Time for {}% of index build completed: {:.3} seconds, {:.3}B cycles", + indexlog.percentage_complete, + indexlog.time_spent_in_seconds, + indexlog.g_cycles_spent + ); + publish(i32_to_log_level(indexlog.log_level), &str)?; + } + + if let Some(disk_index_log) = message.disk_index_construction_log { + let str = format!( + "Time for disk index build [Checkpoint: {:?}] completed: {:.3} seconds, {:.3}B cycles", + DiskIndexConstructionCheckpoint::from_i32(disk_index_log.checkpoint).unwrap_or(DiskIndexConstructionCheckpoint::None), + disk_index_log.time_spent_in_seconds, + disk_index_log.g_cycles_spent + ); + publish(i32_to_log_level(disk_index_log.log_level), &str)?; + } + + if let Some(tracelog) = message.trace_log { + let str = format!("{}:{}", tracelog.log_level, tracelog.log_line); + publish(i32_to_log_level(tracelog.log_level), &str)?; + } + + if let Some(err) = message.error_log { + publish(i32_to_log_level(err.log_level), &err.error_message)?; + } + } + + Ok(()) + }); + + let sender = Mutex::new(sender); + MessageProcessor { sender } + } + + /// Log the message. 
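+    /// Enqueues the message on the mutex-guarded channel sender; formatting and
+    /// ETW/stdout publishing happen on the background processor thread, so the
+    /// caller only pays for the channel `send`.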
+ fn log(&self, message: Log) -> Result<(), LogError> { + Ok(self + .sender + .lock() + .map_err(|err| LogError::LockPoisonError { + err: err.to_string(), + })? + .send(message)?) + } +} + +lazy_static::lazy_static! { + /// Singleton logger. + static ref PROCESSOR: MessageProcessor = { + + MessageProcessor::start_processing() + }; +} + +lazy_static::lazy_static! { + /// Singleton publisher. + static ref PUBLISHER: Result = { + EtwPublisher::new() + }; +} + +/// Send a message to the logging system. +pub fn send_log(message: Log) -> Result<(), LogError> { + PROCESSOR.log(message) +} + +fn publish(log_level: LogLevel, message: &str) -> Result<(), LogError> { + match *PUBLISHER { + Ok(ref etw_publisher) => { + etw_publisher.publish(log_level, message); + Ok(()) + } + Err(ref err) => Err(LogError::ETWProviderError { err: err.clone() }), + } +} + diff --git a/rust/logger/src/trace_logger.rs b/rust/logger/src/trace_logger.rs new file mode 100644 index 000000000..96ef38611 --- /dev/null +++ b/rust/logger/src/trace_logger.rs @@ -0,0 +1,41 @@ +/* + * Copyright (c) Microsoft Corporation. All rights reserved. + * Licensed under the MIT license. + */ +use crate::logger::indexlog::{Log, TraceLog}; +use crate::message_handler::send_log; + +use log; + +pub struct TraceLogger {} + +fn level_to_i32(value: log::Level) -> i32 { + match value { + log::Level::Error => 1, + log::Level::Warn => 2, + log::Level::Info => 3, + log::Level::Debug => 4, + log::Level::Trace => 5, + } +} + +impl log::Log for TraceLogger { + fn enabled(&self, metadata: &log::Metadata) -> bool { + metadata.level() <= log::max_level() + } + + fn log(&self, record: &log::Record) { + let message = record.args().to_string(); + let metadata = record.metadata(); + let mut log = Log::default(); + let trace_log = TraceLog { + log_line: message, + log_level: level_to_i32(metadata.level()), + }; + log.trace_log = Some(trace_log); + let _ = send_log(log); + } + + fn flush(&self) {} +} + diff --git a/rust/platform/Cargo.toml b/rust/platform/Cargo.toml new file mode 100644 index 000000000..057f9e852 --- /dev/null +++ b/rust/platform/Cargo.toml @@ -0,0 +1,13 @@ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT license. +[package] +name = "platform" +version = "0.1.0" +edition = "2021" + +# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html + +[dependencies] +log="0.4.18" +winapi = { version = "0.3.9", features = ["errhandlingapi", "fileapi", "ioapiset", "handleapi", "winnt", "minwindef", "basetsd", "winerror", "winbase"] } + diff --git a/rust/platform/src/file_handle.rs b/rust/platform/src/file_handle.rs new file mode 100644 index 000000000..23da8796a --- /dev/null +++ b/rust/platform/src/file_handle.rs @@ -0,0 +1,212 @@ +/* + * Copyright (c) Microsoft Corporation. All rights reserved. + * Licensed under the MIT license. + */ +use std::ffi::CString; +use std::{io, ptr}; + +use winapi::um::fileapi::OPEN_EXISTING; +use winapi::um::winbase::{FILE_FLAG_NO_BUFFERING, FILE_FLAG_OVERLAPPED, FILE_FLAG_RANDOM_ACCESS}; +use winapi::um::winnt::{FILE_SHARE_DELETE, FILE_SHARE_READ, FILE_SHARE_WRITE, GENERIC_READ, GENERIC_WRITE}; + +use winapi::{ + shared::minwindef::DWORD, + um::{ + errhandlingapi::GetLastError, + fileapi::CreateFileA, + handleapi::{CloseHandle, INVALID_HANDLE_VALUE}, + winnt::HANDLE, + }, +}; + +pub const FILE_ATTRIBUTE_READONLY: DWORD = 0x00000001; + +/// `AccessMode` determines how a file can be accessed. 
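+/// Each variant maps to the corresponding `GENERIC_*` desired-access flag when
+/// the handle is created.
+///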
+/// These modes are used when creating or opening a file to decide what operations are allowed
+/// to be performed on the file.
+///
+/// # Variants
+///
+/// - `Read`: The file is opened in read-only mode.
+///
+/// - `Write`: The file is opened in write-only mode.
+///
+/// - `ReadWrite`: The file is opened for both reading and writing.
+pub enum AccessMode {
+    Read,
+    Write,
+    ReadWrite,
+}
+
+/// `ShareMode` determines how a file can be shared.
+///
+/// These modes are used when creating or opening a file to decide what operations other
+/// opening instances of the file can perform on it.
+///
+/// # Variants
+/// - `None`: Prevents other processes from opening a file if they request delete,
+/// read, or write access.
+///
+/// - `Read`: Allows subsequent open operations on the same file to request read access.
+///
+/// - `Write`: Allows subsequent open operations on the same file to request write access.
+///
+/// - `Delete`: Allows subsequent open operations on the same file to request delete access.
+pub enum ShareMode {
+    None,
+    Read,
+    Write,
+    Delete,
+}
+
+/// # Windows File Handle Wrapper
+///
+/// Introduces a Rust-friendly wrapper around the native Windows `HANDLE` object, `FileHandle`.
+/// `FileHandle` provides safe creation and automatic cleanup of Windows file handles, leveraging Rust's ownership model.
+
+/// `FileHandle` struct that wraps a native Windows `HANDLE` object
+#[cfg(target_os = "windows")]
+pub struct FileHandle {
+    handle: HANDLE,
+}
+
+impl FileHandle {
+    /// Creates a new `FileHandle` by opening an existing file with the given access and share mode.
+    ///
+    /// This function is marked unsafe because it creates a raw pointer to the file name and tries to create
+    /// a Windows `HANDLE` object without checking whether you have sufficient permissions.
+    ///
+    /// # Safety
+    ///
+    /// Ensure that the file specified by `file_name` is valid and the calling process has
+    /// sufficient permissions to perform the specified `access_mode` and `share_mode` operations.
+    ///
+    /// # Parameters
+    ///
+    /// - `file_name`: The name of the file.
+    /// - `access_mode`: The access mode to be used for the file.
+    /// - `share_mode`: The share mode to be used for the file.
+    ///
+    /// # Errors
+    /// This function will return an error if the `file_name` is invalid or if the file cannot
+    /// be opened with the specified `access_mode` and `share_mode`.
+    pub unsafe fn new(
+        file_name: &str,
+        access_mode: AccessMode,
+        share_mode: ShareMode,
+    ) -> io::Result<Self> {
+        let file_name_c = CString::new(file_name).map_err(|_| {
+            io::Error::new(
+                io::ErrorKind::InvalidData,
+                format!("Invalid file name.
{}", file_name), + ) + })?; + + let dw_desired_access = match access_mode { + AccessMode::Read => GENERIC_READ, + AccessMode::Write => GENERIC_WRITE, + AccessMode::ReadWrite => GENERIC_READ | GENERIC_WRITE, + }; + + let dw_share_mode = match share_mode { + ShareMode::None => 0, + ShareMode::Read => FILE_SHARE_READ, + ShareMode::Write => FILE_SHARE_WRITE, + ShareMode::Delete => FILE_SHARE_DELETE, + }; + + let dw_flags_and_attributes = FILE_ATTRIBUTE_READONLY + | FILE_FLAG_NO_BUFFERING + | FILE_FLAG_OVERLAPPED + | FILE_FLAG_RANDOM_ACCESS; + + let handle = unsafe { + CreateFileA( + file_name_c.as_ptr(), + dw_desired_access, + dw_share_mode, + ptr::null_mut(), + OPEN_EXISTING, + dw_flags_and_attributes, + ptr::null_mut(), + ) + }; + + if handle == INVALID_HANDLE_VALUE { + let error_code = unsafe { GetLastError() }; + Err(io::Error::from_raw_os_error(error_code as i32)) + } else { + Ok(Self { handle }) + } + } + + pub fn raw_handle(&self) -> HANDLE { + self.handle + } +} + +impl Drop for FileHandle { + /// Automatically closes the `FileHandle` when it goes out of scope. + /// Any errors in closing the handle are logged, as `Drop` does not support returning `Result`. + fn drop(&mut self) { + let result = unsafe { CloseHandle(self.handle) }; + if result == 0 { + let error_code = unsafe { GetLastError() }; + let error = io::Error::from_raw_os_error(error_code as i32); + + // Only log the error if dropping the handle fails, since Rust's Drop trait does not support returning Result types from the drop method, + // and panicking in the drop method is considered bad practice + log::warn!("Error when dropping IOCompletionPort: {:?}", error); + } + } +} + +/// Returns a `FileHandle` with an `INVALID_HANDLE_VALUE`. +impl Default for FileHandle { + fn default() -> Self { + Self { + handle: INVALID_HANDLE_VALUE, + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + use std::fs::File; + use std::path::Path; + + #[test] + fn test_create_file() { + // Create a dummy file + let dummy_file_path = "dummy_file.txt"; + { + let _file = File::create(dummy_file_path).expect("Failed to create dummy file."); + } + + let path = Path::new(dummy_file_path); + { + let file_handle = unsafe { + FileHandle::new(path.to_str().unwrap(), AccessMode::Read, ShareMode::Read) + }; + + // Check that the file handle is valid + assert!(file_handle.is_ok()); + } + + // Try to delete the file. If the handle was correctly dropped, this should succeed. + match std::fs::remove_file(dummy_file_path) { + Ok(()) => (), // File was deleted successfully, which means the handle was closed. + Err(e) => panic!("Failed to delete file: {}", e), // Failed to delete the file, likely because the handle is still open. + } + } + + #[test] + fn test_file_not_found() { + let path = Path::new("non_existent_file.txt"); + let file_handle = + unsafe { FileHandle::new(path.to_str().unwrap(), AccessMode::Read, ShareMode::Read) }; + + // Check that opening a non-existent file returns an error + assert!(file_handle.is_err()); + } +} diff --git a/rust/platform/src/file_io.rs b/rust/platform/src/file_io.rs new file mode 100644 index 000000000..e5de24773 --- /dev/null +++ b/rust/platform/src/file_io.rs @@ -0,0 +1,154 @@ +/* + * Copyright (c) Microsoft Corporation. All rights reserved. + * Licensed under the MIT license. + */ +/// The module provides unsafe wrappers around two Windows API functions: `ReadFile` and `GetQueuedCompletionStatus`. 
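+///
+/// Used together they support a poll-based overlapped read. A rough sketch
+/// (illustrative only; the buffer and `OVERLAPPED` must outlive the operation,
+/// and all names come from this crate):
+///
+/// ```ignore
+/// let handle = unsafe { FileHandle::new("data.bin", AccessMode::Read, ShareMode::Read)? };
+/// let port = IOCompletionPort::new(&handle, None, 0, 0)?;
+/// let mut buffer = [0u8; 4096];
+/// let mut overlapped: OVERLAPPED = unsafe { std::mem::zeroed() };
+/// let finished = unsafe { read_file_to_slice(&handle, &mut buffer, &mut overlapped, 0)? };
+/// if !finished {
+///     let (mut bytes, mut key) = (0u32, 0usize);
+///     let mut povl: *mut OVERLAPPED = std::ptr::null_mut();
+///     unsafe { get_queued_completion_status(&port, &mut bytes, &mut key, &mut povl, 1000)? };
+/// }
+/// ```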
+/// +/// These wrappers aim to simplify and abstract the use of these functions, providing easier error handling and a safer interface. +/// They return standard Rust `io::Result` types for convenience and consistency with the rest of the Rust standard library. +use std::io; +use std::ptr; + +use winapi::{ + ctypes::c_void, + shared::{ + basetsd::ULONG_PTR, + minwindef::{DWORD, FALSE}, + winerror::{ERROR_IO_PENDING, WAIT_TIMEOUT}, + }, + um::{ + errhandlingapi::GetLastError, fileapi::ReadFile, ioapiset::GetQueuedCompletionStatus, + minwinbase::OVERLAPPED, + }, +}; + +use crate::FileHandle; +use crate::IOCompletionPort; + +/// Asynchronously queue a read request from a file into a buffer slice. +/// +/// Wraps the unsafe Windows API function `ReadFile`, making it safe to call only when the overlapped buffer +/// remains valid and unchanged anywhere else during the entire async operation. +/// +/// Returns a boolean indicating whether the read operation completed synchronously or is pending. +/// +/// # Safety +/// +/// This function is marked as `unsafe` because it uses raw pointers and requires the caller to ensure +/// that the buffer slice and the overlapped buffer stay valid during the whole async operation. +pub unsafe fn read_file_to_slice( + file_handle: &FileHandle, + buffer_slice: &mut [T], + overlapped: *mut OVERLAPPED, + offset: u64, +) -> io::Result { + let num_bytes = std::mem::size_of_val(buffer_slice); + unsafe { + ptr::write(overlapped, std::mem::zeroed()); + (*overlapped).u.s_mut().Offset = offset as u32; + (*overlapped).u.s_mut().OffsetHigh = (offset >> 32) as u32; + } + + let result = unsafe { + ReadFile( + file_handle.raw_handle(), + buffer_slice.as_mut_ptr() as *mut c_void, + num_bytes as DWORD, + ptr::null_mut(), + overlapped, + ) + }; + + match result { + FALSE => { + let error = unsafe { GetLastError() }; + if error != ERROR_IO_PENDING { + Err(io::Error::from_raw_os_error(error as i32)) + } else { + Ok(false) + } + } + _ => Ok(true), + } +} + +/// Retrieves the results of an asynchronous I/O operation on an I/O completion port. +/// +/// Wraps the unsafe Windows API function `GetQueuedCompletionStatus`, making it safe to call only when the overlapped buffer +/// remains valid and unchanged anywhere else during the entire async operation. +/// +/// Returns a boolean indicating whether an I/O operation completed synchronously or is still pending. +/// +/// # Safety +/// +/// This function is marked as `unsafe` because it uses raw pointers and requires the caller to ensure +/// that the overlapped buffer stays valid during the whole async operation. 
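+///
+/// A `WAIT_TIMEOUT` from the OS is surfaced as `Ok(false)` rather than as an
+/// error, so callers can poll by passing a short `dw_milliseconds` value.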
+pub unsafe fn get_queued_completion_status( + completion_port: &IOCompletionPort, + lp_number_of_bytes: &mut DWORD, + lp_completion_key: &mut ULONG_PTR, + lp_overlapped: *mut *mut OVERLAPPED, + dw_milliseconds: DWORD, +) -> io::Result { + let result = unsafe { + GetQueuedCompletionStatus( + completion_port.raw_handle(), + lp_number_of_bytes, + lp_completion_key, + lp_overlapped, + dw_milliseconds, + ) + }; + + match result { + 0 => { + let error = unsafe { GetLastError() }; + if error == WAIT_TIMEOUT { + Ok(false) + } else { + Err(io::Error::from_raw_os_error(error as i32)) + } + } + _ => Ok(true), + } +} + +#[cfg(test)] +mod tests { + use crate::file_handle::{AccessMode, ShareMode}; + + use super::*; + use std::fs::File; + use std::io::Write; + use std::path::Path; + + #[test] + fn test_read_file_to_slice() { + // Create a temporary file and write some data into it + let path = Path::new("temp.txt"); + { + let mut file = File::create(path).unwrap(); + file.write_all(b"Hello, world!").unwrap(); + } + + let mut buffer: [u8; 512] = [0; 512]; + let mut overlapped = unsafe { std::mem::zeroed::() }; + { + let file_handle = unsafe { + FileHandle::new(path.to_str().unwrap(), AccessMode::Read, ShareMode::Read) + } + .unwrap(); + + // Call the function under test + let result = + unsafe { read_file_to_slice(&file_handle, &mut buffer, &mut overlapped, 0) }; + + assert!(result.is_ok()); + let result_str = std::str::from_utf8(&buffer[.."Hello, world!".len()]).unwrap(); + assert_eq!(result_str, "Hello, world!"); + } + + // Clean up + std::fs::remove_file("temp.txt").unwrap(); + } +} diff --git a/rust/platform/src/io_completion_port.rs b/rust/platform/src/io_completion_port.rs new file mode 100644 index 000000000..5bb332281 --- /dev/null +++ b/rust/platform/src/io_completion_port.rs @@ -0,0 +1,142 @@ +/* + * Copyright (c) Microsoft Corporation. All rights reserved. + * Licensed under the MIT license. + */ +use std::io; + +use winapi::{ + ctypes::c_void, + shared::{basetsd::ULONG_PTR, minwindef::DWORD}, + um::{ + errhandlingapi::GetLastError, + handleapi::{CloseHandle, INVALID_HANDLE_VALUE}, + ioapiset::CreateIoCompletionPort, + winnt::HANDLE, + }, +}; + +use crate::FileHandle; + +/// This module provides a safe and idiomatic Rust interface over the IOCompletionPort handle and associated Windows API functions. +/// This struct represents an I/O completion port, which is an object used in asynchronous I/O operations on Windows. +pub struct IOCompletionPort { + io_completion_port: HANDLE, +} + +impl IOCompletionPort { + /// Create a new IOCompletionPort. + /// This function wraps the Windows CreateIoCompletionPort function, providing error handling and automatic resource management. + /// + /// # Arguments + /// + /// * `file_handle` - A reference to a FileHandle to associate with the IOCompletionPort. + /// * `existing_completion_port` - An optional reference to an existing IOCompletionPort. If provided, the new IOCompletionPort will be associated with it. + /// * `completion_key` - The completion key associated with the file handle. + /// * `number_of_concurrent_threads` - The maximum number of threads that the operating system can allow to concurrently process I/O completion packets for the I/O completion port. + /// + /// # Return + /// + /// Returns a Result with the new IOCompletionPort if successful, or an io::Error if the function fails. 
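+    ///
+    /// When `existing_completion_port` is `Some`, the underlying Win32 call
+    /// associates `file_handle` with that existing port instead of creating a
+    /// fresh one.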
+ pub fn new( + file_handle: &FileHandle, + existing_completion_port: Option<&IOCompletionPort>, + completion_key: ULONG_PTR, + number_of_concurrent_threads: DWORD, + ) -> io::Result { + let io_completion_port = unsafe { + CreateIoCompletionPort( + file_handle.raw_handle(), + existing_completion_port + .map_or(std::ptr::null_mut::(), |io_completion_port| { + io_completion_port.raw_handle() + }), + completion_key, + number_of_concurrent_threads, + ) + }; + + if io_completion_port == INVALID_HANDLE_VALUE { + let error_code = unsafe { GetLastError() }; + return Err(io::Error::from_raw_os_error(error_code as i32)); + } + + Ok(IOCompletionPort { io_completion_port }) + } + + pub fn raw_handle(&self) -> HANDLE { + self.io_completion_port + } +} + +impl Drop for IOCompletionPort { + /// Drop method for IOCompletionPort. + /// This wraps the Windows CloseHandle function, providing automatic resource cleanup when the IOCompletionPort is dropped. + /// If an error occurs while dropping, it is logged and the drop continues. This is because panicking in Drop can cause unwinding issues. + fn drop(&mut self) { + let result = unsafe { CloseHandle(self.io_completion_port) }; + if result == 0 { + let error_code = unsafe { GetLastError() }; + let error = io::Error::from_raw_os_error(error_code as i32); + + // Only log the error if dropping the handle fails, since Rust's Drop trait does not support returning Result types from the drop method, + // and panicking in the drop method is considered bad practice + log::warn!("Error when dropping IOCompletionPort: {:?}", error); + } + } +} + +impl Default for IOCompletionPort { + /// Create a default IOCompletionPort, whose handle is set to INVALID_HANDLE_VALUE. + /// Returns a new IOCompletionPort with handle set to INVALID_HANDLE_VALUE. + fn default() -> Self { + Self { + io_completion_port: INVALID_HANDLE_VALUE, + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::file_handle::{AccessMode, ShareMode}; + + #[test] + fn create_io_completion_port() { + let file_name = "../diskann/tests/data/delete_set_50pts.bin"; + let file_handle = unsafe { FileHandle::new(file_name, AccessMode::Read, ShareMode::Read) } + .expect("Failed to create file handle."); + + let io_completion_port = IOCompletionPort::new(&file_handle, None, 0, 0); + + assert!( + io_completion_port.is_ok(), + "Failed to create IOCompletionPort." + ); + } + + #[test] + fn drop_io_completion_port() { + let file_name = "../diskann/tests/data/delete_set_50pts.bin"; + let file_handle = unsafe { FileHandle::new(file_name, AccessMode::Read, ShareMode::Read) } + .expect("Failed to create file handle."); + + let io_completion_port = IOCompletionPort::new(&file_handle, None, 0, 0) + .expect("Failed to create IOCompletionPort."); + + // After this line, io_completion_port goes out of scope and its Drop trait will be called. + let _ = io_completion_port; + // We have no easy way to test that the Drop trait works correctly, but if it doesn't, + // a resource leak or other problem may become apparent in later tests or in real use of the code. + } + + #[test] + fn default_io_completion_port() { + let io_completion_port = IOCompletionPort::default(); + assert_eq!( + io_completion_port.raw_handle(), + INVALID_HANDLE_VALUE, + "Default IOCompletionPort did not have INVALID_HANDLE_VALUE." 
+ ); + } +} + diff --git a/rust/platform/src/lib.rs b/rust/platform/src/lib.rs new file mode 100644 index 000000000..e28257078 --- /dev/null +++ b/rust/platform/src/lib.rs @@ -0,0 +1,20 @@ +/* + * Copyright (c) Microsoft Corporation. All rights reserved. + * Licensed under the MIT license. + */ +#![cfg_attr( + not(test), + warn(clippy::panic, clippy::unwrap_used, clippy::expect_used) +)] + +pub mod perf; +pub use perf::{get_process_cycle_time, get_process_handle}; + +pub mod file_io; +pub use file_io::{get_queued_completion_status, read_file_to_slice}; + +pub mod file_handle; +pub use file_handle::FileHandle; + +pub mod io_completion_port; +pub use io_completion_port::IOCompletionPort; diff --git a/rust/platform/src/perf.rs b/rust/platform/src/perf.rs new file mode 100644 index 000000000..1ea146f9a --- /dev/null +++ b/rust/platform/src/perf.rs @@ -0,0 +1,50 @@ +/* + * Copyright (c) Microsoft Corporation. All rights reserved. + * Licensed under the MIT license. + */ +#[cfg(target_os = "windows")] +#[link(name = "kernel32")] +extern "system" { + fn OpenProcess(dwDesiredAccess: u32, bInheritHandle: bool, dwProcessId: u32) -> usize; + fn QueryProcessCycleTime(hProcess: usize, lpCycleTime: *mut u64) -> bool; + fn GetCurrentProcessId() -> u32; +} + +/// Get current process handle. +pub fn get_process_handle() -> Option { + if cfg!(windows) { + const PROCESS_QUERY_INFORMATION: u32 = 0x0400; + const PROCESS_VM_READ: u32 = 0x0010; + + unsafe { + let current_process_id = GetCurrentProcessId(); + let handle = OpenProcess( + PROCESS_QUERY_INFORMATION | PROCESS_VM_READ, + false, + current_process_id, + ); + if handle == 0 { + None + } else { + Some(handle) + } + } + } else { + None + } +} + +pub fn get_process_cycle_time(process_handle: Option) -> Option { + let mut cycle_time: u64 = 0; + if cfg!(windows) { + if let Some(handle) = process_handle { + let result = unsafe { QueryProcessCycleTime(handle, &mut cycle_time as *mut u64) }; + if result { + return Some(cycle_time); + } + } + } + + None +} + diff --git a/rust/project.code-workspace b/rust/project.code-workspace new file mode 100644 index 000000000..29bed0024 --- /dev/null +++ b/rust/project.code-workspace @@ -0,0 +1,58 @@ +{ + "folders": [ + { + "path": "." + } + ], + "settings": { + "search.exclude": { + "target": true, + }, + "files.exclude": { + "target": true, + }, + "rust-analyzer.linkedProjects": [ + ".\\vector\\Cargo.toml", + ".\\vector\\Cargo.toml", + ".\\vector\\Cargo.toml", + ".\\diskann\\Cargo.toml" + ], + "[rust]": { + "editor.defaultFormatter": "rust-lang.rust-analyzer", + "editor.formatOnSave": true, + } + }, + "launch": { + "version": "0.2.0", + "configurations": [ + { + "name": "Build memory index", + "type": "cppvsdbg", + "request": "launch", + "program": "${workspaceRoot}\\target\\debug\\build_memory_index.exe", + "args": [ + "--data_type", + "float", + "--dist_fn", + "l2", + "--data_path", + ".\\base1m.fbin", + "--index_path_prefix", + ".\\rust_index_sift_base_R32_L50_A1.2_T1", + "-R", + "64", + "-L", + "100", + "--alpha", + "1.2", + "-T", + "1" + ], + "stopAtEntry": false, + "cwd": "c:\\data", + "environment": [], + "externalConsole": true + }, + ] + } +} \ No newline at end of file diff --git a/rust/readme.md b/rust/readme.md new file mode 100644 index 000000000..a6c5a1bd4 --- /dev/null +++ b/rust/readme.md @@ -0,0 +1,25 @@ + +# readme + +run commands under disnann_rust directory. 
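+
+For example, to build and launch a single workspace binary in release mode (the
+`build_memory_index` tool referenced by the debug launch configuration above;
+the arguments here are illustrative, not exhaustive):
+```
+cargo run -r --bin build_memory_index -- --help
+```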
+
+build:
+```
+cargo build // Debug
+
+cargo build -r // Release
+```
+
+run:
+```
+cargo run // Debug
+
+cargo run -r // Release
+```
+
+test:
+```
+cargo test
+```
diff --git a/rust/rust-toolchain.toml b/rust/rust-toolchain.toml new file mode 100644 index 000000000..183a72c9c --- /dev/null +++ b/rust/rust-toolchain.toml @@ -0,0 +1,4 @@
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT license.
+[toolchain]
+channel = "stable"
diff --git a/rust/vector/Cargo.toml b/rust/vector/Cargo.toml new file mode 100644 index 000000000..709a2905c --- /dev/null +++ b/rust/vector/Cargo.toml @@ -0,0 +1,24 @@
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT license.
+[package]
+name = "vector"
+version = "0.1.0"
+edition = "2021"
+
+# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
+
+[dependencies]
+half = "2.2.1"
+thiserror = "1.0.40"
+bytemuck = "1.7.0"
+
+[build-dependencies]
+cc = "1.0.79"
+
+[dev-dependencies]
+base64 = "0.21.2"
+bincode = "1.3.3"
+serde = "1.0.163"
+approx = "0.5.1"
+rand = "0.8.5"
+
diff --git a/rust/vector/build.rs b/rust/vector/build.rs new file mode 100644 index 000000000..2d36c213c --- /dev/null +++ b/rust/vector/build.rs @@ -0,0 +1,29 @@
+/*
+ * Copyright (c) Microsoft Corporation. All rights reserved.
+ * Licensed under the MIT license.
+ */
+fn main() {
+    println!("cargo:rerun-if-changed=distance.c");
+    if cfg!(target_os = "macos") {
+        std::env::set_var("CFLAGS", "-mavx2 -mfma -Wno-error -MP -O2 -D NDEBUG -D MKL_ILP64 -D USE_AVX2 -D USE_ACCELERATED_PQ -D NOMINMAX -D _TARGET_ARM_APPLE_DARWIN");
+
+        cc::Build::new()
+            .file("distance.c")
+            .warnings_into_errors(true)
+            .debug(false)
+            .target("x86_64-apple-darwin")
+            .compile("nativefunctions.lib");
+    } else {
+        std::env::set_var("CFLAGS", "/permissive- /MP /ifcOutput /GS- /W3 /Gy /Zi /Gm- /O2 /Ob2 /Zc:inline /fp:fast /D NDEBUG /D MKL_ILP64 /D USE_AVX2 /D USE_ACCELERATED_PQ /D NOMINMAX /fp:except- /errorReport:prompt /WX /openmp:experimental /Zc:forScope /GR /arch:AVX2 /Gd /Oy /Oi /MD /std:c++14 /FC /EHsc /nologo /Ot");
+        // std::env::set_var("CFLAGS", "/permissive- /MP /ifcOutput /GS- /W3 /Gy /Zi /Gm- /Obd /Zc:inline /fp:fast /D DEBUG /D MKL_ILP64 /D USE_AVX2 /D USE_ACCELERATED_PQ /D NOMINMAX /fp:except- /errorReport:prompt /WX /openmp:experimental /Zc:forScope /GR /arch:AVX512 /Gd /Oy /Oi /MD /std:c++14 /FC /EHsc /nologo /Ot");
+
+        cc::Build::new()
+            .file("distance.c")
+            .warnings_into_errors(true)
+            .debug(false)
+            .compile("nativefunctions");
+
+        println!("cargo:rustc-link-arg=nativefunctions.lib");
+    }
+}
+
diff --git a/rust/vector/distance.c b/rust/vector/distance.c new file mode 100644 index 000000000..ee5333a53 --- /dev/null +++ b/rust/vector/distance.c @@ -0,0 +1,35 @@
+#include <immintrin.h>
+#include <stddef.h>
+
+inline __m256i load_128bit_to_256bit(const __m128i *ptr)
+{
+    __m128i value128 = _mm_loadu_si128(ptr);
+    __m256i value256 = _mm256_castsi128_si256(value128);
+    return _mm256_inserti128_si256(value256, _mm_setzero_si128(), 1);
+}
+
+float distance_compare_avx512f_f16(const unsigned char *vec1, const unsigned char *vec2, size_t size)
+{
+    __m512 sum_squared_diff = _mm512_setzero_ps();
+
+    for (int i = 0; i < size / 16; i += 1)
+    {
+        __m512 v1 = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)(vec1 + i * 2 * 16)));
+        __m512 v2 = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)(vec2 + i * 2 * 16)));
+
+        __m512 diff = _mm512_sub_ps(v1, v2);
+        sum_squared_diff = _mm512_fmadd_ps(diff, diff,
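+            // The fused multiply-add accumulates diff * diff into the running
+            // sum, sixteen f32 lanes (converted from f16 halves) per iteration: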
sum_squared_diff); + } + + size_t i = (size / 16) * 16; + + if (i != size) + { + __m512 va = _mm512_cvtph_ps(load_128bit_to_256bit((const __m128i *)(vec1 + i * 2))); + __m512 vb = _mm512_cvtph_ps(load_128bit_to_256bit((const __m128i *)(vec2 + i * 2))); + __m512 diff512 = _mm512_sub_ps(va, vb); + sum_squared_diff = _mm512_fmadd_ps(diff512, diff512, sum_squared_diff); + } + + return _mm512_reduce_add_ps(sum_squared_diff); +} diff --git a/rust/vector/src/distance.rs b/rust/vector/src/distance.rs new file mode 100644 index 000000000..8ca6cb250 --- /dev/null +++ b/rust/vector/src/distance.rs @@ -0,0 +1,442 @@ +/* + * Copyright (c) Microsoft Corporation. All rights reserved. + * Licensed under the MIT license. + */ +use crate::l2_float_distance::{distance_l2_vector_f16, distance_l2_vector_f32}; +use crate::{Half, Metric}; + +/// Distance contract for full-precision vertex +pub trait FullPrecisionDistance { + /// Get the distance between vertex a and vertex b + fn distance_compare(a: &[T; N], b: &[T; N], vec_type: Metric) -> f32; +} + +// reason = "Not supported Metric type Metric::Cosine" +#[allow(clippy::panic)] +impl FullPrecisionDistance for [f32; N] { + /// Calculate distance between two f32 Vertex + #[inline(always)] + fn distance_compare(a: &[f32; N], b: &[f32; N], metric: Metric) -> f32 { + match metric { + Metric::L2 => distance_l2_vector_f32::(a, b), + _ => panic!("Not supported Metric type {:?}", metric), + } + } +} + +// reason = "Not supported Metric type Metric::Cosine" +#[allow(clippy::panic)] +impl FullPrecisionDistance for [Half; N] { + fn distance_compare(a: &[Half; N], b: &[Half; N], metric: Metric) -> f32 { + match metric { + Metric::L2 => distance_l2_vector_f16::(a, b), + _ => panic!("Not supported Metric type {:?}", metric), + } + } +} + +// reason = "Not yet supported Vector i8" +#[allow(clippy::panic)] +impl FullPrecisionDistance for [i8; N] { + fn distance_compare(_a: &[i8; N], _b: &[i8; N], _metric: Metric) -> f32 { + panic!("Not supported VectorType i8") + } +} + +// reason = "Not yet supported Vector u8" +#[allow(clippy::panic)] +impl FullPrecisionDistance for [u8; N] { + fn distance_compare(_a: &[u8; N], _b: &[u8; N], _metric: Metric) -> f32 { + panic!("Not supported VectorType u8") + } +} + +#[cfg(test)] +mod distance_test { + use super::*; + + #[repr(C, align(32))] + pub struct F32Slice112([f32; 112]); + + #[repr(C, align(32))] + pub struct F16Slice112([Half; 112]); + + fn get_turing_test_data() -> (F32Slice112, F32Slice112) { + let a_slice: [f32; 112] = [ + 0.13961786, + -0.031577103, + -0.09567415, + 0.06695563, + -0.1588727, + 0.089852564, + -0.019837005, + 0.07497972, + 0.010418192, + -0.054594643, + 0.08613386, + -0.05103466, + 0.16568437, + -0.02703799, + 0.00728657, + -0.15313251, + 0.16462992, + -0.030570814, + 0.11635703, + 0.23938893, + 0.018022912, + -0.12646551, + 0.018048918, + -0.035986554, + 0.031986624, + -0.015286017, + 0.010117953, + -0.032691937, + 0.12163067, + -0.04746277, + 0.010213069, + -0.043672588, + -0.099362016, + 0.06599016, + -0.19397286, + -0.13285528, + -0.22040887, + 0.017690737, + -0.104262285, + -0.0044555613, + -0.07383778, + -0.108652934, + 0.13399786, + 0.054912474, + 0.20181285, + 0.1795591, + -0.05425621, + -0.10765217, + 0.1405377, + -0.14101997, + -0.12017701, + 0.011565498, + 0.06952187, + 0.060136646, + 0.0023214167, + 0.04204699, + 0.048470616, + 0.17398086, + 0.024218207, + -0.15626553, + -0.11291045, + -0.09688122, + 0.14393932, + -0.14713104, + -0.108876854, + 0.035279203, + -0.05440188, + 0.017205412, + 0.011413814, + 
0.04009471, + 0.11070237, + -0.058998976, + 0.07260045, + -0.057893746, + -0.0036240944, + -0.0064988653, + -0.13842176, + -0.023219328, + 0.0035885905, + -0.0719257, + -0.21335067, + 0.11415403, + -0.0059823603, + 0.12091869, + 0.08136634, + -0.10769281, + 0.024518685, + 0.0009200326, + -0.11628049, + 0.07448965, + 0.13736208, + -0.04144517, + -0.16426727, + -0.06380103, + -0.21386267, + 0.022373492, + -0.05874115, + 0.017314062, + -0.040344074, + 0.01059176, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + ]; + let b_slice: [f32; 112] = [ + -0.07209058, + -0.17755842, + -0.030627966, + 0.163028, + -0.2233766, + 0.057412963, + 0.0076995124, + -0.017121306, + -0.015759075, + -0.026947778, + -0.010282468, + -0.23968373, + -0.021486737, + -0.09903155, + 0.09361805, + 0.0042711576, + -0.08695552, + -0.042165346, + 0.064218745, + -0.06707651, + 0.07846054, + 0.12235762, + -0.060716823, + 0.18496591, + -0.13023394, + 0.022469055, + 0.056764495, + 0.07168404, + -0.08856144, + -0.15343173, + 0.099879816, + -0.033529017, + 0.0795304, + -0.009242254, + -0.10254546, + 0.13086525, + -0.101518914, + -0.1031299, + -0.056826904, + 0.033196196, + 0.044143833, + -0.049787212, + -0.018148342, + -0.11172959, + -0.06776237, + -0.09185828, + -0.24171598, + 0.05080982, + -0.0727684, + 0.045031235, + -0.11363879, + -0.063389264, + 0.105850354, + -0.19847773, + 0.08828623, + -0.087071925, + 0.033512704, + 0.16118294, + 0.14111553, + 0.020884402, + -0.088860825, + 0.018745849, + 0.047522716, + -0.03665169, + 0.15726231, + -0.09930561, + 0.057844743, + -0.10532736, + -0.091297254, + 0.067029804, + 0.04153976, + 0.06393326, + 0.054578528, + 0.0038539872, + 0.1023088, + -0.10653885, + -0.108500294, + -0.046606563, + 0.020439683, + -0.120957725, + -0.13334097, + -0.13425854, + -0.20481694, + 0.07009538, + 0.08660361, + -0.0096641015, + 0.095316306, + -0.002898167, + -0.19680002, + 0.08466311, + 0.04812689, + -0.028978813, + 0.04780206, + -0.2001506, + -0.036866356, + -0.023720587, + 0.10731964, + 0.05517358, + -0.09580819, + 0.14595725, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + ]; + + (F32Slice112(a_slice), F32Slice112(b_slice)) + } + + fn get_turing_test_data_f16() -> (F16Slice112, F16Slice112) { + let (a_slice, b_slice) = get_turing_test_data(); + let a_data = a_slice.0.iter().map(|x| Half::from_f32(*x)); + let b_data = b_slice.0.iter().map(|x| Half::from_f32(*x)); + + ( + F16Slice112(a_data.collect::>().try_into().unwrap()), + F16Slice112(b_data.collect::>().try_into().unwrap()), + ) + } + + use crate::test_util::*; + use approx::assert_abs_diff_eq; + + #[test] + fn test_dist_l2_float_turing() { + // two vectors are allocated in the contiguous heap memory + let (a_slice, b_slice) = get_turing_test_data(); + let distance = <[f32; 112] as FullPrecisionDistance>::distance_compare( + &a_slice.0, + &b_slice.0, + Metric::L2, + ); + + assert_abs_diff_eq!( + distance, + no_vector_compare_f32(&a_slice.0, &b_slice.0), + epsilon = 1e-6 + ); + } + + #[test] + fn test_dist_l2_f16_turing() { + // two vectors are allocated in the contiguous heap memory + let (a_slice, b_slice) = get_turing_test_data_f16(); + let distance = <[Half; 112] as FullPrecisionDistance>::distance_compare( + &a_slice.0, + &b_slice.0, + Metric::L2, + ); + + // Note the variance between the full 32 bit precision and the 16 bit precision + assert_eq!(distance, no_vector_compare_f16(&a_slice.0, &b_slice.0)); + } + + #[test] + fn distance_test() { + #[repr(C, align(32))] + struct 
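+        // (32-byte alignment mirrors the F32Slice112 fixtures above and keeps
+        // the data friendly to 256-bit AVX loads in the vectorized kernels.)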
Vector32ByteAligned { + v: [f32; 512], + } + + // two vectors are allocated in the contiguous heap memory + let two_vec = Box::new(Vector32ByteAligned { + v: [ + 69.02492, 78.84786, 63.125072, 90.90581, 79.2592, 70.81731, 3.0829668, 33.33287, + 20.777142, 30.147898, 23.681915, 42.553043, 12.602162, 7.3808074, 19.157589, + 65.6791, 76.44677, 76.89124, 86.40756, 84.70118, 87.86142, 16.126896, 5.1277637, + 95.11038, 83.946945, 22.735607, 11.548555, 59.51482, 24.84603, 15.573776, 78.27185, + 71.13179, 38.574017, 80.0228, 13.175261, 62.887978, 15.205181, 18.89392, 96.13162, + 87.55455, 34.179806, 62.920044, 4.9305916, 54.349373, 21.731495, 14.982187, + 40.262867, 20.15214, 36.61963, 72.450806, 55.565, 95.5375, 93.73356, 95.36308, + 66.30762, 58.0397, 18.951357, 67.11702, 43.043316, 30.65622, 99.85361, 2.5889993, + 27.844774, 39.72441, 46.463238, 71.303764, 90.45308, 36.390602, 63.344395, + 26.427078, 35.99528, 82.35505, 32.529175, 23.165905, 74.73179, 9.856939, 59.38126, + 35.714924, 79.81213, 46.704124, 24.47884, 36.01743, 0.46678782, 29.528152, + 1.8980742, 24.68853, 75.58984, 98.72279, 68.62601, 11.890173, 49.49361, 55.45572, + 72.71067, 34.107483, 51.357758, 76.400635, 81.32725, 66.45081, 17.848074, + 62.398876, 94.20444, 2.10886, 17.416393, 64.88253, 29.000723, 62.434315, 53.907238, + 70.51412, 78.70744, 55.181683, 64.45116, 23.419212, 53.68544, 43.506958, 46.89598, + 35.905994, 64.51397, 91.95555, 20.322979, 74.80128, 97.548744, 58.312725, 78.81985, + 31.911612, 14.445949, 49.85094, 70.87396, 40.06766, 7.129991, 78.48008, 75.21636, + 93.623604, 95.95479, 29.571129, 22.721554, 26.73875, 52.075504, 56.783104, + 94.65493, 61.778534, 85.72401, 85.369514, 29.922367, 41.410553, 94.12884, + 80.276855, 55.604828, 54.70947, 74.07216, 44.61955, 31.38113, 68.48596, 34.56782, + 14.424729, 48.204506, 9.675444, 32.01946, 92.32695, 36.292683, 78.31955, 98.05327, + 14.343918, 46.017002, 95.90888, 82.63626, 16.873539, 3.698051, 7.8042626, + 64.194405, 96.71023, 67.93692, 21.618402, 51.92182, 22.834194, 61.56986, 19.749891, + 55.31206, 38.29552, 67.57593, 67.145836, 38.92673, 94.95708, 72.38746, 90.70901, + 69.43995, 9.394085, 31.646872, 88.20112, 9.134722, 99.98214, 5.423498, 41.51995, + 76.94409, 77.373276, 3.2966614, 9.611201, 57.231106, 30.747868, 76.10228, 91.98308, + 70.893585, 0.9067178, 43.96515, 16.321218, 27.734184, 83.271835, 88.23312, + 87.16445, 5.556643, 15.627432, 58.547127, 93.6459, 40.539192, 49.124157, 91.13276, + 57.485855, 8.827019, 4.9690843, 46.511234, 53.91469, 97.71925, 20.135271, + 23.353004, 70.92099, 93.38748, 87.520134, 51.684677, 29.89813, 9.110392, 65.809204, + 34.16554, 93.398605, 84.58669, 96.409645, 9.876037, 94.767784, 99.21523, 1.9330144, + 94.92429, 75.12728, 17.218828, 97.89164, 35.476578, 77.629456, 69.573746, + 40.200542, 42.117836, 5.861628, 75.45282, 82.73633, 0.98086596, 77.24894, + 11.248695, 61.070026, 52.692616, 80.5449, 80.76036, 29.270136, 67.60252, 48.782394, + 95.18851, 83.47162, 52.068756, 46.66002, 90.12216, 15.515327, 33.694042, 96.963036, + 73.49627, 62.805485, 44.715607, 59.98627, 3.8921833, 37.565327, 29.69184, + 39.429665, 83.46899, 44.286453, 21.54851, 56.096413, 18.169249, 5.214751, + 14.691341, 99.779335, 26.32643, 67.69903, 36.41243, 67.27333, 12.157213, 96.18984, + 2.438283, 78.14289, 0.14715195, 98.769, 53.649532, 21.615898, 39.657497, 95.45616, + 18.578386, 71.47976, 22.348118, 17.85519, 6.3717127, 62.176777, 22.033644, + 23.178005, 79.44858, 89.70233, 37.21273, 71.86182, 21.284317, 52.908623, 30.095518, + 63.64478, 77.55823, 80.04871, 
15.133011, 30.439043, 70.16561, 4.4014096, 89.28944,
+            26.29093, 46.827854, 11.764729, 61.887516, 47.774887, 57.19503, 59.444664,
+            28.592825, 98.70386, 1.2497544, 82.28431, 46.76423, 83.746124, 53.032673, 86.53457,
+            99.42168, 90.184, 92.27852, 9.059965, 71.75723, 70.45299, 10.924053, 68.329704,
+            77.27232, 6.677854, 75.63629, 57.370533, 17.09031, 10.554659, 99.56178, 37.53221,
+            72.311104, 75.7565, 65.2042, 36.096478, 64.69502, 38.88497, 64.33723, 84.87812,
+            66.84958, 8.508932, 79.134, 83.431015, 66.72124, 61.801838, 64.30524, 37.194263,
+            77.94725, 89.705185, 23.643505, 19.505919, 48.40264, 43.01083, 21.171177,
+            18.717121, 10.805857, 69.66983, 77.85261, 57.323063, 3.28964, 38.758026, 5.349946,
+            7.46572, 57.485138, 30.822384, 33.9411, 95.53746, 65.57723, 42.1077, 28.591347,
+            11.917269, 5.031073, 31.835615, 19.34116, 85.71027, 87.4516, 1.3798475, 70.70583,
+            51.988052, 45.217144, 14.308596, 54.557167, 86.18323, 79.13666, 76.866745,
+            46.010685, 79.739235, 44.667603, 39.36416, 72.605896, 73.83187, 13.137412,
+            6.7911267, 63.952374, 10.082436, 86.00318, 99.760376, 92.84948, 63.786434,
+            3.4429908, 18.244314, 75.65299, 14.964747, 70.126366, 80.89449, 91.266655,
+            96.58798, 46.439327, 38.253975, 87.31036, 21.093178, 37.19671, 58.28973, 9.75231,
+            12.350321, 25.75115, 87.65073, 53.610504, 36.850048, 18.66356, 94.48941, 83.71898,
+            44.49315, 44.186737, 19.360733, 84.365974, 46.76272, 44.924366, 50.279808,
+            54.868866, 91.33004, 18.683397, 75.13282, 15.070831, 47.04839, 53.780903,
+            26.911152, 74.65651, 57.659935, 25.604189, 37.235474, 65.39667, 53.952206,
+            40.37131, 59.173275, 96.00756, 54.591274, 10.787476, 69.51549, 31.970142,
+            25.408005, 55.972492, 85.01888, 97.48981, 91.006134, 28.98619, 97.151276,
+            34.388496, 47.498177, 11.985874, 64.73775, 33.877014, 13.370312, 34.79146,
+            86.19321, 15.019405, 94.07832, 93.50433, 60.168625, 50.95409, 38.27827, 47.458614,
+            32.83715, 69.54998, 69.0361, 84.1418, 34.270298, 74.23852, 70.707466, 78.59845,
+            9.651399, 24.186779, 58.255756, 53.72362, 92.46477, 97.75528, 20.257462, 30.122698,
+            50.41517, 28.156603, 42.644154,
+        ],
+    });
+
+    let distance = compare::<f32, 256>(256, Metric::L2, &two_vec.v);
+
+    assert_eq!(distance, 429141.2);
+}
+
+fn compare<T, const N: usize>(dim: usize, metric: Metric, v: &[f32]) -> f32
+where
+    for<'a> [T; N]: FullPrecisionDistance<T, N>,
+{
+    let a_ptr = v.as_ptr();
+    let b_ptr = unsafe { a_ptr.add(dim) };
+
+    let a_ref =
+        <&[f32; N]>::try_from(unsafe { std::slice::from_raw_parts(a_ptr, dim) }).unwrap();
+    let b_ref =
+        <&[f32; N]>::try_from(unsafe { std::slice::from_raw_parts(b_ptr, dim) }).unwrap();
+
+    <[f32; N]>::distance_compare(a_ref, b_ref, metric)
+}
+}
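The `compare` helper above slices two logical vectors out of one contiguous 512-float allocation before dispatching through the const-generic trait. A minimal sketch of the same slicing in safe Rust, with a plain closure standing in for the `FullPrecisionDistance` impl (the `split_and_compare` helper here is hypothetical, not part of the crate):

    fn split_and_compare(v: &[f32], dim: usize, dist: impl Fn(&[f32], &[f32]) -> f32) -> f32 {
        assert!(v.len() >= 2 * dim);
        let (a, rest) = v.split_at(dim); // first vector: v[0..dim]
        let b = &rest[..dim];            // second vector: v[dim..2*dim]
        dist(a, b)
    }

    // Usage with the scalar reference distance (matches no_vector_compare_f32):
    // let d = split_and_compare(&two_vec.v, 256, |a, b| {
    //     a.iter().zip(b).map(|(x, y)| (x - y) * (x - y)).sum()
    // });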
diff --git a/rust/vector/src/distance_test.rs b/rust/vector/src/distance_test.rs
new file mode 100644
index 000000000..0def0264a
--- /dev/null
+++ b/rust/vector/src/distance_test.rs
@@ -0,0 +1,152 @@
+/*
+ * Copyright (c) Microsoft Corporation. All rights reserved.
+ * Licensed under the MIT license.
+ */
+#[cfg(test)]
+mod e2e_test {
+
+    #[repr(C, align(32))]
+    pub struct F32Slice104([f32; 104]);
+
+    #[repr(C, align(32))]
+    pub struct F16Slice104([Half; 104]);
+
+    use approx::assert_abs_diff_eq;
+
+    use crate::half::Half;
+    use crate::l2_float_distance::{distance_l2_vector_f16, distance_l2_vector_f32};
+
+    fn no_vector_compare_f32(a: &[f32], b: &[f32]) -> f32 {
+        let mut sum = 0.0;
+        for i in 0..a.len() {
+            let a_f32 = a[i];
+            let b_f32 = b[i];
+            let diff = a_f32 - b_f32;
+            sum += diff * diff;
+        }
+        sum
+    }
+
+    fn no_vector_compare(a: &[Half], b: &[Half]) -> f32 {
+        let mut sum = 0.0;
+        for i in 0..a.len() {
+            let a_f32 = a[i].to_f32();
+            let b_f32 = b[i].to_f32();
+            let diff = a_f32 - b_f32;
+            sum += diff * diff;
+        }
+        sum
+    }
+
+    #[test]
+    fn avx2_matches_novector() {
+        for i in 1..3 {
+            let (f1, f2) = get_test_data(0, i);
+
+            let distance_f32x8 = distance_l2_vector_f32::<104>(&f1.0, &f2.0);
+            let distance = no_vector_compare_f32(&f1.0, &f2.0);
+
+            assert_abs_diff_eq!(distance, distance_f32x8, epsilon = 1e-6);
+        }
+    }
+
+    #[test]
+    fn avx2_matches_novector_random() {
+        let (f1, f2) = get_test_data_random();
+
+        let distance_f32x8 = distance_l2_vector_f32::<104>(&f1.0, &f2.0);
+        let distance = no_vector_compare_f32(&f1.0, &f2.0);
+
+        assert_abs_diff_eq!(distance, distance_f32x8, epsilon = 1e-4);
+    }
+
+    #[test]
+    fn avx_f16_matches_novector() {
+        for i in 1..3 {
+            let (f1, f2) = get_test_data_f16(0, i);
+            let _a_slice = f1.0.map(|x| x.to_f32().to_string()).join(", ");
+            let _b_slice = f2.0.map(|x| x.to_f32().to_string()).join(", ");
+
+            let expected = no_vector_compare(f1.0[0..].as_ref(), f2.0[0..].as_ref());
+            let distance_f16x8 = distance_l2_vector_f16::<104>(&f1.0, &f2.0);
+
+            assert_abs_diff_eq!(distance_f16x8, expected, epsilon = 1e-4);
+        }
+    }
+
+    #[test]
+    fn avx_f16_matches_novector_random() {
+        let (f1, f2) = get_test_data_f16_random();
+
+        let expected = no_vector_compare(f1.0[0..].as_ref(), f2.0[0..].as_ref());
+        let distance_f16x8 = distance_l2_vector_f16::<104>(&f1.0, &f2.0);
+
+        assert_abs_diff_eq!(distance_f16x8, expected, epsilon = 1e-4);
+    }
+
+    fn get_test_data_f16(i1: usize, i2: usize) -> (F16Slice104, F16Slice104) {
+        let (a_slice, b_slice) = get_test_data(i1, i2);
+        let a_data = a_slice.0.iter().map(|x| Half::from_f32(*x));
+        let b_data = b_slice.0.iter().map(|x| Half::from_f32(*x));
+
+        (
+            F16Slice104(a_data.collect::<Vec<Half>>().try_into().unwrap()),
+            F16Slice104(b_data.collect::<Vec<Half>>().try_into().unwrap()),
+        )
+    }
+
+    fn get_test_data(i1: usize, i2: usize) -> (F32Slice104, F32Slice104) {
+        use base64::{engine::general_purpose, Engine as _};
+
+        let b64 = general_purpose::STANDARD.decode(TEST_DATA).unwrap();
+
+        let decoded: Vec<Vec<f32>> = bincode::deserialize(&b64).unwrap();
+        debug_assert!(decoded.len() > i1);
+        debug_assert!(decoded.len() > i2);
+
+        let mut f1 = F32Slice104([0.0; 104]);
+        let v1 = &decoded[i1];
+        debug_assert!(v1.len() == 104);
+        f1.0.copy_from_slice(v1);
+
+        let mut f2 = F32Slice104([0.0; 104]);
+        let v2 = &decoded[i2];
+        debug_assert!(v2.len() == 104);
+        f2.0.copy_from_slice(v2);
+
+        (f1, f2)
+    }
+
+    fn get_test_data_f16_random() -> (F16Slice104, F16Slice104) {
+        let (a_slice, b_slice) = get_test_data_random();
+        let a_data = a_slice.0.iter().map(|x| Half::from_f32(*x));
+        let b_data = b_slice.0.iter().map(|x| Half::from_f32(*x));
+
+        (
+            F16Slice104(a_data.collect::<Vec<Half>>().try_into().unwrap()),
+            F16Slice104(b_data.collect::<Vec<Half>>().try_into().unwrap()),
+        )
+    }
+
+    fn get_test_data_random() -> (F32Slice104, F32Slice104) {
+        use rand::Rng;
+
+        let mut rng =
rand::thread_rng(); + let mut f1 = F32Slice104([0.0; 104]); + + for i in 0..104 { + f1.0[i] = rng.gen_range(-1.0..1.0); + } + + let mut f2 = F32Slice104([0.0; 104]); + + for i in 0..104 { + f2.0[i] = rng.gen_range(-1.0..1.0); + } + + (f1, f2) + } + + const TEST_DATA: &str = "BQAAAAAAAABoAAAAAAAAAPz3Dj7+VgG9z/DDvQkgiT2GryK+nwS4PTeBorz4jpk9ELEqPKKeX73zZrA9uAlRvSqpKT7Gft28LsTuO8XOHL6/lCg+pW/6vJhM7j1fInU+yaSTPC2AAb5T25M8o2YTvWgEAz00cnq8xcUlPPvnBb2AGfk9UmhCvbdUJzwH4jK9UH7Lvdklhz3SoEa+NwsIvt2yYb4q7JA8d4fVvfX/kbtDOJe9boXevbw2CT7n62A9B6hOPlfeNz7CO169vnjcvR3pDz6KZxC+XR/2vTd9PTx7YY492FF2PekiGDt3OSw9IIlGPQooMj5DZcY8EgQgvpg9572paca91GQTPoWpFr7U+t697YAQPYHUXr1d8ow8AQE7PFo6JD3tt+I96ahxvYuvlD3+IW29N4Jtu2/01Ltvvg2+dja+vI8uazvITZO9mXhavpfJ6T2tB8S7OKT3PWWjpj0Mjty9advIPFgucTp3JO69CI6YPaWoDD5pwim9rjUovh2qgr3R/lq+nUi3PI+acL041o081D8lvRCJLTwAAAAAAAAAAAAAAAAAAAAAaAAAAAAAAAA6pJO94NE1voDn+rzQ8CY+1rxkvtspaz0xTPw7+0GMvC0ZgbyWwdy8zHcovKdvdb70BLC8DtHKvdK6vz0R9Ys7vBWyvZK1LL0ehYM9aV+JveuvoD2ilvo9NLJ4vbRnPT4MXAW+BhG4POOBaD0Vz5I9s1+1vTUdHb7Kjcw9uVUJvdbgoj3TbBe8WwPSvYoBBj4m6c+9xTXTvVTDaL28+Ac9KtA0Pa3tS73Vq5S8fNLkvf/Gir0yILy9ZYR3vvUdUD2ZB5W9rHI4PXS76L070oG9EsjYPb89S75pz7Q9xFKyvZ5ECT0kDSU+l4AQPsQVqzyq/LW95ZCZPC6nQj0VIBa9XwkhPr1gy72c7mw937XXvQ76ur3sRok9mCUqPXHvgj28jV89LZN8O0eH0T0KMdq9ZzXevYbmPr0fcac8r7j3vYmKCL4Sewm+iLtRviuOjz08XbE9LlYevDI1wz0s7z278oVJvtpjrT20IEU9+mTtvBjMQz1H9Ey+LQEXva1Rwrxmyts9sf1hPRY3xL3RdRU+AAAAAAAAAAAAAAAAAAAAAGgAAAAAAAAARqSTvbYJpLx1x869cW67PeeJhb7/cBu9m0eFPQO3oL0I+L49YQDavTYSez3SmTg96hBGPuh4oL2x2ow6WdCUO6XUSz4xcU88GReAvVfekj0Ph3Y9z43hvBzT5z1I2my9UVy3vAj8jL08Gtm9CfJcPRihTr1+8Yu9TiP+PNrJa77Dfa09IhpEPesJNr0XzFU8yye3PZKFyz3uzJ09FLRUvYq3l73X4X07DDUzvq9VXjwWtg8+JrzYPcFCkr0jDCg9T9zlvZbZjz4Y8pM89xo8PgAcfbvYSnY8XoFKvO05/L36yzE8J+5yPqfe5r2AZFq8ULRDvnkTgrw+S7q9qGYLvQDZYL1T8d09bFikvZw3+jsYLdO8H3GVveHBYT4gnsE8ZBIJPpzOEj7OSDC+ZYu+vFc1Erzko4M9GqLtPBHH5TwpeRs+miC4PBHH5Tw9Z9k9VUsUPjnppj0oC5C9mcqDvY7y1rxdvZU8PdFAPov9lz0bOmq94kdyPBBokTxtOj89fu4avSsazj1P7iE+x8YkPAAAAAAAAAAAAAAAAAAAAABoAAAAAAAAAHEruT3mgKM8JnEvvAsfHL63906+ifhgvldl1r14OeO9waUyuw3yUzx+PDW9UbDhPQP4Lb4KRRk+Oky2vaLfaT30mrA9YMeZPfzPMz4h42M+XfCHva4AGr6MOSM+iBOzvdsaE7xFxgI+gJGXvVMzE75kHY+8oAWNvVqNK7yOx589fU3lvVVPg730Cwk+DKkEPWYtxjqQ2MK9H0T+vTnGQj2yq5w8L49BvrEJrzyB4Yo9AXV7PYGCLr3MxsG9oWM7PTyu8TzEOhW+dyWrvUTxHD2nL+c9+VKFPcthhLsc0PM8FdyPPeLj/z1WAHS8ZvW2PGg4Cb5u3IU9g4CovSHW+L2CWoG++nZnPAi2ST3HmUC9P5rJuxQbU765lwU+7FLBPUPTfL0uGgk+yKy2PYwXaT1I4I+9AU6VPQ5QaDx9mdE8Qg8zPfGCUjzD/io9rr+BvTNDqT0MFNi9mHatvS1iJD0nVrK78WmIPE0QsL3PAQq9cMRgPWXmmr3yTcw9UcXrPccwa76+cBq+5iVOvUg9c70AAAAAAAAAAAAAAAAAAAAAaAAAAAAAAAB/K7k9hCsnPUJXJr2Wg4a9MEtXve33Sj0VJZ89pciEvWLqwLzUgyu8ADTGPAVenL2UZ/c96YtMved+Wr3LUro9H8a7vGTSA77C5n69Lf3pPQj4KD5cFKq9fZ0uvvYQCT7b23G9XGMCPrGuy736Z9A9kZzFPSuCSD7/9/07Y4/6POxLir3/JBS9qFKMvkSzjryPgVY+ugq8PC9yhbsXaiq+O6WfPcvFK7vZXAy+goAQvXpHHj5jwPI87eokvrySET5QoOm8h8ixOhXzKb5s8+A9sjcJPjiLAz598yQ9yCYSPq6eGz4rvjE82lvGvWuIOLx23zK9hHg8vTWOv70/Tse81fA6Pr2wNz34Eza+2Uj3PZ3trr0aXAI9PCkKPiybe721P9U9QkNLO927jT3LpRA+mpJUvUeU6rwC/Qa+lr4Cvgrpnj1pQ/i9TxhSvJqYr72RS6y8aQLTPQzPiz3vSRY94NfrPJl6LL2adjO8iYfPuhRzZz2f7R8+iVskPcUeXr12ZiI+nd3xvIYv8bwqYlg+AAAAAAAAAAAAAAAAAAAAAA=="; +} + diff --git a/rust/vector/src/half.rs b/rust/vector/src/half.rs new file mode 100644 index 000000000..87d7df6a1 --- /dev/null +++ b/rust/vector/src/half.rs @@ -0,0 +1,82 @@ +/* + * Copyright (c) Microsoft Corporation. All rights reserved. + * Licensed under the MIT license. + */ +use bytemuck::{Pod, Zeroable}; +use half::f16; +use std::convert::AsRef; +use std::fmt; + +// Define the Half type as a new type over f16. 
+// the memory layout of the Half struct will be the same as the memory layout of the f16 type itself.
+// The Half struct serves as a simple wrapper around the f16 type and does not introduce any additional memory overhead.
+// Test function:
+// use half::f16;
+// pub struct Half(f16);
+// fn main() {
+//     let size_of_half = std::mem::size_of::<Half>();
+//     let alignment_of_half = std::mem::align_of::<Half>();
+//     println!("Size of Half: {} bytes", size_of_half);
+//     println!("Alignment of Half: {} bytes", alignment_of_half);
+// }
+// Output:
+// Size of Half: 2 bytes
+// Alignment of Half: 2 bytes
+pub struct Half(f16);
+
+unsafe impl Pod for Half {}
+unsafe impl Zeroable for Half {}
+
+// Implement From<Half> for f32.
+impl From<Half> for f32 {
+    fn from(val: Half) -> Self {
+        val.0.to_f32()
+    }
+}
+
+// Implement AsRef<f16> for Half so that it can be used in distance_compare.
+impl AsRef<f16> for Half {
+    fn as_ref(&self) -> &f16 {
+        &self.0
+    }
+}
+
+// Conversion from f32 for Half.
+impl Half {
+    pub fn from_f32(value: f32) -> Self {
+        Self(f16::from_f32(value))
+    }
+}
+
+// Implement Default for Half.
+impl Default for Half {
+    fn default() -> Self {
+        Self(f16::from_f32(Default::default()))
+    }
+}
+
+// Implement Clone for Half.
+impl Clone for Half {
+    fn clone(&self) -> Self {
+        Half(self.0)
+    }
+}
+
+// Implement Debug for Half.
+impl fmt::Debug for Half {
+    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+        write!(f, "Half({:?})", self.0)
+    }
+}
+
+impl Copy for Half {}
+
+impl Half {
+    pub fn to_f32(&self) -> f32 {
+        self.0.to_f32()
+    }
+}
+
+unsafe impl Send for Half {}
+unsafe impl Sync for Half {}
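Since Half is Pod, Zeroable, and layout-compatible with f16, buffers of Half can be viewed at the byte level without copying. A short sketch of the f32 round-trip and a zero-copy bytemuck cast (assuming Half and the bytemuck crate are in scope, as above):

    use bytemuck::cast_slice;

    fn half_demo() {
        // f32 -> f16 -> f32 round-trip loses precision past roughly 3 decimal digits.
        let h = Half::from_f32(0.1234567);
        assert!((h.to_f32() - 0.1234567).abs() < 1e-3);

        // Pod + Zeroable allow a zero-copy byte view: 2 bytes per element.
        let xs = [Half::from_f32(1.0), Half::from_f32(2.0)];
        let bytes: &[u8] = cast_slice(&xs[..]);
        assert_eq!(bytes.len(), 4);
    }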
diff --git a/rust/vector/src/l2_float_distance.rs b/rust/vector/src/l2_float_distance.rs
new file mode 100644
index 000000000..b818899bf
--- /dev/null
+++ b/rust/vector/src/l2_float_distance.rs
@@ -0,0 +1,78 @@
+/*
+ * Copyright (c) Microsoft Corporation. All rights reserved.
+ * Licensed under the MIT license.
+ */
+#![warn(missing_debug_implementations, missing_docs)]
+
+//! Distance calculation for L2 Metric
+
+#[cfg(not(target_feature = "avx2"))]
+compile_error!("Library must be compiled with -C target-feature=+avx2");
+
+use std::arch::x86_64::*;
+
+use crate::Half;
+
+/// Calculate the f16 L2 distance by vector arithmetic
+#[inline(never)]
+pub fn distance_l2_vector_f16<const N: usize>(a: &[Half; N], b: &[Half; N]) -> f32 {
+    debug_assert_eq!(N % 8, 0);
+
+    // make sure the addresses are 32-byte aligned
+    debug_assert_eq!(a.as_ptr().align_offset(32), 0);
+    debug_assert_eq!(b.as_ptr().align_offset(32), 0);
+
+    unsafe {
+        let mut sum = _mm256_setzero_ps();
+        let a_ptr = a.as_ptr() as *const __m128i;
+        let b_ptr = b.as_ptr() as *const __m128i;
+
+        // Iterate over the elements in steps of 8
+        for i in (0..N).step_by(8) {
+            let a_vec = _mm256_cvtph_ps(_mm_load_si128(a_ptr.add(i / 8)));
+            let b_vec = _mm256_cvtph_ps(_mm_load_si128(b_ptr.add(i / 8)));
+
+            let diff = _mm256_sub_ps(a_vec, b_vec);
+            sum = _mm256_fmadd_ps(diff, diff, sum);
+        }
+
+        let x128: __m128 = _mm_add_ps(_mm256_extractf128_ps(sum, 1), _mm256_castps256_ps128(sum));
+        /* ( -, -, x1+x3+x5+x7, x0+x2+x4+x6 ) */
+        let x64: __m128 = _mm_add_ps(x128, _mm_movehl_ps(x128, x128));
+        /* ( -, -, -, x0+x1+x2+x3+x4+x5+x6+x7 ) */
+        let x32: __m128 = _mm_add_ss(x64, _mm_shuffle_ps(x64, x64, 0x55));
+        /* Conversion to float is a no-op on x86-64 */
+        _mm_cvtss_f32(x32)
+    }
+}
+
+/// Calculate the f32 L2 distance by vector arithmetic
+#[inline(never)]
+pub fn distance_l2_vector_f32<const N: usize>(a: &[f32; N], b: &[f32; N]) -> f32 {
+    debug_assert_eq!(N % 8, 0);
+
+    // make sure the addresses are 32-byte aligned
+    debug_assert_eq!(a.as_ptr().align_offset(32), 0);
+    debug_assert_eq!(b.as_ptr().align_offset(32), 0);
+
+    unsafe {
+        let mut sum = _mm256_setzero_ps();
+
+        // Iterate over the elements in steps of 8
+        for i in (0..N).step_by(8) {
+            let a_vec = _mm256_load_ps(&a[i]);
+            let b_vec = _mm256_load_ps(&b[i]);
+            let diff = _mm256_sub_ps(a_vec, b_vec);
+            sum = _mm256_fmadd_ps(diff, diff, sum);
+        }
+
+        let x128: __m128 = _mm_add_ps(_mm256_extractf128_ps(sum, 1), _mm256_castps256_ps128(sum));
+        /* ( -, -, x1+x3+x5+x7, x0+x2+x4+x6 ) */
+        let x64: __m128 = _mm_add_ps(x128, _mm_movehl_ps(x128, x128));
+        /* ( -, -, -, x0+x1+x2+x3+x4+x5+x6+x7 ) */
+        let x32: __m128 = _mm_add_ss(x64, _mm_shuffle_ps(x64, x64, 0x55));
+        /* Conversion to float is a no-op on x86-64 */
+        _mm_cvtss_f32(x32)
+    }
+}
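Both kernels end with the same three-step horizontal reduction of the eight f32 partial sums held in the ymm accumulator (8 lanes -> 4 -> 2 -> 1). A scalar rendering of that tail, written over a plain array of the eight lane values purely for illustration:

    // Scalar equivalent of the extractf128/movehl/shuffle reduction above.
    fn hsum8(lanes: [f32; 8]) -> f32 {
        // _mm256_extractf128_ps + _mm_add_ps: fold the high 128 bits onto the low.
        let x4 = [lanes[0] + lanes[4], lanes[1] + lanes[5],
                  lanes[2] + lanes[6], lanes[3] + lanes[7]];
        // _mm_movehl_ps + _mm_add_ps: fold four lanes to two.
        let x2 = [x4[0] + x4[2], x4[1] + x4[3]];
        // _mm_shuffle_ps(.., 0x55) + _mm_add_ss: fold two lanes to one.
        x2[0] + x2[1]
    }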
diff --git a/rust/vector/src/lib.rs b/rust/vector/src/lib.rs
new file mode 100644
index 000000000..d221070b5
--- /dev/null
+++ b/rust/vector/src/lib.rs
@@ -0,0 +1,26 @@
+/*
+ * Copyright (c) Microsoft Corporation. All rights reserved.
+ * Licensed under the MIT license.
+ */
+#![cfg_attr(
+    not(test),
+    warn(clippy::panic, clippy::unwrap_used, clippy::expect_used)
+)]
+
+// #![feature(stdsimd)]
+// mod f32x16;
+// Uncomment above 2 to experiment with f32x16
+mod distance;
+mod half;
+mod l2_float_distance;
+mod metric;
+mod utils;
+
+pub use crate::half::Half;
+pub use distance::FullPrecisionDistance;
+pub use metric::Metric;
+pub use utils::prefetch_vector;
+
+#[cfg(test)]
+mod distance_test;
+mod test_util;
diff --git a/rust/vector/src/metric.rs b/rust/vector/src/metric.rs
new file mode 100644
index 000000000..c60ef291b
--- /dev/null
+++ b/rust/vector/src/metric.rs
@@ -0,0 +1,36 @@
+/*
+ * Copyright (c) Microsoft Corporation. All rights reserved.
+ * Licensed under the MIT license.
+ */
+#![warn(missing_debug_implementations, missing_docs)]
+use std::str::FromStr;
+
+/// Distance metric
+#[derive(Debug, PartialEq, Eq, Clone, Copy)]
+pub enum Metric {
+    /// Squared Euclidean (L2-Squared)
+    L2,
+
+    /// Cosine similarity
+    /// TODO: T should be float for Cosine distance
+    Cosine,
+}
+
+#[derive(thiserror::Error, Debug)]
+pub enum ParseMetricError {
+    #[error("Invalid format for Metric: {0}")]
+    InvalidFormat(String),
+}
+
+impl FromStr for Metric {
+    type Err = ParseMetricError;
+
+    fn from_str(s: &str) -> Result<Self, Self::Err> {
+        match s.to_lowercase().as_str() {
+            "l2" => Ok(Metric::L2),
+            "cosine" => Ok(Metric::Cosine),
+            _ => Err(ParseMetricError::InvalidFormat(String::from(s))),
+        }
+    }
+}
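Implementing FromStr also makes str::parse available for Metric. A short usage sketch (the demo function is illustrative only):

    use std::str::FromStr;

    fn metric_demo() {
        // Parsing is case-insensitive, per the to_lowercase() above.
        assert_eq!(Metric::from_str("L2").unwrap(), Metric::L2);
        assert_eq!("cosine".parse::<Metric>().unwrap(), Metric::Cosine);
        // Anything else surfaces as ParseMetricError::InvalidFormat.
        assert!(Metric::from_str("dot").is_err());
    }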
diff --git a/rust/vector/src/test_util.rs b/rust/vector/src/test_util.rs
new file mode 100644
index 000000000..7cfc92985
--- /dev/null
+++ b/rust/vector/src/test_util.rs
@@ -0,0 +1,29 @@
+/*
+ * Copyright (c) Microsoft Corporation. All rights reserved.
+ * Licensed under the MIT license.
+ */
+#[cfg(test)]
+use crate::Half;
+
+#[cfg(test)]
+pub fn no_vector_compare_f16(a: &[Half], b: &[Half]) -> f32 {
+    let mut sum = 0.0;
+    debug_assert_eq!(a.len(), b.len());
+
+    for i in 0..a.len() {
+        sum += (a[i].to_f32() - b[i].to_f32()).powi(2);
+    }
+    sum
+}
+
+#[cfg(test)]
+pub fn no_vector_compare_f32(a: &[f32], b: &[f32]) -> f32 {
+    let mut sum = 0.0;
+    debug_assert_eq!(a.len(), b.len());
+
+    for i in 0..a.len() {
+        sum += (a[i] - b[i]).powi(2);
+    }
+    sum
+}
diff --git a/rust/vector/src/utils.rs b/rust/vector/src/utils.rs
new file mode 100644
index 000000000..a61c99aad
--- /dev/null
+++ b/rust/vector/src/utils.rs
@@ -0,0 +1,21 @@
+/*
+ * Copyright (c) Microsoft Corporation. All rights reserved.
+ * Licensed under the MIT license.
+ */
+use std::arch::x86_64::{_mm_prefetch, _MM_HINT_T0};
+
+/// Prefetch the given vector in chunks of 64 bytes, which is a cache line size
+/// NOTE: best efficiency when the total vector size is an integral multiple of 64
+#[inline]
+pub fn prefetch_vector<T>(vec: &[T]) {
+    let vec_ptr = vec.as_ptr() as *const i8;
+    let vecsize = std::mem::size_of_val(vec);
+    let max_prefetch_size = (vecsize / 64) * 64;
+
+    for d in (0..max_prefetch_size).step_by(64) {
+        unsafe {
+            _mm_prefetch(vec_ptr.add(d), _MM_HINT_T0);
+        }
+    }
+}
diff --git a/rust/vector_base64/Cargo.toml b/rust/vector_base64/Cargo.toml
new file mode 100644
index 000000000..6f50ad96e
--- /dev/null
+++ b/rust/vector_base64/Cargo.toml
@@ -0,0 +1,15 @@
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT license.
+[package]
+name = "vector_base64"
+version = "0.1.0"
+edition = "2021"
+
+# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
+
+[dependencies]
+base64 = "0.21.2"
+bincode = "1.3.3"
+half = "2.2.1"
+serde = "1.0.163"
diff --git a/rust/vector_base64/src/main.rs b/rust/vector_base64/src/main.rs
new file mode 100644
index 000000000..2867436a9
--- /dev/null
+++ b/rust/vector_base64/src/main.rs
@@ -0,0 +1,82 @@
+/*
+ * Copyright (c) Microsoft Corporation. All rights reserved.
+ * Licensed under the MIT license.
+ */
+use std::fs::File;
+use std::io::{self, BufReader, Read};
+use std::{env, vec};
+
+fn main() -> io::Result<()> {
+    // Retrieve command-line arguments
+    let args: Vec<String> = env::args().collect();
+
+    // Check if the correct number of arguments is provided
+    if args.len() != 4 {
+        print_usage();
+        return Ok(());
+    }
+
+    // Retrieve the input file path and conversion parameters from the arguments
+    let input_file_path = &args[1];
+    let item_count: usize = args[2].parse::<usize>().unwrap();
+    let return_dimension: usize = args[3].parse::<usize>().unwrap();
+
+    // Open the input file for reading
+    let mut input_file = BufReader::new(File::open(input_file_path)?);
+
+    // Read the first 8 bytes as metadata
+    let mut metadata = [0; 8];
+    input_file.read_exact(&mut metadata)?;
+
+    // Extract the number of points and dimension from the metadata
+    let _ = i32::from_le_bytes(metadata[..4].try_into().unwrap());
+    let mut dimension: usize = (i32::from_le_bytes(metadata[4..].try_into().unwrap())) as usize;
+    if return_dimension < dimension {
+        dimension = return_dimension;
+    }
+
+    let mut float_array = Vec::<Vec<f32>>::with_capacity(item_count);
+
+    // Process each data point
+    for _ in 0..item_count {
+        // Read one data point from the input file
+        let mut buffer = vec![0; dimension * std::mem::size_of::<f32>()];
+        match input_file.read_exact(&mut buffer) {
+            Ok(()) => {
+                let mut float_data = buffer
+                    .chunks_exact(4)
+                    .map(|chunk| f32::from_le_bytes([chunk[0], chunk[1], chunk[2], chunk[3]]))
+                    .collect::<Vec<f32>>();
+
+                // Zero-pad the vector when the requested dimension exceeds the stored one
+                let mut i = return_dimension;
+                while i > dimension {
+                    float_data.push(0.0);
+                    i -= 1;
+                }
+
+                float_array.push(float_data);
+            }
+            Err(err) => {
+                println!("Error: {}", err);
+                break;
+            }
+        }
+    }
+
+    use base64::{engine::general_purpose, Engine as _};
+
+    let encoded: Vec<u8> = bincode::serialize(&float_array).unwrap();
+    let b64 = general_purpose::STANDARD.encode(encoded);
+    println!("Float {}", b64);
+
+    Ok(())
+}
+
+/// Prints the usage information
+fn print_usage() {
+    println!("Usage: program_name input_file <item_count> <return_dimension>");
+    println!(
+        "item_count is the number of vectors to convert. \
+         Vectors are zero-padded to return_dimension when it exceeds the stored dimension"
+    );
+}
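The reader above assumes the flat .bin layout used by the DiskANN tools: two little-endian i32 header fields (point count, then dimension) followed by row-major f32 data. A minimal writer for producing synthetic input in that layout (a sketch; write_bin is not part of this crate):

    use std::fs::File;
    use std::io::{self, Write};

    // Write `vectors` (all of equal length) in the flat .bin layout sketched above.
    fn write_bin(path: &str, vectors: &[Vec<f32>]) -> io::Result<()> {
        let mut f = File::create(path)?;
        let npts = vectors.len() as i32;
        let dim = vectors.first().map_or(0, |v| v.len()) as i32;
        f.write_all(&npts.to_le_bytes())?;
        f.write_all(&dim.to_le_bytes())?;
        for v in vectors {
            for x in v {
                f.write_all(&x.to_le_bytes())?;
            }
        }
        Ok(())
    }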
diff --git a/scripts/dev/install-dev-deps-ubuntu.bash b/scripts/dev/install-dev-deps-ubuntu.bash
new file mode 100755
index 000000000..84f558ed6
--- /dev/null
+++ b/scripts/dev/install-dev-deps-ubuntu.bash
@@ -0,0 +1,12 @@
+#!/bin/bash
+
+apt install cmake \
+    g++ \
+    libaio-dev \
+    libgoogle-perftools-dev \
+    libunwind-dev \
+    clang-format \
+    libboost-dev \
+    libboost-program-options-dev \
+    libboost-test-dev \
+    libmkl-full-dev
\ No newline at end of file
diff --git a/setup.py b/setup.py
index dc453fcb2..ff5bed187 100644
--- a/setup.py
+++ b/setup.py
@@ -3,13 +3,14 @@

 import os
 import re
+import shutil
 import subprocess
 import sys
 from pathlib import Path
-from typing import List

 from setuptools import Extension, setup
 from setuptools.command.build_ext import build_ext
+from setuptools.command.install_lib import install_lib

 # Convert distutils Windows platform specifiers to CMake -A arguments
 PLAT_TO_CMAKE = {
@@ -45,7 +46,7 @@ def build_extension(self, ext: CMakeExtension) -> None:
             f"-DCMAKE_LIBRARY_OUTPUT_DIRECTORY={extdir}{os.sep}",
             f"-DPYTHON_EXECUTABLE={sys.executable}",
             f"-DCMAKE_BUILD_TYPE={cfg}",  # not used on MSVC, but no harm
-            f"-DVERSION_INFO={self.distribution.get_version()}"  # commented out because we want this set in the CMake file
+            f"-DVERSION_INFO={self.distribution.get_version()}"  # commented out, we want this set in the CMake file
         ]
         build_args = []
         # Adding CMake arguments set as environment variable
@@ -126,8 +127,43 @@ def build_extension(self, ext: CMakeExtension) -> None:
         )


+class InstallCMakeLibs(install_lib):
+    def run(self):
+        """
+        On Windows, copy the built libraries from the x64/Release directory into the package
+        """
+
+        self.announce("Moving library files", level=3)
+
+        self.skip_build = True
+
+        # we only need to move the windows build output
+        windows_build_output_dir = Path('.') / 'x64' / 'Release'
+
+        if windows_build_output_dir.exists():
+            libs = [
+                os.path.join(windows_build_output_dir, _lib) for _lib in
+                os.listdir(windows_build_output_dir) if
+                os.path.isfile(os.path.join(windows_build_output_dir, _lib)) and
+                os.path.splitext(_lib)[1] in [".dll", '.lib', '.pyd', '.exp']
+            ]
+
+            for lib in libs:
+                shutil.move(
+                    lib,
+                    os.path.join(self.build_dir, 'diskannpy', os.path.basename(lib))
+                )
+
+        super().run()
+
+
 setup(
-    ext_modules=[CMakeExtension("diskannpy", ".")],
-    cmdclass={"build_ext": CMakeBuild},
-    zip_safe=False
+    ext_modules=[CMakeExtension("diskannpy._diskannpy", ".")],
+    cmdclass={
+        "build_ext": CMakeBuild,
+        'install_lib': InstallCMakeLibs
+    },
+    zip_safe=False,
+    package_dir={"diskannpy": "python/src"},
+    exclude_package_data={"diskannpy": ["diskann_bindings.cpp"]}
 )
diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt
index d993ccc69..2206a01f7 100644
--- a/src/CMakeLists.txt
+++ b/src/CMakeLists.txt
@@ -1,20 +1,26 @@
 #Copyright(c) Microsoft Corporation.All rights reserved.
 #Licensed under the MIT license.

-set(CMAKE_CXX_STANDARD 14)
+set(CMAKE_CXX_STANDARD 17)
+set(CMAKE_COMPILE_WARNING_AS_ERROR ON)

 if(MSVC)
     add_subdirectory(dll)
 else()
     #file(GLOB CPP_SOURCES *.cpp)
-    set(CPP_SOURCES ann_exception.cpp disk_utils.cpp distance.cpp index.cpp
+    set(CPP_SOURCES abstract_data_store.cpp ann_exception.cpp disk_utils.cpp
+            distance.cpp index.cpp in_mem_graph_store.cpp in_mem_data_store.cpp
             linux_aligned_file_reader.cpp math_utils.cpp natural_number_map.cpp
+            in_mem_data_store.cpp in_mem_graph_store.cpp
             natural_number_set.cpp memory_mapper.cpp partition.cpp pq.cpp
-            pq_flash_index.cpp scratch.cpp logger.cpp utils.cpp)
+            pq_flash_index.cpp scratch.cpp logger.cpp utils.cpp filter_utils.cpp index_factory.cpp abstract_index.cpp)
     if (RESTAPI)
         list(APPEND CPP_SOURCES restapi/search_wrapper.cpp restapi/server.cpp)
     endif()
     add_library(${PROJECT_NAME} ${CPP_SOURCES})
     add_library(${PROJECT_NAME}_s STATIC ${CPP_SOURCES})
 endif()
-install()
+
+if (NOT MSVC)
+    install(TARGETS ${PROJECT_NAME} LIBRARY)
+endif()
diff --git a/src/abstract_data_store.cpp b/src/abstract_data_store.cpp
new file mode 100644
index 000000000..a980bd545
--- /dev/null
+++ b/src/abstract_data_store.cpp
@@ -0,0 +1,46 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT license.
+
+#include <vector>
+
+#include "abstract_data_store.h"
+
+namespace diskann
+{
+
+template <typename data_t>
+AbstractDataStore<data_t>::AbstractDataStore(const location_t capacity, const size_t dim)
+    : _capacity(capacity), _dim(dim)
+{
+}
+
+template <typename data_t> location_t AbstractDataStore<data_t>::capacity() const
+{
+    return _capacity;
+}
+
+template <typename data_t> size_t AbstractDataStore<data_t>::get_dims() const
+{
+    return _dim;
+}
+
+template <typename data_t> location_t AbstractDataStore<data_t>::resize(const location_t new_num_points)
+{
+    if (new_num_points > _capacity)
+    {
+        return expand(new_num_points);
+    }
+    else if (new_num_points < _capacity)
+    {
+        return shrink(new_num_points);
+    }
+    else
+    {
+        return _capacity;
+    }
+}
+
+template DISKANN_DLLEXPORT class AbstractDataStore<float>;
+template DISKANN_DLLEXPORT class AbstractDataStore<int8_t>;
+template DISKANN_DLLEXPORT class AbstractDataStore<uint8_t>;
+} // namespace diskann
diff --git a/src/abstract_index.cpp b/src/abstract_index.cpp
new file mode 100644
index 000000000..518f8b7dd
--- /dev/null
+++ b/src/abstract_index.cpp
@@ -0,0 +1,280 @@
+#include "common_includes.h"
+#include "windows_customizations.h"
+#include "abstract_index.h"
+
+namespace diskann
+{
+
+template <typename data_type, typename tag_type>
+void AbstractIndex::build(const data_type *data, const size_t num_points_to_load,
+                          const IndexWriteParameters &parameters, const std::vector<tag_type> &tags)
+{
+    auto any_data = std::any(data);
+    auto any_tags_vec = TagVector(tags);
+    this->_build(any_data, num_points_to_load, parameters, any_tags_vec);
+}
+
+template <typename data_type, typename IDType>
+std::pair<uint32_t, uint32_t> AbstractIndex::search(const data_type *query, const size_t K, const uint32_t L,
+                                                    IDType *indices, float *distances)
+{
+    auto any_indices = std::any(indices);
+    auto any_query = std::any(query);
+    return _search(any_query, K, L, any_indices, distances);
+}
+
+template <typename data_type, typename tag_type>
+size_t AbstractIndex::search_with_tags(const data_type *query, const uint64_t K, const uint32_t L, tag_type *tags,
+                                       float *distances, std::vector<data_type *> &res_vectors)
+{
+    auto any_query = std::any(query);
+    auto any_tags = std::any(tags);
+    auto any_res_vectors = DataVector(res_vectors);
+    return this->_search_with_tags(any_query, K, L, any_tags, distances, any_res_vectors);
+}
+
+template <typename IndexType>
+std::pair<uint32_t, uint32_t> AbstractIndex::search_with_filters(const DataType &query, const std::string &raw_label,
+                                                                 const size_t K, const uint32_t L, IndexType *indices,
float *distances) +{ + auto any_indices = std::any(indices); + return _search_with_filters(query, raw_label, K, L, any_indices, distances); +} + +template +void AbstractIndex::search_with_optimized_layout(const data_type *query, size_t K, size_t L, uint32_t *indices) +{ + auto any_query = std::any(query); + this->_search_with_optimized_layout(any_query, K, L, indices); +} + +template +int AbstractIndex::insert_point(const data_type *point, const tag_type tag) +{ + auto any_point = std::any(point); + auto any_tag = std::any(tag); + return this->_insert_point(any_point, any_tag); +} + +template int AbstractIndex::lazy_delete(const tag_type &tag) +{ + auto any_tag = std::any(tag); + return this->_lazy_delete(any_tag); +} + +template +void AbstractIndex::lazy_delete(const std::vector &tags, std::vector &failed_tags) +{ + auto any_tags = TagVector(tags); + auto any_failed_tags = TagVector(failed_tags); + this->_lazy_delete(any_tags, any_failed_tags); +} + +template void AbstractIndex::get_active_tags(tsl::robin_set &active_tags) +{ + auto any_active_tags = TagRobinSet(active_tags); + this->_get_active_tags(any_active_tags); +} + +template void AbstractIndex::set_start_points_at_random(data_type radius, uint32_t random_seed) +{ + auto any_radius = std::any(radius); + this->_set_start_points_at_random(any_radius, random_seed); +} + +template int AbstractIndex::get_vector_by_tag(tag_type &tag, data_type *vec) +{ + auto any_tag = std::any(tag); + auto any_data_ptr = std::any(vec); + return this->_get_vector_by_tag(any_tag, any_data_ptr); +} + +// exports +template DISKANN_DLLEXPORT void AbstractIndex::build(const float *data, const size_t num_points_to_load, + const IndexWriteParameters ¶meters, + const std::vector &tags); +template DISKANN_DLLEXPORT void AbstractIndex::build(const int8_t *data, + const size_t num_points_to_load, + const IndexWriteParameters ¶meters, + const std::vector &tags); +template DISKANN_DLLEXPORT void AbstractIndex::build(const uint8_t *data, + const size_t num_points_to_load, + const IndexWriteParameters ¶meters, + const std::vector &tags); +template DISKANN_DLLEXPORT void AbstractIndex::build(const float *data, + const size_t num_points_to_load, + const IndexWriteParameters ¶meters, + const std::vector &tags); +template DISKANN_DLLEXPORT void AbstractIndex::build(const int8_t *data, + const size_t num_points_to_load, + const IndexWriteParameters ¶meters, + const std::vector &tags); +template DISKANN_DLLEXPORT void AbstractIndex::build(const uint8_t *data, + const size_t num_points_to_load, + const IndexWriteParameters ¶meters, + const std::vector &tags); +template DISKANN_DLLEXPORT void AbstractIndex::build(const float *data, const size_t num_points_to_load, + const IndexWriteParameters ¶meters, + const std::vector &tags); +template DISKANN_DLLEXPORT void AbstractIndex::build(const int8_t *data, + const size_t num_points_to_load, + const IndexWriteParameters ¶meters, + const std::vector &tags); +template DISKANN_DLLEXPORT void AbstractIndex::build(const uint8_t *data, + const size_t num_points_to_load, + const IndexWriteParameters ¶meters, + const std::vector &tags); +template DISKANN_DLLEXPORT void AbstractIndex::build(const float *data, + const size_t num_points_to_load, + const IndexWriteParameters ¶meters, + const std::vector &tags); +template DISKANN_DLLEXPORT void AbstractIndex::build(const int8_t *data, + const size_t num_points_to_load, + const IndexWriteParameters ¶meters, + const std::vector &tags); +template DISKANN_DLLEXPORT void AbstractIndex::build(const 
uint8_t *data, + const size_t num_points_to_load, + const IndexWriteParameters ¶meters, + const std::vector &tags); + +template DISKANN_DLLEXPORT std::pair AbstractIndex::search( + const float *query, const size_t K, const uint32_t L, uint32_t *indices, float *distances); +template DISKANN_DLLEXPORT std::pair AbstractIndex::search( + const uint8_t *query, const size_t K, const uint32_t L, uint32_t *indices, float *distances); +template DISKANN_DLLEXPORT std::pair AbstractIndex::search( + const int8_t *query, const size_t K, const uint32_t L, uint32_t *indices, float *distances); + +template DISKANN_DLLEXPORT std::pair AbstractIndex::search( + const float *query, const size_t K, const uint32_t L, uint64_t *indices, float *distances); +template DISKANN_DLLEXPORT std::pair AbstractIndex::search( + const uint8_t *query, const size_t K, const uint32_t L, uint64_t *indices, float *distances); +template DISKANN_DLLEXPORT std::pair AbstractIndex::search( + const int8_t *query, const size_t K, const uint32_t L, uint64_t *indices, float *distances); + +template DISKANN_DLLEXPORT std::pair AbstractIndex::search_with_filters( + const DataType &query, const std::string &raw_label, const size_t K, const uint32_t L, uint32_t *indices, + float *distances); + +template DISKANN_DLLEXPORT std::pair AbstractIndex::search_with_filters( + const DataType &query, const std::string &raw_label, const size_t K, const uint32_t L, uint64_t *indices, + float *distances); + +template DISKANN_DLLEXPORT size_t AbstractIndex::search_with_tags(const float *query, const uint64_t K, + const uint32_t L, int32_t *tags, + float *distances, + std::vector &res_vectors); + +template DISKANN_DLLEXPORT size_t +AbstractIndex::search_with_tags(const uint8_t *query, const uint64_t K, const uint32_t L, + int32_t *tags, float *distances, std::vector &res_vectors); + +template DISKANN_DLLEXPORT size_t AbstractIndex::search_with_tags(const int8_t *query, + const uint64_t K, const uint32_t L, + int32_t *tags, float *distances, + std::vector &res_vectors); + +template DISKANN_DLLEXPORT size_t AbstractIndex::search_with_tags(const float *query, const uint64_t K, + const uint32_t L, uint32_t *tags, + float *distances, + std::vector &res_vectors); + +template DISKANN_DLLEXPORT size_t AbstractIndex::search_with_tags( + const uint8_t *query, const uint64_t K, const uint32_t L, uint32_t *tags, float *distances, + std::vector &res_vectors); + +template DISKANN_DLLEXPORT size_t AbstractIndex::search_with_tags(const int8_t *query, + const uint64_t K, const uint32_t L, + uint32_t *tags, float *distances, + std::vector &res_vectors); + +template DISKANN_DLLEXPORT size_t AbstractIndex::search_with_tags(const float *query, const uint64_t K, + const uint32_t L, int64_t *tags, + float *distances, + std::vector &res_vectors); + +template DISKANN_DLLEXPORT size_t +AbstractIndex::search_with_tags(const uint8_t *query, const uint64_t K, const uint32_t L, + int64_t *tags, float *distances, std::vector &res_vectors); + +template DISKANN_DLLEXPORT size_t AbstractIndex::search_with_tags(const int8_t *query, + const uint64_t K, const uint32_t L, + int64_t *tags, float *distances, + std::vector &res_vectors); + +template DISKANN_DLLEXPORT size_t AbstractIndex::search_with_tags(const float *query, const uint64_t K, + const uint32_t L, uint64_t *tags, + float *distances, + std::vector &res_vectors); + +template DISKANN_DLLEXPORT size_t AbstractIndex::search_with_tags( + const uint8_t *query, const uint64_t K, const uint32_t L, uint64_t *tags, float *distances, + 
std::vector &res_vectors); + +template DISKANN_DLLEXPORT size_t AbstractIndex::search_with_tags(const int8_t *query, + const uint64_t K, const uint32_t L, + uint64_t *tags, float *distances, + std::vector &res_vectors); + +template DISKANN_DLLEXPORT void AbstractIndex::search_with_optimized_layout(const float *query, size_t K, + size_t L, uint32_t *indices); +template DISKANN_DLLEXPORT void AbstractIndex::search_with_optimized_layout(const uint8_t *query, size_t K, + size_t L, uint32_t *indices); +template DISKANN_DLLEXPORT void AbstractIndex::search_with_optimized_layout(const int8_t *query, size_t K, + size_t L, uint32_t *indices); + +template DISKANN_DLLEXPORT int AbstractIndex::insert_point(const float *point, const int32_t tag); +template DISKANN_DLLEXPORT int AbstractIndex::insert_point(const uint8_t *point, const int32_t tag); +template DISKANN_DLLEXPORT int AbstractIndex::insert_point(const int8_t *point, const int32_t tag); + +template DISKANN_DLLEXPORT int AbstractIndex::insert_point(const float *point, const uint32_t tag); +template DISKANN_DLLEXPORT int AbstractIndex::insert_point(const uint8_t *point, const uint32_t tag); +template DISKANN_DLLEXPORT int AbstractIndex::insert_point(const int8_t *point, const uint32_t tag); + +template DISKANN_DLLEXPORT int AbstractIndex::insert_point(const float *point, const int64_t tag); +template DISKANN_DLLEXPORT int AbstractIndex::insert_point(const uint8_t *point, const int64_t tag); +template DISKANN_DLLEXPORT int AbstractIndex::insert_point(const int8_t *point, const int64_t tag); + +template DISKANN_DLLEXPORT int AbstractIndex::insert_point(const float *point, const uint64_t tag); +template DISKANN_DLLEXPORT int AbstractIndex::insert_point(const uint8_t *point, const uint64_t tag); +template DISKANN_DLLEXPORT int AbstractIndex::insert_point(const int8_t *point, const uint64_t tag); + +template DISKANN_DLLEXPORT int AbstractIndex::lazy_delete(const int32_t &tag); +template DISKANN_DLLEXPORT int AbstractIndex::lazy_delete(const uint32_t &tag); +template DISKANN_DLLEXPORT int AbstractIndex::lazy_delete(const int64_t &tag); +template DISKANN_DLLEXPORT int AbstractIndex::lazy_delete(const uint64_t &tag); + +template DISKANN_DLLEXPORT void AbstractIndex::lazy_delete(const std::vector &tags, + std::vector &failed_tags); +template DISKANN_DLLEXPORT void AbstractIndex::lazy_delete(const std::vector &tags, + std::vector &failed_tags); +template DISKANN_DLLEXPORT void AbstractIndex::lazy_delete(const std::vector &tags, + std::vector &failed_tags); +template DISKANN_DLLEXPORT void AbstractIndex::lazy_delete(const std::vector &tags, + std::vector &failed_tags); + +template DISKANN_DLLEXPORT void AbstractIndex::get_active_tags(tsl::robin_set &active_tags); +template DISKANN_DLLEXPORT void AbstractIndex::get_active_tags(tsl::robin_set &active_tags); +template DISKANN_DLLEXPORT void AbstractIndex::get_active_tags(tsl::robin_set &active_tags); +template DISKANN_DLLEXPORT void AbstractIndex::get_active_tags(tsl::robin_set &active_tags); + +template DISKANN_DLLEXPORT void AbstractIndex::set_start_points_at_random(float radius, uint32_t random_seed); +template DISKANN_DLLEXPORT void AbstractIndex::set_start_points_at_random(uint8_t radius, + uint32_t random_seed); +template DISKANN_DLLEXPORT void AbstractIndex::set_start_points_at_random(int8_t radius, uint32_t random_seed); + +template DISKANN_DLLEXPORT int AbstractIndex::get_vector_by_tag(int32_t &tag, float *vec); +template DISKANN_DLLEXPORT int AbstractIndex::get_vector_by_tag(int32_t &tag, uint8_t 
*vec); +template DISKANN_DLLEXPORT int AbstractIndex::get_vector_by_tag(int32_t &tag, int8_t *vec); +template DISKANN_DLLEXPORT int AbstractIndex::get_vector_by_tag(uint32_t &tag, float *vec); +template DISKANN_DLLEXPORT int AbstractIndex::get_vector_by_tag(uint32_t &tag, uint8_t *vec); +template DISKANN_DLLEXPORT int AbstractIndex::get_vector_by_tag(uint32_t &tag, int8_t *vec); + +template DISKANN_DLLEXPORT int AbstractIndex::get_vector_by_tag(int64_t &tag, float *vec); +template DISKANN_DLLEXPORT int AbstractIndex::get_vector_by_tag(int64_t &tag, uint8_t *vec); +template DISKANN_DLLEXPORT int AbstractIndex::get_vector_by_tag(int64_t &tag, int8_t *vec); +template DISKANN_DLLEXPORT int AbstractIndex::get_vector_by_tag(uint64_t &tag, float *vec); +template DISKANN_DLLEXPORT int AbstractIndex::get_vector_by_tag(uint64_t &tag, uint8_t *vec); +template DISKANN_DLLEXPORT int AbstractIndex::get_vector_by_tag(uint64_t &tag, int8_t *vec); + +} // namespace diskann diff --git a/src/ann_exception.cpp b/src/ann_exception.cpp index 7ba70a6d9..ba55e3655 100644 --- a/src/ann_exception.cpp +++ b/src/ann_exception.cpp @@ -18,7 +18,7 @@ std::string package_string(const std::string &item_name, const std::string &item } ANNException::ANNException(const std::string &message, int errorCode, const std::string &funcSig, - const std::string &fileName, unsigned lineNum) + const std::string &fileName, uint32_t lineNum) : ANNException(package_string(std::string("FUNC"), funcSig) + package_string(std::string("FILE"), fileName) + package_string(std::string("LINE"), std::to_string(lineNum)) + " " + message, errorCode) @@ -26,7 +26,7 @@ ANNException::ANNException(const std::string &message, int errorCode, const std: } FileException::FileException(const std::string &filename, std::system_error &e, const std::string &funcSig, - const std::string &fileName, unsigned int lineNum) + const std::string &fileName, uint32_t lineNum) : ANNException(std::string(" While opening file \'") + filename + std::string("\', error code: ") + std::to_string(e.code().value()) + " " + e.code().message(), e.code().value(), funcSig, fileName, lineNum) diff --git a/src/disk_utils.cpp b/src/disk_utils.cpp index a2372abc3..08adb186c 100644 --- a/src/disk_utils.cpp +++ b/src/disk_utils.cpp @@ -24,9 +24,9 @@ namespace diskann void add_new_file_to_single_index(std::string index_file, std::string new_file) { - std::unique_ptr<_u64[]> metadata; - _u64 nr, nc; - diskann::load_bin<_u64>(index_file, metadata, nr, nc); + std::unique_ptr metadata; + uint64_t nr, nc; + diskann::load_bin(index_file, metadata, nr, nc); if (nc != 1) { std::stringstream stream; @@ -34,9 +34,9 @@ void add_new_file_to_single_index(std::string index_file, std::string new_file) throw diskann::ANNException(stream.str(), -1); } size_t index_ending_offset = metadata[nr - 1]; - _u64 read_blk_size = 64 * 1024 * 1024; + size_t read_blk_size = 64 * 1024 * 1024; cached_ofstream writer(index_file, read_blk_size); - _u64 check_file_size = get_file_size(index_file); + size_t check_file_size = get_file_size(index_file); if (check_file_size != index_ending_offset) { std::stringstream stream; @@ -57,7 +57,7 @@ void add_new_file_to_single_index(std::string index_file, std::string new_file) size_t num_blocks = DIV_ROUND_UP(fsize, read_blk_size); char *dump = new char[read_blk_size]; - for (_u64 i = 0; i < num_blocks; i++) + for (uint64_t i = 0; i < num_blocks; i++) { size_t cur_block_size = read_blk_size > fsize - (i * read_blk_size) ? 
fsize - (i * read_blk_size) : read_blk_size; @@ -68,12 +68,12 @@ void add_new_file_to_single_index(std::string index_file, std::string new_file) // writer.close(); delete[] dump; - std::vector<_u64> new_meta; - for (_u64 i = 0; i < nr; i++) + std::vector new_meta; + for (uint64_t i = 0; i < nr; i++) new_meta.push_back(metadata[i]); new_meta.push_back(metadata[nr - 1] + fsize); - diskann::save_bin<_u64>(index_file, new_meta.data(), new_meta.size(), 1); + diskann::save_bin(index_file, new_meta.data(), new_meta.size(), 1); } double get_memory_budget(double search_ram_budget) @@ -96,7 +96,7 @@ double get_memory_budget(const std::string &mem_budget_str) size_t calculate_num_pq_chunks(double final_index_ram_limit, size_t points_num, uint32_t dim, const std::vector ¶m_list) { - size_t num_pq_chunks = (size_t)(std::floor)(_u64(final_index_ram_limit / (double)points_num)); + size_t num_pq_chunks = (size_t)(std::floor)(uint64_t(final_index_ram_limit / (double)points_num)); diskann::cout << "Calculated num_pq_chunks :" << num_pq_chunks << std::endl; if (param_list.size() >= 6) { @@ -217,7 +217,7 @@ T *load_warmup(const std::string &cache_warmup_file, uint64_t &warmup_num, uint6 Support for Merging Many Vamana Indices ***************************************************/ -void read_idmap(const std::string &fname, std::vector &ivecs) +void read_idmap(const std::string &fname, std::vector &ivecs) { uint32_t npts32, dim; size_t actual_file_size = get_file_size(fname); @@ -239,27 +239,27 @@ void read_idmap(const std::string &fname, std::vector &ivecs) } int merge_shards(const std::string &vamana_prefix, const std::string &vamana_suffix, const std::string &idmaps_prefix, - const std::string &idmaps_suffix, const _u64 nshards, unsigned max_degree, + const std::string &idmaps_suffix, const uint64_t nshards, uint32_t max_degree, const std::string &output_vamana, const std::string &medoids_file, bool use_filters, const std::string &labels_to_medoids_file) { // Read ID maps std::vector vamana_names(nshards); - std::vector> idmaps(nshards); - for (_u64 shard = 0; shard < nshards; shard++) + std::vector> idmaps(nshards); + for (uint64_t shard = 0; shard < nshards; shard++) { vamana_names[shard] = vamana_prefix + std::to_string(shard) + vamana_suffix; read_idmap(idmaps_prefix + std::to_string(shard) + idmaps_suffix, idmaps[shard]); } // find max node id - _u64 nnodes = 0; - _u64 nelems = 0; + size_t nnodes = 0; + size_t nelems = 0; for (auto &idmap : idmaps) { for (auto &id : idmap) { - nnodes = std::max(nnodes, (_u64)id); + nnodes = std::max(nnodes, (size_t)id); } nelems += idmap.size(); } @@ -267,15 +267,15 @@ int merge_shards(const std::string &vamana_prefix, const std::string &vamana_suf diskann::cout << "# nodes: " << nnodes << ", max. 
degree: " << max_degree << std::endl; // compute inverse map: node -> shards - std::vector> node_shard; + std::vector> node_shard; node_shard.reserve(nelems); - for (_u64 shard = 0; shard < nshards; shard++) + for (size_t shard = 0; shard < nshards; shard++) { diskann::cout << "Creating inverse map -- shard #" << shard << std::endl; - for (_u64 idx = 0; idx < idmaps[shard].size(); idx++) + for (size_t idx = 0; idx < idmaps[shard].size(); idx++) { - _u64 node_id = idmaps[shard][idx]; - node_shard.push_back(std::make_pair((_u32)node_id, (_u32)shard)); + size_t node_id = idmaps[shard][idx]; + node_shard.push_back(std::make_pair((uint32_t)node_id, (uint32_t)shard)); } } std::sort(node_shard.begin(), node_shard.end(), [](const auto &left, const auto &right) { @@ -287,29 +287,29 @@ int merge_shards(const std::string &vamana_prefix, const std::string &vamana_suf // combined file if (use_filters) { - std::unordered_map> global_label_to_medoids; + std::unordered_map> global_label_to_medoids; - for (_u64 i = 0; i < nshards; i++) + for (size_t i = 0; i < nshards; i++) { std::ifstream mapping_reader; std::string map_file = vamana_names[i] + "_labels_to_medoids.txt"; mapping_reader.open(map_file); std::string line, token; - unsigned line_cnt = 0; + uint32_t line_cnt = 0; while (std::getline(mapping_reader, line)) { std::istringstream iss(line); - _u32 cnt = 0; - _u32 medoid; - _u32 label; + uint32_t cnt = 0; + uint32_t medoid = 0; + uint32_t label = 0; while (std::getline(iss, token, ',')) { token.erase(std::remove(token.begin(), token.end(), '\n'), token.end()); token.erase(std::remove(token.begin(), token.end(), '\r'), token.end()); - unsigned token_as_num = std::stoul(token); + uint32_t token_as_num = std::stoul(token); if (cnt == 0) label = token_as_num; @@ -329,7 +329,7 @@ int merge_shards(const std::string &vamana_prefix, const std::string &vamana_suf { mapping_writer << iter.first << ", "; auto &vec = iter.second; - for (_u32 idx = 0; idx < vec.size() - 1; idx++) + for (uint32_t idx = 0; idx < vec.size() - 1; idx++) { mapping_writer << vec[idx] << ", "; } @@ -340,7 +340,7 @@ int merge_shards(const std::string &vamana_prefix, const std::string &vamana_suf // create cached vamana readers std::vector vamana_readers(nshards); - for (_u64 i = 0; i < nshards; i++) + for (size_t i = 0; i < nshards; i++) { vamana_readers[i].open(vamana_names[i], BUFFER_SIZE_FOR_CACHED_IO); size_t expected_file_size; @@ -348,8 +348,8 @@ int merge_shards(const std::string &vamana_prefix, const std::string &vamana_suf } size_t vamana_metadata_size = - sizeof(_u64) + sizeof(_u32) + sizeof(_u32) + sizeof(_u64); // expected file size + max degree + medoid_id + - // frozen_point info + sizeof(uint64_t) + sizeof(uint32_t) + sizeof(uint32_t) + sizeof(uint64_t); // expected file size + max degree + + // medoid_id + frozen_point info // create cached vamana writers cached_ofstream merged_vamana_writer(output_vamana, BUFFER_SIZE_FOR_CACHED_IO); @@ -360,34 +360,34 @@ int merge_shards(const std::string &vamana_prefix, const std::string &vamana_suf merged_vamana_writer.write((char *)&merged_index_size, sizeof(uint64_t)); // we will overwrite the index size at the end - unsigned output_width = max_degree; - unsigned max_input_width = 0; - // read width from each vamana to advance buffer by sizeof(unsigned) bytes + uint32_t output_width = max_degree; + uint32_t max_input_width = 0; + // read width from each vamana to advance buffer by sizeof(uint32_t) bytes for (auto &reader : vamana_readers) { - unsigned input_width; - 
reader.read((char *)&input_width, sizeof(unsigned)); + uint32_t input_width; + reader.read((char *)&input_width, sizeof(uint32_t)); max_input_width = input_width > max_input_width ? input_width : max_input_width; } diskann::cout << "Max input width: " << max_input_width << ", output width: " << output_width << std::endl; - merged_vamana_writer.write((char *)&output_width, sizeof(unsigned)); + merged_vamana_writer.write((char *)&output_width, sizeof(uint32_t)); std::ofstream medoid_writer(medoids_file.c_str(), std::ios::binary); - _u32 nshards_u32 = (_u32)nshards; - _u32 one_val = 1; + uint32_t nshards_u32 = (uint32_t)nshards; + uint32_t one_val = 1; medoid_writer.write((char *)&nshards_u32, sizeof(uint32_t)); medoid_writer.write((char *)&one_val, sizeof(uint32_t)); - _u64 vamana_index_frozen = 0; // as of now the functionality to merge many overlapping vamana - // indices is supported only for bulk indices without frozen point. - // Hence the final index will also not have any frozen points. - for (_u64 shard = 0; shard < nshards; shard++) + uint64_t vamana_index_frozen = 0; // as of now the functionality to merge many overlapping vamana + // indices is supported only for bulk indices without frozen point. + // Hence the final index will also not have any frozen points. + for (uint64_t shard = 0; shard < nshards; shard++) { - unsigned medoid; + uint32_t medoid; // read medoid - vamana_readers[shard].read((char *)&medoid, sizeof(unsigned)); - vamana_readers[shard].read((char *)&vamana_index_frozen, sizeof(_u64)); + vamana_readers[shard].read((char *)&medoid, sizeof(uint32_t)); + vamana_readers[shard].read((char *)&vamana_index_frozen, sizeof(uint64_t)); assert(vamana_index_frozen == false); // rename medoid medoid = idmaps[shard][medoid]; @@ -395,9 +395,9 @@ int merge_shards(const std::string &vamana_prefix, const std::string &vamana_suf medoid_writer.write((char *)&medoid, sizeof(uint32_t)); // write renamed medoid if (shard == (nshards - 1)) //--> uncomment if running hierarchical - merged_vamana_writer.write((char *)&medoid, sizeof(unsigned)); + merged_vamana_writer.write((char *)&medoid, sizeof(uint32_t)); } - merged_vamana_writer.write((char *)&merged_index_frozen, sizeof(_u64)); + merged_vamana_writer.write((char *)&merged_index_frozen, sizeof(uint64_t)); medoid_writer.close(); diskann::cout << "Starting merge" << std::endl; @@ -407,23 +407,23 @@ int merge_shards(const std::string &vamana_prefix, const std::string &vamana_suf std::mt19937 urng(rng()); std::vector nhood_set(nnodes, 0); - std::vector final_nhood; + std::vector final_nhood; - unsigned nnbrs = 0, shard_nnbrs = 0; - unsigned cur_id = 0; + uint32_t nnbrs = 0, shard_nnbrs = 0; + uint32_t cur_id = 0; for (const auto &id_shard : node_shard) { - unsigned node_id = id_shard.first; - unsigned shard_id = id_shard.second; + uint32_t node_id = id_shard.first; + uint32_t shard_id = id_shard.second; if (cur_id < node_id) { // Gopal. random_shuffle() is deprecated. 
std::shuffle(final_nhood.begin(), final_nhood.end(), urng); - nnbrs = (unsigned)(std::min)(final_nhood.size(), (uint64_t)max_degree); + nnbrs = (uint32_t)(std::min)(final_nhood.size(), (uint64_t)max_degree); // write into merged ofstream - merged_vamana_writer.write((char *)&nnbrs, sizeof(unsigned)); - merged_vamana_writer.write((char *)final_nhood.data(), nnbrs * sizeof(unsigned)); - merged_index_size += (sizeof(unsigned) + nnbrs * sizeof(unsigned)); + merged_vamana_writer.write((char *)&nnbrs, sizeof(uint32_t)); + merged_vamana_writer.write((char *)final_nhood.data(), nnbrs * sizeof(uint32_t)); + merged_index_size += (sizeof(uint32_t) + nnbrs * sizeof(uint32_t)); if (cur_id % 499999 == 1) { diskann::cout << "." << std::flush; @@ -435,18 +435,18 @@ int merge_shards(const std::string &vamana_prefix, const std::string &vamana_suf final_nhood.clear(); } // read from shard_id ifstream - vamana_readers[shard_id].read((char *)&shard_nnbrs, sizeof(unsigned)); + vamana_readers[shard_id].read((char *)&shard_nnbrs, sizeof(uint32_t)); if (shard_nnbrs == 0) { diskann::cout << "WARNING: shard #" << shard_id << ", node_id " << node_id << " has 0 nbrs" << std::endl; } - std::vector shard_nhood(shard_nnbrs); + std::vector shard_nhood(shard_nnbrs); if (shard_nnbrs > 0) - vamana_readers[shard_id].read((char *)shard_nhood.data(), shard_nnbrs * sizeof(unsigned)); + vamana_readers[shard_id].read((char *)shard_nhood.data(), shard_nnbrs * sizeof(uint32_t)); // rename nodes - for (_u64 j = 0; j < shard_nnbrs; j++) + for (uint64_t j = 0; j < shard_nnbrs; j++) { if (nhood_set[idmaps[shard_id][shard_nhood[j]]] == 0) { @@ -458,14 +458,14 @@ int merge_shards(const std::string &vamana_prefix, const std::string &vamana_suf // Gopal. random_shuffle() is deprecated. std::shuffle(final_nhood.begin(), final_nhood.end(), urng); - nnbrs = (unsigned)(std::min)(final_nhood.size(), (uint64_t)max_degree); + nnbrs = (uint32_t)(std::min)(final_nhood.size(), (uint64_t)max_degree); // write into merged ofstream - merged_vamana_writer.write((char *)&nnbrs, sizeof(unsigned)); + merged_vamana_writer.write((char *)&nnbrs, sizeof(uint32_t)); if (nnbrs > 0) { - merged_vamana_writer.write((char *)final_nhood.data(), nnbrs * sizeof(unsigned)); + merged_vamana_writer.write((char *)final_nhood.data(), nnbrs * sizeof(uint32_t)); } - merged_index_size += (sizeof(unsigned) + nnbrs * sizeof(unsigned)); + merged_index_size += (sizeof(uint32_t) + nnbrs * sizeof(uint32_t)); for (auto &p : final_nhood) nhood_set[p] = 0; final_nhood.clear(); @@ -488,32 +488,32 @@ int merge_shards(const std::string &vamana_prefix, const std::string &vamana_suf the new nodes at the end. 
The dummy map contains the real graph id of the new nodes added to the graph */ template -void breakup_dense_points(const std::string data_file, const std::string labels_file, _u32 density, +void breakup_dense_points(const std::string data_file, const std::string labels_file, uint32_t density, const std::string out_data_file, const std::string out_labels_file, const std::string out_metadata_file) { std::string token, line; std::ifstream labels_stream(labels_file); T *data; - _u64 npts, ndims; + uint64_t npts, ndims; diskann::load_bin(data_file, data, npts, ndims); - std::unordered_map<_u32, _u32> dummy_pt_ids; - _u32 next_dummy_id = (_u32)npts; + std::unordered_map dummy_pt_ids; + uint32_t next_dummy_id = (uint32_t)npts; - _u32 point_cnt = 0; + uint32_t point_cnt = 0; std::vector> labels_per_point; labels_per_point.resize(npts); - _u32 dense_pts = 0; + uint32_t dense_pts = 0; if (labels_stream.is_open()) { while (getline(labels_stream, line)) { std::stringstream iss(line); - _u32 lbl_cnt = 0; - _u32 label_host = point_cnt; + uint32_t lbl_cnt = 0; + uint32_t label_host = point_cnt; while (getline(iss, token, ',')) { if (lbl_cnt == density) @@ -522,13 +522,13 @@ void breakup_dense_points(const std::string data_file, const std::string labels_ dense_pts++; label_host = next_dummy_id; labels_per_point.resize(next_dummy_id + 1); - dummy_pt_ids[next_dummy_id] = (_u32)point_cnt; + dummy_pt_ids[next_dummy_id] = (uint32_t)point_cnt; next_dummy_id++; lbl_cnt = 0; } token.erase(std::remove(token.begin(), token.end(), '\n'), token.end()); token.erase(std::remove(token.begin(), token.end(), '\r'), token.end()); - unsigned token_as_num = std::stoul(token); + uint32_t token_as_num = std::stoul(token); labels_per_point[label_host].push_back(token_as_num); lbl_cnt++; } @@ -543,9 +543,9 @@ void breakup_dense_points(const std::string data_file, const std::string labels_ diskann::cout << labels_per_point.size() << " is the new number of points" << std::endl; std::ofstream label_writer(out_labels_file); assert(label_writer.is_open()); - for (_u32 i = 0; i < labels_per_point.size(); i++) + for (uint32_t i = 0; i < labels_per_point.size(); i++) { - for (_u32 j = 0; j < (labels_per_point[i].size() - 1); j++) + for (uint32_t j = 0; j < (labels_per_point[i].size() - 1); j++) { label_writer << labels_per_point[i][j] << ","; } @@ -579,11 +579,11 @@ void extract_shard_labels(const std::string &in_label_file, const std::string &s // point in labels file diskann::cout << "Extracting labels for shard" << std::endl; - _u32 *ids = nullptr; - _u64 num_ids, tmp_dim; + uint32_t *ids = nullptr; + uint64_t num_ids, tmp_dim; diskann::load_bin(shard_ids_bin, ids, num_ids, tmp_dim); - _u32 counter = 0, shard_counter = 0; + uint32_t counter = 0, shard_counter = 0; std::string cur_line; std::ifstream label_reader(in_label_file); @@ -611,55 +611,49 @@ void extract_shard_labels(const std::string &in_label_file, const std::string &s } template -int build_merged_vamana_index(std::string base_file, diskann::Metric compareMetric, unsigned L, unsigned R, +int build_merged_vamana_index(std::string base_file, diskann::Metric compareMetric, uint32_t L, uint32_t R, double sampling_rate, double ram_budget, std::string mem_index_path, std::string medoids_file, std::string centroids_file, size_t build_pq_bytes, bool use_opq, - bool use_filters, const std::string &label_file, + uint32_t num_threads, bool use_filters, const std::string &label_file, const std::string &labels_to_medoids_file, const std::string &universal_label, - const _u32 Lf) + 
const uint32_t Lf, uint32_t universal_label_num = 0) { size_t base_num, base_dim; diskann::get_bin_metadata(base_file, base_num, base_dim); - double full_index_ram = estimate_ram_usage(base_num, base_dim, sizeof(T), R); + double full_index_ram = estimate_ram_usage(base_num, (uint32_t)base_dim, sizeof(T), R); // TODO: Make this honest when there is filter support if (full_index_ram < ram_budget * 1024 * 1024 * 1024) { diskann::cout << "Full index fits in RAM budget, should consume at most " << full_index_ram / (1024 * 1024 * 1024) << "GiBs, so building in one shot" << std::endl; - diskann::Parameters paras; - paras.Set("L", (unsigned)L); - paras.Set("Lf", (unsigned)Lf); - paras.Set("R", (unsigned)R); - paras.Set("C", 750); - paras.Set("alpha", 1.2f); - paras.Set("num_rnds", 2); - if (!use_filters) - paras.Set("saturate_graph", 1); - else - paras.Set("saturate_graph", 0); + + diskann::IndexWriteParameters paras = diskann::IndexWriteParametersBuilder(L, R) + .with_filter_list_size(Lf) + .with_saturate_graph(!use_filters) + .with_num_threads(num_threads) + .build(); using TagT = uint32_t; - paras.Set("save_path", mem_index_path); - std::unique_ptr> _pvamanaIndex = - std::unique_ptr>(new diskann::Index( - compareMetric, base_dim, base_num, false, false, false, build_pq_bytes > 0, build_pq_bytes, use_opq)); + diskann::Index _index(compareMetric, base_dim, base_num, false, false, false, + build_pq_bytes > 0, build_pq_bytes, use_opq); if (!use_filters) - _pvamanaIndex->build(base_file.c_str(), base_num, paras); + _index.build(base_file.c_str(), base_num, paras); else { if (universal_label != "") { // indicates no universal label - LabelT unv_label_as_num = 0; - _pvamanaIndex->set_universal_label(unv_label_as_num); + // LabelT unv_label_as_num = 0; + _index.set_universal_label(universal_label_num); } - _pvamanaIndex->build_filtered_index(base_file.c_str(), label_file, base_num, paras); + _index.build_filtered_index(base_file.c_str(), label_file, base_num, paras); } - _pvamanaIndex->save(mem_index_path.c_str()); + _index.save(mem_index_path.c_str()); if (use_filters) { - // need to copy the labels_to_medoids file to the specified input file + // need to copy the labels_to_medoids file to the specified input + // file std::remove(labels_to_medoids_file.c_str()); std::string mem_labels_to_medoid_file = mem_index_path + "_labels_to_medoids.txt"; copy_file(mem_labels_to_medoid_file, labels_to_medoids_file); @@ -697,36 +691,28 @@ int build_merged_vamana_index(std::string base_file, diskann::Metric compareMetr std::string shard_index_file = merged_index_prefix + "_subshard-" + std::to_string(p) + "_mem.index"; - diskann::Parameters paras; - paras.Set("L", L); - paras.Set("Lf", Lf); - paras.Set("R", (2 * (R / 3))); - paras.Set("C", 750); - paras.Set("alpha", 1.2f); - paras.Set("num_rnds", 2); - paras.Set("saturate_graph", 0); - paras.Set("save_path", shard_index_file); - - _u64 shard_base_dim, shard_base_pts; + diskann::IndexWriteParameters paras = + diskann::IndexWriteParametersBuilder(L, (2 * R / 3)).with_filter_list_size(Lf).build(); + + uint64_t shard_base_dim, shard_base_pts; get_bin_metadata(shard_base_file, shard_base_pts, shard_base_dim); - std::unique_ptr> _pvamanaIndex = std::unique_ptr>( - new diskann::Index(compareMetric, shard_base_dim, shard_base_pts, false, false, false, - build_pq_bytes > 0, build_pq_bytes, use_opq)); + diskann::Index _index(compareMetric, shard_base_dim, shard_base_pts, false, false, false, build_pq_bytes > 0, + build_pq_bytes, use_opq); if (!use_filters) { - 
_pvamanaIndex->build(shard_base_file.c_str(), shard_base_pts, paras); + _index.build(shard_base_file.c_str(), shard_base_pts, paras); } else { diskann::extract_shard_labels(label_file, shard_ids_file, shard_labels_file); if (universal_label != "") { // indicates no universal label - LabelT unv_label_as_num = 0; - _pvamanaIndex->set_universal_label(unv_label_as_num); +// LabelT unv_label_as_num = 0; + _index.set_universal_label(universal_label_num); } - _pvamanaIndex->build_filtered_index(shard_base_file.c_str(), shard_labels_file, shard_base_pts, paras); + _index.build_filtered_index(shard_base_file.c_str(), shard_labels_file, shard_base_pts, paras); } - _pvamanaIndex->save(shard_index_file.c_str()); + _index.save(shard_index_file.c_str()); // copy universal label file from first shard to the final destination // index, since all shards anyway share the universal label if (p == 0) @@ -781,8 +767,8 @@ int build_merged_vamana_index(std::string base_file, diskann::Metric compareMetr // 99.9 latency not blowing up template uint32_t optimize_beamwidth(std::unique_ptr> &pFlashIndex, T *tuning_sample, - _u64 tuning_sample_num, _u64 tuning_sample_aligned_dim, uint32_t L, uint32_t nthreads, - uint32_t start_bw) + uint64_t tuning_sample_num, uint64_t tuning_sample_aligned_dim, uint32_t L, + uint32_t nthreads, uint32_t start_bw) { uint32_t cur_bw = start_bw; double max_qps = 0; @@ -797,7 +783,7 @@ uint32_t optimize_beamwidth(std::unique_ptr> &p auto s = std::chrono::high_resolution_clock::now(); #pragma omp parallel for schedule(dynamic, 1) num_threads(nthreads) - for (_s64 i = 0; i < (int64_t)tuning_sample_num; i++) + for (int64_t i = 0; i < (int64_t)tuning_sample_num; i++) { pFlashIndex->cached_beam_search(tuning_sample + (i * tuning_sample_aligned_dim), 1, L, tuning_sample_result_ids_64.data() + (i * 1), @@ -835,11 +821,11 @@ template void create_disk_layout(const std::string base_file, const std::string mem_index_file, const std::string output_file, const std::string reorder_data_file) { - unsigned npts, ndims; + uint32_t npts, ndims; // amount to read or write in one shot - _u64 read_blk_size = 64 * 1024 * 1024; - _u64 write_blk_size = read_blk_size; + size_t read_blk_size = 64 * 1024 * 1024; + size_t write_blk_size = read_blk_size; cached_ifstream base_reader(base_file, read_blk_size); base_reader.read((char *)&npts, sizeof(uint32_t)); base_reader.read((char *)&ndims, sizeof(uint32_t)); @@ -852,7 +838,7 @@ void create_disk_layout(const std::string base_file, const std::string mem_index bool append_reorder_data = false; std::ifstream reorder_data_reader; - unsigned npts_reorder_file = 0, ndims_reorder_file = 0; + uint32_t npts_reorder_file = 0, ndims_reorder_file = 0; if (reorder_data_file != std::string("")) { append_reorder_data = true; @@ -862,11 +848,12 @@ void create_disk_layout(const std::string base_file, const std::string mem_index try { reorder_data_reader.open(reorder_data_file, std::ios::binary); - reorder_data_reader.read((char *)&npts_reorder_file, sizeof(unsigned)); - reorder_data_reader.read((char *)&ndims_reorder_file, sizeof(unsigned)); + reorder_data_reader.read((char *)&npts_reorder_file, sizeof(uint32_t)); + reorder_data_reader.read((char *)&ndims_reorder_file, sizeof(uint32_t)); if (npts_reorder_file != npts) - throw ANNException("Mismatch in num_points between reorder data file and base file", -1, __FUNCSIG__, - __FILE__, __LINE__); + throw ANNException("Mismatch in num_points between reorder " + "data file and base file", + -1, __FUNCSIG__, __FILE__, __LINE__); if 
(reorder_data_file_size != 8 + sizeof(float) * (size_t)npts_reorder_file * (size_t)ndims_reorder_file) throw ANNException("Discrepancy in reorder data file size ", -1, __FUNCSIG__, __FILE__, __LINE__); } @@ -883,7 +870,7 @@ void create_disk_layout(const std::string base_file, const std::string mem_index cached_ofstream diskann_writer(output_file, write_blk_size); // metadata: width, medoid - unsigned width_u32, medoid_u32; + uint32_t width_u32, medoid_u32; size_t index_file_size; vamana_reader.read((char *)&index_file_size, sizeof(uint64_t)); @@ -896,18 +883,18 @@ void create_disk_layout(const std::string base_file, const std::string mem_index throw diskann::ANNException(stream.str(), -1, __FUNCSIG__, __FILE__, __LINE__); } - _u64 vamana_frozen_num = false, vamana_frozen_loc = 0; + uint64_t vamana_frozen_num = false, vamana_frozen_loc = 0; - vamana_reader.read((char *)&width_u32, sizeof(unsigned)); - vamana_reader.read((char *)&medoid_u32, sizeof(unsigned)); - vamana_reader.read((char *)&vamana_frozen_num, sizeof(_u64)); + vamana_reader.read((char *)&width_u32, sizeof(uint32_t)); + vamana_reader.read((char *)&medoid_u32, sizeof(uint32_t)); + vamana_reader.read((char *)&vamana_frozen_num, sizeof(uint64_t)); // compute - _u64 medoid, max_node_len, nnodes_per_sector; - npts_64 = (_u64)npts; - medoid = (_u64)medoid_u32; + uint64_t medoid, max_node_len, nnodes_per_sector; + npts_64 = (uint64_t)npts; + medoid = (uint64_t)medoid_u32; if (vamana_frozen_num == 1) vamana_frozen_loc = medoid; - max_node_len = (((_u64)width_u32 + 1) * sizeof(unsigned)) + (ndims_64 * sizeof(T)); + max_node_len = (((uint64_t)width_u32 + 1) * sizeof(uint32_t)) + (ndims_64 * sizeof(T)); nnodes_per_sector = SECTOR_LEN / max_node_len; diskann::cout << "medoid: " << medoid << "B" << std::endl; @@ -917,22 +904,22 @@ void create_disk_layout(const std::string base_file, const std::string mem_index // SECTOR_LEN buffer for each sector std::unique_ptr sector_buf = std::make_unique(SECTOR_LEN); std::unique_ptr node_buf = std::make_unique(max_node_len); - unsigned &nnbrs = *(unsigned *)(node_buf.get() + ndims_64 * sizeof(T)); - unsigned *nhood_buf = (unsigned *)(node_buf.get() + (ndims_64 * sizeof(T)) + sizeof(unsigned)); + uint32_t &nnbrs = *(uint32_t *)(node_buf.get() + ndims_64 * sizeof(T)); + uint32_t *nhood_buf = (uint32_t *)(node_buf.get() + (ndims_64 * sizeof(T)) + sizeof(uint32_t)); // number of sectors (1 for meta data) - _u64 n_sectors = ROUND_UP(npts_64, nnodes_per_sector) / nnodes_per_sector; - _u64 n_reorder_sectors = 0; - _u64 n_data_nodes_per_sector = 0; + uint64_t n_sectors = ROUND_UP(npts_64, nnodes_per_sector) / nnodes_per_sector; + uint64_t n_reorder_sectors = 0; + uint64_t n_data_nodes_per_sector = 0; if (append_reorder_data) { n_data_nodes_per_sector = SECTOR_LEN / (ndims_reorder_file * sizeof(float)); n_reorder_sectors = ROUND_UP(npts_64, n_data_nodes_per_sector) / n_data_nodes_per_sector; } - _u64 disk_index_file_size = (n_sectors + n_reorder_sectors + 1) * SECTOR_LEN; + uint64_t disk_index_file_size = (n_sectors + n_reorder_sectors + 1) * SECTOR_LEN; - std::vector<_u64> output_file_meta; + std::vector output_file_meta; output_file_meta.push_back(npts_64); output_file_meta.push_back(ndims_64); output_file_meta.push_back(medoid); @@ -940,7 +927,7 @@ void create_disk_layout(const std::string base_file, const std::string mem_index output_file_meta.push_back(nnodes_per_sector); output_file_meta.push_back(vamana_frozen_num); output_file_meta.push_back(vamana_frozen_loc); - 
output_file_meta.push_back((_u64)append_reorder_data); + output_file_meta.push_back((uint64_t)append_reorder_data); if (append_reorder_data) { output_file_meta.push_back(n_sectors + 1); @@ -953,42 +940,42 @@ void create_disk_layout(const std::string base_file, const std::string mem_index std::unique_ptr cur_node_coords = std::make_unique(ndims_64); diskann::cout << "# sectors: " << n_sectors << std::endl; - _u64 cur_node_id = 0; - for (_u64 sector = 0; sector < n_sectors; sector++) + uint64_t cur_node_id = 0; + for (uint64_t sector = 0; sector < n_sectors; sector++) { if (sector % 100000 == 0) { diskann::cout << "Sector #" << sector << "written" << std::endl; } memset(sector_buf.get(), 0, SECTOR_LEN); - for (_u64 sector_node_id = 0; sector_node_id < nnodes_per_sector && cur_node_id < npts_64; sector_node_id++) + for (uint64_t sector_node_id = 0; sector_node_id < nnodes_per_sector && cur_node_id < npts_64; sector_node_id++) { memset(node_buf.get(), 0, max_node_len); // read cur node's nnbrs - vamana_reader.read((char *)&nnbrs, sizeof(unsigned)); + vamana_reader.read((char *)&nnbrs, sizeof(uint32_t)); // sanity checks on nnbrs assert(nnbrs > 0); assert(nnbrs <= width_u32); // read node's nhood - vamana_reader.read((char *)nhood_buf, (std::min)(nnbrs, width_u32) * sizeof(unsigned)); + vamana_reader.read((char *)nhood_buf, (std::min)(nnbrs, width_u32) * sizeof(uint32_t)); if (nnbrs > width_u32) { - vamana_reader.seekg((nnbrs - width_u32) * sizeof(unsigned), vamana_reader.cur); + vamana_reader.seekg((nnbrs - width_u32) * sizeof(uint32_t), vamana_reader.cur); } // write coords of node first - // T *node_coords = data + ((_u64) ndims_64 * cur_node_id); + // T *node_coords = data + ((uint64_t) ndims_64 * cur_node_id); base_reader.read((char *)cur_node_coords.get(), sizeof(T) * ndims_64); memcpy(node_buf.get(), cur_node_coords.get(), ndims_64 * sizeof(T)); // write nnbrs - *(unsigned *)(node_buf.get() + ndims_64 * sizeof(T)) = (std::min)(nnbrs, width_u32); + *(uint32_t *)(node_buf.get() + ndims_64 * sizeof(T)) = (std::min)(nnbrs, width_u32); // write nhood next - memcpy(node_buf.get() + ndims_64 * sizeof(T) + sizeof(unsigned), nhood_buf, - (std::min)(nnbrs, width_u32) * sizeof(unsigned)); + memcpy(node_buf.get() + ndims_64 * sizeof(T) + sizeof(uint32_t), nhood_buf, + (std::min)(nnbrs, width_u32) * sizeof(uint32_t)); // get offset into sector_buf char *sector_node_buf = sector_buf.get() + (sector_node_id * max_node_len); @@ -1007,7 +994,7 @@ void create_disk_layout(const std::string base_file, const std::string mem_index auto vec_len = ndims_reorder_file * sizeof(float); std::unique_ptr vec_buf = std::make_unique(vec_len); - for (_u64 sector = 0; sector < n_reorder_sectors; sector++) + for (uint64_t sector = 0; sector < n_reorder_sectors; sector++) { if (sector % 100000 == 0) { @@ -1016,7 +1003,7 @@ void create_disk_layout(const std::string base_file, const std::string mem_index memset(sector_buf.get(), 0, SECTOR_LEN); - for (_u64 sector_node_id = 0; sector_node_id < n_data_nodes_per_sector && sector_node_id < npts_64; + for (uint64_t sector_node_id = 0; sector_node_id < n_data_nodes_per_sector && sector_node_id < npts_64; sector_node_id++) { memset(vec_buf.get(), 0, vec_len); @@ -1030,14 +1017,15 @@ void create_disk_layout(const std::string base_file, const std::string mem_index } } diskann_writer.close(); - diskann::save_bin<_u64>(output_file, output_file_meta.data(), output_file_meta.size(), 1, 0); + diskann::save_bin(output_file, output_file_meta.data(), output_file_meta.size(), 1, 0); 
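// ---------------------------------------------------------------------------
// Editor's note (illustrative sketch, not part of the patch): create_disk_layout
// above packs whole nodes into SECTOR_LEN-byte sectors, with one leading
// metadata sector. Assuming SECTOR_LEN = 4096 and the ROUND_UP macro from the
// surrounding codebase, the layout arithmetic reduces to:
//
//     uint64_t max_node_len = (((uint64_t)width_u32 + 1) * sizeof(uint32_t)) +
//                             (ndims_64 * sizeof(T));          // nbr count + nbrs + coords
//     uint64_t nnodes_per_sector = SECTOR_LEN / max_node_len;  // nodes never straddle sectors
//     uint64_t n_sectors = ROUND_UP(npts_64, nnodes_per_sector) / nnodes_per_sector;
//     uint64_t disk_index_file_size = (n_sectors + n_reorder_sectors + 1) * SECTOR_LEN;
//
// For example, float vectors with ndims_64 = 128 and width_u32 = 64 give
// max_node_len = 260 + 512 = 772 bytes, so 5 nodes fit per 4096-byte sector and
// a 10K-point index needs ROUND_UP(10000, 5) / 5 = 2000 data sectors plus the
// metadata sector.
// ---------------------------------------------------------------------------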
diskann::cout << "Output disk index file written to " << output_file << std::endl; } template int build_disk_index(const char *dataFilePath, const char *indexFilePath, const char *indexBuildParameters, - diskann::Metric compareMetric, bool use_opq, bool use_filters, const std::string &label_file, - const std::string &universal_label, const _u32 filter_threshold, const _u32 Lf) + diskann::Metric compareMetric, bool use_opq, const std::string &codebook_prefix, bool use_filters, + const std::string &label_file, const std::string &universal_label, const uint32_t filter_threshold, + const uint32_t Lf) { std::stringstream parser; parser << std::string(indexBuildParameters); @@ -1047,7 +1035,7 @@ int build_disk_index(const char *dataFilePath, const char *indexFilePath, const { param_list.push_back(cur_param); } - if (param_list.size() < 5 || param_list.size() > 8) + if (param_list.size() < 5 || param_list.size() > 9) { diskann::cout << "Correct usage of parameters is R (max degree)\n" "L (indexing list size, better if >= R)\n" @@ -1059,7 +1047,8 @@ int build_disk_index(const char *dataFilePath, const char *indexFilePath, const "reorder (set true to include full precision in data file" ": optional paramter, use only when using disk PQ\n" "build_PQ_byte (number of PQ bytes for inde build; set 0 to use " - "full precision vectors)" + "full precision vectors)\n" + "QD Quantized Dimension to overwrite the derived dim from B " << std::endl; return -1; } @@ -1097,7 +1086,7 @@ int build_disk_index(const char *dataFilePath, const char *indexFilePath, const } } - if (param_list.size() == 8) + if (param_list.size() >= 8) { build_pq_bytes = atoi(param_list[7].c_str()); } @@ -1107,7 +1096,9 @@ int build_disk_index(const char *dataFilePath, const char *indexFilePath, const std::string labels_file_original = label_file; std::string index_prefix_path(indexFilePath); std::string labels_file_to_use = index_prefix_path + "_label_formatted.txt"; - std::string pq_pivots_path = index_prefix_path + "_pq_pivots.bin"; + std::string pq_pivots_path_base = codebook_prefix; + std::string pq_pivots_path = file_exists(pq_pivots_path_base) ? pq_pivots_path_base + "_pq_pivots.bin" + : index_prefix_path + "_pq_pivots.bin"; std::string pq_compressed_vectors_path = index_prefix_path + "_pq_compressed.bin"; std::string mem_index_path = index_prefix_path + "_mem.index"; std::string disk_index_path = index_prefix_path + "_disk.index"; @@ -1148,8 +1139,8 @@ int build_disk_index(const char *dataFilePath, const char *indexFilePath, const diskann::cout << timer.elapsed_seconds_for_step("preprocessing data for inner product") << std::endl; } - unsigned R = (unsigned)atoi(param_list[0].c_str()); - unsigned L = (unsigned)atoi(param_list[1].c_str()); + uint32_t R = (uint32_t)atoi(param_list[0].c_str()); + uint32_t L = (uint32_t)atoi(param_list[1].c_str()); double final_index_ram_limit = get_memory_budget(param_list[2]); if (final_index_ram_limit <= 0) @@ -1165,7 +1156,7 @@ int build_disk_index(const char *dataFilePath, const char *indexFilePath, const std::cerr << "Not building index. Please provide more RAM budget" << std::endl; return -1; } - _u32 num_threads = (_u32)atoi(param_list[4].c_str()); + uint32_t num_threads = (uint32_t)atoi(param_list[4].c_str()); if (num_threads != 0) { @@ -1182,9 +1173,9 @@ int build_disk_index(const char *dataFilePath, const char *indexFilePath, const // into replica dummy points which evenly distribute the filters. 
The rest // of index build happens on the augmented base and labels std::string augmented_data_file, augmented_labels_file; + std::uint32_t universal_label_id = 0; if (use_filters) { - std::uint32_t universal_label_id = 0; convert_labels_string_to_int(labels_file_original, labels_file_to_use, disk_labels_int_map_file, universal_label, universal_label_id); augmented_data_file = index_prefix_path + "_augmented_data.bin"; @@ -1194,8 +1185,8 @@ int build_disk_index(const char *dataFilePath, const char *indexFilePath, const dummy_remap_file = index_prefix_path + "_dummy_remap.txt"; breakup_dense_points(data_file_to_use, labels_file_to_use, filter_threshold, augmented_data_file, augmented_labels_file, - dummy_remap_file); // RKNOTE: This has large memory footprint, need - // to make this streaming + dummy_remap_file); // RKNOTE: This has large memory footprint, + // need to make this streaming data_file_to_use = augmented_data_file; labels_file_to_use = augmented_labels_file; } @@ -1212,17 +1203,25 @@ int build_disk_index(const char *dataFilePath, const char *indexFilePath, const generate_disk_quantized_data(data_file_to_use, disk_pq_pivots_path, disk_pq_compressed_vectors_path, compareMetric, p_val, disk_pq_dims); } - size_t num_pq_chunks = (size_t)(std::floor)(_u64(final_index_ram_limit / points_num)); + size_t num_pq_chunks = (size_t)(std::floor)(uint64_t(final_index_ram_limit / points_num)); num_pq_chunks = num_pq_chunks <= 0 ? 1 : num_pq_chunks; num_pq_chunks = num_pq_chunks > dim ? dim : num_pq_chunks; num_pq_chunks = num_pq_chunks > MAX_PQ_CHUNKS ? MAX_PQ_CHUNKS : num_pq_chunks; + if (param_list.size() >= 9 && atoi(param_list[8].c_str()) <= MAX_PQ_CHUNKS && atoi(param_list[8].c_str()) > 0) + { + std::cout << "Use quantized dimension (QD) to overwrite derived quantized " + "dimension from search_DRAM_budget (B)" + << std::endl; + num_pq_chunks = atoi(param_list[8].c_str()); + } + diskann::cout << "Compressing " << dim << "-dimensional data into " << num_pq_chunks << " bytes per vector." << std::endl; generate_quantized_data(data_file_to_use, pq_pivots_path, pq_compressed_vectors_path, compareMetric, p_val, - num_pq_chunks, use_opq); + num_pq_chunks, use_opq, codebook_prefix); diskann::cout << timer.elapsed_seconds_for_step("generating quantized data") << std::endl; // Gopal. Splitting diskann_dll into separate DLLs for search and build. 
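// ---------------------------------------------------------------------------
// Editor's note (illustrative sketch, not part of the patch): the hunk above
// wires the new optional 9th build parameter (QD) through to PQ training. The
// number of PQ bytes per vector is first derived from the search DRAM budget
// and then optionally overridden by QD; pick_num_pq_chunks is a hypothetical
// helper that mirrors the in-line logic of build_disk_index:
//
//     size_t pick_num_pq_chunks(double final_index_ram_limit, size_t points_num,
//                               size_t dim, int qd /* 0 when QD was not given */)
//     {
//         size_t n = (size_t)(final_index_ram_limit / points_num); // budget bytes per vector
//         n = n == 0 ? 1 : n;                          // at least one chunk
//         n = n > dim ? dim : n;                       // at most one chunk per dimension
//         n = n > MAX_PQ_CHUNKS ? MAX_PQ_CHUNKS : n;
//         if (qd > 0 && (size_t)qd <= MAX_PQ_CHUNKS)   // an explicit QD wins over the budget
//             n = (size_t)qd;
//         return n;
//     }
//
// For example, a 3 GiB search budget over 100 million points allows roughly 32
// bytes per compressed vector unless QD says otherwise.
// ---------------------------------------------------------------------------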
@@ -1234,8 +1233,8 @@ int build_disk_index(const char *dataFilePath, const char *indexFilePath, const timer.reset(); diskann::build_merged_vamana_index(data_file_to_use.c_str(), diskann::Metric::L2, L, R, p_val, indexing_ram_budget, mem_index_path, medoids_path, centroids_path, - build_pq_bytes, use_opq, use_filters, labels_file_to_use, - labels_to_medoids_path, universal_label, Lf); + build_pq_bytes, use_opq, num_threads, use_filters, labels_file_to_use, + labels_to_medoids_path, universal_label, Lf, universal_label_id); diskann::cout << timer.elapsed_seconds_for_step("building merged vamana index") << std::endl; timer.reset(); @@ -1246,10 +1245,10 @@ int build_disk_index(const char *dataFilePath, const char *indexFilePath, const else { if (!reorder_data) - diskann::create_disk_layout<_u8>(disk_pq_compressed_vectors_path, mem_index_path, disk_index_path); + diskann::create_disk_layout(disk_pq_compressed_vectors_path, mem_index_path, disk_index_path); else - diskann::create_disk_layout<_u8>(disk_pq_compressed_vectors_path, mem_index_path, disk_index_path, - data_file_to_use.c_str()); + diskann::create_disk_layout(disk_pq_compressed_vectors_path, mem_index_path, disk_index_path, + data_file_to_use.c_str()); } diskann::cout << timer.elapsed_seconds_for_step("generating disk layout") << std::endl; @@ -1316,91 +1315,97 @@ template DISKANN_DLLEXPORT float *load_warmup(MemoryMappedFiles &files, c template DISKANN_DLLEXPORT uint32_t optimize_beamwidth( std::unique_ptr> &pFlashIndex, int8_t *tuning_sample, - _u64 tuning_sample_num, _u64 tuning_sample_aligned_dim, uint32_t L, uint32_t nthreads, uint32_t start_bw); + uint64_t tuning_sample_num, uint64_t tuning_sample_aligned_dim, uint32_t L, uint32_t nthreads, uint32_t start_bw); template DISKANN_DLLEXPORT uint32_t optimize_beamwidth( std::unique_ptr> &pFlashIndex, uint8_t *tuning_sample, - _u64 tuning_sample_num, _u64 tuning_sample_aligned_dim, uint32_t L, uint32_t nthreads, uint32_t start_bw); + uint64_t tuning_sample_num, uint64_t tuning_sample_aligned_dim, uint32_t L, uint32_t nthreads, uint32_t start_bw); template DISKANN_DLLEXPORT uint32_t optimize_beamwidth( - std::unique_ptr> &pFlashIndex, float *tuning_sample, _u64 tuning_sample_num, - _u64 tuning_sample_aligned_dim, uint32_t L, uint32_t nthreads, uint32_t start_bw); + std::unique_ptr> &pFlashIndex, float *tuning_sample, + uint64_t tuning_sample_num, uint64_t tuning_sample_aligned_dim, uint32_t L, uint32_t nthreads, uint32_t start_bw); template DISKANN_DLLEXPORT uint32_t optimize_beamwidth( std::unique_ptr> &pFlashIndex, int8_t *tuning_sample, - _u64 tuning_sample_num, _u64 tuning_sample_aligned_dim, uint32_t L, uint32_t nthreads, uint32_t start_bw); + uint64_t tuning_sample_num, uint64_t tuning_sample_aligned_dim, uint32_t L, uint32_t nthreads, uint32_t start_bw); template DISKANN_DLLEXPORT uint32_t optimize_beamwidth( std::unique_ptr> &pFlashIndex, uint8_t *tuning_sample, - _u64 tuning_sample_num, _u64 tuning_sample_aligned_dim, uint32_t L, uint32_t nthreads, uint32_t start_bw); + uint64_t tuning_sample_num, uint64_t tuning_sample_aligned_dim, uint32_t L, uint32_t nthreads, uint32_t start_bw); template DISKANN_DLLEXPORT uint32_t optimize_beamwidth( - std::unique_ptr> &pFlashIndex, float *tuning_sample, _u64 tuning_sample_num, - _u64 tuning_sample_aligned_dim, uint32_t L, uint32_t nthreads, uint32_t start_bw); + std::unique_ptr> &pFlashIndex, float *tuning_sample, + uint64_t tuning_sample_num, uint64_t tuning_sample_aligned_dim, uint32_t L, uint32_t nthreads, uint32_t start_bw); template 
DISKANN_DLLEXPORT int build_disk_index(const char *dataFilePath, const char *indexFilePath, const char *indexBuildParameters, diskann::Metric compareMetric, bool use_opq, - bool use_filters, const std::string &label_file, + const std::string &codebook_prefix, bool use_filters, + const std::string &label_file, const std::string &universal_label, - const _u32 filter_threshold, const _u32 Lf); + const uint32_t filter_threshold, const uint32_t Lf); template DISKANN_DLLEXPORT int build_disk_index(const char *dataFilePath, const char *indexFilePath, const char *indexBuildParameters, diskann::Metric compareMetric, bool use_opq, - bool use_filters, const std::string &label_file, + const std::string &codebook_prefix, bool use_filters, + const std::string &label_file, const std::string &universal_label, - const _u32 filter_threshold, const _u32 Lf); + const uint32_t filter_threshold, const uint32_t Lf); template DISKANN_DLLEXPORT int build_disk_index(const char *dataFilePath, const char *indexFilePath, const char *indexBuildParameters, diskann::Metric compareMetric, bool use_opq, - bool use_filters, const std::string &label_file, + const std::string &codebook_prefix, bool use_filters, + const std::string &label_file, const std::string &universal_label, - const _u32 filter_threshold, const _u32 Lf); + const uint32_t filter_threshold, const uint32_t Lf); // LabelT = uint16 template DISKANN_DLLEXPORT int build_disk_index(const char *dataFilePath, const char *indexFilePath, const char *indexBuildParameters, diskann::Metric compareMetric, bool use_opq, - bool use_filters, const std::string &label_file, + const std::string &codebook_prefix, bool use_filters, + const std::string &label_file, const std::string &universal_label, - const _u32 filter_threshold, const _u32 Lf); + const uint32_t filter_threshold, const uint32_t Lf); template DISKANN_DLLEXPORT int build_disk_index(const char *dataFilePath, const char *indexFilePath, const char *indexBuildParameters, diskann::Metric compareMetric, bool use_opq, - bool use_filters, const std::string &label_file, + const std::string &codebook_prefix, bool use_filters, + const std::string &label_file, const std::string &universal_label, - const _u32 filter_threshold, const _u32 Lf); + const uint32_t filter_threshold, const uint32_t Lf); template DISKANN_DLLEXPORT int build_disk_index(const char *dataFilePath, const char *indexFilePath, const char *indexBuildParameters, diskann::Metric compareMetric, bool use_opq, - bool use_filters, const std::string &label_file, + const std::string &codebook_prefix, bool use_filters, + const std::string &label_file, const std::string &universal_label, - const _u32 filter_threshold, const _u32 Lf); + const uint32_t filter_threshold, const uint32_t Lf); template DISKANN_DLLEXPORT int build_merged_vamana_index( - std::string base_file, diskann::Metric compareMetric, unsigned L, unsigned R, double sampling_rate, + std::string base_file, diskann::Metric compareMetric, uint32_t L, uint32_t R, double sampling_rate, double ram_budget, std::string mem_index_path, std::string medoids_path, std::string centroids_file, - size_t build_pq_bytes, bool use_opq, bool use_filters, const std::string &label_file, - const std::string &labels_to_medoids_file, const std::string &universal_label, const _u32 Lf); + size_t build_pq_bytes, bool use_opq, uint32_t num_threads, bool use_filters, const std::string &label_file, + const std::string &labels_to_medoids_file, const std::string &universal_label, const uint32_t Lf); template DISKANN_DLLEXPORT int 
build_merged_vamana_index( - std::string base_file, diskann::Metric compareMetric, unsigned L, unsigned R, double sampling_rate, + std::string base_file, diskann::Metric compareMetric, uint32_t L, uint32_t R, double sampling_rate, double ram_budget, std::string mem_index_path, std::string medoids_path, std::string centroids_file, - size_t build_pq_bytes, bool use_opq, bool use_filters, const std::string &label_file, - const std::string &labels_to_medoids_file, const std::string &universal_label, const _u32 Lf); + size_t build_pq_bytes, bool use_opq, uint32_t num_threads, bool use_filters, const std::string &label_file, + const std::string &labels_to_medoids_file, const std::string &universal_label, const uint32_t Lf); template DISKANN_DLLEXPORT int build_merged_vamana_index( - std::string base_file, diskann::Metric compareMetric, unsigned L, unsigned R, double sampling_rate, + std::string base_file, diskann::Metric compareMetric, uint32_t L, uint32_t R, double sampling_rate, double ram_budget, std::string mem_index_path, std::string medoids_path, std::string centroids_file, - size_t build_pq_bytes, bool use_opq, bool use_filters, const std::string &label_file, - const std::string &labels_to_medoids_file, const std::string &universal_label, const _u32 Lf); + size_t build_pq_bytes, bool use_opq, uint32_t num_threads, bool use_filters, const std::string &label_file, + const std::string &labels_to_medoids_file, const std::string &universal_label, const uint32_t Lf); // Label=16_t template DISKANN_DLLEXPORT int build_merged_vamana_index( - std::string base_file, diskann::Metric compareMetric, unsigned L, unsigned R, double sampling_rate, + std::string base_file, diskann::Metric compareMetric, uint32_t L, uint32_t R, double sampling_rate, double ram_budget, std::string mem_index_path, std::string medoids_path, std::string centroids_file, - size_t build_pq_bytes, bool use_opq, bool use_filters, const std::string &label_file, - const std::string &labels_to_medoids_file, const std::string &universal_label, const _u32 Lf); + size_t build_pq_bytes, bool use_opq, uint32_t num_threads, bool use_filters, const std::string &label_file, + const std::string &labels_to_medoids_file, const std::string &universal_label, const uint32_t Lf); template DISKANN_DLLEXPORT int build_merged_vamana_index( - std::string base_file, diskann::Metric compareMetric, unsigned L, unsigned R, double sampling_rate, + std::string base_file, diskann::Metric compareMetric, uint32_t L, uint32_t R, double sampling_rate, double ram_budget, std::string mem_index_path, std::string medoids_path, std::string centroids_file, - size_t build_pq_bytes, bool use_opq, bool use_filters, const std::string &label_file, - const std::string &labels_to_medoids_file, const std::string &universal_label, const _u32 Lf); + size_t build_pq_bytes, bool use_opq, uint32_t num_threads, bool use_filters, const std::string &label_file, + const std::string &labels_to_medoids_file, const std::string &universal_label, const uint32_t Lf); template DISKANN_DLLEXPORT int build_merged_vamana_index( - std::string base_file, diskann::Metric compareMetric, unsigned L, unsigned R, double sampling_rate, + std::string base_file, diskann::Metric compareMetric, uint32_t L, uint32_t R, double sampling_rate, double ram_budget, std::string mem_index_path, std::string medoids_path, std::string centroids_file, - size_t build_pq_bytes, bool use_opq, bool use_filters, const std::string &label_file, - const std::string &labels_to_medoids_file, const std::string &universal_label, 
const _u32 Lf); + size_t build_pq_bytes, bool use_opq, uint32_t num_threads, bool use_filters, const std::string &label_file, + const std::string &labels_to_medoids_file, const std::string &universal_label, const uint32_t Lf); }; // namespace diskann diff --git a/src/distance.cpp b/src/distance.cpp index 15d30c8cb..31ab9d3ff 100644 --- a/src/distance.cpp +++ b/src/distance.cpp @@ -22,6 +22,49 @@ namespace diskann { +// +// Base Class Implementatons +// +template +float Distance::compare(const T *a, const T *b, const float normA, const float normB, uint32_t length) const +{ + throw std::logic_error("This function is not implemented."); +} + +template uint32_t Distance::post_normalization_dimension(uint32_t orig_dimension) const +{ + return orig_dimension; +} + +template diskann::Metric Distance::get_metric() const +{ + return _distance_metric; +} + +template bool Distance::preprocessing_required() const +{ + return false; +} + +template +void Distance::preprocess_base_points(T *original_data, const size_t orig_dim, const size_t num_points) +{ +} + +template void Distance::preprocess_query(const T *query_vec, const size_t query_dim, T *scratch_query) +{ + std::memcpy(scratch_query, query_vec, query_dim * sizeof(T)); +} + +template size_t Distance::get_required_alignment() const +{ + return _alignment_factor; +} + +template Distance::~Distance() +{ +} + // // Cosine distance functions. // @@ -104,7 +147,7 @@ float DistanceL2Int8::compare(const int8_t *a, const int8_t *b, uint32_t size) c #else int32_t result = 0; #pragma omp simd reduction(+ : result) aligned(a, b : 8) - for (_s32 i = 0; i < (_s32)size; i++) + for (int32_t i = 0; i < (int32_t)size; i++) { result += ((int32_t)((int16_t)a[i] - (int16_t)b[i])) * ((int32_t)((int16_t)a[i] - (int16_t)b[i])); } @@ -181,12 +224,12 @@ float DistanceL2Float::compare(const float *a, const float *b, uint32_t size) co return result; } -float SlowDistanceL2Float::compare(const float *a, const float *b, uint32_t length) const +template float SlowDistanceL2::compare(const T *a, const T *b, uint32_t length) const { float result = 0.0f; for (uint32_t i = 0; i < length; i++) { - result += (a[i] - b[i]) * (a[i] - b[i]); + result += ((float)(a[i] - b[i])) * (a[i] - b[i]); } return result; } @@ -259,7 +302,7 @@ float AVXDistanceL2Float::compare(const float *, const float *, uint32_t) const } #endif -template float DistanceInnerProduct::inner_product(const T *a, const T *b, unsigned size) const +template float DistanceInnerProduct::inner_product(const T *a, const T *b, uint32_t size) const { if (!std::is_floating_point::value) { @@ -281,9 +324,9 @@ template float DistanceInnerProduct::inner_product(const T *a, c __m256 sum; __m256 l0, l1; __m256 r0, r1; - unsigned D = (size + 7) & ~7U; - unsigned DR = D % 16; - unsigned DD = D - DR; + uint32_t D = (size + 7) & ~7U; + uint32_t DR = D % 16; + uint32_t DD = D - DR; const float *l = (float *)a; const float *r = (float *)b; const float *e_l = l + DD; @@ -296,7 +339,7 @@ template float DistanceInnerProduct::inner_product(const T *a, c AVX_DOT(e_l, e_r, sum, l0, r0); } - for (unsigned i = 0; i < DD; i += 16, l += 16, r += 16) + for (uint32_t i = 0; i < DD; i += 16, l += 16, r += 16) { AVX_DOT(l, r, sum, l0, r0); AVX_DOT(l + 8, r + 8, sum, l1, r1); @@ -314,9 +357,9 @@ template float DistanceInnerProduct::inner_product(const T *a, c __m128 sum; __m128 l0, l1, l2, l3; __m128 r0, r1, r2, r3; - unsigned D = (size + 3) & ~3U; - unsigned DR = D % 16; - unsigned DD = D - DR; + uint32_t D = (size + 3) & ~3U; + uint32_t DR = D % 
16; + uint32_t DD = D - DR; const float *l = a; const float *r = b; const float *e_l = l + DD; @@ -335,7 +378,7 @@ template float DistanceInnerProduct::inner_product(const T *a, c default: break; } - for (unsigned i = 0; i < DD; i += 16, l += 16, r += 16) + for (uint32_t i = 0; i < DD; i += 16, l += 16, r += 16) { SSE_DOT(l, r, sum, l0, r0); SSE_DOT(l + 4, r + 4, sum, l1, r1); @@ -372,14 +415,14 @@ template float DistanceInnerProduct::inner_product(const T *a, c return result; } -template float DistanceFastL2::compare(const T *a, const T *b, float norm, unsigned size) const +template float DistanceFastL2::compare(const T *a, const T *b, float norm, uint32_t size) const { float result = -2 * DistanceInnerProduct::inner_product(a, b, size); result += norm; return result; } -template float DistanceFastL2::norm(const T *a, unsigned size) const +template float DistanceFastL2::norm(const T *a, uint32_t size) const { if (!std::is_floating_point::value) { @@ -397,9 +440,9 @@ template float DistanceFastL2::norm(const T *a, unsigned size) c __m256 sum; __m256 l0, l1; - unsigned D = (size + 7) & ~7U; - unsigned DR = D % 16; - unsigned DD = D - DR; + uint32_t D = (size + 7) & ~7U; + uint32_t DR = D % 16; + uint32_t DD = D - DR; const float *l = (float *)a; const float *e_l = l + DD; float unpack[8] __attribute__((aligned(32))) = {0, 0, 0, 0, 0, 0, 0, 0}; @@ -409,7 +452,7 @@ template float DistanceFastL2::norm(const T *a, unsigned size) c { AVX_L2NORM(e_l, sum, l0); } - for (unsigned i = 0; i < DD; i += 16, l += 16) + for (uint32_t i = 0; i < DD; i += 16, l += 16) { AVX_L2NORM(l, sum, l0); AVX_L2NORM(l + 8, sum, l1); @@ -425,9 +468,9 @@ template float DistanceFastL2::norm(const T *a, unsigned size) c __m128 sum; __m128 l0, l1, l2, l3; - unsigned D = (size + 3) & ~3U; - unsigned DR = D % 16; - unsigned DD = D - DR; + uint32_t D = (size + 3) & ~3U; + uint32_t DR = D % 16; + uint32_t DD = D - DR; const float *l = a; const float *e_l = l + DD; float unpack[4] __attribute__((aligned(16))) = {0, 0, 0, 0}; @@ -444,7 +487,7 @@ template float DistanceFastL2::norm(const T *a, unsigned size) c default: break; } - for (unsigned i = 0; i < DD; i += 16, l += 16) + for (uint32_t i = 0; i < DD; i += 16, l += 16) { SSE_L2NORM(l, sum, l0); SSE_L2NORM(l + 4, sum, l1); @@ -492,9 +535,9 @@ float AVXDistanceInnerProductFloat::compare(const float *a, const float *b, uint __m256 sum; __m256 l0, l1; __m256 r0, r1; - unsigned D = (size + 7) & ~7U; - unsigned DR = D % 16; - unsigned DD = D - DR; + uint32_t D = (size + 7) & ~7U; + uint32_t DR = D % 16; + uint32_t DD = D - DR; const float *l = (float *)a; const float *r = (float *)b; const float *e_l = l + DD; @@ -511,7 +554,7 @@ float AVXDistanceInnerProductFloat::compare(const float *a, const float *b, uint AVX_DOT(e_l, e_r, sum, l0, r0); } - for (unsigned i = 0; i < DD; i += 16, l += 16, r += 16) + for (uint32_t i = 0; i < DD; i += 16, l += 16, r += 16) { AVX_DOT(l, r, sum, l0, r0); AVX_DOT(l + 8, r + 8, sum, l1, r1); @@ -522,6 +565,40 @@ float AVXDistanceInnerProductFloat::compare(const float *a, const float *b, uint return -result; } +uint32_t AVXNormalizedCosineDistanceFloat::post_normalization_dimension(uint32_t orig_dimension) const +{ + return orig_dimension; +} +bool AVXNormalizedCosineDistanceFloat::preprocessing_required() const +{ + return true; +} +void AVXNormalizedCosineDistanceFloat::preprocess_base_points(float *original_data, const size_t orig_dim, + const size_t num_points) +{ + for (uint32_t i = 0; i < num_points; i++) + { + normalize((float *)(original_data + 
i * orig_dim), orig_dim);
+    }
+}
+
+void AVXNormalizedCosineDistanceFloat::preprocess_query(const float *query_vec, const size_t query_dim,
+                                                        float *query_scratch)
+{
+    normalize_and_copy(query_vec, (uint32_t)query_dim, query_scratch);
+}
+
+void AVXNormalizedCosineDistanceFloat::normalize_and_copy(const float *query_vec, const uint32_t query_dim,
+                                                          float *query_target) const
+{
+    float norm = get_norm(query_vec, query_dim);
+
+    for (uint32_t i = 0; i < query_dim; i++)
+    {
+        query_target[i] = query_vec[i] / norm;
+    }
+}
+
 // Get the right distance function for the given metric.
 template <> diskann::Distance<float> *get_distance_function(diskann::Metric m)
 {
@@ -540,7 +617,7 @@ template <> diskann::Distance<float> *get_distance_function(diskann::Metric m)
     else
     {
         diskann::cout << "L2: Older CPU. Using slow distance computation" << std::endl;
-        return new diskann::SlowDistanceL2Float();
+        return new diskann::SlowDistanceL2<float>();
     }
     }
     else if (m == diskann::Metric::COSINE)
@@ -592,7 +669,7 @@ template <> diskann::Distance<int8_t> *get_distance_function(diskann::Metric m)
             diskann::cout << "Older CPU. Using slow distance computation "
                              "SlowDistanceL2Int."
                           << std::endl;
-            return new diskann::SlowDistanceL2Int();
+            return new diskann::SlowDistanceL2<int8_t>();
         }
     }
     else if (m == diskann::Metric::COSINE)
@@ -616,7 +693,8 @@ template <> diskann::Distance<uint8_t> *get_distance_function(diskann::Metric m)
     if (m == diskann::Metric::L2)
     {
 #ifdef _WINDOWS
-        diskann::cout << "WARNING: AVX/AVX2 distance function not defined for Uint8. Using "
+        diskann::cout << "WARNING: AVX/AVX2 distance function not defined for Uint8. "
+                         "Using "
                          "slow version. "
                          "Contact gopalsr@microsoft.com if you need AVX/AVX2 support."
                       << std::endl;
@@ -634,7 +712,7 @@ template <> diskann::Distance<uint8_t> *get_distance_function(diskann::Metric m)
     else
     {
         std::stringstream stream;
-        stream << "Only L2 and cosine supported for unsigned byte vectors." << std::endl;
+        stream << "Only L2 and cosine supported for uint8 vectors." << std::endl;
         diskann::cerr << stream.str() << std::endl;
         throw diskann::ANNException(stream.str(), -1, __FUNCSIG__, __FILE__, __LINE__);
     }
@@ -648,4 +726,8 @@ template DISKANN_DLLEXPORT class DistanceFastL2<float>;
 template DISKANN_DLLEXPORT class DistanceFastL2<int8_t>;
 template DISKANN_DLLEXPORT class DistanceFastL2<uint8_t>;
 
+template DISKANN_DLLEXPORT class SlowDistanceL2<float>;
+template DISKANN_DLLEXPORT class SlowDistanceL2<int8_t>;
+template DISKANN_DLLEXPORT class SlowDistanceL2<uint8_t>;
+
 } // namespace diskann
diff --git a/src/dll/CMakeLists.txt b/src/dll/CMakeLists.txt
index f9abe1e15..d00cfeb95 100644
--- a/src/dll/CMakeLists.txt
+++ b/src/dll/CMakeLists.txt
@@ -1,11 +1,13 @@
 #Copyright(c) Microsoft Corporation.All rights reserved.
 #Licensed under the MIT license.
-add_library(${PROJECT_NAME} SHARED dllmain.cpp ../partition.cpp ../pq.cpp ../pq_flash_index.cpp ../logger.cpp ../utils.cpp - ../windows_aligned_file_reader.cpp ../distance.cpp ../memory_mapper.cpp ../index.cpp ../math_utils.cpp ../disk_utils.cpp - ../ann_exception.cpp ../natural_number_set.cpp ../natural_number_map.cpp ../scratch.cpp) +add_library(${PROJECT_NAME} SHARED dllmain.cpp ../abstract_data_store.cpp ../partition.cpp ../pq.cpp ../pq_flash_index.cpp ../logger.cpp ../utils.cpp + ../windows_aligned_file_reader.cpp ../distance.cpp ../memory_mapper.cpp ../index.cpp + ../in_mem_data_store.cpp ../in_mem_graph_store.cpp ../math_utils.cpp ../disk_utils.cpp ../filter_utils.cpp + ../ann_exception.cpp ../natural_number_set.cpp ../natural_number_map.cpp ../scratch.cpp ../index_factory.cpp ../abstract_index.cpp) set(TARGET_DIR "$<$:${CMAKE_LIBRARY_OUTPUT_DIRECTORY_DEBUG}>$<$:${CMAKE_LIBRARY_OUTPUT_DIRECTORY_RELEASE}>") + set(DISKANN_DLL_IMPLIB "${TARGET_DIR}/${PROJECT_NAME}.lib") target_compile_definitions(${PROJECT_NAME} PRIVATE _USRDLL _WINDLL) diff --git a/src/filter_utils.cpp b/src/filter_utils.cpp new file mode 100644 index 000000000..965762d1f --- /dev/null +++ b/src/filter_utils.cpp @@ -0,0 +1,284 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT license. + +#include +#include +#include +#include +#include +#include + +#include +#include "filter_utils.h" +#include "index.h" +#include "parameters.h" +#include "utils.h" + +namespace diskann +{ +/* + * Using passed in parameters and files generated from step 3, + * builds a vanilla diskANN index for each label. + * + * Each index is saved under the following path: + * final_index_path_prefix + "_" + label + */ +template +void generate_label_indices(path input_data_path, path final_index_path_prefix, label_set all_labels, uint32_t R, + uint32_t L, float alpha, uint32_t num_threads) +{ + diskann::IndexWriteParameters label_index_build_parameters = diskann::IndexWriteParametersBuilder(L, R) + .with_saturate_graph(false) + .with_alpha(alpha) + .with_num_threads(num_threads) + .build(); + + std::cout << "Generating indices per label..." << std::endl; + // for each label, build an index on resp. points + double total_indexing_time = 0.0, indexing_percentage = 0.0; + std::cout.setstate(std::ios_base::failbit); + diskann::cout.setstate(std::ios_base::failbit); + for (const auto &lbl : all_labels) + { + path curr_label_input_data_path(input_data_path + "_" + lbl); + path curr_label_index_path(final_index_path_prefix + "_" + lbl); + + size_t number_of_label_points, dimension; + diskann::get_bin_metadata(curr_label_input_data_path, number_of_label_points, dimension); + diskann::Index index(diskann::Metric::L2, dimension, number_of_label_points, false, false); + + auto index_build_timer = std::chrono::high_resolution_clock::now(); + index.build(curr_label_input_data_path.c_str(), number_of_label_points, label_index_build_parameters); + std::chrono::duration current_indexing_time = + std::chrono::high_resolution_clock::now() - index_build_timer; + + total_indexing_time += current_indexing_time.count(); + indexing_percentage += (1 / (double)all_labels.size()); + print_progress(indexing_percentage); + + index.save(curr_label_index_path.c_str()); + } + std::cout.clear(); + diskann::cout.clear(); + + std::cout << "\nDone. Generated per-label indices in " << total_indexing_time << " seconds\n" << std::endl; +} + +// for use on systems without writev (i.e. 
Windows) +template +tsl::robin_map> generate_label_specific_vector_files_compat( + path input_data_path, tsl::robin_map labels_to_number_of_points, + std::vector point_ids_to_labels, label_set all_labels) +{ + auto file_writing_timer = std::chrono::high_resolution_clock::now(); + std::ifstream input_data_stream(input_data_path); + + uint32_t number_of_points, dimension; + input_data_stream.read((char *)&number_of_points, sizeof(uint32_t)); + input_data_stream.read((char *)&dimension, sizeof(uint32_t)); + const uint32_t VECTOR_SIZE = dimension * sizeof(T); + if (number_of_points != point_ids_to_labels.size()) + { + std::cerr << "Error: number of points in labels file and data file differ." << std::endl; + throw; + } + + tsl::robin_map labels_to_vectors; + tsl::robin_map labels_to_curr_vector; + tsl::robin_map> label_id_to_orig_id; + + for (const auto &lbl : all_labels) + { + uint32_t number_of_label_pts = labels_to_number_of_points[lbl]; + char *vectors = (char *)malloc(number_of_label_pts * VECTOR_SIZE); + if (vectors == nullptr) + { + throw; + } + labels_to_vectors[lbl] = vectors; + labels_to_curr_vector[lbl] = 0; + label_id_to_orig_id[lbl].reserve(number_of_label_pts); + } + + for (uint32_t point_id = 0; point_id < number_of_points; point_id++) + { + char *curr_vector = (char *)malloc(VECTOR_SIZE); + input_data_stream.read(curr_vector, VECTOR_SIZE); + for (const auto &lbl : point_ids_to_labels[point_id]) + { + char *curr_label_vector_ptr = labels_to_vectors[lbl] + (labels_to_curr_vector[lbl] * VECTOR_SIZE); + memcpy(curr_label_vector_ptr, curr_vector, VECTOR_SIZE); + labels_to_curr_vector[lbl]++; + label_id_to_orig_id[lbl].push_back(point_id); + } + free(curr_vector); + } + + for (const auto &lbl : all_labels) + { + path curr_label_input_data_path(input_data_path + "_" + lbl); + uint32_t number_of_label_pts = labels_to_number_of_points[lbl]; + + std::ofstream label_file_stream; + label_file_stream.exceptions(std::ios::badbit | std::ios::failbit); + label_file_stream.open(curr_label_input_data_path, std::ios_base::binary); + label_file_stream.write((char *)&number_of_label_pts, sizeof(uint32_t)); + label_file_stream.write((char *)&dimension, sizeof(uint32_t)); + label_file_stream.write((char *)labels_to_vectors[lbl], number_of_label_pts * VECTOR_SIZE); + + label_file_stream.close(); + free(labels_to_vectors[lbl]); + } + input_data_stream.close(); + + std::chrono::duration file_writing_time = std::chrono::high_resolution_clock::now() - file_writing_timer; + std::cout << "generated " << all_labels.size() << " label-specific vector files for index building in time " + << file_writing_time.count() << "\n" + << std::endl; + + return label_id_to_orig_id; +} + +/* + * Manually loads a graph index in from a given file. + * + * Returns both the graph index and the size of the file in bytes. 
+ */ +load_label_index_return_values load_label_index(path label_index_path, uint32_t label_number_of_points) +{ + std::ifstream label_index_stream; + label_index_stream.exceptions(std::ios::badbit | std::ios::failbit); + label_index_stream.open(label_index_path, std::ios::binary); + + uint64_t index_file_size, index_num_frozen_points; + uint32_t index_max_observed_degree, index_entry_point; + const size_t INDEX_METADATA = 2 * sizeof(uint64_t) + 2 * sizeof(uint32_t); + label_index_stream.read((char *)&index_file_size, sizeof(uint64_t)); + label_index_stream.read((char *)&index_max_observed_degree, sizeof(uint32_t)); + label_index_stream.read((char *)&index_entry_point, sizeof(uint32_t)); + label_index_stream.read((char *)&index_num_frozen_points, sizeof(uint64_t)); + size_t bytes_read = INDEX_METADATA; + + std::vector> label_index(label_number_of_points); + uint32_t nodes_read = 0; + while (bytes_read != index_file_size) + { + uint32_t current_node_num_neighbors; + label_index_stream.read((char *)¤t_node_num_neighbors, sizeof(uint32_t)); + nodes_read++; + + std::vector current_node_neighbors(current_node_num_neighbors); + label_index_stream.read((char *)current_node_neighbors.data(), current_node_num_neighbors * sizeof(uint32_t)); + label_index[nodes_read - 1].swap(current_node_neighbors); + bytes_read += sizeof(uint32_t) * (current_node_num_neighbors + 1); + } + + return std::make_tuple(label_index, index_file_size); +} + +/* + * Parses the label datafile, which has comma-separated labels on + * each line. Line i corresponds to point id i. + * + * Returns three objects via std::tuple: + * 1. map: key is point id, value is vector of labels said point has + * 2. map: key is label, value is number of points with the label + * 3. the label universe as a set + */ +parse_label_file_return_values parse_label_file(path label_data_path, std::string universal_label) +{ + std::ifstream label_data_stream(label_data_path); + std::string line, token; + uint32_t line_cnt = 0; + + // allows us to reserve space for the points_to_labels vector + while (std::getline(label_data_stream, line)) + line_cnt++; + label_data_stream.clear(); + label_data_stream.seekg(0, std::ios::beg); + + // values to return + std::vector point_ids_to_labels(line_cnt); + tsl::robin_map labels_to_number_of_points; + label_set all_labels; + + std::vector points_with_universal_label; + line_cnt = 0; + while (std::getline(label_data_stream, line)) + { + std::istringstream current_labels_comma_separated(line); + label_set current_labels; + + // get point id + uint32_t point_id = line_cnt; + + // parse comma separated labels + bool current_universal_label_check = false; + while (getline(current_labels_comma_separated, token, ',')) + { + token.erase(std::remove(token.begin(), token.end(), '\n'), token.end()); + token.erase(std::remove(token.begin(), token.end(), '\r'), token.end()); + + // if token is empty, there's no labels for the point + if (token == universal_label) + { + points_with_universal_label.push_back(point_id); + current_universal_label_check = true; + } + else + { + all_labels.insert(token); + current_labels.insert(token); + labels_to_number_of_points[token]++; + } + } + + if (current_labels.size() <= 0 && !current_universal_label_check) + { + std::cerr << "Error: " << point_id << " has no labels." 
<< std::endl; + exit(-1); + } + point_ids_to_labels[point_id] = current_labels; + line_cnt++; + } + + // for every point with universal label, set its label set to all labels + // also, increment the count for number of points a label has + for (const auto &point_id : points_with_universal_label) + { + point_ids_to_labels[point_id] = all_labels; + for (const auto &lbl : all_labels) + labels_to_number_of_points[lbl]++; + } + + std::cout << "Identified " << all_labels.size() << " distinct label(s) for " << point_ids_to_labels.size() + << " points\n" + << std::endl; + + return std::make_tuple(point_ids_to_labels, labels_to_number_of_points, all_labels); +} + +template DISKANN_DLLEXPORT void generate_label_indices(path input_data_path, path final_index_path_prefix, + label_set all_labels, uint32_t R, uint32_t L, float alpha, + uint32_t num_threads); +template DISKANN_DLLEXPORT void generate_label_indices(path input_data_path, path final_index_path_prefix, + label_set all_labels, uint32_t R, uint32_t L, + float alpha, uint32_t num_threads); +template DISKANN_DLLEXPORT void generate_label_indices(path input_data_path, path final_index_path_prefix, + label_set all_labels, uint32_t R, uint32_t L, + float alpha, uint32_t num_threads); + +template DISKANN_DLLEXPORT tsl::robin_map> +generate_label_specific_vector_files_compat(path input_data_path, + tsl::robin_map labels_to_number_of_points, + std::vector point_ids_to_labels, label_set all_labels); +template DISKANN_DLLEXPORT tsl::robin_map> +generate_label_specific_vector_files_compat(path input_data_path, + tsl::robin_map labels_to_number_of_points, + std::vector point_ids_to_labels, label_set all_labels); +template DISKANN_DLLEXPORT tsl::robin_map> +generate_label_specific_vector_files_compat(path input_data_path, + tsl::robin_map labels_to_number_of_points, + std::vector point_ids_to_labels, label_set all_labels); + +} // namespace diskann \ No newline at end of file diff --git a/src/in_mem_data_store.cpp b/src/in_mem_data_store.cpp new file mode 100644 index 000000000..f5f973917 --- /dev/null +++ b/src/in_mem_data_store.cpp @@ -0,0 +1,370 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT license. 
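// ---------------------------------------------------------------------------
// Editor's note (illustrative sketch, not part of the patch): the new
// InMemDataStore below pads every vector to
// _aligned_dim = ROUND_UP(dim, alignment_factor) so the SIMD distance kernels
// can always read whole registers. A toy model of the padding, assuming the
// 8-element alignment used for float/AVX:
//
//     size_t dim = 100;
//     size_t aligned_dim = (dim + 7) & ~size_t(7);  // ROUND_UP(100, 8) == 104
//     // Each point then occupies aligned_dim * sizeof(float) bytes; the
//     // trailing (aligned_dim - dim) entries stay zeroed, so they contribute
//     // nothing to L2 distances or inner products.
// ---------------------------------------------------------------------------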
+ +#include +#include "in_mem_data_store.h" + +#include "utils.h" + +namespace diskann +{ + +template +InMemDataStore::InMemDataStore(const location_t num_points, const size_t dim, + std::shared_ptr> distance_fn) + : AbstractDataStore(num_points, dim), _distance_fn(distance_fn) +{ + _aligned_dim = ROUND_UP(dim, _distance_fn->get_required_alignment()); + alloc_aligned(((void **)&_data), this->_capacity * _aligned_dim * sizeof(data_t), 8 * sizeof(data_t)); + std::memset(_data, 0, this->_capacity * _aligned_dim * sizeof(data_t)); +} + +template InMemDataStore::~InMemDataStore() +{ + if (_data != nullptr) + { + aligned_free(this->_data); + } +} + +template size_t InMemDataStore::get_aligned_dim() const +{ + return _aligned_dim; +} + +template size_t InMemDataStore::get_alignment_factor() const +{ + return _distance_fn->get_required_alignment(); +} + +template location_t InMemDataStore::load(const std::string &filename) +{ + return load_impl(filename); +} + +#ifdef EXEC_ENV_OLS +template location_t InMemDataStore::load_impl(AlignedFileReader &reader) +{ + size_t file_dim, file_num_points; + + diskann::get_bin_metadata(reader, file_num_points, file_dim); + + if (file_dim != this->_dim) + { + std::stringstream stream; + stream << "ERROR: Driver requests loading " << this->_dim << " dimension," + << "but file has " << file_dim << " dimension." << std::endl; + diskann::cerr << stream.str() << std::endl; + aligned_free(_data); + throw diskann::ANNException(stream.str(), -1, __FUNCSIG__, __FILE__, __LINE__); + } + + if (file_num_points > this->capacity()) + { + this->resize((location_t)file_num_points); + } + copy_aligned_data_from_file(reader, _data, file_num_points, file_dim, _aligned_dim); + + return (location_t)file_num_points; +} +#endif + +template location_t InMemDataStore::load_impl(const std::string &filename) +{ + size_t file_dim, file_num_points; + if (!file_exists(filename)) + { + std::stringstream stream; + stream << "ERROR: data file " << filename << " does not exist." << std::endl; + diskann::cerr << stream.str() << std::endl; + aligned_free(_data); + throw diskann::ANNException(stream.str(), -1, __FUNCSIG__, __FILE__, __LINE__); + } + diskann::get_bin_metadata(filename, file_num_points, file_dim); + + if (file_dim != this->_dim) + { + std::stringstream stream; + stream << "ERROR: Driver requests loading " << this->_dim << " dimension," + << "but file has " << file_dim << " dimension." 
<< std::endl; + diskann::cerr << stream.str() << std::endl; + aligned_free(_data); + throw diskann::ANNException(stream.str(), -1, __FUNCSIG__, __FILE__, __LINE__); + } + + if (file_num_points > this->capacity()) + { + this->resize((location_t)file_num_points); + } + + copy_aligned_data_from_file(filename.c_str(), _data, file_num_points, file_dim, _aligned_dim); + + return (location_t)file_num_points; +} + +template size_t InMemDataStore::save(const std::string &filename, const location_t num_points) +{ + return save_data_in_base_dimensions(filename, _data, num_points, this->get_dims(), this->get_aligned_dim(), 0U); +} + +template void InMemDataStore::populate_data(const data_t *vectors, const location_t num_pts) +{ + memset(_data, 0, _aligned_dim * sizeof(data_t) * num_pts); + for (location_t i = 0; i < num_pts; i++) + { + std::memmove(_data + i * _aligned_dim, vectors + i * this->_dim, this->_dim * sizeof(data_t)); + } + + if (_distance_fn->preprocessing_required()) + { + _distance_fn->preprocess_base_points(_data, this->_aligned_dim, num_pts); + } +} + +template void InMemDataStore::populate_data(const std::string &filename, const size_t offset) +{ + size_t npts, ndim; + copy_aligned_data_from_file(filename.c_str(), _data, npts, ndim, _aligned_dim, offset); + + if ((location_t)npts > this->capacity()) + { + std::stringstream ss; + ss << "Number of points in the file: " << filename + << " is greater than the capacity of data store: " << this->capacity() + << ". Must invoke resize before calling populate_data()" << std::endl; + throw diskann::ANNException(ss.str(), -1); + } + + if ((location_t)ndim != this->get_dims()) + { + std::stringstream ss; + ss << "Number of dimensions of a point in the file: " << filename + << " is not equal to dimensions of data store: " << this->capacity() << "." 
<< std::endl; + throw diskann::ANNException(ss.str(), -1); + } + + if (_distance_fn->preprocessing_required()) + { + _distance_fn->preprocess_base_points(_data, this->_aligned_dim, this->capacity()); + } +} + +template +void InMemDataStore::extract_data_to_bin(const std::string &filename, const location_t num_points) +{ + save_data_in_base_dimensions(filename, _data, num_points, this->get_dims(), this->get_aligned_dim(), 0U); +} + +template void InMemDataStore::get_vector(const location_t i, data_t *dest) const +{ + memcpy(dest, _data + i * _aligned_dim, this->_dim * sizeof(data_t)); +} + +template void InMemDataStore::set_vector(const location_t loc, const data_t *const vector) +{ + size_t offset_in_data = loc * _aligned_dim; + memset(_data + offset_in_data, 0, _aligned_dim * sizeof(data_t)); + memcpy(_data + offset_in_data, vector, this->_dim * sizeof(data_t)); + if (_distance_fn->preprocessing_required()) + { + _distance_fn->preprocess_base_points(_data + offset_in_data, _aligned_dim, 1); + } +} + +template void InMemDataStore::prefetch_vector(const location_t loc) +{ + diskann::prefetch_vector((const char *)_data + _aligned_dim * (size_t)loc, sizeof(data_t) * _aligned_dim); +} + +template float InMemDataStore::get_distance(const data_t *query, const location_t loc) const +{ + return _distance_fn->compare(query, _data + _aligned_dim * loc, (uint32_t)_aligned_dim); +} + +template +void InMemDataStore::get_distance(const data_t *query, const location_t *locations, + const uint32_t location_count, float *distances) const +{ + for (location_t i = 0; i < location_count; i++) + { + distances[i] = _distance_fn->compare(query, _data + locations[i] * _aligned_dim, (uint32_t)this->_aligned_dim); + } +} + +template +float InMemDataStore::get_distance(const location_t loc1, const location_t loc2) const +{ + return _distance_fn->compare(_data + loc1 * _aligned_dim, _data + loc2 * _aligned_dim, + (uint32_t)this->_aligned_dim); +} + +template location_t InMemDataStore::expand(const location_t new_size) +{ + if (new_size == this->capacity()) + { + return this->capacity(); + } + else if (new_size < this->capacity()) + { + std::stringstream ss; + ss << "Cannot 'expand' datastore when new capacity (" << new_size << ") < existing capacity(" + << this->capacity() << ")" << std::endl; + throw diskann::ANNException(ss.str(), -1); + } +#ifndef _WINDOWS + data_t *new_data; + alloc_aligned((void **)&new_data, new_size * _aligned_dim * sizeof(data_t), 8 * sizeof(data_t)); + memcpy(new_data, _data, this->capacity() * _aligned_dim * sizeof(data_t)); + aligned_free(_data); + _data = new_data; +#else + realloc_aligned((void **)&_data, new_size * _aligned_dim * sizeof(data_t), 8 * sizeof(data_t)); +#endif + this->_capacity = new_size; + return this->_capacity; +} + +template location_t InMemDataStore::shrink(const location_t new_size) +{ + if (new_size == this->capacity()) + { + return this->capacity(); + } + else if (new_size > this->capacity()) + { + std::stringstream ss; + ss << "Cannot 'shrink' datastore when new capacity (" << new_size << ") > existing capacity(" + << this->capacity() << ")" << std::endl; + throw diskann::ANNException(ss.str(), -1); + } +#ifndef _WINDOWS + data_t *new_data; + alloc_aligned((void **)&new_data, new_size * _aligned_dim * sizeof(data_t), 8 * sizeof(data_t)); + memcpy(new_data, _data, new_size * _aligned_dim * sizeof(data_t)); + aligned_free(_data); + _data = new_data; +#else + realloc_aligned((void **)&_data, new_size * _aligned_dim * sizeof(data_t), 8 * sizeof(data_t)); +#endif + 
this->_capacity = new_size; + return this->_capacity; +} + +template +void InMemDataStore::move_vectors(const location_t old_location_start, const location_t new_location_start, + const location_t num_locations) +{ + if (num_locations == 0 || old_location_start == new_location_start) + { + return; + } + + /* // Update pointers to the moved nodes. Note: the computation is correct + even + // when new_location_start < old_location_start given the C++ uint32_t + // integer arithmetic rules. + const uint32_t location_delta = new_location_start - old_location_start; + */ + // The [start, end) interval which will contain obsolete points to be + // cleared. + uint32_t mem_clear_loc_start = old_location_start; + uint32_t mem_clear_loc_end_limit = old_location_start + num_locations; + + if (new_location_start < old_location_start) + { + // If ranges are overlapping, make sure not to clear the newly copied + // data. + if (mem_clear_loc_start < new_location_start + num_locations) + { + // Clear only after the end of the new range. + mem_clear_loc_start = new_location_start + num_locations; + } + } + else + { + // If ranges are overlapping, make sure not to clear the newly copied + // data. + if (mem_clear_loc_end_limit > new_location_start) + { + // Clear only up to the beginning of the new range. + mem_clear_loc_end_limit = new_location_start; + } + } + + // Use memmove to handle overlapping ranges. + copy_vectors(old_location_start, new_location_start, num_locations); + memset(_data + _aligned_dim * mem_clear_loc_start, 0, + sizeof(data_t) * _aligned_dim * (mem_clear_loc_end_limit - mem_clear_loc_start)); +} + +template +void InMemDataStore::copy_vectors(const location_t from_loc, const location_t to_loc, + const location_t num_points) +{ + assert(from_loc < this->_capacity); + assert(to_loc < this->_capacity); + assert(num_points < this->_capacity); + memmove(_data + _aligned_dim * to_loc, _data + _aligned_dim * from_loc, num_points * _aligned_dim * sizeof(data_t)); +} + +template location_t InMemDataStore::calculate_medoid() const +{ + // allocate and init centroid + float *center = new float[_aligned_dim]; + for (size_t j = 0; j < _aligned_dim; j++) + center[j] = 0; + + for (size_t i = 0; i < this->capacity(); i++) + for (size_t j = 0; j < _aligned_dim; j++) + center[j] += (float)_data[i * _aligned_dim + j]; + + for (size_t j = 0; j < _aligned_dim; j++) + center[j] /= (float)this->capacity(); + + // compute all to one distance + float *distances = new float[this->capacity()]; + + // TODO: REFACTOR. Removing pragma might make this slow. Must revisit. + // Problem is that we need to pass num_threads here, it is not clear + // if data store must be aware of threads! 
+ // #pragma omp parallel for schedule(static, 65536) + for (int64_t i = 0; i < (int64_t)this->capacity(); i++) + { + // extract point and distance reference + float &dist = distances[i]; + const data_t *cur_vec = _data + (i * (size_t)_aligned_dim); + dist = 0; + float diff = 0; + for (size_t j = 0; j < _aligned_dim; j++) + { + diff = (center[j] - (float)cur_vec[j]) * (center[j] - (float)cur_vec[j]); + dist += diff; + } + } + // find imin + uint32_t min_idx = 0; + float min_dist = distances[0]; + for (uint32_t i = 1; i < this->capacity(); i++) + { + if (distances[i] < min_dist) + { + min_idx = i; + min_dist = distances[i]; + } + } + + delete[] distances; + delete[] center; + return min_idx; +} + +template Distance *InMemDataStore::get_dist_fn() +{ + return this->_distance_fn.get(); +} + +template DISKANN_DLLEXPORT class InMemDataStore; +template DISKANN_DLLEXPORT class InMemDataStore; +template DISKANN_DLLEXPORT class InMemDataStore; + +} // namespace diskann \ No newline at end of file diff --git a/src/in_mem_graph_store.cpp b/src/in_mem_graph_store.cpp new file mode 100644 index 000000000..e9bfd4e9e --- /dev/null +++ b/src/in_mem_graph_store.cpp @@ -0,0 +1,31 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT license. + +#include "in_mem_graph_store.h" +#include "utils.h" + +namespace diskann +{ + +InMemGraphStore::InMemGraphStore(const size_t max_pts) : AbstractGraphStore(max_pts) +{ +} + +int InMemGraphStore::load(const std::string &index_path_prefix) +{ + return 0; +} +int InMemGraphStore::store(const std::string &index_path_prefix) +{ + return 0; +} + +void InMemGraphStore::get_adj_list(const location_t i, std::vector &neighbors) +{ +} + +void InMemGraphStore::set_adj_list(const location_t i, std::vector &neighbors) +{ +} + +} // namespace diskann diff --git a/src/index.cpp b/src/index.cpp index d348a510f..ef35c6912 100644 --- a/src/index.cpp +++ b/src/index.cpp @@ -30,34 +30,38 @@ namespace diskann // (bin), and initialize max_points template Index::Index(Metric m, const size_t dim, const size_t max_points, const bool dynamic_index, - const Parameters &indexParams, const Parameters &searchParams, const bool enable_tags, - const bool concurrent_consolidate, const bool pq_dist_build, const size_t num_pq_chunks, - const bool use_opq) + const IndexWriteParameters &indexParams, const uint32_t initial_search_list_size, + const uint32_t search_threads, const bool enable_tags, const bool concurrent_consolidate, + const bool pq_dist_build, const size_t num_pq_chunks, const bool use_opq) : Index(m, dim, max_points, dynamic_index, enable_tags, concurrent_consolidate, pq_dist_build, num_pq_chunks, - use_opq, indexParams.Get("num_frozen_pts", 0)) + use_opq, indexParams.num_frozen_points) { - _indexingQueueSize = indexParams.Get("L"); - _indexingRange = indexParams.Get("R"); - _indexingMaxC = indexParams.Get("C"); - _indexingAlpha = indexParams.Get("alpha"); - _filterIndexingQueueSize = indexParams.Get("Lf"); + if (dynamic_index) + { + this->enable_delete(); + } + _indexingQueueSize = indexParams.search_list_size; + _indexingRange = indexParams.max_degree; + _indexingMaxC = indexParams.max_occlusion_size; + _indexingAlpha = indexParams.alpha; + _filterIndexingQueueSize = indexParams.filter_list_size; - uint32_t num_threads_srch = searchParams.Get("num_threads"); - uint32_t num_threads_indx = indexParams.Get("num_threads"); - uint32_t num_scratch_spaces = num_threads_srch + num_threads_indx; - uint32_t search_l = searchParams.Get("L"); + uint32_t 
num_threads_indx = indexParams.num_threads; + uint32_t num_scratch_spaces = search_threads + num_threads_indx; - initialize_query_scratch(num_scratch_spaces, search_l, _indexingQueueSize, _indexingRange, _indexingMaxC, dim); + initialize_query_scratch(num_scratch_spaces, initial_search_list_size, _indexingQueueSize, _indexingRange, + _indexingMaxC, dim); } template Index::Index(Metric m, const size_t dim, const size_t max_points, const bool dynamic_index, const bool enable_tags, const bool concurrent_consolidate, const bool pq_dist_build, - const size_t num_pq_chunks, const bool use_opq, const size_t num_frozen_pts) - : _dist_metric(m), _dim(dim), _num_frozen_pts(num_frozen_pts), _max_points(max_points), + const size_t num_pq_chunks, const bool use_opq, const size_t num_frozen_pts, + const bool init_data_store) + : _dist_metric(m), _dim(dim), _max_points(max_points), _num_frozen_pts(num_frozen_pts), _dynamic_index(dynamic_index), _enable_tags(enable_tags), _indexingMaxC(DEFAULT_MAXC), _query_scratch(nullptr), - _conc_consolidate(concurrent_consolidate), _delete_set(new tsl::robin_set), _pq_dist(pq_dist_build), - _use_opq(use_opq), _num_pq_chunks(num_pq_chunks) + _pq_dist(pq_dist_build), _use_opq(use_opq), _num_pq_chunks(num_pq_chunks), + _delete_set(new tsl::robin_set), _conc_consolidate(concurrent_consolidate) { if (dynamic_index && !enable_tags) { @@ -71,14 +75,12 @@ Index::Index(Metric m, const size_t dim, const size_t max_point "index construction", -1, __FUNCSIG__, __FILE__, __LINE__); if (m == diskann::Metric::INNER_PRODUCT) - throw ANNException("ERROR: Inner product metrics not yet supported with PQ distance " + throw ANNException("ERROR: Inner product metrics not yet supported " + "with PQ distance " "base index", -1, __FUNCSIG__, __FILE__, __LINE__); } - // data stored to _nd * aligned_dim matrix with necessary zero-padding - _aligned_dim = ROUND_UP(_dim, 8); - if (dynamic_index && _num_frozen_pts == 0) { _num_frozen_pts = 1; @@ -98,25 +100,31 @@ Index::Index(Metric m, const size_t dim, const size_t max_point alloc_aligned(((void **)&_pq_data), total_internal_points * _num_pq_chunks * sizeof(char), 8 * sizeof(char)); std::memset(_pq_data, 0, total_internal_points * _num_pq_chunks * sizeof(char)); } - alloc_aligned(((void **)&_data), total_internal_points * _aligned_dim * sizeof(T), 8 * sizeof(T)); - std::memset(_data, 0, total_internal_points * _aligned_dim * sizeof(T)); - _start = (unsigned)_max_points; + _start = (uint32_t)_max_points; _final_graph.resize(total_internal_points); - if (m == diskann::Metric::COSINE && std::is_floating_point::value) + if (init_data_store) { - // This is safe because T is float inside the if block. - this->_distance = (Distance *)new AVXNormalizedCosineDistanceFloat(); - this->_normalize_vecs = true; - diskann::cout << "Normalizing vectors and using L2 for cosine " - "AVXNormalizedCosineDistanceFloat()." - << std::endl; - } - else - { - this->_distance = get_distance_function(m); + // Issue #374: data_store is injected from index factory. Keeping this for backward compatibility. + // distance is owned by data_store + if (m == diskann::Metric::COSINE && std::is_floating_point::value) + { + // This is safe because T is float inside the if block. + this->_distance.reset((Distance *)new AVXNormalizedCosineDistanceFloat()); + this->_normalize_vecs = true; + diskann::cout << "Normalizing vectors and using L2 for cosine " + "AVXNormalizedCosineDistanceFloat()." 
+ << std::endl; + } + else + { + this->_distance.reset((Distance *)get_distance_function(m)); + } + // Note: moved this to factory, keeping this for backward compatibility. + _data_store = + std::make_unique>((location_t)total_internal_points, _dim, this->_distance); } _locks = std::vector(total_internal_points); @@ -128,6 +136,37 @@ Index::Index(Metric m, const size_t dim, const size_t max_point } } +template +Index::Index(const IndexConfig &index_config, std::unique_ptr> data_store) + : Index(index_config.metric, index_config.dimension, index_config.max_points, index_config.dynamic_index, + index_config.enable_tags, index_config.concurrent_consolidate, index_config.pq_dist_build, + index_config.num_pq_chunks, index_config.use_opq, index_config.num_frozen_pts, false) +{ + + _data_store = std::move(data_store); + _distance.reset(_data_store->get_dist_fn()); + + // enable delete by default for dynamic index + if (_dynamic_index) + { + this->enable_delete(); + } + if (_dynamic_index && index_config.index_write_params != nullptr) + { + _indexingQueueSize = index_config.index_write_params->search_list_size; + _indexingRange = index_config.index_write_params->max_degree; + _indexingMaxC = index_config.index_write_params->max_occlusion_size; + _indexingAlpha = index_config.index_write_params->alpha; + _filterIndexingQueueSize = index_config.index_write_params->filter_list_size; + + uint32_t num_threads_indx = index_config.index_write_params->num_threads; + uint32_t num_scratch_spaces = index_config.search_threads + num_threads_indx; + + initialize_query_scratch(num_scratch_spaces, index_config.initial_search_list_size, _indexingQueueSize, + _indexingRange, _indexingMaxC, _data_store->get_dims()); + } +} + template Index::~Index() { // Ensure that no other activity is happening before dtor() @@ -141,16 +180,13 @@ template Index::~I LockGuard lg(lock); } - if (this->_distance != nullptr) - { - delete this->_distance; - this->_distance = nullptr; - } - if (this->_data != nullptr) - { - aligned_free(this->_data); - this->_data = nullptr; - } + // if (this->_distance != nullptr) + //{ + // delete this->_distance; + // this->_distance = nullptr; + // } + // REFACTOR + if (_opt_graph != nullptr) { delete[] _opt_graph; @@ -169,12 +205,13 @@ void Index::initialize_query_scratch(uint32_t num_threads, uint { for (uint32_t i = 0; i < num_threads; i++) { - auto scratch = new InMemQueryScratch(search_l, indexing_l, r, maxc, dim, _pq_dist, bitmask_size); + auto scratch = new InMemQueryScratch(search_l, indexing_l, r, maxc, dim, _data_store->get_aligned_dim(), + _data_store->get_alignment_factor(), _pq_dist, bitmask_size); _query_scratch.push(scratch); } } -template _u64 Index::save_tags(std::string tags_file) +template size_t Index::save_tags(std::string tags_file) { if (!_enable_tags) { @@ -183,7 +220,7 @@ template _u64 Index _u64 Index _u64 Index::save_data(std::string data_file) +template size_t Index::save_data(std::string data_file) { - // Note: at this point, either _nd == _max_points or any frozen points have been - // temporarily moved to _nd, so _nd + _num_frozen_points is the valid location limit. - return save_data_in_base_dimensions(data_file, _data, _nd + _num_frozen_pts, _dim, _aligned_dim); + // Note: at this point, either _nd == _max_points or any frozen points have + // been temporarily moved to _nd, so _nd + _num_frozen_points is the valid + // location limit. + return _data_store->save(data_file, (location_t)(_nd + _num_frozen_pts)); } // save the graph index on a file as an adjacency list. 
For each point, // first store the number of neighbors, and then the neighbor list (each as -// 4 byte unsigned) -template _u64 Index::save_graph(std::string graph_file) +// 4 byte uint32_t) +template size_t Index::save_graph(std::string graph_file) { std::ofstream out; open_file_to_write(out, graph_file); - _u64 file_offset = 0; // we will use this if we want + size_t file_offset = 0; // we will use this if we want out.seekp(file_offset, out.beg); - _u64 index_size = 24; - _u32 max_degree = 0; + size_t index_size = 24; + uint32_t max_degree = 0; out.write((char *)&index_size, sizeof(uint64_t)); - out.write((char *)&_max_observed_degree, sizeof(unsigned)); - unsigned ep_u32 = _start; - out.write((char *)&ep_u32, sizeof(unsigned)); - out.write((char *)&_num_frozen_pts, sizeof(_u64)); - // Note: at this point, either _nd == _max_points or any frozen points have been - // temporarily moved to _nd, so _nd + _num_frozen_points is the valid location limit. - for (unsigned i = 0; i < _nd + _num_frozen_pts; i++) - { - unsigned GK = (unsigned)_final_graph[i].size(); - out.write((char *)&GK, sizeof(unsigned)); - out.write((char *)_final_graph[i].data(), GK * sizeof(unsigned)); - max_degree = _final_graph[i].size() > max_degree ? (_u32)_final_graph[i].size() : max_degree; - index_size += (_u64)(sizeof(unsigned) * (GK + 1)); + out.write((char *)&_max_observed_degree, sizeof(uint32_t)); + uint32_t ep_u32 = _start; + out.write((char *)&ep_u32, sizeof(uint32_t)); + out.write((char *)&_num_frozen_pts, sizeof(size_t)); + // Note: at this point, either _nd == _max_points or any frozen points have + // been temporarily moved to _nd, so _nd + _num_frozen_points is the valid + // location limit. + for (uint32_t i = 0; i < _nd + _num_frozen_pts; i++) + { + uint32_t GK = (uint32_t)_final_graph[i].size(); + out.write((char *)&GK, sizeof(uint32_t)); + out.write((char *)_final_graph[i].data(), GK * sizeof(uint32_t)); + max_degree = _final_graph[i].size() > max_degree ? 
(uint32_t)_final_graph[i].size() : max_degree; + index_size += (size_t)(sizeof(uint32_t) * (GK + 1)); } out.seekp(file_offset, out.beg); out.write((char *)&index_size, sizeof(uint64_t)); - out.write((char *)&max_degree, sizeof(_u32)); + out.write((char *)&max_degree, sizeof(uint32_t)); out.close(); return index_size; // number of bytes written } template -_u64 Index::save_delete_list(const std::string &filename) +size_t Index::save_delete_list(const std::string &filename) { if (_delete_set->size() == 0) { return 0; } - std::unique_ptr<_u32[]> delete_list = std::make_unique<_u32[]>(_delete_set->size()); - _u32 i = 0; + std::unique_ptr delete_list = std::make_unique(_delete_set->size()); + uint32_t i = 0; for (auto &del : *_delete_set) { delete_list[i++] = del; } - return save_bin<_u32>(filename, delete_list.get(), _delete_set->size(), 1); + return save_bin(filename, delete_list.get(), _delete_set->size(), 1); } template @@ -323,9 +362,9 @@ void Index::save(const char *filename, bool compact_before_save { std::ofstream label_writer(std::string(filename) + "_labels.txt"); assert(label_writer.is_open()); - for (_u32 i = 0; i < _pts_to_labels.size(); i++) + for (uint32_t i = 0; i < _pts_to_labels.size(); i++) { - for (_u32 j = 0; j < (_pts_to_labels[i].size() - 1); j++) + for (uint32_t j = 0; j < (_pts_to_labels[i].size() - 1); j++) { label_writer << _pts_to_labels[i][j] << ","; } @@ -362,7 +401,8 @@ void Index::save(const char *filename, bool compact_before_save << std::endl; } - // If frozen points were temporarily compacted to _nd, move back to _max_points. + // If frozen points were temporarily compacted to _nd, move back to + // _max_points. reposition_frozen_point_to_end(); diskann::cout << "Time taken for save: " << timer.elapsed() / 1000000.0 << "s." << std::endl; @@ -378,8 +418,9 @@ size_t Index::load_tags(const std::string tag_filename) { if (_enable_tags && !file_exists(tag_filename)) { - diskann::cerr << "Tag file provided does not exist!" << std::endl; - throw diskann::ANNException("Tag file provided does not exist!", -1, __FUNCSIG__, __FILE__, __LINE__); + diskann::cerr << "Tag file " << tag_filename << " does not exist!" << std::endl; + throw diskann::ANNException("Tag file " + tag_filename + " does not exist!", -1, __FUNCSIG__, __FILE__, + __LINE__); } #endif if (!_enable_tags) @@ -409,7 +450,7 @@ size_t Index::load_tags(const std::string tag_filename) const size_t num_data_points = file_num_points - _num_frozen_pts; _location_to_tag.reserve(num_data_points); _tag_to_location.reserve(num_data_points); - for (_u32 i = 0; i < (_u32)num_data_points; i++) + for (uint32_t i = 0; i < (uint32_t)num_data_points; i++) { TagT tag = *(tag_data + i); if (_delete_set->find(i) == _delete_set->end()) @@ -440,7 +481,6 @@ size_t Index::load_data(std::string filename) std::stringstream stream; stream << "ERROR: data file " << filename << " does not exist." << std::endl; diskann::cerr << stream.str() << std::endl; - aligned_free(_data); throw diskann::ANNException(stream.str(), -1, __FUNCSIG__, __FILE__, __LINE__); } diskann::get_bin_metadata(filename, file_num_points, file_dim); @@ -455,7 +495,6 @@ size_t Index::load_data(std::string filename) stream << "ERROR: Driver requests loading " << _dim << " dimension," << "but file has " << file_dim << " dimension." 
<< std::endl; diskann::cerr << stream.str() << std::endl; - aligned_free(_data); throw diskann::ANNException(stream.str(), -1, __FUNCSIG__, __FILE__, __LINE__); } @@ -466,9 +505,10 @@ size_t Index::load_data(std::string filename) } #ifdef EXEC_ENV_OLS - copy_aligned_data_from_file(reader, _data, file_num_points, file_dim, _aligned_dim); + // REFACTOR TODO: Must figure out how to support aligned reader in a clean manner. + copy_aligned_data_from_file(reader, _data, file_num_points, file_dim, _data_store->get_aligned_dim()); #else - copy_aligned_data_from_file(filename.c_str(), _data, file_num_points, file_dim, _aligned_dim); + _data_store->load(filename); // offset == 0. #endif return file_num_points; } @@ -482,13 +522,13 @@ template size_t Index::load_delete_set(const std::string &filename) { #endif - std::unique_ptr<_u32[]> delete_list; - _u64 npts, ndim; + std::unique_ptr delete_list; + size_t npts, ndim; #ifdef EXEC_ENV_OLS - diskann::load_bin<_u32>(reader, delete_list, npts, ndim); + diskann::load_bin(reader, delete_list, npts, ndim); #else - diskann::load_bin<_u32>(filename, delete_list, npts, ndim); + diskann::load_bin(filename, delete_list, npts, ndim); #endif assert(ndim == 1); for (uint32_t i = 0; i < npts; i++) @@ -516,15 +556,16 @@ void Index::load(const char *filename, uint32_t num_threads, ui _has_built = true; size_t tags_file_num_pts = 0, graph_num_pts = 0, data_file_num_pts = 0, label_num_pts = 0; - +#ifndef EXEC_ENV_OLS std::string mem_index_file(filename); std::string labels_file = mem_index_file + "_labels.txt"; std::string labels_to_medoids = mem_index_file + "_labels_to_medoids.txt"; std::string labels_map_file = mem_index_file + "_labels_map.txt"; - +#endif if (!_save_as_one_file) { - // For DLVS Store, we will not support saving the index in multiple files. + // For DLVS Store, we will not support saving the index in multiple + // files. #ifndef EXEC_ENV_OLS std::string data_file = std::string(filename) + ".data"; std::string tags_file = std::string(filename) + ".tags"; @@ -557,10 +598,9 @@ void Index::load(const char *filename, uint32_t num_threads, ui << graph_num_pts << " from graph, and " << tags_file_num_pts << " tags, with num_frozen_pts being set to " << _num_frozen_pts << " in constructor." 
<< std::endl; diskann::cerr << stream.str() << std::endl; - aligned_free(_data); throw diskann::ANNException(stream.str(), -1, __FUNCSIG__, __FILE__, __LINE__); } - +#ifndef EXEC_ENV_OLS if (file_exists(labels_file)) { _label_map = load_label_map(labels_map_file); @@ -571,21 +611,21 @@ void Index::load(const char *filename, uint32_t num_threads, ui { std::ifstream medoid_stream(labels_to_medoids); std::string line, token; - unsigned line_cnt = 0; + uint32_t line_cnt = 0; _label_to_medoid_id.clear(); while (std::getline(medoid_stream, line)) { std::istringstream iss(line); - _u32 cnt = 0; - _u32 medoid = 0; + uint32_t cnt = 0; + uint32_t medoid = 0; LabelT label; while (std::getline(iss, token, ',')) { token.erase(std::remove(token.begin(), token.end(), '\n'), token.end()); token.erase(std::remove(token.begin(), token.end(), '\r'), token.end()); - LabelT token_as_num = std::stoul(token); + LabelT token_as_num = (LabelT)std::stoul(token); if (cnt == 0) label = token_as_num; else @@ -607,7 +647,7 @@ void Index::load(const char *filename, uint32_t num_threads, ui universal_label_reader.close(); } } - +#endif _nd = data_file_num_pts - _num_frozen_pts; _empty_slots.clear(); _empty_slots.reserve(_max_points); @@ -638,17 +678,17 @@ template size_t Index::get_graph_num_frozen_points(const std::string &graph_file) { size_t expected_file_size; - unsigned max_observed_degree, start; - _u64 file_frozen_pts; + uint32_t max_observed_degree, start; + size_t file_frozen_pts; std::ifstream in; in.exceptions(std::ios::badbit | std::ios::failbit); in.open(graph_file, std::ios::binary); - in.read((char *)&expected_file_size, sizeof(_u64)); - in.read((char *)&max_observed_degree, sizeof(unsigned)); - in.read((char *)&start, sizeof(unsigned)); - in.read((char *)&file_frozen_pts, sizeof(_u64)); + in.read((char *)&expected_file_size, sizeof(size_t)); + in.read((char *)&max_observed_degree, sizeof(uint32_t)); + in.read((char *)&start, sizeof(uint32_t)); + in.read((char *)&file_frozen_pts, sizeof(size_t)); return file_frozen_pts; } @@ -665,29 +705,29 @@ size_t Index::load_graph(std::string filename, size_t expected_ { #endif size_t expected_file_size; - _u64 file_frozen_pts; + size_t file_frozen_pts; #ifdef EXEC_ENV_OLS - int header_size = 2 * sizeof(_u64) + 2 * sizeof(unsigned); + int header_size = 2 * sizeof(size_t) + 2 * sizeof(uint32_t); std::unique_ptr header = std::make_unique(header_size); read_array(reader, header.get(), header_size); - expected_file_size = *((_u64 *)header.get()); - _max_observed_degree = *((_u32 *)(header.get() + sizeof(_u64))); - _start = *((_u32 *)(header.get() + sizeof(_u64) + sizeof(unsigned))); - file_frozen_pts = *((_u64 *)(header.get() + sizeof(_u64) + sizeof(unsigned) + sizeof(unsigned))); + expected_file_size = *((size_t *)header.get()); + _max_observed_degree = *((uint32_t *)(header.get() + sizeof(size_t))); + _start = *((uint32_t *)(header.get() + sizeof(size_t) + sizeof(uint32_t))); + file_frozen_pts = *((size_t *)(header.get() + sizeof(size_t) + sizeof(uint32_t) + sizeof(uint32_t))); #else - _u64 file_offset = 0; // will need this for single file format support + size_t file_offset = 0; // will need this for single file format support std::ifstream in; in.exceptions(std::ios::badbit | std::ios::failbit); in.open(filename, std::ios::binary); in.seekg(file_offset, in.beg); - in.read((char *)&expected_file_size, sizeof(_u64)); - in.read((char *)&_max_observed_degree, sizeof(unsigned)); - in.read((char *)&_start, sizeof(unsigned)); - in.read((char *)&file_frozen_pts, 
sizeof(_u64)); - _u64 vamana_metadata_size = sizeof(_u64) + sizeof(_u32) + sizeof(_u32) + sizeof(_u64); + in.read((char *)&expected_file_size, sizeof(size_t)); + in.read((char *)&_max_observed_degree, sizeof(uint32_t)); + in.read((char *)&_start, sizeof(uint32_t)); + in.read((char *)&file_frozen_pts, sizeof(size_t)); + size_t vamana_metadata_size = sizeof(size_t) + sizeof(uint32_t) + sizeof(uint32_t) + sizeof(size_t); #endif diskann::cout << "From graph header, expected_file_size: " << expected_file_size @@ -710,7 +750,6 @@ size_t Index::load_graph(std::string filename, size_t expected_ << std::endl; } diskann::cerr << stream.str() << std::endl; - aligned_free(_data); throw diskann::ANNException(stream.str(), -1, __FUNCSIG__, __FILE__, __LINE__); } @@ -733,18 +772,18 @@ size_t Index::load_graph(std::string filename, size_t expected_ _max_points = expected_max_points; } #ifdef EXEC_ENV_OLS - _u32 nodes_read = 0; - _u64 cc = 0; - _u64 graph_offset = header_size; + uint32_t nodes_read = 0; + size_t cc = 0; + size_t graph_offset = header_size; while (nodes_read < expected_num_points) { - _u32 k; + uint32_t k; read_value(reader, k, graph_offset); - graph_offset += sizeof(_u32); - std::vector<_u32> tmp(k); + graph_offset += sizeof(uint32_t); + std::vector tmp(k); tmp.reserve(k); read_array(reader, tmp.data(), k, graph_offset); - graph_offset += k * sizeof(_u32); + graph_offset += k * sizeof(uint32_t); cc += k; _final_graph[nodes_read].swap(tmp); nodes_read++; @@ -760,11 +799,11 @@ size_t Index::load_graph(std::string filename, size_t expected_ #else size_t bytes_read = vamana_metadata_size; size_t cc = 0; - unsigned nodes_read = 0; + uint32_t nodes_read = 0; while (bytes_read != expected_file_size) { - unsigned k; - in.read((char *)&k, sizeof(unsigned)); + uint32_t k; + in.read((char *)&k, sizeof(uint32_t)); if (k == 0) { @@ -773,11 +812,11 @@ size_t Index::load_graph(std::string filename, size_t expected_ cc += k; ++nodes_read; - std::vector tmp(k); + std::vector tmp(k); tmp.reserve(k); - in.read((char *)tmp.data(), k * sizeof(unsigned)); + in.read((char *)tmp.data(), k * sizeof(uint32_t)); _final_graph[nodes_read - 1].swap(tmp); - bytes_read += sizeof(uint32_t) * ((_u64)k + 1); + bytes_read += sizeof(uint32_t) * ((size_t)k + 1); if (nodes_read % 10000000 == 0) diskann::cout << "." 
<< std::flush; if (k > _max_range_of_loaded_graph) @@ -792,6 +831,25 @@ size_t Index::load_graph(std::string filename, size_t expected_ return nodes_read; } +template +int Index::_get_vector_by_tag(TagType &tag, DataType &vec) +{ + try + { + TagT tag_val = std::any_cast(tag); + T *vec_val = std::any_cast(vec); + return this->get_vector_by_tag(tag_val, vec_val); + } + catch (const std::bad_any_cast &e) + { + throw ANNException("Error: bad any cast while performing _get_vector_by_tags() " + std::string(e.what()), -1); + } + catch (const std::exception &e) + { + throw ANNException("Error: " + std::string(e.what()), -1); + } +} + template int Index::get_vector_by_tag(TagT &tag, T *vec) { std::shared_lock lock(_tag_lock); @@ -801,73 +859,34 @@ template int Index return -1; } - size_t location = _tag_to_location[tag]; - memcpy((void *)vec, (void *)(_data + location * _aligned_dim), (size_t)_dim * sizeof(T)); + location_t location = _tag_to_location[tag]; + _data_store->get_vector(location, vec); + return 0; } -template unsigned Index::calculate_entry_point() +template uint32_t Index::calculate_entry_point() { // TODO: need to compute medoid with PQ data too, for now sample at random if (_pq_dist) { size_t r = (size_t)rand() * (size_t)RAND_MAX + (size_t)rand(); - return (unsigned)(r % (size_t)_nd); + return (uint32_t)(r % (size_t)_nd); } - // allocate and init centroid - float *center = new float[_aligned_dim](); - for (size_t j = 0; j < _aligned_dim; j++) - center[j] = 0; - - for (size_t i = 0; i < _nd; i++) - for (size_t j = 0; j < _aligned_dim; j++) - center[j] += (float)_data[i * _aligned_dim + j]; - - for (size_t j = 0; j < _aligned_dim; j++) - center[j] /= (float)_nd; - - // compute all to one distance - float *distances = new float[_nd](); -#pragma omp parallel for schedule(static, 65536) - for (_s64 i = 0; i < (_s64)_nd; i++) - { - // extract point and distance reference - float &dist = distances[i]; - const T *cur_vec = _data + (i * (size_t)_aligned_dim); - dist = 0; - float diff = 0; - for (size_t j = 0; j < _aligned_dim; j++) - { - diff = (center[j] - (float)cur_vec[j]) * (center[j] - (float)cur_vec[j]); - dist += diff; - } - } - // find imin - unsigned min_idx = 0; - float min_dist = distances[0]; - for (unsigned i = 1; i < _nd; i++) - { - if (distances[i] < min_dist) - { - min_idx = i; - min_dist = distances[i]; - } - } - - delete[] distances; - delete[] center; - return min_idx; + // TODO: This function does not support multi-threaded calculation of medoid. + // Must revisit if perf is a concern. 
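// The PQ branch of calculate_entry_point() above composes two rand() calls
// because RAND_MAX is only guaranteed to be >= 32767, too small to address a
// large index on its own. A sketch of the same sampling with <random>, which
// also avoids the modulo bias; uniform_location is a hypothetical helper
// name, not part of this codebase.
#include <cstdint>
#include <random>

inline uint32_t uniform_location(size_t num_points, uint64_t seed)
{
    std::mt19937_64 gen(seed);
    std::uniform_int_distribution<size_t> pick(0, num_points - 1);
    return (uint32_t)pick(gen);
}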
+ return _data_store->calculate_medoid(); } -template std::vector Index::get_init_ids() +template std::vector Index::get_init_ids() { - std::vector init_ids; + std::vector init_ids; init_ids.reserve(1 + _num_frozen_pts); init_ids.emplace_back(_start); - for (unsigned frozen = _max_points; frozen < _max_points + _num_frozen_pts; frozen++) + for (uint32_t frozen = (uint32_t)_max_points; frozen < _max_points + _num_frozen_pts; frozen++) { if (frozen != _start) { @@ -878,32 +897,66 @@ template std::vector Inde return init_ids; } +// Find common filter between a node's labels and a given set of labels, while taking into account universal label +template +bool Index::detect_common_filters(uint32_t point_id, bool search_invocation, + const std::vector &incoming_labels) +{ + auto &curr_node_labels = _pts_to_labels[point_id]; + std::vector common_filters; + std::set_intersection(incoming_labels.begin(), incoming_labels.end(), curr_node_labels.begin(), + curr_node_labels.end(), std::back_inserter(common_filters)); + if (common_filters.size() > 0) + { + // This is to reduce the repetitive calls. If common_filters size is > 0, we don't need to check further for + // universal label + return true; + } + if (_use_universal_label) + { + if (!search_invocation) + { + if (std::find(incoming_labels.begin(), incoming_labels.end(), _universal_label) != incoming_labels.end() || + std::find(curr_node_labels.begin(), curr_node_labels.end(), _universal_label) != curr_node_labels.end()) + common_filters.push_back(_universal_label); + } + else + { + if (std::find(curr_node_labels.begin(), curr_node_labels.end(), _universal_label) != curr_node_labels.end()) + common_filters.push_back(_universal_label); + } + } + return (common_filters.size() > 0); +} + template std::pair Index::iterate_to_fixed_point( - const T *query, const unsigned Lsize, const std::vector &init_ids, InMemQueryScratch *scratch, + const T *query, const uint32_t Lsize, const std::vector &init_ids, InMemQueryScratch *scratch, bool use_filter, const std::vector &filter_label, bool search_invocation) { std::vector &expanded_nodes = scratch->pool(); NeighborPriorityQueue &best_L_nodes = scratch->best_l_nodes(); best_L_nodes.reserve(Lsize); - tsl::robin_set &inserted_into_pool_rs = scratch->inserted_into_pool_rs(); + tsl::robin_set &inserted_into_pool_rs = scratch->inserted_into_pool_rs(); boost::dynamic_bitset<> &inserted_into_pool_bs = scratch->inserted_into_pool_bs(); - std::vector &id_scratch = scratch->id_scratch(); + std::vector &id_scratch = scratch->id_scratch(); std::vector &dist_scratch = scratch->dist_scratch(); assert(id_scratch.size() == 0); - T *aligned_query = scratch->aligned_query(); - memcpy(aligned_query, query, _dim * sizeof(T)); - if (_normalize_vecs) - { - normalize((float *)aligned_query, _dim); - } + // REFACTOR + // T *aligned_query = scratch->aligned_query(); + // memcpy(aligned_query, query, _dim * sizeof(T)); + // if (_normalize_vecs) + //{ + // normalize((float *)aligned_query, _dim); + // } + + T *aligned_query = scratch->aligned_query(); std::vector& query_bitmask_buf = scratch->query_label_bitmask(); - - float *query_float; - float *query_rotated; - float *pq_dists; - _u8 *pq_coord_scratch; + float *query_float = nullptr; + float *query_rotated = nullptr; + float *pq_dists = nullptr; + uint8_t *pq_coord_scratch = nullptr; // Initialize PQ related scratch to use PQ based distances if (_pq_dist) { @@ -948,13 +1001,13 @@ std::pair Index::iterate_to_fixed_point( } // Lambda to determine if a node has been visited - auto 
is_not_visited = [this, fast_iterate, &inserted_into_pool_bs, &inserted_into_pool_rs](const unsigned id) { + auto is_not_visited = [this, fast_iterate, &inserted_into_pool_bs, &inserted_into_pool_rs](const uint32_t id) { return fast_iterate ? inserted_into_pool_bs[id] == 0 : inserted_into_pool_rs.find(id) == inserted_into_pool_rs.end(); }; // Lambda to batch compute query<-> node distances in PQ space - auto compute_dists = [this, pq_coord_scratch, pq_dists](const std::vector &ids, + auto compute_dists = [this, pq_coord_scratch, pq_dists](const std::vector &ids, std::vector &dists_out) { diskann::aggregate_coords(ids, this->_pq_data, this->_num_pq_chunks, pq_coord_scratch); diskann::pq_dist_lookup(pq_coord_scratch, ids.size(), this->_num_pq_chunks, pq_dists, dists_out); @@ -1022,9 +1075,13 @@ std::pair Index::iterate_to_fixed_point( float distance; if (_pq_dist) + { pq_dist_lookup(pq_coord_scratch, 1, this->_num_pq_chunks, pq_dists, &distance); + } else - distance = _distance->compare(_data + _aligned_dim * (size_t)id, aligned_query, (unsigned)_aligned_dim); + { + distance = _data_store->get_distance(aligned_query, id); + } Neighbor nn = Neighbor(id, distance); best_L_nodes.insert(nn); } @@ -1114,17 +1171,15 @@ std::pair Index::iterate_to_fixed_point( assert(dist_scratch.size() == 0); for (size_t m = 0; m < id_scratch.size(); ++m) { - unsigned id = id_scratch[m]; + uint32_t id = id_scratch[m]; if (m + 1 < id_scratch.size()) { auto nextn = id_scratch[m + 1]; - diskann::prefetch_vector((const char *)_data + _aligned_dim * (size_t)nextn, - sizeof(T) * _aligned_dim); + _data_store->prefetch_vector(nextn); } - dist_scratch.push_back( - _distance->compare(aligned_query, _data + _aligned_dim * (size_t)id, (unsigned)_aligned_dim)); + dist_scratch.push_back(_data_store->get_distance(aligned_query, id)); } } // cmps += id_scratch.size(); @@ -1139,32 +1194,35 @@ std::pair Index::iterate_to_fixed_point( } template -void Index::search_for_point_and_prune(int location, _u32 Lindex, std::vector &pruned_list, +void Index::search_for_point_and_prune(int location, uint32_t Lindex, + std::vector &pruned_list, InMemQueryScratch *scratch, bool use_filter, - _u32 filteredLindex) + uint32_t filteredLindex) { - const std::vector init_ids = get_init_ids(); + const std::vector init_ids = get_init_ids(); const std::vector unused_filter_label; if (!use_filter) { - iterate_to_fixed_point(_data + _aligned_dim * location, Lindex, init_ids, scratch, false, unused_filter_label, - false); + _data_store->get_vector(location, scratch->aligned_query()); + iterate_to_fixed_point(scratch->aligned_query(), Lindex, init_ids, scratch, false, unused_filter_label, false); } else { - std::vector<_u32> filter_specific_start_nodes; + std::vector filter_specific_start_nodes; for (auto &x : _pts_to_labels[location]) filter_specific_start_nodes.emplace_back(_label_to_medoid_id[x]); - iterate_to_fixed_point(_data + _aligned_dim * location, filteredLindex, filter_specific_start_nodes, scratch, - true, _pts_to_labels[location], false); + + _data_store->get_vector(location, scratch->aligned_query()); + iterate_to_fixed_point(scratch->aligned_query(), filteredLindex, filter_specific_start_nodes, scratch, true, + _pts_to_labels[location], false); } auto &pool = scratch->pool(); - for (unsigned i = 0; i < pool.size(); i++) + for (uint32_t i = 0; i < pool.size(); i++) { - if (pool[i].id == (unsigned)location) + if (pool[i].id == (uint32_t)location) { pool.erase(pool.begin() + i); i--; @@ -1183,10 +1241,10 @@ void 
Index::search_for_point_and_prune(int location, _u32 Linde } template -void Index::occlude_list(const unsigned location, std::vector &pool, const float alpha, - const unsigned degree, const unsigned maxc, std::vector &result, +void Index::occlude_list(const uint32_t location, std::vector &pool, const float alpha, + const uint32_t degree, const uint32_t maxc, std::vector &result, InMemQueryScratch *scratch, - const tsl::robin_set *const delete_set_ptr) + const tsl::robin_set *const delete_set_ptr) { if (pool.size() == 0) return; @@ -1218,8 +1276,8 @@ void Index::occlude_list(const unsigned location, std::vector::max(); - // Add the entry to the result if its not been deleted, and doesn't add - // a self loop + // Add the entry to the result if its not been deleted, and doesn't + // add a self loop if (delete_set_ptr == nullptr || delete_set_ptr->find(iter->id) == delete_set_ptr->end()) { if (iter->id != location) @@ -1238,8 +1296,8 @@ void Index::occlude_list(const unsigned location, std::vectorid; - _u32 b = iter2->id; + uint32_t a = iter->id; + uint32_t b = iter2->id; simple_bitmask bm1(_bitmask_buf.get_bitmask(a), _bitmask_buf._bitmask_size); simple_bitmask bm2(_bitmask_buf.get_bitmask(b), _bitmask_buf._bitmask_size); @@ -1249,8 +1307,7 @@ void Index::occlude_list(const unsigned location, std::vectorcompare(_data + _aligned_dim * (size_t)iter2->id, - _data + _aligned_dim * (size_t)iter->id, (unsigned)_aligned_dim); + float djk = _data_store->get_distance(iter2->id, iter->id); if (_dist_metric == diskann::Metric::L2 || _dist_metric == diskann::Metric::COSINE) { occlude_factor[t] = (djk == 0) ? std::numeric_limits::max() @@ -1273,16 +1330,16 @@ void Index::occlude_list(const unsigned location, std::vector -void Index::prune_neighbors(const unsigned location, std::vector &pool, - std::vector &pruned_list, InMemQueryScratch *scratch) +void Index::prune_neighbors(const uint32_t location, std::vector &pool, + std::vector &pruned_list, InMemQueryScratch *scratch) { prune_neighbors(location, pool, _indexingRange, _indexingMaxC, _indexingAlpha, pruned_list, scratch); } template -void Index::prune_neighbors(const unsigned location, std::vector &pool, const _u32 range, - const _u32 max_candidate_size, const float alpha, - std::vector &pruned_list, InMemQueryScratch *scratch) +void Index::prune_neighbors(const uint32_t location, std::vector &pool, const uint32_t range, + const uint32_t max_candidate_size, const float alpha, + std::vector &pruned_list, InMemQueryScratch *scratch) { if (pool.size() == 0) { @@ -1297,14 +1354,14 @@ void Index::prune_neighbors(const unsigned location, std::vecto if (_pq_dist) { for (auto &ngh : pool) - ngh.distance = _distance->compare(_data + _aligned_dim * (size_t)ngh.id, - _data + _aligned_dim * (size_t)location, (unsigned)_aligned_dim); + ngh.distance = _data_store->get_distance(ngh.id, location); } // sort the pool based on distance to query and prune it with occlude_list std::sort(pool.begin(), pool.end()); pruned_list.clear(); pruned_list.reserve(range); + occlude_list(location, pool, alpha, range, max_candidate_size, pruned_list, scratch); assert(pruned_list.size() <= range); @@ -1322,7 +1379,7 @@ void Index::prune_neighbors(const unsigned location, std::vecto } template -void Index::inter_insert(unsigned n, std::vector &pruned_list, const _u32 range, +void Index::inter_insert(uint32_t n, std::vector &pruned_list, const uint32_t range, InMemQueryScratch *scratch) { const auto &src_pool = pruned_list; @@ -1334,14 +1391,14 @@ void Index::inter_insert(unsigned n, 
std::vector &pru // des.loc is the loc of the neighbors of n assert(des < _max_points + _num_frozen_pts); // des_pool contains the neighbors of the neighbors of n - std::vector copy_of_neighbors; + std::vector copy_of_neighbors; bool prune_needed = false; { LockGuard guard(_locks[des]); auto &des_pool = _final_graph[des]; if (std::find(des_pool.begin(), des_pool.end(), n) == des_pool.end()) { - if (des_pool.size() < (_u64)(GRAPH_SLACK_FACTOR * range)) + if (des_pool.size() < (uint64_t)(GRAPH_SLACK_FACTOR * range)) { des_pool.emplace_back(n); prune_needed = false; @@ -1358,7 +1415,7 @@ void Index::inter_insert(unsigned n, std::vector &pru if (prune_needed) { - tsl::robin_set dummy_visited(0); + tsl::robin_set dummy_visited(0); std::vector dummy_pool(0); size_t reserveSize = (size_t)(std::ceil(1.05 * GRAPH_SLACK_FACTOR * range)); @@ -1369,13 +1426,12 @@ void Index::inter_insert(unsigned n, std::vector &pru { if (dummy_visited.find(cur_nbr) == dummy_visited.end() && cur_nbr != des) { - float dist = _distance->compare(_data + _aligned_dim * (size_t)des, - _data + _aligned_dim * (size_t)cur_nbr, (unsigned)_aligned_dim); + float dist = _data_store->get_distance(des, cur_nbr); dummy_pool.emplace_back(Neighbor(cur_nbr, dist)); dummy_visited.insert(cur_nbr); } } - std::vector new_out_neighbors; + std::vector new_out_neighbors; prune_neighbors(des, dummy_pool, new_out_neighbors, scratch); { LockGuard guard(_locks[des]); @@ -1387,51 +1443,49 @@ void Index::inter_insert(unsigned n, std::vector &pru } template -void Index::inter_insert(unsigned n, std::vector &pruned_list, InMemQueryScratch *scratch) +void Index::inter_insert(uint32_t n, std::vector &pruned_list, InMemQueryScratch *scratch) { inter_insert(n, pruned_list, _indexingRange, scratch); } -template void Index::link(Parameters ¶meters) +template +void Index::link(const IndexWriteParameters ¶meters) { - unsigned num_threads = parameters.Get("num_threads"); + uint32_t num_threads = parameters.num_threads; if (num_threads != 0) omp_set_num_threads(num_threads); - _saturate_graph = parameters.Get("saturate_graph"); + _saturate_graph = parameters.saturate_graph; - if (num_threads != 0) - omp_set_num_threads(num_threads); - - _indexingQueueSize = parameters.Get("L"); // Search list size - _filterIndexingQueueSize = parameters.Get("Lf"); - _indexingRange = parameters.Get("R"); - _indexingMaxC = parameters.Get("C"); - _indexingAlpha = parameters.Get("alpha"); + _indexingQueueSize = parameters.search_list_size; + _filterIndexingQueueSize = parameters.filter_list_size; + _indexingRange = parameters.max_degree; + _indexingMaxC = parameters.max_occlusion_size; + _indexingAlpha = parameters.alpha; /* visit_order is a vector that is initialized to the entire graph */ - std::vector visit_order; + std::vector visit_order; std::vector pool, tmp; - tsl::robin_set visited; + tsl::robin_set visited; visit_order.reserve(_nd + _num_frozen_pts); - for (unsigned i = 0; i < (unsigned)_nd; i++) + for (uint32_t i = 0; i < (uint32_t)_nd; i++) { visit_order.emplace_back(i); } // If there are any frozen points, add them all. 
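// Condensed sketch of the inter_insert() flow above: once node n has its
// pruned out-list, each listed neighbor d receives the reverse edge d -> n.
// The edge is appended cheaply while d's list is within slack * R, and only
// when the list overflows is d's neighborhood re-pruned. Locking is omitted
// here (the real code guards each list with a per-node lock), the resize()
// is a stand-in for the real prune_neighbors() call, and slack assumes the
// usual GRAPH_SLACK_FACTOR of 1.3.
#include <algorithm>
#include <cstdint>
#include <vector>

using AdjList = std::vector<std::vector<uint32_t>>;

void add_back_edges(AdjList &graph, uint32_t n, const std::vector<uint32_t> &pruned_list,
                    uint32_t R, float slack = 1.3f)
{
    for (uint32_t d : pruned_list)
    {
        std::vector<uint32_t> &des_pool = graph[d];
        if (std::find(des_pool.begin(), des_pool.end(), n) != des_pool.end())
            continue; // reverse edge already present
        des_pool.emplace_back(n);
        if (des_pool.size() >= (size_t)(slack * R))
        {
            // Over the slack budget: shrink d's list back to degree R.
            des_pool.resize(R); // stand-in for prune_neighbors(d, ...)
        }
    }
}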
- for (unsigned frozen = _max_points; frozen < _max_points + _num_frozen_pts; frozen++) + for (uint32_t frozen = (uint32_t)_max_points; frozen < _max_points + _num_frozen_pts; frozen++) { visit_order.emplace_back(frozen); } // if there are frozen points, the first such one is set to be the _start if (_num_frozen_pts > 0) - _start = (unsigned)_max_points; + _start = (uint32_t)_max_points; else _start = calculate_entry_point(); - for (uint64_t p = 0; p < _nd; p++) + for (size_t p = 0; p < _nd; p++) { _final_graph[p].reserve((size_t)(std::ceil(_indexingRange * GRAPH_SLACK_FACTOR * 1.05))); } @@ -1439,14 +1493,14 @@ template void Index> manager(_query_scratch); auto scratch = manager.scratch_space(); - std::vector pruned_list; + std::vector pruned_list; if (_filtered_index) { search_for_point_and_prune(node, _indexingQueueSize, pruned_list, scratch, _filtered_index, @@ -1458,7 +1512,7 @@ template void Index void Index _indexingRange) @@ -1485,16 +1539,15 @@ template void Index> manager(_query_scratch); auto scratch = manager.scratch_space(); - tsl::robin_set dummy_visited(0); + tsl::robin_set dummy_visited(0); std::vector dummy_pool(0); - std::vector new_out_neighbors; + std::vector new_out_neighbors; for (auto cur_nbr : _final_graph[node]) { if (dummy_visited.find(cur_nbr) == dummy_visited.end() && cur_nbr != node) { - float dist = _distance->compare(_data + _aligned_dim * (size_t)node, - _data + _aligned_dim * (size_t)cur_nbr, (unsigned)_aligned_dim); + float dist = _data_store->get_distance(node, cur_nbr); dummy_pool.emplace_back(Neighbor(cur_nbr, dist)); dummy_visited.insert(cur_nbr); } @@ -1513,24 +1566,25 @@ template void Index -void Index::prune_all_nbrs(const Parameters ¶meters) +void Index::prune_all_neighbors(const uint32_t max_degree, const uint32_t max_occlusion_size, + const float alpha) { - const unsigned range = parameters.Get("R"); - const unsigned maxc = parameters.Get("C"); - const float alpha = parameters.Get("alpha"); + const uint32_t range = max_degree; + const uint32_t maxc = max_occlusion_size; + _filtered_index = true; diskann::Timer timer; #pragma omp parallel for - for (_s64 node = 0; node < (_s64)(_max_points + _num_frozen_pts); node++) + for (int64_t node = 0; node < (int64_t)(_max_points + _num_frozen_pts); node++) { if ((size_t)node < _nd || (size_t)node >= _max_points) { if (_final_graph[node].size() > range) { - tsl::robin_set dummy_visited(0); + tsl::robin_set dummy_visited(0); std::vector dummy_pool(0); - std::vector new_out_neighbors; + std::vector new_out_neighbors; ScratchStoreManager> manager(_query_scratch); auto scratch = manager.scratch_space(); @@ -1539,14 +1593,13 @@ void Index::prune_all_nbrs(const Parameters ¶meters) { if (dummy_visited.find(cur_nbr) == dummy_visited.end() && cur_nbr != node) { - float dist = _distance->compare(_data + _aligned_dim * (size_t)node, - _data + _aligned_dim * (size_t)cur_nbr, (unsigned)_aligned_dim); + float dist = _data_store->get_distance((location_t)node, (location_t)cur_nbr); dummy_pool.emplace_back(Neighbor(cur_nbr, dist)); dummy_visited.insert(cur_nbr); } } - prune_neighbors((_u32)node, dummy_pool, range, maxc, alpha, new_out_neighbors, scratch); + prune_neighbors((uint32_t)node, dummy_pool, range, maxc, alpha, new_out_neighbors, scratch); _final_graph[node].clear(); for (auto id : new_out_neighbors) _final_graph[node].emplace_back(id); @@ -1560,7 +1613,7 @@ void Index::prune_all_nbrs(const Parameters ¶meters) { if (i < _nd || i >= _max_points) { - const std::vector &pool = _final_graph[i]; + const std::vector 
&pool = _final_graph[i]; max = (std::max)(max, pool.size()); min = (std::min)(min, pool.size()); total += pool.size(); @@ -1578,6 +1631,7 @@ void Index::prune_all_nbrs(const Parameters ¶meters) } } +// REFACTOR template void Index::set_start_points(const T *data, size_t data_count) { @@ -1586,23 +1640,47 @@ void Index::set_start_points(const T *data, size_t data_count) if (_nd > 0) throw ANNException("Can not set starting point for a non-empty index", -1, __FUNCSIG__, __FILE__, __LINE__); - if (data_count != _num_frozen_pts * _aligned_dim) + if (data_count != _num_frozen_pts * _dim) throw ANNException("Invalid number of points", -1, __FUNCSIG__, __FILE__, __LINE__); - memcpy(_data + _aligned_dim * _max_points, data, _aligned_dim * sizeof(T) * _num_frozen_pts); + // memcpy(_data + _aligned_dim * _max_points, data, _aligned_dim * + // sizeof(T) * _num_frozen_pts); + for (location_t i = 0; i < _num_frozen_pts; i++) + { + _data_store->set_vector((location_t)(i + _max_points), data + i * _dim); + } _has_built = true; diskann::cout << "Index start points set: #" << _num_frozen_pts << std::endl; } template -void Index::set_start_points_at_random(T radius, unsigned int random_seed) +void Index::_set_start_points_at_random(DataType radius, uint32_t random_seed) +{ + try + { + T radius_to_use = std::any_cast(radius); + this->set_start_points_at_random(radius_to_use, random_seed); + } + catch (const std::bad_any_cast &e) + { + throw ANNException( + "Error: bad any cast while performing _set_start_points_at_random() " + std::string(e.what()), -1); + } + catch (const std::exception &e) + { + throw ANNException("Error: " + std::string(e.what()), -1); + } +} + +template +void Index::set_start_points_at_random(T radius, uint32_t random_seed) { std::mt19937 gen{random_seed}; std::normal_distribution<> d{0.0, 1.0}; std::vector points_data; - points_data.reserve(_aligned_dim * _num_frozen_pts); - std::vector real_vec(_aligned_dim); + points_data.reserve(_dim * _num_frozen_pts); + std::vector real_vec(_dim); for (size_t frozen_point = 0; frozen_point < _num_frozen_pts; frozen_point++) { @@ -1623,7 +1701,8 @@ void Index::set_start_points_at_random(T radius, unsigned int r } template -void Index::build_with_data_populated(Parameters ¶meters, const std::vector &tags) +void Index::build_with_data_populated(const IndexWriteParameters ¶meters, + const std::vector &tags) { diskann::cout << "Starting index build with " << _nd << " points... " << std::endl; @@ -1636,26 +1715,26 @@ void Index::build_with_data_populated(Parameters ¶meters, c stream << "ERROR: Driver requests loading " << _nd << " points from file," << "but tags vector is of size " << tags.size() << "." 
<< std::endl; diskann::cerr << stream.str() << std::endl; - aligned_free(_data); throw diskann::ANNException(stream.str(), -1, __FUNCSIG__, __FILE__, __LINE__); } if (_enable_tags) { for (size_t i = 0; i < tags.size(); ++i) { - _tag_to_location[tags[i]] = (unsigned)i; - _location_to_tag.set(static_cast(i), tags[i]); + _tag_to_location[tags[i]] = (uint32_t)i; + _location_to_tag.set(static_cast(i), tags[i]); } } - uint32_t index_R = parameters.Get("R"); - uint32_t num_threads_index = parameters.Get("num_threads"); - uint32_t index_L = parameters.Get("L"); - uint32_t maxc = parameters.Get("C"); + uint32_t index_R = parameters.max_degree; + uint32_t num_threads_index = parameters.num_threads; + uint32_t index_L = parameters.search_list_size; + uint32_t maxc = parameters.max_occlusion_size; if (_query_scratch.size() == 0) { - initialize_query_scratch(5 + num_threads_index, index_L, index_L, index_R, maxc, _aligned_dim); + initialize_query_scratch(5 + num_threads_index, index_L, index_L, index_R, maxc, + _data_store->get_aligned_dim()); } generate_frozen_point(); @@ -1674,13 +1753,30 @@ void Index::build_with_data_populated(Parameters &parameters, c diskann::cout << "Index built with degree: max:" << max << " avg:" << (float)total / (float)(_nd + _num_frozen_pts) << " min:" << min << " count(deg<2):" << cnt << std::endl; - _max_observed_degree = std::max((unsigned)max, _max_observed_degree); + _max_observed_degree = std::max((uint32_t)max, _max_observed_degree); _has_built = true; } - template -void Index::build(const T *data, const size_t num_points_to_load, Parameters &parameters, - const std::vector &tags) +template +void Index::_build(const DataType &data, const size_t num_points_to_load, + const IndexWriteParameters &parameters, TagVector &tags) +{ + try + { + this->build(std::any_cast(data), num_points_to_load, parameters, + tags.get>()); + } + catch (const std::bad_any_cast &e) + { + throw ANNException("Error: bad any cast while building index. 
" + std::string(e.what()), -1); + } + catch (const std::exception &e) + { + throw ANNException("Error" + std::string(e.what()), -1); + } +} +template +void Index::build(const T *data, const size_t num_points_to_load, + const IndexWriteParameters ¶meters, const std::vector &tags) { if (num_points_to_load == 0) { @@ -1698,25 +1794,31 @@ void Index::build(const T *data, const size_t num_points_to_loa std::unique_lock tl(_tag_lock); _nd = num_points_to_load; - memcpy((char *)_data, (char *)data, _aligned_dim * _nd * sizeof(T)); + _data_store->populate_data(data, (location_t)num_points_to_load); - if (_normalize_vecs) - { - for (uint64_t i = 0; i < num_points_to_load; i++) - { - normalize(_data + _aligned_dim * i, _aligned_dim); - } - } + // REFACTOR + // memcpy((char *)_data, (char *)data, _aligned_dim * _nd * sizeof(T)); + // if (_normalize_vecs) + //{ + // for (size_t i = 0; i < num_points_to_load; i++) + // { + // normalize(_data + _aligned_dim * i, _aligned_dim); + // } + // } } build_with_data_populated(parameters, tags); } template -void Index::build(const char *filename, const size_t num_points_to_load, Parameters ¶meters, - const std::vector &tags) +void Index::build(const char *filename, const size_t num_points_to_load, + const IndexWriteParameters ¶meters, const std::vector &tags) { + // idealy this should call build_filtered_index based on params passed + std::unique_lock ul(_update_lock); + + // error checks if (num_points_to_load == 0) throw ANNException("Do not call build with 0 points", -1, __FUNCSIG__, __FILE__, __LINE__); @@ -1744,8 +1846,6 @@ void Index::build(const char *filename, const size_t num_points if (_pq_dist) aligned_free(_pq_data); - else - aligned_free(_data); throw diskann::ANNException(stream.str(), -1, __FUNCSIG__, __FILE__, __LINE__); } @@ -1757,8 +1857,6 @@ void Index::build(const char *filename, const size_t num_points if (_pq_dist) aligned_free(_pq_data); - else - aligned_free(_data); throw diskann::ANNException(stream.str(), -1, __FUNCSIG__, __FILE__, __LINE__); } @@ -1771,8 +1869,6 @@ void Index::build(const char *filename, const size_t num_points if (_pq_dist) aligned_free(_pq_data); - else - aligned_free(_data); throw diskann::ANNException(stream.str(), -1, __FUNCSIG__, __FILE__, __LINE__); } @@ -1787,25 +1883,18 @@ void Index::build(const char *filename, const size_t num_points generate_quantized_data(std::string(filename), pq_pivots_file, pq_compressed_file, _dist_metric, p_val, _num_pq_chunks, _use_opq); - copy_aligned_data_from_file<_u8>(pq_compressed_file.c_str(), _pq_data, file_num_points, _num_pq_chunks, - _num_pq_chunks); + copy_aligned_data_from_file(pq_compressed_file.c_str(), _pq_data, file_num_points, _num_pq_chunks, + _num_pq_chunks); #ifdef EXEC_ENV_OLS - throw ANNException("load_pq_centroid_bin should not be called when EXEC_ENV_OLS is defined.", -1, __FUNCSIG__, - __FILE__, __LINE__); + throw ANNException("load_pq_centroid_bin should not be called when " + "EXEC_ENV_OLS is defined.", + -1, __FUNCSIG__, __FILE__, __LINE__); #else _pq_table.load_pq_centroid_bin(pq_pivots_file.c_str(), _num_pq_chunks); #endif } - copy_aligned_data_from_file(filename, _data, file_num_points, file_dim, _aligned_dim); - if (_normalize_vecs) - { - for (uint64_t i = 0; i < file_num_points; i++) - { - normalize(_data + _aligned_dim * i, _aligned_dim); - } - } - + _data_store->populate_data(filename, 0U); diskann::cout << "Using only first " << num_points_to_load << " from file.. 
" << std::endl; { @@ -1816,8 +1905,8 @@ void Index::build(const char *filename, const size_t num_points } template -void Index::build(const char *filename, const size_t num_points_to_load, Parameters ¶meters, - const char *tag_filename) +void Index::build(const char *filename, const size_t num_points_to_load, + const IndexWriteParameters ¶meters, const char *tag_filename) { std::vector tags; @@ -1859,6 +1948,43 @@ void Index::build(const char *filename, const size_t num_points build(filename, num_points_to_load, parameters, tags); } +template +void Index::build(const std::string &data_file, const size_t num_points_to_load, + IndexBuildParams &build_params) +{ + std::string labels_file_to_use = build_params.save_path_prefix + "_label_formatted.txt"; + std::string mem_labels_int_map_file = build_params.save_path_prefix + "_labels_map.txt"; + + size_t points_to_load = num_points_to_load == 0 ? _max_points : num_points_to_load; + + auto s = std::chrono::high_resolution_clock::now(); + if (build_params.label_file == "") + { + this->build(data_file.c_str(), points_to_load, build_params.index_write_params); + } + else + { + // TODO: this should ideally happen in save() + uint32_t unv_label_as_num = 0; + convert_labels_string_to_int(build_params.label_file, labels_file_to_use, mem_labels_int_map_file, + build_params.universal_label, unv_label_as_num); + if (build_params.universal_label != "") + { + LabelT unv_label_as_num = 0; + this->set_universal_label(unv_label_as_num); + } + this->build_filtered_index(data_file.c_str(), labels_file_to_use, points_to_load, + build_params.index_write_params); + } + std::chrono::duration diff = std::chrono::high_resolution_clock::now() - s; + std::cout << "Indexing time: " << diff.count() << "\n"; + // cleanup + if (build_params.label_file != "") + { + // clean_up_artifacts({labels_file_to_use, mem_labels_int_map_file}, {}); + } +} + template std::unordered_map Index::load_label_map(const std::string &labels_map_file) { @@ -1873,7 +1999,7 @@ std::unordered_map Index::load_label_map(c getline(iss, token, '\t'); label_str = token; getline(iss, token, '\t'); - token_as_num = std::stoul(token); + token_as_num = (LabelT)std::stoul(token); string_to_int_mp[label_str] = token_as_num; } return string_to_int_mp; @@ -1908,7 +2034,7 @@ void Index::parse_label_file(const std::string &label_file, siz } std::string line, token; - unsigned line_cnt = 0; + uint32_t line_cnt = 0; while (std::getline(infile, line)) { @@ -1930,7 +2056,7 @@ void Index::parse_label_file(const std::string &label_file, siz { token.erase(std::remove(token.begin(), token.end(), '\n'), token.end()); token.erase(std::remove(token.begin(), token.end(), '\r'), token.end()); - LabelT token_as_num = std::stoul(token); + LabelT token_as_num = (LabelT)std::stoul(token); lbls.push_back(token_as_num); _labels.insert(token_as_num); } @@ -2017,20 +2143,21 @@ void Index::set_universal_label(const LabelT &label) template void Index::build_filtered_index(const char *filename, const std::string &label_file, - const size_t num_points_to_load, Parameters ¶meters, + const size_t num_points_to_load, IndexWriteParameters ¶meters, const std::vector &tags) { - _labels_file = label_file; + _labels_file = label_file; // original label file _filtered_index = true; _label_to_medoid_id.clear(); size_t num_points_labels = 0; + parse_label_file(label_file, - num_points_labels); // determines medoid for each label and - // identifies the points to label mapping + num_points_labels); // determines medoid for each label and identifies + 
// the points to label mapping convert_pts_label_to_bitmask(_pts_to_labels, _bitmask_buf, _labels.size()); - std::unordered_map> label_to_points; + std::unordered_map> label_to_points; std::vector label_bitmask; for (int lbl = 0; lbl < _labels.size(); lbl++) { @@ -2051,8 +2178,8 @@ void Index::build_filtered_index(const char *filename, const st bitmask_full_val.merge_bitmask_val(bitmask_val); } - std::vector<_u32> labeled_points; - for (_u32 point_id = 0; point_id < num_points_to_load; point_id++) + std::vector labeled_points; + for (uint32_t point_id = 0; point_id < num_points_to_load; point_id++) { simple_bitmask bm(_bitmask_buf.get_bitmask(point_id), _bitmask_buf._bitmask_size); bool pt_has_lbl = bm.test_full_mask_val(bitmask_full_val); @@ -2065,17 +2192,17 @@ void Index::build_filtered_index(const char *filename, const st label_to_points[x] = labeled_points; } - _u32 num_cands = 25; + uint32_t num_cands = 25; for (auto itr = _labels.begin(); itr != _labels.end(); itr++) { - _u32 best_medoid_count = std::numeric_limits<_u32>::max(); + uint32_t best_medoid_count = std::numeric_limits::max(); auto &curr_label = *itr; - _u32 best_medoid; + uint32_t best_medoid; auto labeled_points = label_to_points[curr_label]; - for (_u32 cnd = 0; cnd < num_cands; cnd++) + for (uint32_t cnd = 0; cnd < num_cands; cnd++) { - _u32 cur_cnd = labeled_points[rand() % labeled_points.size()]; - _u32 cur_cnt = std::numeric_limits<_u32>::max(); + uint32_t cur_cnd = labeled_points[rand() % labeled_points.size()]; + uint32_t cur_cnt = std::numeric_limits::max(); if (_medoid_counts.find(cur_cnd) == _medoid_counts.end()) { _medoid_counts[cur_cnd] = 0; @@ -2098,9 +2225,41 @@ void Index::build_filtered_index(const char *filename, const st this->build(filename, num_points_to_load, parameters, tags); } +template +std::pair Index::_search(const DataType &query, const size_t K, const uint32_t L, + std::any &indices, float *distances) +{ + try + { + auto typed_query = std::any_cast(query); + if (typeid(uint32_t *) == indices.type()) + { + auto u32_ptr = std::any_cast(indices); + return this->search(typed_query, K, L, u32_ptr, distances); + } + else if (typeid(uint64_t *) == indices.type()) + { + auto u64_ptr = std::any_cast(indices); + return this->search(typed_query, K, L, u64_ptr, distances); + } + else + { + throw ANNException("Error: indices type can only be uint64_t or uint32_t.", -1); + } + } + catch (const std::bad_any_cast &e) + { + throw ANNException("Error: bad any cast while searching. 
" + std::string(e.what()), -1); + } + catch (const std::exception &e) + { + throw ANNException("Error: " + std::string(e.what()), -1); + } +} + template template -std::pair Index::search(const T *query, const size_t K, const unsigned L, +std::pair Index::search(const T *query, const size_t K, const uint32_t L, IdType *indices, float *distances) { if (K > (uint64_t)L) @@ -2120,16 +2279,18 @@ std::pair Index::search(const T *query, con } const std::vector unused_filter_label; - const std::vector init_ids = get_init_ids(); + const std::vector init_ids = get_init_ids(); std::shared_lock lock(_update_lock); - auto retval = iterate_to_fixed_point(query, L, init_ids, scratch, false, unused_filter_label, true); + _distance->preprocess_query(query, _data_store->get_dims(), scratch->aligned_query()); + auto retval = + iterate_to_fixed_point(scratch->aligned_query(), L, init_ids, scratch, false, unused_filter_label, true); NeighborPriorityQueue &best_L_nodes = scratch->best_l_nodes(); size_t pos = 0; - for (int i = 0; i < best_L_nodes.size(); ++i) + for (size_t i = 0; i < best_L_nodes.size(); ++i) { if (best_L_nodes[i].id < _max_points) { @@ -2153,16 +2314,39 @@ std::pair Index::search(const T *query, con } if (pos < K) { - diskann::cerr << "Found fewer than K elements for query" << std::endl; + diskann::cerr << "Found pos: " << pos << "fewer than K elements " << K << " for query" << std::endl; } return retval; } +template +std::pair Index::_search_with_filters(const DataType &query, + const std::string &raw_label, const size_t K, + const uint32_t L, std::any &indices, + float *distances) +{ + auto converted_label = this->get_converted_label(raw_label); + if (typeid(uint64_t *) == indices.type()) + { + auto ptr = std::any_cast(indices); + return this->search_with_filters(std::any_cast(query), converted_label, K, L, ptr, distances); + } + else if (typeid(uint32_t *) == indices.type()) + { + auto ptr = std::any_cast(indices); + return this->search_with_filters(std::any_cast(query), converted_label, K, L, ptr, distances); + } + else + { + throw ANNException("Error: Id type can only be uint64_t or uint32_t.", -1); + } +} + template template std::pair Index::search_with_filters(const T *query, const LabelT &filter_label, - const size_t K, const unsigned L, + const size_t K, const uint32_t L, IdType *indices, float *distances) { if (K > (uint64_t)L) @@ -2182,7 +2366,7 @@ std::pair Index::search_with_filters(const } std::vector filter_vec; - std::vector init_ids = get_init_ids(); + std::vector init_ids = get_init_ids(); std::shared_lock lock(_update_lock); @@ -2198,15 +2382,16 @@ std::pair Index::search_with_filters(const } filter_vec.emplace_back(filter_label); - T *aligned_query = scratch->aligned_query(); - memcpy(aligned_query, query, _dim * sizeof(T)); - - auto retval = iterate_to_fixed_point(aligned_query, L, init_ids, scratch, true, filter_vec, true); + // REFACTOR + // T *aligned_query = scratch->aligned_query(); + // memcpy(aligned_query, query, _dim * sizeof(T)); + _distance->preprocess_query(query, _data_store->get_dims(), scratch->aligned_query()); + auto retval = iterate_to_fixed_point(scratch->aligned_query(), L, init_ids, scratch, true, filter_vec, true); auto best_L_nodes = scratch->best_l_nodes(); size_t pos = 0; - for (int i = 0; i < best_L_nodes.size(); ++i) + for (size_t i = 0; i < best_L_nodes.size(); ++i) { if (best_L_nodes[i].id < _max_points) { @@ -2237,7 +2422,26 @@ std::pair Index::search_with_filters(const } template -size_t Index::search_with_tags(const T *query, const uint64_t 
K, const unsigned L, TagT *tags, +size_t Index::_search_with_tags(const DataType &query, const uint64_t K, const uint32_t L, + const TagType &tags, float *distances, DataVector &res_vectors) +{ + try + { + return this->search_with_tags(std::any_cast(query), K, L, std::any_cast(tags), distances, + res_vectors.get>()); + } + catch (const std::bad_any_cast &e) + { + throw ANNException("Error: bad any cast while performing _search_with_tags() " + std::string(e.what()), -1); + } + catch (const std::exception &e) + { + throw ANNException("Error: " + std::string(e.what()), -1); + } +} + +template +size_t Index::search_with_tags(const T *query, const uint64_t K, const uint32_t L, TagT *tags, float *distances, std::vector &res_vectors) { if (K > (uint64_t)L) @@ -2257,10 +2461,12 @@ size_t Index::search_with_tags(const T *query, const uint64_t K std::shared_lock ul(_update_lock); - const std::vector init_ids = get_init_ids(); + const std::vector init_ids = get_init_ids(); const std::vector unused_filter_label; - iterate_to_fixed_point(query, L, init_ids, scratch, false, unused_filter_label, true); + _distance->preprocess_query(query, _data_store->get_dims(), scratch->aligned_query()); + iterate_to_fixed_point(scratch->aligned_query(), L, init_ids, scratch, false, unused_filter_label, true); + NeighborPriorityQueue &best_L_nodes = scratch->best_l_nodes(); assert(best_L_nodes.size() <= L); @@ -2278,7 +2484,7 @@ size_t Index::search_with_tags(const T *query, const uint64_t K if (res_vectors.size() > 0) { - memcpy(res_vectors[pos], _data + ((size_t)node.id) * _aligned_dim, _dim * sizeof(T)); + _data_store->get_vector(node.id, res_vectors[pos]); } if (distances != nullptr) @@ -2337,7 +2543,7 @@ template void Indexcopy_vectors((location_t)res, (location_t)_max_points, 1); } } @@ -2351,33 +2557,38 @@ template int Index return -2; } + if (this->_deletes_enabled) + { + return 0; + } + std::unique_lock ul(_update_lock); std::unique_lock tl(_tag_lock); std::unique_lock dl(_delete_lock); if (_data_compacted) { - for (unsigned slot = (unsigned)_nd; slot < _max_points; ++slot) + for (uint32_t slot = (uint32_t)_nd; slot < _max_points; ++slot) { _empty_slots.insert(slot); } } - + this->_deletes_enabled = true; return 0; } template -inline void Index::process_delete(const tsl::robin_set &old_delete_set, size_t loc, - const unsigned range, const unsigned maxc, const float alpha, +inline void Index::process_delete(const tsl::robin_set &old_delete_set, size_t loc, + const uint32_t range, const uint32_t maxc, const float alpha, InMemQueryScratch *scratch) { - tsl::robin_set &expanded_nodes_set = scratch->expanded_nodes_set(); + tsl::robin_set &expanded_nodes_set = scratch->expanded_nodes_set(); std::vector &expanded_nghrs_vec = scratch->expanded_nodes_vec(); // If this condition were not true, deadlock could result - assert(old_delete_set.find(loc) == old_delete_set.end()); + assert(old_delete_set.find((uint32_t)loc) == old_delete_set.end()); - std::vector adj_list; + std::vector adj_list; { // Acquire and release lock[loc] before acquiring locks for neighbors std::unique_lock adj_list_lock; @@ -2421,13 +2632,12 @@ inline void Index::process_delete(const tsl::robin_setcompare(_data + _aligned_dim * loc, _data + _aligned_dim * ngh, (unsigned)_aligned_dim)); + expanded_nghrs_vec.emplace_back(ngh, _data_store->get_distance((location_t)loc, (location_t)ngh)); } std::sort(expanded_nghrs_vec.begin(), expanded_nghrs_vec.end()); - std::vector &occlude_list_output = scratch->occlude_list_output(); - occlude_list(loc, 
expanded_nghrs_vec, alpha, range, maxc, occlude_list_output, scratch, &old_delete_set); + std::vector &occlude_list_output = scratch->occlude_list_output(); + occlude_list((uint32_t)loc, expanded_nghrs_vec, alpha, range, maxc, occlude_list_output, scratch, + &old_delete_set); std::unique_lock adj_list_lock(_locks[loc]); _final_graph[loc] = occlude_list_output; } @@ -2436,7 +2646,7 @@ inline void Index::process_delete(const tsl::robin_set -consolidation_report Index::consolidate_deletes(const Parameters ¶ms) +consolidation_report Index::consolidate_deletes(const IndexWriteParameters ¶ms) { if (!_enable_tags) throw diskann::ANNException("Point tag array not instantiated", -1, __FUNCSIG__, __FILE__, __LINE__); @@ -2480,7 +2690,7 @@ consolidation_report Index::consolidate_deletes(const Parameter diskann::cout << "Starting consolidate_deletes... "; - std::unique_ptr> old_delete_set(new tsl::robin_set); + std::unique_ptr> old_delete_set(new tsl::robin_set); { std::unique_lock dl(_delete_lock); std::swap(_delete_set, old_delete_set); @@ -2491,18 +2701,17 @@ consolidation_report Index::consolidate_deletes(const Parameter throw diskann::ANNException("ERROR: start node has been deleted", -1, __FUNCSIG__, __FILE__, __LINE__); } - const unsigned range = params.Get("R"); - const unsigned maxc = params.Get("C"); - const float alpha = params.Get("alpha"); - const unsigned num_threads = - params.Get("num_threads") == 0 ? omp_get_num_threads() : params.Get("num_threads"); + const uint32_t range = params.max_degree; + const uint32_t maxc = params.max_occlusion_size; + const float alpha = params.alpha; + const uint32_t num_threads = params.num_threads == 0 ? omp_get_num_threads() : params.num_threads; - unsigned num_calls_to_process_delete = 0; + uint32_t num_calls_to_process_delete = 0; diskann::Timer timer; #pragma omp parallel for num_threads(num_threads) schedule(dynamic, 8192) reduction(+ : num_calls_to_process_delete) - for (_s64 loc = 0; loc < (_s64)_max_points; loc++) + for (int64_t loc = 0; loc < (int64_t)_max_points; loc++) { - if (old_delete_set->find((_u32)loc) == old_delete_set->end() && !_empty_slots.is_in_set((_u32)loc)) + if (old_delete_set->find((uint32_t)loc) == old_delete_set->end() && !_empty_slots.is_in_set((uint32_t)loc)) { ScratchStoreManager> manager(_query_scratch); auto scratch = manager.scratch_space(); @@ -2510,7 +2719,7 @@ consolidation_report Index::consolidate_deletes(const Parameter num_calls_to_process_delete += 1; } } - for (_s64 loc = _max_points; loc < (_s64)(_max_points + _num_frozen_pts); loc++) + for (int64_t loc = _max_points; loc < (int64_t)(_max_points + _num_frozen_pts); loc++) { ScratchStoreManager> manager(_query_scratch); auto scratch = manager.scratch_space(); @@ -2543,8 +2752,8 @@ template void Index 0) { - reposition_points(_max_points, _nd, _num_frozen_pts); - _start = (_u32)_nd; + reposition_points((uint32_t)_max_points, (uint32_t)_nd, (uint32_t)_num_frozen_pts); + _start = (uint32_t)_nd; } } @@ -2570,11 +2779,11 @@ template void Index new_location = std::vector(_max_points + _num_frozen_pts, (_u32)UINT32_MAX); + std::vector new_location = std::vector(_max_points + _num_frozen_pts, UINT32_MAX); - _u32 new_counter = 0; - std::set<_u32> empty_locations; - for (_u32 old_location = 0; old_location < _max_points; old_location++) + uint32_t new_counter = 0; + std::set empty_locations; + for (uint32_t old_location = 0; old_location < _max_points; old_location++) { if (_location_to_tag.contains(old_location)) { @@ -2586,7 +2795,7 @@ template void Index void Index 
new_adj_list; + std::vector new_adj_list; if ((new_location[old] < _max_points) // If point continues to exist || (old >= _max_points && old < _max_points + _num_frozen_pts)) @@ -2627,8 +2836,7 @@ template void Indexcopy_vectors(old, new_location[old], 1); } } else @@ -2650,7 +2858,7 @@ template void Index void Index int Index { return -1; } - unsigned location; + uint32_t location; if (_data_compacted && _empty_slots.is_empty()) { // This code path is encountered when enable_delete hasn't been // called yet, so no points have been deleted and _empty_slots // hasn't been filled in. In that case, just keep assigning // consecutive locations. - location = (unsigned)_nd; + location = (uint32_t)_nd; } else { @@ -2707,13 +2914,14 @@ template size_t Index -size_t Index::release_locations(const tsl::robin_set &locations) +size_t Index::release_locations(const tsl::robin_set &locations) { for (auto location : locations) { if (_empty_slots.is_in_set(location)) - throw ANNException("Trying to release location, but location already in empty slots", -1, __FUNCSIG__, - __FILE__, __LINE__); + throw ANNException("Trying to release location, but location " + "already in empty slots", + -1, __FUNCSIG__, __FILE__, __LINE__); _empty_slots.insert(location); _nd--; @@ -2726,40 +2934,43 @@ size_t Index::release_locations(const tsl::robin_set } template -void Index::reposition_points(unsigned old_location_start, unsigned new_location_start, - unsigned num_locations) +void Index::reposition_points(uint32_t old_location_start, uint32_t new_location_start, + uint32_t num_locations) { if (num_locations == 0 || old_location_start == new_location_start) { return; } - // Update pointers to the moved nodes. Note: the computation is correct even when - // new_location_start < old_location_start given the C++ unsigned integer arithmetic - // rules. - const unsigned location_delta = new_location_start - old_location_start; + // Update pointers to the moved nodes. Note: the computation is correct even + // when new_location_start < old_location_start given the C++ uint32_t + // integer arithmetic rules. + const uint32_t location_delta = new_location_start - old_location_start; - for (unsigned i = 0; i < _max_points + _num_frozen_pts; i++) + for (uint32_t i = 0; i < _max_points + _num_frozen_pts; i++) for (auto &loc : _final_graph[i]) if (loc >= old_location_start && loc < old_location_start + num_locations) loc += location_delta; - // The [start, end) interval which will contain obsolete points to be cleared. - unsigned mem_clear_loc_start = old_location_start; - unsigned mem_clear_loc_end_limit = old_location_start + num_locations; + // The [start, end) interval which will contain obsolete points to be + // cleared. + uint32_t mem_clear_loc_start = old_location_start; + uint32_t mem_clear_loc_end_limit = old_location_start + num_locations; - // Move the adjacency lists. Make sure that overlapping ranges are handled correctly. + // Move the adjacency lists. Make sure that overlapping ranges are handled + // correctly. if (new_location_start < old_location_start) { // New location before the old location: copy the entries in order // to avoid modifying locations that are yet to be copied. 
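A quick standalone illustration (not part of the patch) of the unsigned-arithmetic note above: uint32_t subtraction is modular, so new_location_start - old_location_start deliberately wraps around when the new start precedes the old one, and adding the wrapped delta to a neighbor id still lands on the intended target.

    #include <cassert>
    #include <cstdint>

    int main()
    {
        const uint32_t old_start = 1000, new_start = 200;
        const uint32_t location_delta = new_start - old_start; // wraps to 2^32 - 800
        uint32_t loc = 1005;   // a neighbor id inside the moved range
        loc += location_delta; // modular addition yields 205, the correct new id
        assert(loc == 205);
        return 0;
    }
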
- for (unsigned loc_offset = 0; loc_offset < num_locations; loc_offset++) + for (uint32_t loc_offset = 0; loc_offset < num_locations; loc_offset++) { assert(_final_graph[new_location_start + loc_offset].empty()); _final_graph[new_location_start + loc_offset].swap(_final_graph[old_location_start + loc_offset]); } - // If ranges are overlapping, make sure not to clear the newly copied data. + // If ranges are overlapping, make sure not to clear the newly copied + // data. if (mem_clear_loc_start < new_location_start + num_locations) { // Clear only after the end of the new range. @@ -2770,25 +2981,21 @@ void Index::reposition_points(unsigned old_location_start, unsi { // Old location after the new location: copy from the end of the range // to avoid modifying locations that are yet to be copied. - for (unsigned loc_offset = num_locations; loc_offset > 0; loc_offset--) + for (uint32_t loc_offset = num_locations; loc_offset > 0; loc_offset--) { assert(_final_graph[new_location_start + loc_offset - 1u].empty()); _final_graph[new_location_start + loc_offset - 1u].swap(_final_graph[old_location_start + loc_offset - 1u]); } - // If ranges are overlapping, make sure not to clear the newly copied data. + // If ranges are overlapping, make sure not to clear the newly copied + // data. if (mem_clear_loc_end_limit > new_location_start) { // Clear only up to the beginning of the new range. mem_clear_loc_end_limit = new_location_start; } } - - // Use memmove to handle overlapping ranges. - memmove(_data + _aligned_dim * new_location_start, _data + _aligned_dim * old_location_start, - sizeof(T) * _aligned_dim * num_locations); - memset(_data + _aligned_dim * mem_clear_loc_start, 0, - sizeof(T) * _aligned_dim * (mem_clear_loc_end_limit - mem_clear_loc_start)); + _data_store->move_vectors(old_location_start, new_location_start, num_locations); } template void Index::reposition_frozen_point_to_end() @@ -2802,8 +3009,8 @@ template void Index void Index::resize(size_t new_max_points) @@ -2812,22 +3019,14 @@ template void Indexresize((location_t)new_internal_points); _final_graph.resize(new_internal_points); _locks = std::vector(new_internal_points); if (_num_frozen_pts != 0) { - reposition_points((_u32)_max_points, (_u32)new_max_points, (_u32)_num_frozen_pts); - _start = (_u32)new_max_points; + reposition_points((uint32_t)_max_points, (uint32_t)new_max_points, (uint32_t)_num_frozen_pts); + _start = (uint32_t)new_max_points; } _max_points = new_max_points; @@ -2841,13 +3040,31 @@ template void Index(stop - start).count() << "s" << std::endl; } +template +int Index::_insert_point(const DataType &point, const TagType tag) +{ + try + { + return this->insert_point(std::any_cast(point), std::any_cast(tag)); + } + catch (const std::bad_any_cast &anycast_e) + { + throw new ANNException("Error:Trying to insert invalid data type" + std::string(anycast_e.what()), -1); + } + catch (const std::exception &e) + { + throw new ANNException("Error:" + std::string(e.what()), -1); + } +} + template int Index::insert_point(const T *point, const TagT tag) { assert(_has_built); if (tag == static_cast(0)) { - throw diskann::ANNException("Do not insert point with tag 0. That is reserved for points hidden " + throw diskann::ANNException("Do not insert point with tag 0. 
That is " + "reserved for points hidden " "from the user.", -1, __FUNCSIG__, __FILE__, __LINE__); } @@ -2888,8 +3105,9 @@ int Index::insert_point(const T *point, const TagT tag) location = reserve_location(); if (location == -1) { - throw diskann::ANNException("Cannot reserve location even after expanding graph. Terminating.", -1, - __FUNCSIG__, __FILE__, __LINE__); + throw diskann::ANNException("Cannot reserve location even after " + "expanding graph. Terminating.", + -1, __FUNCSIG__, __FILE__, __LINE__); } #else return -1; @@ -2911,20 +3129,12 @@ int Index::insert_point(const T *point, const TagT tag) } tl.unlock(); - // Copy the vector in to the data array - auto offset_data = _data + (size_t)_aligned_dim * location; - memset((void *)offset_data, 0, sizeof(T) * _aligned_dim); - memcpy((void *)offset_data, point, sizeof(T) * _dim); - - if (_normalize_vecs) - { - normalize((float *)offset_data, _dim); - } + _data_store->set_vector(location, point); // Find and add appropriate graph edges ScratchStoreManager> manager(_query_scratch); auto scratch = manager.scratch_space(); - std::vector pruned_list; + std::vector pruned_list; if (_filtered_index) { search_for_point_and_prune(location, _indexingQueueSize, pruned_list, scratch, true, _filterIndexingQueueSize); @@ -2940,7 +3150,7 @@ int Index::insert_point(const T *point, const TagT tag) LockGuard guard(_locks[location]); _final_graph[location].clear(); - _final_graph[location].reserve((_u64)(_indexingRange * GRAPH_SLACK_FACTOR * 1.05)); + _final_graph[location].reserve((size_t)(_indexingRange * GRAPH_SLACK_FACTOR * 1.05)); for (auto link : pruned_list) { @@ -2960,6 +3170,35 @@ int Index::insert_point(const T *point, const TagT tag) return 0; } +template int Index::_lazy_delete(const TagType &tag) +{ + try + { + return lazy_delete(std::any_cast(tag)); + } + catch (const std::bad_any_cast &e) + { + throw ANNException(std::string("Error: ") + e.what(), -1); + } +} + +template +void Index::_lazy_delete(TagVector &tags, TagVector &failed_tags) +{ + try + { + this->lazy_delete(tags.get>(), failed_tags.get>()); + } + catch (const std::bad_any_cast &e) + { + throw ANNException("Error: bad any cast while performing _lazy_delete() " + std::string(e.what()), -1); + } + catch (const std::exception &e) + { + throw ANNException("Error: " + std::string(e.what()), -1); + } +} + template int Index::lazy_delete(const TagT &tag) { std::shared_lock ul(_update_lock); @@ -3015,6 +3254,23 @@ template bool Index +void Index::_get_active_tags(TagRobinSet &active_tags) +{ + try + { + this->get_active_tags(active_tags.get>()); + } + catch (const std::bad_any_cast &e) + { + throw ANNException("Error: bad_any cast while performing _get_active_tags() " + std::string(e.what()), -1); + } + catch (const std::exception &e) + { + throw ANNException("Error :" + std::string(e.what()), -1); + } +} + template void Index::get_active_tags(tsl::robin_set &active_tags) { @@ -3052,12 +3308,12 @@ template void Index visited(_max_points + _num_frozen_pts); size_t MAX_BFS_LEVELS = 32; - auto bfs_sets = new tsl::robin_set[MAX_BFS_LEVELS]; + auto bfs_sets = new tsl::robin_set[MAX_BFS_LEVELS]; bfs_sets[0].insert(_start); visited.set(_start); - for (unsigned i = _max_points; i < _max_points + _num_frozen_pts; ++i) + for (uint32_t i = (uint32_t)_max_points; i < _max_points + _num_frozen_pts; ++i) { if (i != _start) { @@ -3087,6 +3343,13 @@ template void Index void +// Index::optimize_index_layout() +//{ // use after build or load +//} + +// REFACTOR: This should be an OptimizedDataStore class 
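The _insert_point, _lazy_delete, and _get_active_tags wrappers above all follow one type-erasure pattern: the abstract index API carries arguments as std::any (or thin wrappers over it), and the typed layer recovers them with std::any_cast, converting a std::bad_any_cast into an ANNException. A self-contained sketch of that dispatch, with hypothetical stand-in names (insert_typed/insert_erased are illustrative, not DiskANN symbols):

    #include <any>
    #include <cstdint>
    #include <iostream>
    #include <stdexcept>
    #include <string>

    // Stand-in for the concrete, fully typed insert_point().
    static int insert_typed(const float *point, uint32_t tag)
    {
        return (point != nullptr && tag != 0) ? 0 : -1;
    }

    // Stand-in for the type-erased _insert_point() wrapper.
    static int insert_erased(const std::any &point, const std::any &tag)
    {
        try
        {
            return insert_typed(std::any_cast<const float *>(point), std::any_cast<uint32_t>(tag));
        }
        catch (const std::bad_any_cast &e)
        {
            // The real wrapper rethrows as diskann::ANNException.
            throw std::runtime_error(std::string("Error: trying to insert invalid data type ") + e.what());
        }
    }

    int main()
    {
        float vec[8] = {};
        std::cout << insert_erased(static_cast<const float *>(vec), uint32_t(42)) << std::endl;
        return 0;
    }
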
template void Index::optimize_index_layout() { // use after build or load if (_dynamic_index) @@ -3095,40 +3358,70 @@ template void Indexget_aligned_dim()]; + std::memset(cur_vec, 0, _data_store->get_aligned_dim() * sizeof(float)); + _data_len = (_data_store->get_aligned_dim() + 1) * sizeof(float); + _neighbor_len = (_max_observed_degree + 1) * sizeof(uint32_t); _node_size = _data_len + _neighbor_len; _opt_graph = new char[_node_size * _nd]; - DistanceFastL2 *dist_fast = (DistanceFastL2 *)_distance; - for (unsigned i = 0; i < _nd; i++) + DistanceFastL2 *dist_fast = (DistanceFastL2 *)_data_store->get_dist_fn(); + for (uint32_t i = 0; i < _nd; i++) { char *cur_node_offset = _opt_graph + i * _node_size; - float cur_norm = dist_fast->norm(_data + i * _aligned_dim, _aligned_dim); + _data_store->get_vector(i, (T *)cur_vec); + float cur_norm = dist_fast->norm((T *)cur_vec, (uint32_t)_data_store->get_aligned_dim()); std::memcpy(cur_node_offset, &cur_norm, sizeof(float)); - std::memcpy(cur_node_offset + sizeof(float), _data + i * _aligned_dim, _data_len - sizeof(float)); + std::memcpy(cur_node_offset + sizeof(float), cur_vec, _data_len - sizeof(float)); cur_node_offset += _data_len; - unsigned k = _final_graph[i].size(); - std::memcpy(cur_node_offset, &k, sizeof(unsigned)); - std::memcpy(cur_node_offset + sizeof(unsigned), _final_graph[i].data(), k * sizeof(unsigned)); - std::vector().swap(_final_graph[i]); + uint32_t k = (uint32_t)_final_graph[i].size(); + std::memcpy(cur_node_offset, &k, sizeof(uint32_t)); + std::memcpy(cur_node_offset + sizeof(uint32_t), _final_graph[i].data(), k * sizeof(uint32_t)); + std::vector().swap(_final_graph[i]); } _final_graph.clear(); _final_graph.shrink_to_fit(); + delete[] cur_vec; +} + +// REFACTOR: once optimized layout becomes its own Data+Graph store, we should +// just invoke regular search +// template +// void Index::search_with_optimized_layout(const T *query, +// size_t K, size_t L, uint32_t *indices) +//{ +//} + +template +void Index::_search_with_optimized_layout(const DataType &query, size_t K, size_t L, uint32_t *indices) +{ + try + { + return this->search_with_optimized_layout(std::any_cast(query), K, L, indices); + } + catch (const std::bad_any_cast &e) + { + throw ANNException( + "Error: bad any cast while performing _search_with_optimized_layout() " + std::string(e.what()), -1); + } + catch (const std::exception &e) + { + throw ANNException("Error: " + std::string(e.what()), -1); + } } template -void Index::search_with_optimized_layout(const T *query, size_t K, size_t L, unsigned *indices) +void Index::search_with_optimized_layout(const T *query, size_t K, size_t L, uint32_t *indices) { - DistanceFastL2 *dist_fast = (DistanceFastL2 *)_distance; + DistanceFastL2 *dist_fast = (DistanceFastL2 *)_data_store->get_dist_fn(); NeighborPriorityQueue retset(L); - std::vector init_ids(L); + std::vector init_ids(L); boost::dynamic_bitset<> flags{_nd, 0}; - unsigned tmp_l = 0; - unsigned *neighbors = (unsigned *)(_opt_graph + _node_size * _start + _data_len); - unsigned MaxM_ep = *neighbors; + uint32_t tmp_l = 0; + uint32_t *neighbors = (uint32_t *)(_opt_graph + _node_size * _start + _data_len); + uint32_t MaxM_ep = *neighbors; neighbors++; for (; tmp_l < L && tmp_l < MaxM_ep; tmp_l++) @@ -3139,7 +3432,7 @@ void Index::search_with_optimized_layout(const T *query, size_t while (tmp_l < L) { - unsigned id = rand() % _nd; + uint32_t id = rand() % _nd; if (flags[id]) continue; flags[id] = true; @@ -3147,23 +3440,23 @@ void Index::search_with_optimized_layout(const T 
*query, size_t tmp_l++; } - for (unsigned i = 0; i < init_ids.size(); i++) + for (uint32_t i = 0; i < init_ids.size(); i++) { - unsigned id = init_ids[i]; + uint32_t id = init_ids[i]; if (id >= _nd) continue; _mm_prefetch(_opt_graph + _node_size * id, _MM_HINT_T0); } L = 0; - for (unsigned i = 0; i < init_ids.size(); i++) + for (uint32_t i = 0; i < init_ids.size(); i++) { - unsigned id = init_ids[i]; + uint32_t id = init_ids[i]; if (id >= _nd) continue; T *x = (T *)(_opt_graph + _node_size * id); float norm_x = *x; x++; - float dist = dist_fast->compare(x, query, norm_x, (unsigned)_aligned_dim); + float dist = dist_fast->compare(x, query, norm_x, (uint32_t)_data_store->get_aligned_dim()); retset.insert(Neighbor(id, dist)); flags[id] = true; L++; @@ -3174,21 +3467,21 @@ void Index::search_with_optimized_layout(const T *query, size_t auto nbr = retset.closest_unexpanded(); auto n = nbr.id; _mm_prefetch(_opt_graph + _node_size * n + _data_len, _MM_HINT_T0); - neighbors = (unsigned *)(_opt_graph + _node_size * n + _data_len); - unsigned MaxM = *neighbors; + neighbors = (uint32_t *)(_opt_graph + _node_size * n + _data_len); + uint32_t MaxM = *neighbors; neighbors++; - for (unsigned m = 0; m < MaxM; ++m) + for (uint32_t m = 0; m < MaxM; ++m) _mm_prefetch(_opt_graph + _node_size * neighbors[m], _MM_HINT_T0); - for (unsigned m = 0; m < MaxM; ++m) + for (uint32_t m = 0; m < MaxM; ++m) { - unsigned id = neighbors[m]; + uint32_t id = neighbors[m]; if (flags[id]) continue; flags[id] = 1; T *data = (T *)(_opt_graph + _node_size * id); float norm = *data; data++; - float dist = dist_fast->compare(query, data, norm, (unsigned)_aligned_dim); + float dist = dist_fast->compare(query, data, norm, (uint32_t)_data_store->get_aligned_dim()); Neighbor nn(id, dist); retset.insert(nn); } @@ -3231,131 +3524,130 @@ template DISKANN_DLLEXPORT class Index; template DISKANN_DLLEXPORT class Index; template DISKANN_DLLEXPORT std::pair Index::search( - const float *query, const size_t K, const unsigned L, uint64_t *indices, float *distances); + const float *query, const size_t K, const uint32_t L, uint64_t *indices, float *distances); template DISKANN_DLLEXPORT std::pair Index::search( - const float *query, const size_t K, const unsigned L, uint32_t *indices, float *distances); + const float *query, const size_t K, const uint32_t L, uint32_t *indices, float *distances); template DISKANN_DLLEXPORT std::pair Index::search( - const uint8_t *query, const size_t K, const unsigned L, uint64_t *indices, float *distances); + const uint8_t *query, const size_t K, const uint32_t L, uint64_t *indices, float *distances); template DISKANN_DLLEXPORT std::pair Index::search( - const uint8_t *query, const size_t K, const unsigned L, uint32_t *indices, float *distances); + const uint8_t *query, const size_t K, const uint32_t L, uint32_t *indices, float *distances); template DISKANN_DLLEXPORT std::pair Index::search( - const int8_t *query, const size_t K, const unsigned L, uint64_t *indices, float *distances); + const int8_t *query, const size_t K, const uint32_t L, uint64_t *indices, float *distances); template DISKANN_DLLEXPORT std::pair Index::search( - const int8_t *query, const size_t K, const unsigned L, uint32_t *indices, float *distances); + const int8_t *query, const size_t K, const uint32_t L, uint32_t *indices, float *distances); // TagT==uint32_t template DISKANN_DLLEXPORT std::pair Index::search( - const float *query, const size_t K, const unsigned L, uint64_t *indices, float *distances); + const float *query, const size_t K, 
const uint32_t L, uint64_t *indices, float *distances); template DISKANN_DLLEXPORT std::pair Index::search( - const float *query, const size_t K, const unsigned L, uint32_t *indices, float *distances); + const float *query, const size_t K, const uint32_t L, uint32_t *indices, float *distances); template DISKANN_DLLEXPORT std::pair Index::search( - const uint8_t *query, const size_t K, const unsigned L, uint64_t *indices, float *distances); + const uint8_t *query, const size_t K, const uint32_t L, uint64_t *indices, float *distances); template DISKANN_DLLEXPORT std::pair Index::search( - const uint8_t *query, const size_t K, const unsigned L, uint32_t *indices, float *distances); + const uint8_t *query, const size_t K, const uint32_t L, uint32_t *indices, float *distances); template DISKANN_DLLEXPORT std::pair Index::search( - const int8_t *query, const size_t K, const unsigned L, uint64_t *indices, float *distances); + const int8_t *query, const size_t K, const uint32_t L, uint64_t *indices, float *distances); template DISKANN_DLLEXPORT std::pair Index::search( - const int8_t *query, const size_t K, const unsigned L, uint32_t *indices, float *distances); + const int8_t *query, const size_t K, const uint32_t L, uint32_t *indices, float *distances); template DISKANN_DLLEXPORT std::pair Index::search_with_filters< - uint64_t>(const float *query, const uint32_t &filter_label, const size_t K, const unsigned L, uint64_t *indices, + uint64_t>(const float *query, const uint32_t &filter_label, const size_t K, const uint32_t L, uint64_t *indices, float *distances); template DISKANN_DLLEXPORT std::pair Index::search_with_filters< - uint32_t>(const float *query, const uint32_t &filter_label, const size_t K, const unsigned L, uint32_t *indices, + uint32_t>(const float *query, const uint32_t &filter_label, const size_t K, const uint32_t L, uint32_t *indices, float *distances); template DISKANN_DLLEXPORT std::pair Index::search_with_filters< - uint64_t>(const uint8_t *query, const uint32_t &filter_label, const size_t K, const unsigned L, uint64_t *indices, + uint64_t>(const uint8_t *query, const uint32_t &filter_label, const size_t K, const uint32_t L, uint64_t *indices, float *distances); template DISKANN_DLLEXPORT std::pair Index::search_with_filters< - uint32_t>(const uint8_t *query, const uint32_t &filter_label, const size_t K, const unsigned L, uint32_t *indices, + uint32_t>(const uint8_t *query, const uint32_t &filter_label, const size_t K, const uint32_t L, uint32_t *indices, float *distances); template DISKANN_DLLEXPORT std::pair Index::search_with_filters< - uint64_t>(const int8_t *query, const uint32_t &filter_label, const size_t K, const unsigned L, uint64_t *indices, + uint64_t>(const int8_t *query, const uint32_t &filter_label, const size_t K, const uint32_t L, uint64_t *indices, float *distances); template DISKANN_DLLEXPORT std::pair Index::search_with_filters< - uint32_t>(const int8_t *query, const uint32_t &filter_label, const size_t K, const unsigned L, uint32_t *indices, + uint32_t>(const int8_t *query, const uint32_t &filter_label, const size_t K, const uint32_t L, uint32_t *indices, float *distances); // TagT==uint32_t template DISKANN_DLLEXPORT std::pair Index::search_with_filters< - uint64_t>(const float *query, const uint32_t &filter_label, const size_t K, const unsigned L, uint64_t *indices, + uint64_t>(const float *query, const uint32_t &filter_label, const size_t K, const uint32_t L, uint64_t *indices, float *distances); template DISKANN_DLLEXPORT std::pair 
Index::search_with_filters< - uint32_t>(const float *query, const uint32_t &filter_label, const size_t K, const unsigned L, uint32_t *indices, + uint32_t>(const float *query, const uint32_t &filter_label, const size_t K, const uint32_t L, uint32_t *indices, float *distances); template DISKANN_DLLEXPORT std::pair Index::search_with_filters< - uint64_t>(const uint8_t *query, const uint32_t &filter_label, const size_t K, const unsigned L, uint64_t *indices, + uint64_t>(const uint8_t *query, const uint32_t &filter_label, const size_t K, const uint32_t L, uint64_t *indices, float *distances); template DISKANN_DLLEXPORT std::pair Index::search_with_filters< - uint32_t>(const uint8_t *query, const uint32_t &filter_label, const size_t K, const unsigned L, uint32_t *indices, + uint32_t>(const uint8_t *query, const uint32_t &filter_label, const size_t K, const uint32_t L, uint32_t *indices, float *distances); template DISKANN_DLLEXPORT std::pair Index::search_with_filters< - uint64_t>(const int8_t *query, const uint32_t &filter_label, const size_t K, const unsigned L, uint64_t *indices, + uint64_t>(const int8_t *query, const uint32_t &filter_label, const size_t K, const uint32_t L, uint64_t *indices, float *distances); template DISKANN_DLLEXPORT std::pair Index::search_with_filters< - uint32_t>(const int8_t *query, const uint32_t &filter_label, const size_t K, const unsigned L, uint32_t *indices, + uint32_t>(const int8_t *query, const uint32_t &filter_label, const size_t K, const uint32_t L, uint32_t *indices, float *distances); template DISKANN_DLLEXPORT std::pair Index::search( - const float *query, const size_t K, const unsigned L, uint64_t *indices, float *distances); + const float *query, const size_t K, const uint32_t L, uint64_t *indices, float *distances); template DISKANN_DLLEXPORT std::pair Index::search( - const float *query, const size_t K, const unsigned L, uint32_t *indices, float *distances); + const float *query, const size_t K, const uint32_t L, uint32_t *indices, float *distances); template DISKANN_DLLEXPORT std::pair Index::search( - const uint8_t *query, const size_t K, const unsigned L, uint64_t *indices, float *distances); + const uint8_t *query, const size_t K, const uint32_t L, uint64_t *indices, float *distances); template DISKANN_DLLEXPORT std::pair Index::search( - const uint8_t *query, const size_t K, const unsigned L, uint32_t *indices, float *distances); + const uint8_t *query, const size_t K, const uint32_t L, uint32_t *indices, float *distances); template DISKANN_DLLEXPORT std::pair Index::search( - const int8_t *query, const size_t K, const unsigned L, uint64_t *indices, float *distances); + const int8_t *query, const size_t K, const uint32_t L, uint64_t *indices, float *distances); template DISKANN_DLLEXPORT std::pair Index::search( - const int8_t *query, const size_t K, const unsigned L, uint32_t *indices, float *distances); + const int8_t *query, const size_t K, const uint32_t L, uint32_t *indices, float *distances); // TagT==uint32_t template DISKANN_DLLEXPORT std::pair Index::search( - const float *query, const size_t K, const unsigned L, uint64_t *indices, float *distances); + const float *query, const size_t K, const uint32_t L, uint64_t *indices, float *distances); template DISKANN_DLLEXPORT std::pair Index::search( - const float *query, const size_t K, const unsigned L, uint32_t *indices, float *distances); + const float *query, const size_t K, const uint32_t L, uint32_t *indices, float *distances); template DISKANN_DLLEXPORT std::pair Index::search( - const 
uint8_t *query, const size_t K, const unsigned L, uint64_t *indices, float *distances); + const uint8_t *query, const size_t K, const uint32_t L, uint64_t *indices, float *distances); template DISKANN_DLLEXPORT std::pair Index::search( - const uint8_t *query, const size_t K, const unsigned L, uint32_t *indices, float *distances); + const uint8_t *query, const size_t K, const uint32_t L, uint32_t *indices, float *distances); template DISKANN_DLLEXPORT std::pair Index::search( - const int8_t *query, const size_t K, const unsigned L, uint64_t *indices, float *distances); + const int8_t *query, const size_t K, const uint32_t L, uint64_t *indices, float *distances); template DISKANN_DLLEXPORT std::pair Index::search( - const int8_t *query, const size_t K, const unsigned L, uint32_t *indices, float *distances); + const int8_t *query, const size_t K, const uint32_t L, uint32_t *indices, float *distances); template DISKANN_DLLEXPORT std::pair Index::search_with_filters< - uint64_t>(const float *query, const uint16_t &filter_label, const size_t K, const unsigned L, uint64_t *indices, + uint64_t>(const float *query, const uint16_t &filter_label, const size_t K, const uint32_t L, uint64_t *indices, float *distances); template DISKANN_DLLEXPORT std::pair Index::search_with_filters< - uint32_t>(const float *query, const uint16_t &filter_label, const size_t K, const unsigned L, uint32_t *indices, + uint32_t>(const float *query, const uint16_t &filter_label, const size_t K, const uint32_t L, uint32_t *indices, float *distances); template DISKANN_DLLEXPORT std::pair Index::search_with_filters< - uint64_t>(const uint8_t *query, const uint16_t &filter_label, const size_t K, const unsigned L, uint64_t *indices, + uint64_t>(const uint8_t *query, const uint16_t &filter_label, const size_t K, const uint32_t L, uint64_t *indices, float *distances); template DISKANN_DLLEXPORT std::pair Index::search_with_filters< - uint32_t>(const uint8_t *query, const uint16_t &filter_label, const size_t K, const unsigned L, uint32_t *indices, + uint32_t>(const uint8_t *query, const uint16_t &filter_label, const size_t K, const uint32_t L, uint32_t *indices, float *distances); template DISKANN_DLLEXPORT std::pair Index::search_with_filters< - uint64_t>(const int8_t *query, const uint16_t &filter_label, const size_t K, const unsigned L, uint64_t *indices, + uint64_t>(const int8_t *query, const uint16_t &filter_label, const size_t K, const uint32_t L, uint64_t *indices, float *distances); template DISKANN_DLLEXPORT std::pair Index::search_with_filters< - uint32_t>(const int8_t *query, const uint16_t &filter_label, const size_t K, const unsigned L, uint32_t *indices, + uint32_t>(const int8_t *query, const uint16_t &filter_label, const size_t K, const uint32_t L, uint32_t *indices, float *distances); // TagT==uint32_t template DISKANN_DLLEXPORT std::pair Index::search_with_filters< - uint64_t>(const float *query, const uint16_t &filter_label, const size_t K, const unsigned L, uint64_t *indices, + uint64_t>(const float *query, const uint16_t &filter_label, const size_t K, const uint32_t L, uint64_t *indices, float *distances); template DISKANN_DLLEXPORT std::pair Index::search_with_filters< - uint32_t>(const float *query, const uint16_t &filter_label, const size_t K, const unsigned L, uint32_t *indices, + uint32_t>(const float *query, const uint16_t &filter_label, const size_t K, const uint32_t L, uint32_t *indices, float *distances); template DISKANN_DLLEXPORT std::pair Index::search_with_filters< - uint64_t>(const uint8_t *query, 
const uint16_t &filter_label, const size_t K, const unsigned L, uint64_t *indices, + uint64_t>(const uint8_t *query, const uint16_t &filter_label, const size_t K, const uint32_t L, uint64_t *indices, float *distances); template DISKANN_DLLEXPORT std::pair Index::search_with_filters< - uint32_t>(const uint8_t *query, const uint16_t &filter_label, const size_t K, const unsigned L, uint32_t *indices, + uint32_t>(const uint8_t *query, const uint16_t &filter_label, const size_t K, const uint32_t L, uint32_t *indices, float *distances); template DISKANN_DLLEXPORT std::pair Index::search_with_filters< - uint64_t>(const int8_t *query, const uint16_t &filter_label, const size_t K, const unsigned L, uint64_t *indices, + uint64_t>(const int8_t *query, const uint16_t &filter_label, const size_t K, const uint32_t L, uint64_t *indices, float *distances); template DISKANN_DLLEXPORT std::pair Index::search_with_filters< - uint32_t>(const int8_t *query, const uint16_t &filter_label, const size_t K, const unsigned L, uint32_t *indices, + uint32_t>(const int8_t *query, const uint16_t &filter_label, const size_t K, const uint32_t L, uint32_t *indices, float *distances); - } // namespace diskann diff --git a/src/index_factory.cpp b/src/index_factory.cpp new file mode 100644 index 000000000..c5607f4a0 --- /dev/null +++ b/src/index_factory.cpp @@ -0,0 +1,150 @@ +#include "index_factory.h" + +namespace diskann +{ + +IndexFactory::IndexFactory(const IndexConfig &config) : _config(std::make_unique(config)) +{ + check_config(); +} + +std::unique_ptr IndexFactory::create_instance() +{ + return create_instance(_config->data_type, _config->tag_type, _config->label_type); +} + +void IndexFactory::check_config() +{ + if (_config->dynamic_index && !_config->enable_tags) + { + throw ANNException("ERROR: Dynamic Indexing must have tags enabled.", -1, __FUNCSIG__, __FILE__, __LINE__); + } + + if (_config->pq_dist_build) + { + if (_config->dynamic_index) + throw ANNException("ERROR: Dynamic Indexing not supported with PQ distance based " + "index construction", + -1, __FUNCSIG__, __FILE__, __LINE__); + if (_config->metric == diskann::Metric::INNER_PRODUCT) + throw ANNException("ERROR: Inner product metrics not yet supported " + "with PQ distance " + "base index", + -1, __FUNCSIG__, __FILE__, __LINE__); + } + + if (_config->data_type != "float" && _config->data_type != "uint8" && _config->data_type != "int8") + { + throw ANNException("ERROR: invalid data type : + " + _config->data_type + + " is not supported. please select from [float, int8, uint8]", + -1); + } + + if (_config->tag_type != "int32" && _config->tag_type != "uint32" && _config->tag_type != "int64" && + _config->tag_type != "uint64") + { + throw ANNException("ERROR: invalid data type : + " + _config->tag_type + + " is not supported. 
please select from [int32, uint32, int64, uint64]", + -1); + } +} + +template +std::unique_ptr> IndexFactory::construct_datastore(DataStoreStrategy strategy, size_t num_points, + size_t dimension) +{ + const size_t total_internal_points = num_points + _config->num_frozen_pts; + std::shared_ptr> distance; + switch (strategy) + { + case MEMORY: + if (_config->metric == diskann::Metric::COSINE && std::is_same::value) + { + distance.reset((Distance *)new AVXNormalizedCosineDistanceFloat()); + return std::make_unique>((location_t)total_internal_points, dimension, distance); + } + else + { + distance.reset((Distance *)get_distance_function(_config->metric)); + return std::make_unique>((location_t)total_internal_points, dimension, distance); + } + break; + default: + break; + } + return nullptr; +} + +std::unique_ptr IndexFactory::construct_graphstore(GraphStoreStrategy, size_t size) +{ + return std::make_unique(size); +} + +template +std::unique_ptr IndexFactory::create_instance() +{ + size_t num_points = _config->max_points; + size_t dim = _config->dimension; + // auto graph_store = construct_graphstore(_config->graph_strategy, num_points); + auto data_store = construct_datastore(_config->data_strategy, num_points, dim); + return std::make_unique>(*_config, std::move(data_store)); +} + +std::unique_ptr IndexFactory::create_instance(const std::string &data_type, const std::string &tag_type, + const std::string &label_type) +{ + if (data_type == std::string("float")) + { + return create_instance(tag_type, label_type); + } + else if (data_type == std::string("uint8")) + { + return create_instance(tag_type, label_type); + } + else if (data_type == std::string("int8")) + { + return create_instance(tag_type, label_type); + } + else + throw ANNException("Error: unsupported data_type please choose from [float/int8/uint8]", -1); +} + +template +std::unique_ptr IndexFactory::create_instance(const std::string &tag_type, const std::string &label_type) +{ + if (tag_type == std::string("int32")) + { + return create_instance(label_type); + } + else if (tag_type == std::string("uint32")) + { + return create_instance(label_type); + } + else if (tag_type == std::string("int64")) + { + return create_instance(label_type); + } + else if (tag_type == std::string("uint64")) + { + return create_instance(label_type); + } + else + throw ANNException("Error: unsupported tag_type please choose from [int32/uint32/int64/uint64]", -1); +} + +template +std::unique_ptr IndexFactory::create_instance(const std::string &label_type) +{ + if (label_type == std::string("uint16") || label_type == std::string("ushort")) + { + return create_instance(); + } + else if (label_type == std::string("uint32") || label_type == std::string("uint")) + { + return create_instance(); + } + else + throw ANNException("Error: unsupported label_type please choose from [uint/ushort]", -1); +} + +} // namespace diskann diff --git a/src/logger.cpp b/src/logger.cpp index 1444487f7..052f54877 100644 --- a/src/logger.cpp +++ b/src/logger.cpp @@ -10,20 +10,19 @@ namespace diskann { +#ifdef ENABLE_CUSTOM_LOGGER DISKANN_DLLEXPORT ANNStreamBuf coutBuff(stdout); DISKANN_DLLEXPORT ANNStreamBuf cerrBuff(stderr); DISKANN_DLLEXPORT std::basic_ostream cout(&coutBuff); DISKANN_DLLEXPORT std::basic_ostream cerr(&cerrBuff); - -#ifdef EXEC_ENV_OLS std::function g_logger; void SetCustomLogger(std::function logger) { g_logger = logger; + diskann::cout << "Set Custom Logger" << std::endl; } -#endif ANNStreamBuf::ANNStreamBuf(FILE *fp) { @@ -37,11 +36,7 @@ 
ANNStreamBuf::ANNStreamBuf(FILE *fp) } _fp = fp; _logLevel = (_fp == stdout) ? LogLevel::LL_Info : LogLevel::LL_Error; -#ifdef EXEC_ENV_OLS _buf = new char[BUFFER_SIZE + 1]; // See comment in the header -#else - _buf = new char[BUFFER_SIZE]; // See comment in the header -#endif std::memset(_buf, 0, (BUFFER_SIZE) * sizeof(char)); setp(_buf, _buf + BUFFER_SIZE - 1); @@ -87,17 +82,16 @@ int ANNStreamBuf::flush() } void ANNStreamBuf::logImpl(char *str, int num) { -#ifdef EXEC_ENV_OLS str[num] = '\0'; // Safe. See the c'tor. // Invoke the OLS custom logging function. if (g_logger) { g_logger(_logLevel, str); } +} #else - fwrite(str, sizeof(char), num, _fp); - fflush(_fp); +using std::cerr; +using std::cout; #endif -} } // namespace diskann diff --git a/src/math_utils.cpp b/src/math_utils.cpp index 151e6a97d..7481da848 100644 --- a/src/math_utils.cpp +++ b/src/math_utils.cpp @@ -27,7 +27,7 @@ float calc_distance(float *vec_1, float *vec_2, size_t dim) void compute_vecs_l2sq(float *vecs_l2sq, float *data, const size_t num_points, const size_t dim) { #pragma omp parallel for schedule(static, 8192) - for (int64_t n_iter = 0; n_iter < (_s64)num_points; n_iter++) + for (int64_t n_iter = 0; n_iter < (int64_t)num_points; n_iter++) { vecs_l2sq[n_iter] = cblas_snrm2((MKL_INT)dim, (data + (n_iter * dim)), 1); vecs_l2sq[n_iter] *= vecs_l2sq[n_iter]; @@ -96,7 +96,7 @@ void compute_closest_centers_in_block(const float *const data, const size_t num_ if (k == 1) { #pragma omp parallel for schedule(static, 8192) - for (int64_t i = 0; i < (_s64)num_points; i++) + for (int64_t i = 0; i < (int64_t)num_points; i++) { float min = std::numeric_limits::max(); float *current = dist_matrix + (i * num_centers); @@ -113,7 +113,7 @@ void compute_closest_centers_in_block(const float *const data, const size_t num_ else { #pragma omp parallel for schedule(static, 8192) - for (int64_t i = 0; i < (_s64)num_points; i++) + for (int64_t i = 0; i < (int64_t)num_points; i++) { std::priority_queue top_k_queue; float *current = dist_matrix + (i * num_centers); @@ -182,7 +182,7 @@ void compute_closest_centers(float *data, size_t num_points, size_t dim, float * #pragma omp parallel for schedule(static, 1) for (int64_t j = cur_blk * PAR_BLOCK_SIZE; - j < std::min((_s64)num_points, (_s64)((cur_blk + 1) * PAR_BLOCK_SIZE)); j++) + j < std::min((int64_t)num_points, (int64_t)((cur_blk + 1) * PAR_BLOCK_SIZE)); j++) { for (size_t l = 0; l < k; l++) { @@ -212,7 +212,7 @@ void process_residuals(float *data_load, size_t num_points, size_t dim, float *c diskann::cout << "Processing residuals of " << num_points << " points in " << dim << " dimensions using " << num_centers << " centers " << std::endl; #pragma omp parallel for schedule(static, 8192) - for (int64_t n_iter = 0; n_iter < (_s64)num_points; n_iter++) + for (int64_t n_iter = 0; n_iter < (int64_t)num_points; n_iter++) { for (size_t d_iter = 0; d_iter < dim; d_iter++) { @@ -259,7 +259,7 @@ float lloyds_iter(float *data, size_t num_points, size_t dim, float *centers, si memset(centers, 0, sizeof(float) * (size_t)num_centers * (size_t)dim); #pragma omp parallel for schedule(static, 1) - for (int64_t c = 0; c < (_s64)num_centers; ++c) + for (int64_t c = 0; c < (int64_t)num_centers; ++c) { float *center = centers + (size_t)c * (size_t)dim; double *cluster_sum = new double[dim]; @@ -290,7 +290,7 @@ float lloyds_iter(float *data, size_t num_points, size_t dim, float *centers, si std::vector residuals(nchunks * BUF_PAD, 0.0); #pragma omp parallel for schedule(static, 32) - for (int64_t chunk = 0; 
chunk < (_s64)nchunks; ++chunk) + for (int64_t chunk = 0; chunk < (int64_t)nchunks; ++chunk) for (size_t d = chunk * CHUNK_SIZE; d < num_points && d < (chunk + 1) * CHUNK_SIZE; ++d) residuals[chunk * BUF_PAD] += math_utils::calc_distance(data + (d * dim), centers + (size_t)closest_center[d] * (size_t)dim, dim); @@ -405,7 +405,7 @@ void kmeanspp_selecting_pivots(float *data, size_t num_points, size_t dim, float float *dist = new float[num_points]; #pragma omp parallel for schedule(static, 8192) - for (int64_t i = 0; i < (_s64)num_points; i++) + for (int64_t i = 0; i < (int64_t)num_points; i++) { dist[i] = math_utils::calc_distance(data + i * dim, data + init_id * dim, dim); } @@ -446,7 +446,7 @@ void kmeanspp_selecting_pivots(float *data, size_t num_points, size_t dim, float std::memcpy(pivot_data + num_picked * dim, data + tmp_pivot * dim, dim * sizeof(float)); #pragma omp parallel for schedule(static, 8192) - for (int64_t i = 0; i < (_s64)num_points; i++) + for (int64_t i = 0; i < (int64_t)num_points; i++) { dist[i] = (std::min)(dist[i], math_utils::calc_distance(data + i * dim, data + tmp_pivot * dim, dim)); } diff --git a/src/natural_number_map.cpp b/src/natural_number_map.cpp index f5ba9523f..9050831a2 100644 --- a/src/natural_number_map.cpp +++ b/src/natural_number_map.cpp @@ -107,8 +107,8 @@ template void natural_number_map::cle } // Instantiate used templates. -template class natural_number_map; -template class natural_number_map; -template class natural_number_map; -template class natural_number_map; +template class natural_number_map; +template class natural_number_map; +template class natural_number_map; +template class natural_number_map; } // namespace diskann diff --git a/src/partition.cpp b/src/partition.cpp index ece453e87..2d46f9faf 100644 --- a/src/partition.cpp +++ b/src/partition.cpp @@ -33,7 +33,7 @@ template void gen_random_slice(const std::string base_file, const std::string output_prefix, double sampling_rate) { - _u64 read_blk_size = 64 * 1024 * 1024; + size_t read_blk_size = 64 * 1024 * 1024; cached_ifstream base_reader(base_file.c_str(), read_blk_size); std::ofstream sample_writer(std::string(output_prefix + "_data.bin").c_str(), std::ios::binary); std::ofstream sample_id_writer(std::string(output_prefix + "_ids.bin").c_str(), std::ios::binary); @@ -68,7 +68,7 @@ void gen_random_slice(const std::string base_file, const std::string output_pref if (sample < sampling_rate) { sample_writer.write((char *)cur_row.get(), sizeof(T) * nd); - uint32_t cur_i_u32 = (_u32)i; + uint32_t cur_i_u32 = (uint32_t)i; sample_id_writer.write((char *)&cur_i_u32, sizeof(uint32_t)); num_sampled_pts_u32++; } @@ -100,13 +100,13 @@ void gen_random_slice(const std::string data_file, double p_val, float *&sampled std::vector> sampled_vectors; // amount to read in one shot - _u64 read_blk_size = 64 * 1024 * 1024; + size_t read_blk_size = 64 * 1024 * 1024; // create cached reader + writer cached_ifstream base_reader(data_file.c_str(), read_blk_size); // metadata: npts, ndims - base_reader.read((char *)&npts32, sizeof(unsigned)); - base_reader.read((char *)&ndims32, sizeof(unsigned)); + base_reader.read((char *)&npts32, sizeof(uint32_t)); + base_reader.read((char *)&ndims32, sizeof(uint32_t)); npts = npts32; ndims = ndims32; @@ -115,7 +115,7 @@ void gen_random_slice(const std::string data_file, double p_val, float *&sampled std::random_device rd; // Will be used to obtain a seed for the random number size_t x = rd(); - std::mt19937 generator((unsigned)x); + std::mt19937 generator((uint32_t)x); 
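The generator seeded here drives a single-pass Bernoulli sample: each row is kept independently with the given probability, so the expected sample size is the rate times npts without a second pass over the data. Because the survivor count is unknown up front, the writers in this file first emit a placeholder point count and later seekp back to patch it. A self-contained sketch of both ideas (illustrative only, simplified from gen_random_slice):

    #include <cstdint>
    #include <fstream>
    #include <random>
    #include <vector>

    int main()
    {
        const uint32_t npts = 10000, ndims = 16;
        const double p_val = 0.1; // keep roughly 10% of the rows

        std::ofstream sample_writer("sample_data.bin", std::ios::binary);
        uint32_t num_sampled = 0;
        sample_writer.write((char *)&num_sampled, sizeof(uint32_t)); // placeholder count
        sample_writer.write((char *)&ndims, sizeof(uint32_t));

        std::random_device rd;
        std::mt19937 generator((uint32_t)rd());
        std::uniform_real_distribution<float> distribution(0, 1);

        std::vector<float> row(ndims, 0.0f); // stand-in for a row read from the base file
        for (uint32_t i = 0; i < npts; i++)
        {
            if (distribution(generator) < p_val)
            {
                sample_writer.write((char *)row.data(), ndims * sizeof(float));
                num_sampled++;
            }
        }
        sample_writer.seekp(0, std::ios::beg); // rewind and patch the real count
        sample_writer.write((char *)&num_sampled, sizeof(uint32_t));
        return 0;
    }
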
std::uniform_real_distribution distribution(0, 1); for (size_t i = 0; i < npts; i++) @@ -154,7 +154,7 @@ void gen_random_slice(const T *inputdata, size_t npts, size_t ndims, double p_va std::random_device rd; // Will be used to obtain a seed for the random number engine size_t x = rd(); - std::mt19937 generator((unsigned)x); // Standard mersenne_twister_engine seeded with rd() + std::mt19937 generator((uint32_t)x); // Standard mersenne_twister_engine seeded with rd() std::uniform_real_distribution distribution(0, 1); for (size_t i = 0; i < npts; i++) @@ -193,7 +193,7 @@ int estimate_cluster_sizes(float *test_data_float, size_t num_test, float *pivot } size_t block_size = num_test <= BLOCK_SIZE ? num_test : BLOCK_SIZE; - _u32 *block_closest_centers = new _u32[block_size * k_base]; + uint32_t *block_closest_centers = new uint32_t[block_size * k_base]; float *block_data_float; size_t num_blocks = DIV_ROUND_UP(num_test, block_size); @@ -222,7 +222,7 @@ int estimate_cluster_sizes(float *test_data_float, size_t num_test, float *pivot diskann::cout << "Estimated cluster sizes: "; for (size_t i = 0; i < num_centers; i++) { - _u32 cur_shard_count = (_u32)shard_counts[i]; + uint32_t cur_shard_count = (uint32_t)shard_counts[i]; cluster_sizes.push_back((size_t)cur_shard_count); diskann::cout << cur_shard_count << " "; } @@ -236,12 +236,12 @@ template int shard_data_into_clusters(const std::string data_file, float *pivots, const size_t num_centers, const size_t dim, const size_t k_base, std::string prefix_path) { - _u64 read_blk_size = 64 * 1024 * 1024; - // _u64 write_blk_size = 64 * 1024 * 1024; + size_t read_blk_size = 64 * 1024 * 1024; + // uint64_t write_blk_size = 64 * 1024 * 1024; // create cached reader + writer cached_ifstream base_reader(data_file, read_blk_size); - _u32 npts32; - _u32 basedim32; + uint32_t npts32; + uint32_t basedim32; base_reader.read((char *)&npts32, sizeof(uint32_t)); base_reader.read((char *)&basedim32, sizeof(uint32_t)); size_t num_points = npts32; @@ -254,8 +254,8 @@ int shard_data_into_clusters(const std::string data_file, float *pivots, const s std::unique_ptr shard_counts = std::make_unique(num_centers); std::vector shard_data_writer(num_centers); std::vector shard_idmap_writer(num_centers); - _u32 dummy_size = 0; - _u32 const_one = 1; + uint32_t dummy_size = 0; + uint32_t const_one = 1; for (size_t i = 0; i < num_centers; i++) { @@ -271,7 +271,7 @@ int shard_data_into_clusters(const std::string data_file, float *pivots, const s } size_t block_size = num_points <= BLOCK_SIZE ? 
num_points : BLOCK_SIZE;
- std::unique_ptr<_u32[]> block_closest_centers = std::make_unique<_u32[]>(block_size * k_base);
+ std::unique_ptr<uint32_t[]> block_closest_centers = std::make_unique<uint32_t[]>(block_size * k_base);
std::unique_ptr<T[]> block_data_T = std::make_unique<T[]>(block_size * dim);
std::unique_ptr<float[]> block_data_float = std::make_unique<float[]>(block_size * dim);
@@ -306,7 +306,7 @@ int shard_data_into_clusters(const std::string data_file, float *pivots, const s
diskann::cout << "Actual shard sizes: " << std::flush;
for (size_t i = 0; i < num_centers; i++)
{
- _u32 cur_shard_count = (_u32)shard_counts[i];
+ uint32_t cur_shard_count = (uint32_t)shard_counts[i];
total_count += cur_shard_count;
diskann::cout << cur_shard_count << " ";
shard_data_writer[i].seekp(0);
@@ -328,12 +328,12 @@ template <typename T>
int shard_data_into_clusters_only_ids(const std::string data_file, float *pivots, const size_t num_centers,
                                      const size_t dim, const size_t k_base, std::string prefix_path)
{
- _u64 read_blk_size = 64 * 1024 * 1024;
- // _u64 write_blk_size = 64 * 1024 * 1024;
+ size_t read_blk_size = 64 * 1024 * 1024;
+ // uint64_t write_blk_size = 64 * 1024 * 1024;
// create cached reader + writer
cached_ifstream base_reader(data_file, read_blk_size);
- _u32 npts32;
- _u32 basedim32;
+ uint32_t npts32;
+ uint32_t basedim32;
base_reader.read((char *)&npts32, sizeof(uint32_t));
base_reader.read((char *)&basedim32, sizeof(uint32_t));
size_t num_points = npts32;
@@ -346,8 +346,8 @@ int shard_data_into_clusters_only_ids(const std::string data_file, float *pivots
std::unique_ptr<size_t[]> shard_counts = std::make_unique<size_t[]>(num_centers);
std::vector<std::ofstream> shard_idmap_writer(num_centers);
- _u32 dummy_size = 0;
- _u32 const_one = 1;
+ uint32_t dummy_size = 0;
+ uint32_t const_one = 1;
for (size_t i = 0; i < num_centers; i++)
{
@@ -359,7 +359,7 @@ int shard_data_into_clusters_only_ids(const std::string data_file, float *pivots
}
size_t block_size = num_points <= BLOCK_SIZE ? 
num_points : BLOCK_SIZE;
- std::unique_ptr<_u32[]> block_closest_centers = std::make_unique<_u32[]>(block_size * k_base);
+ std::unique_ptr<uint32_t[]> block_closest_centers = std::make_unique<uint32_t[]>(block_size * k_base);
std::unique_ptr<T[]> block_data_T = std::make_unique<T[]>(block_size * dim);
std::unique_ptr<float[]> block_data_float = std::make_unique<float[]>(block_size * dim);
@@ -393,7 +393,7 @@ int shard_data_into_clusters_only_ids(const std::string data_file, float *pivots
diskann::cout << "Actual shard sizes: " << std::flush;
for (size_t i = 0; i < num_centers; i++)
{
- _u32 cur_shard_count = (_u32)shard_counts[i];
+ uint32_t cur_shard_count = (uint32_t)shard_counts[i];
total_count += cur_shard_count;
diskann::cout << cur_shard_count << " ";
shard_idmap_writer[i].seekp(0);
@@ -409,29 +409,29 @@ int shard_data_into_clusters_only_ids(const std::string data_file, float *pivots
template <typename T>
int retrieve_shard_data_from_ids(const std::string data_file, std::string idmap_filename, std::string data_filename)
{
- _u64 read_blk_size = 64 * 1024 * 1024;
- // _u64 write_blk_size = 64 * 1024 * 1024;
+ size_t read_blk_size = 64 * 1024 * 1024;
+ // uint64_t write_blk_size = 64 * 1024 * 1024;
// create cached reader + writer
cached_ifstream base_reader(data_file, read_blk_size);
- _u32 npts32;
- _u32 basedim32;
+ uint32_t npts32;
+ uint32_t basedim32;
base_reader.read((char *)&npts32, sizeof(uint32_t));
base_reader.read((char *)&basedim32, sizeof(uint32_t));
size_t num_points = npts32;
size_t dim = basedim32;
- _u32 dummy_size = 0;
+ uint32_t dummy_size = 0;
std::ofstream shard_data_writer(data_filename.c_str(), std::ios::binary);
shard_data_writer.write((char *)&dummy_size, sizeof(uint32_t));
shard_data_writer.write((char *)&basedim32, sizeof(uint32_t));
- _u32 *shard_ids;
- _u64 shard_size, tmp;
- diskann::load_bin<_u32>(idmap_filename, shard_ids, shard_size, tmp);
+ uint32_t *shard_ids;
+ uint64_t shard_size, tmp;
+ diskann::load_bin<uint32_t>(idmap_filename, shard_ids, shard_size, tmp);
- _u32 cur_pos = 0;
- _u32 num_written = 0;
+ uint32_t cur_pos = 0;
+ uint32_t num_written = 0;
std::cout << "Shard has " << shard_size << " points" << std::endl;
size_t block_size = num_points <= BLOCK_SIZE ? num_points : BLOCK_SIZE;
@@ -495,7 +495,8 @@ int partition(const std::string data_file, const float sampling_rate, size_t num
// kmeans_partitioning on training data
- // cur_file = cur_file + "_kmeans_partitioning-" + std::to_string(num_parts);
+ // cur_file = cur_file + "_kmeans_partitioning-" +
+ // std::to_string(num_parts);
output_file = cur_file + "_centroids.bin";
pivot_data = new float[num_parts * train_dim];
@@ -544,7 +545,8 @@ int partition_with_ram_budget(const std::string data_file, const double sampling
// kmeans_partitioning on training data
- // cur_file = cur_file + "_kmeans_partitioning-" + std::to_string(num_parts);
+ // cur_file = cur_file + "_kmeans_partitioning-" +
+ // std::to_string(num_parts);
output_file = cur_file + "_centroids.bin";
while (!fit_in_ram)
@@ -572,8 +574,9 @@
{
// to account for the fact that p is the size of the shard over the
// testing sample.
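The scaling below inverts the sampling: if a shard drew p points from a sample taken at rate sampling_rate, the full base set is expected to contribute about p / sampling_rate points to that shard, and the RAM estimate has to be computed for that extrapolated size. A toy check of the arithmetic (illustrative only):

    #include <cassert>
    #include <cstdint>

    int main()
    {
        const double sampling_rate = 0.25; // fraction of the base set that was clustered
        uint64_t p = 12500;                // points that landed in this shard's sample
        p = (uint64_t)(p / sampling_rate); // extrapolate to the full base set
        assert(p == 50000);
        return 0;
    }
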
- p = (_u64)(p / sampling_rate); - double cur_shard_ram_estimate = diskann::estimate_ram_usage(p, train_dim, sizeof(T), graph_degree); + p = (uint64_t)(p / sampling_rate); + double cur_shard_ram_estimate = + diskann::estimate_ram_usage(p, (uint32_t)train_dim, sizeof(T), (uint32_t)graph_degree); if (cur_shard_ram_estimate > max_ram_usage) max_ram_usage = cur_shard_ram_estimate; diff --git a/src/pq.cpp b/src/pq.cpp index 2eb0a3cb0..86c68ce0a 100644 --- a/src/pq.cpp +++ b/src/pq.cpp @@ -41,16 +41,16 @@ void FixedChunkPQTable::load_pq_centroid_bin(const char *pq_table_file, size_t n { #endif - _u64 nr, nc; + uint64_t nr, nc; std::string rotmat_file = std::string(pq_table_file) + "_rotation_matrix.bin"; #ifdef EXEC_ENV_OLS - _u64 *file_offset_data; // since load_bin only sets the pointer, no need - // to delete. - diskann::load_bin<_u64>(files, pq_table_file, file_offset_data, nr, nc); + size_t *file_offset_data; // since load_bin only sets the pointer, no need + // to delete. + diskann::load_bin<size_t>(files, pq_table_file, file_offset_data, nr, nc); #else - std::unique_ptr<_u64[]> file_offset_data; - diskann::load_bin<_u64>(pq_table_file, file_offset_data, nr, nc); + std::unique_ptr<size_t[]> file_offset_data; + diskann::load_bin<size_t>(pq_table_file, file_offset_data, nr, nc); #endif bool use_old_filetype = false; @@ -150,32 +150,32 @@ void FixedChunkPQTable::load_pq_centroid_bin(const char *pq_table_file, size_t n // alloc and compute transpose tables_tr = new float[256 * this->ndims]; - for (_u64 i = 0; i < 256; i++) + for (size_t i = 0; i < 256; i++) { - for (_u64 j = 0; j < this->ndims; j++) + for (size_t j = 0; j < this->ndims; j++) { tables_tr[j * 256 + i] = tables[i * this->ndims + j]; } } } -_u32 FixedChunkPQTable::get_num_chunks() +uint32_t FixedChunkPQTable::get_num_chunks() { - return static_cast<_u32>(n_chunks); + return static_cast<uint32_t>(n_chunks); } void FixedChunkPQTable::preprocess_query(float *query_vec) { - for (_u32 d = 0; d < ndims; d++) + for (uint32_t d = 0; d < ndims; d++) { query_vec[d] -= centroid[d]; } std::vector<float> tmp(ndims, 0); if (use_rotation) { - for (_u32 d = 0; d < ndims; d++) + for (uint32_t d = 0; d < ndims; d++) { - for (_u32 d1 = 0; d1 < ndims; d1++) + for (uint32_t d1 = 0; d1 < ndims; d1++) { tmp[d] += query_vec[d1] * rotmat_tr[d1 * ndims + d]; } @@ -189,14 +189,14 @@ void FixedChunkPQTable::populate_chunk_distances(const float *query_vec, float * { memset(dist_vec, 0, 256 * n_chunks * sizeof(float)); // chunk wise distance computation - for (_u64 chunk = 0; chunk < n_chunks; chunk++) + for (size_t chunk = 0; chunk < n_chunks; chunk++) { // sum (q-c)^2 for the dimensions associated with this chunk float *chunk_dists = dist_vec + (256 * chunk); - for (_u64 j = chunk_offsets[chunk]; j < chunk_offsets[chunk + 1]; j++) + for (size_t j = chunk_offsets[chunk]; j < chunk_offsets[chunk + 1]; j++) { const float *centers_dim_vec = tables_tr + (256 * j); - for (_u64 idx = 0; idx < 256; idx++) + for (size_t idx = 0; idx < 256; idx++) { double diff = centers_dim_vec[idx] - (query_vec[j]); chunk_dists[idx] += (float)(diff * diff); @@ -205,12 +205,12 @@ void FixedChunkPQTable::populate_chunk_distances(const float *query_vec, float * } } -float FixedChunkPQTable::l2_distance(const float *query_vec, _u8 *base_vec) +float FixedChunkPQTable::l2_distance(const float *query_vec, uint8_t *base_vec) { float res = 0; - for (_u64 chunk = 0; chunk < n_chunks; chunk++) + for (size_t chunk = 0; chunk < n_chunks; chunk++) { - for (_u64 j = chunk_offsets[chunk]; j < chunk_offsets[chunk + 1]; j++) + for (size_t j =
chunk_offsets[chunk]; j < chunk_offsets[chunk + 1]; j++) { const float *centers_dim_vec = tables_tr + (256 * j); float diff = centers_dim_vec[base_vec[chunk]] - (query_vec[j]); @@ -220,12 +220,12 @@ float FixedChunkPQTable::l2_distance(const float *query_vec, _u8 *base_vec) return res; } -float FixedChunkPQTable::inner_product(const float *query_vec, _u8 *base_vec) +float FixedChunkPQTable::inner_product(const float *query_vec, uint8_t *base_vec) { float res = 0; - for (_u64 chunk = 0; chunk < n_chunks; chunk++) + for (size_t chunk = 0; chunk < n_chunks; chunk++) { - for (_u64 j = chunk_offsets[chunk]; j < chunk_offsets[chunk + 1]; j++) + for (size_t j = chunk_offsets[chunk]; j < chunk_offsets[chunk + 1]; j++) { const float *centers_dim_vec = tables_tr + (256 * j); float diff = centers_dim_vec[base_vec[chunk]] * query_vec[j]; // assumes centroid is 0 to @@ -238,11 +238,11 @@ float FixedChunkPQTable::inner_product(const float *query_vec, _u8 *base_vec) } // assumes no rotation is involved -void FixedChunkPQTable::inflate_vector(_u8 *base_vec, float *out_vec) +void FixedChunkPQTable::inflate_vector(uint8_t *base_vec, float *out_vec) { - for (_u64 chunk = 0; chunk < n_chunks; chunk++) + for (size_t chunk = 0; chunk < n_chunks; chunk++) { - for (_u64 j = chunk_offsets[chunk]; j < chunk_offsets[chunk + 1]; j++) + for (size_t j = chunk_offsets[chunk]; j < chunk_offsets[chunk + 1]; j++) { const float *centers_dim_vec = tables_tr + (256 * j); out_vec[j] = centers_dim_vec[base_vec[chunk]] + centroid[j]; @@ -254,35 +254,35 @@ void FixedChunkPQTable::populate_chunk_inner_products(const float *query_vec, fl { memset(dist_vec, 0, 256 * n_chunks * sizeof(float)); // chunk wise distance computation - for (_u64 chunk = 0; chunk < n_chunks; chunk++) + for (size_t chunk = 0; chunk < n_chunks; chunk++) { // sum (q-c)^2 for the dimensions associated with this chunk float *chunk_dists = dist_vec + (256 * chunk); - for (_u64 j = chunk_offsets[chunk]; j < chunk_offsets[chunk + 1]; j++) + for (size_t j = chunk_offsets[chunk]; j < chunk_offsets[chunk + 1]; j++) { const float *centers_dim_vec = tables_tr + (256 * j); - for (_u64 idx = 0; idx < 256; idx++) + for (size_t idx = 0; idx < 256; idx++) { double prod = centers_dim_vec[idx] * query_vec[j]; // assumes that we are not // shifting the vectors to // mean zero, i.e., centroid // array should be all zeros - chunk_dists[idx] -= (float)prod; // returning negative to keep the search code clean - // (max inner product vs min distance) + chunk_dists[idx] -= (float)prod; // returning negative to keep the search code + // clean (max inner product vs min distance) } } } } -void aggregate_coords(const std::vector<unsigned> &ids, const _u8 *all_coords, const _u64 ndims, _u8 *out) +void aggregate_coords(const std::vector<uint32_t> &ids, const uint8_t *all_coords, const size_t ndims, uint8_t *out) { - for (_u64 i = 0; i < ids.size(); i++) + for (size_t i = 0; i < ids.size(); i++) { - memcpy(out + i * ndims, all_coords + ids[i] * ndims, ndims * sizeof(_u8)); + memcpy(out + i * ndims, all_coords + ids[i] * ndims, ndims * sizeof(uint8_t)); } } -void pq_dist_lookup(const _u8 *pq_ids, const _u64 n_pts, const _u64 pq_nchunks, const float *pq_dists, +void pq_dist_lookup(const uint8_t *pq_ids, const size_t n_pts, const size_t pq_nchunks, const float *pq_dists, std::vector<float> &dists_out) { //_mm_prefetch((char*) dists_out, _MM_HINT_T0); @@ -291,47 +291,50 @@ void pq_dist_lookup(const _u8 *pq_ids, const _u64 n_pts, const _u64 pq_nchunks, _mm_prefetch((char *)(pq_ids + 128), _MM_HINT_T0); dists_out.clear();
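// The loops below score every PQ code against the precomputed table:
// dist(q, x) is approximated as the sum over chunks of
// pq_dists[256 * chunk + code[chunk]]. Iterating chunk-major (chunks outer,
// points inner) keeps one 256-float slice of the table hot in cache while all
// points are scanned, which is what the prefetch of the next slice relies on.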
dists_out.resize(n_pts, 0); - for (_u64 chunk = 0; chunk < pq_nchunks; chunk++) + for (size_t chunk = 0; chunk < pq_nchunks; chunk++) { const float *chunk_dists = pq_dists + 256 * chunk; if (chunk < pq_nchunks - 1) { _mm_prefetch((char *)(chunk_dists + 256), _MM_HINT_T0); } - for (_u64 idx = 0; idx < n_pts; idx++) + for (size_t idx = 0; idx < n_pts; idx++) { - _u8 pq_centerid = pq_ids[pq_nchunks * idx + chunk]; + uint8_t pq_centerid = pq_ids[pq_nchunks * idx + chunk]; dists_out[idx] += chunk_dists[pq_centerid]; } } } -// Need to replace calls to these functions with calls to vector& based functions above -void aggregate_coords(const unsigned *ids, const _u64 n_ids, const _u8 *all_coords, const _u64 ndims, _u8 *out) +// Need to replace calls to these functions with calls to vector& based +// functions above +void aggregate_coords(const uint32_t *ids, const size_t n_ids, const uint8_t *all_coords, const size_t ndims, + uint8_t *out) { - for (_u64 i = 0; i < n_ids; i++) + for (size_t i = 0; i < n_ids; i++) { - memcpy(out + i * ndims, all_coords + ids[i] * ndims, ndims * sizeof(_u8)); + memcpy(out + i * ndims, all_coords + ids[i] * ndims, ndims * sizeof(uint8_t)); } } -void pq_dist_lookup(const _u8 *pq_ids, const _u64 n_pts, const _u64 pq_nchunks, const float *pq_dists, float *dists_out) +void pq_dist_lookup(const uint8_t *pq_ids, const size_t n_pts, const size_t pq_nchunks, const float *pq_dists, + float *dists_out) { _mm_prefetch((char *)dists_out, _MM_HINT_T0); _mm_prefetch((char *)pq_ids, _MM_HINT_T0); _mm_prefetch((char *)(pq_ids + 64), _MM_HINT_T0); _mm_prefetch((char *)(pq_ids + 128), _MM_HINT_T0); memset(dists_out, 0, n_pts * sizeof(float)); - for (_u64 chunk = 0; chunk < pq_nchunks; chunk++) + for (size_t chunk = 0; chunk < pq_nchunks; chunk++) { const float *chunk_dists = pq_dists + 256 * chunk; if (chunk < pq_nchunks - 1) { _mm_prefetch((char *)(chunk_dists + 256), _MM_HINT_T0); } - for (_u64 idx = 0; idx < n_pts; idx++) + for (size_t idx = 0; idx < n_pts; idx++) { - _u8 pq_centerid = pq_ids[pq_nchunks * idx + chunk]; + uint8_t pq_centerid = pq_ids[pq_nchunks * idx + chunk]; dists_out[idx] += chunk_dists[pq_centerid]; } } @@ -342,8 +345,8 @@ void pq_dist_lookup(const _u8 *pq_ids, const _u64 n_pts, const _u64 pq_nchunks, // num_pq_chunks (if it divides dimension, else rounded) chunks, and runs // k-means in each chunk to compute the PQ pivots and stores in bin format in // file pq_pivots_path as a num_centers*dim floating point binary file -int generate_pq_pivots(const float *const passed_train_data, size_t num_train, unsigned dim, unsigned num_centers, - unsigned num_pq_chunks, unsigned max_k_means_reps, std::string pq_pivots_path, +int generate_pq_pivots(const float *const passed_train_data, size_t num_train, uint32_t dim, uint32_t num_centers, + uint32_t num_pq_chunks, uint32_t max_k_means_reps, std::string pq_pivots_path, bool make_zero_mean) { if (num_pq_chunks > dim) @@ -440,7 +443,7 @@ int generate_pq_pivots(const float *const passed_train_data, size_t num_train, u for (uint32_t b = 0; b < num_pq_chunks; b++) { if (b > 0) - chunk_offsets.push_back(chunk_offsets[b - 1] + (unsigned)bin_to_dims[b - 1].size()); + chunk_offsets.push_back(chunk_offsets[b - 1] + (uint32_t)bin_to_dims[b - 1].size()); } chunk_offsets.push_back(dim); @@ -460,7 +463,7 @@ int generate_pq_pivots(const float *const passed_train_data, size_t num_train, u << chunk_offsets[i + 1] << ")" << std::endl; #pragma omp parallel for schedule(static, 65536) - for (int64_t j = 0; j < (_s64)num_train; j++) + for
(int64_t j = 0; j < (int64_t)num_train; j++) { std::memcpy(cur_data.get() + j * cur_chunk_size, train_data.get() + j * dim + chunk_offsets[i], cur_chunk_size * sizeof(float)); @@ -486,7 +489,7 @@ int generate_pq_pivots(const float *const passed_train_data, size_t num_train, u diskann::save_bin<float>(pq_pivots_path.c_str(), centroid.get(), (size_t)dim, 1, cumul_bytes[1]); cumul_bytes[3] = cumul_bytes[2] + diskann::save_bin<uint32_t>(pq_pivots_path.c_str(), chunk_offsets.data(), chunk_offsets.size(), 1, cumul_bytes[2]); - diskann::save_bin<_u64>(pq_pivots_path.c_str(), cumul_bytes.data(), cumul_bytes.size(), 1, 0); + diskann::save_bin<uint64_t>(pq_pivots_path.c_str(), cumul_bytes.data(), cumul_bytes.size(), 1, 0); diskann::cout << "Saved pq pivot data to " << pq_pivots_path << " of size " << cumul_bytes[cumul_bytes.size() - 1] << "B." << std::endl; @@ -494,8 +497,8 @@ int generate_pq_pivots(const float *const passed_train_data, size_t num_train, u return 0; } -int generate_opq_pivots(const float *passed_train_data, size_t num_train, unsigned dim, unsigned num_centers, - unsigned num_pq_chunks, std::string opq_pivots_path, bool make_zero_mean) +int generate_opq_pivots(const float *passed_train_data, size_t num_train, uint32_t dim, uint32_t num_centers, + uint32_t num_pq_chunks, std::string opq_pivots_path, bool make_zero_mean) { if (num_pq_chunks > dim) { @@ -591,7 +594,7 @@ int generate_opq_pivots(const float *passed_train_data, size_t num_train, unsign for (uint32_t b = 0; b < num_pq_chunks; b++) { if (b > 0) - chunk_offsets.push_back(chunk_offsets[b - 1] + (unsigned)bin_to_dims[b - 1].size()); + chunk_offsets.push_back(chunk_offsets[b - 1] + (uint32_t)bin_to_dims[b - 1].size()); } chunk_offsets.push_back(dim); @@ -599,10 +602,10 @@ int generate_opq_pivots(const float *passed_train_data, size_t num_train, unsign rotmat_tr.reset(new float[dim * dim]); std::memset(rotmat_tr.get(), 0, dim * dim * sizeof(float)); - for (_u32 d1 = 0; d1 < dim; d1++) + for (uint32_t d1 = 0; d1 < dim; d1++) *(rotmat_tr.get() + d1 * dim + d1) = 1; - for (_u32 rnd = 0; rnd < MAX_OPQ_ITERS; rnd++) + for (uint32_t rnd = 0; rnd < MAX_OPQ_ITERS; rnd++) { // rotate the training data using the current rotation matrix cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasNoTrans, (MKL_INT)num_train, (MKL_INT)dim, (MKL_INT)dim, 1.0f, @@ -624,7 +627,7 @@ int generate_opq_pivots(const float *passed_train_data, size_t num_train, unsign << chunk_offsets[i + 1] << ")" << std::endl; #pragma omp parallel for schedule(static, 65536) - for (int64_t j = 0; j < (_s64)num_train; j++) + for (int64_t j = 0; j < (int64_t)num_train; j++) { std::memcpy(cur_data.get() + j * cur_chunk_size, rotated_train_data.get() + j * dim + chunk_offsets[i], cur_chunk_size * sizeof(float)); @@ -644,7 +647,7 @@ int generate_opq_pivots(const float *passed_train_data, size_t num_train, unsign } } - _u32 num_lloyds_iters = 8; + uint32_t num_lloyds_iters = 8; kmeans::run_lloyds(cur_data.get(), num_train, cur_chunk_size, cur_pivot_data.get(), num_centers, num_lloyds_iters, NULL, closest_center.get()); @@ -654,10 +657,10 @@ int generate_opq_pivots(const float *passed_train_data, size_t num_train, unsign cur_pivot_data.get() + j * cur_chunk_size, cur_chunk_size * sizeof(float)); } - for (_u64 j = 0; j < num_train; j++) + for (size_t j = 0; j < num_train; j++) { std::memcpy(rotated_and_quantized_train_data.get() + j * dim + chunk_offsets[i], - cur_pivot_data.get() + (_u64)closest_center[j] * cur_chunk_size, + cur_pivot_data.get() + (size_t)closest_center[j] * cur_chunk_size, cur_chunk_size *
sizeof(float)); } } @@ -670,9 +673,9 @@ int generate_opq_pivots(const float *passed_train_data, size_t num_train, unsign // compute the SVD of the correlation matrix to help determine the new // rotation matrix - _u32 errcode = - LAPACKE_sgesdd(LAPACK_ROW_MAJOR, 'A', (MKL_INT)dim, (MKL_INT)dim, correlation_matrix.get(), (MKL_INT)dim, - singular_values.get(), Umat.get(), (MKL_INT)dim, Vmat_T.get(), (MKL_INT)dim); + uint32_t errcode = (uint32_t)LAPACKE_sgesdd(LAPACK_ROW_MAJOR, 'A', (MKL_INT)dim, (MKL_INT)dim, + correlation_matrix.get(), (MKL_INT)dim, singular_values.get(), + Umat.get(), (MKL_INT)dim, Vmat_T.get(), (MKL_INT)dim); if (errcode > 0) { @@ -694,7 +697,7 @@ int generate_opq_pivots(const float *passed_train_data, size_t num_train, unsign diskann::save_bin<float>(opq_pivots_path.c_str(), centroid.get(), (size_t)dim, 1, cumul_bytes[1]); cumul_bytes[3] = cumul_bytes[2] + diskann::save_bin<uint32_t>(opq_pivots_path.c_str(), chunk_offsets.data(), chunk_offsets.size(), 1, cumul_bytes[2]); - diskann::save_bin<_u64>(opq_pivots_path.c_str(), cumul_bytes.data(), cumul_bytes.size(), 1, 0); + diskann::save_bin<uint64_t>(opq_pivots_path.c_str(), cumul_bytes.data(), cumul_bytes.size(), 1, 0); diskann::cout << "Saved opq pivot data to " << opq_pivots_path << " of size " << cumul_bytes[cumul_bytes.size() - 1] << "B." << std::endl; @@ -711,13 +714,14 @@ int generate_opq_pivots(const float *passed_train_data, size_t num_train, unsign // If the number of centers is < 256, it stores as byte vector, else as // 4-byte vector in binary format. template <typename T> -int generate_pq_data_from_pivots(const std::string data_file, unsigned num_centers, unsigned num_pq_chunks, - std::string pq_pivots_path, std::string pq_compressed_vectors_path, bool use_opq) +int generate_pq_data_from_pivots(const std::string &data_file, uint32_t num_centers, uint32_t num_pq_chunks, + const std::string &pq_pivots_path, const std::string &pq_compressed_vectors_path, + bool use_opq) { - _u64 read_blk_size = 64 * 1024 * 1024; + size_t read_blk_size = 64 * 1024 * 1024; cached_ifstream base_reader(data_file, read_blk_size); - _u32 npts32; - _u32 basedim32; + uint32_t npts32; + uint32_t basedim32; base_reader.read((char *)&npts32, sizeof(uint32_t)); base_reader.read((char *)&basedim32, sizeof(uint32_t)); size_t num_points = npts32; @@ -737,10 +741,10 @@ int generate_pq_data_from_pivots(const std::string data_file, unsigned num_cente } else { - _u64 nr, nc; - std::unique_ptr<_u64[]> file_offset_data; + size_t nr, nc; + std::unique_ptr<size_t[]> file_offset_data; - diskann::load_bin<_u64>(pq_pivots_path.c_str(), file_offset_data, nr, nc, 0); + diskann::load_bin<size_t>(pq_pivots_path.c_str(), file_offset_data, nr, nc, 0); if (nr != 4) { @@ -796,7 +800,7 @@ int generate_pq_data_from_pivots(const std::string data_file, unsigned num_cente } std::ofstream compressed_file_writer(pq_compressed_vectors_path, std::ios::binary); - _u32 num_pq_chunks_u32 = num_pq_chunks; + uint32_t num_pq_chunks_u32 = num_pq_chunks; compressed_file_writer.write((char *)&num_points, sizeof(uint32_t)); compressed_file_writer.write((char *)&num_pq_chunks_u32, sizeof(uint32_t)); @@ -812,8 +816,9 @@ int generate_pq_data_from_pivots(const std::string data_file, unsigned num_cente std::memset(block_inflated_base.get(), 0, block_size * dim * sizeof(float)); #endif - std::unique_ptr<_u32[]> block_compressed_base = std::make_unique<_u32[]>(block_size * (_u64)num_pq_chunks); - std::memset(block_compressed_base.get(), 0, block_size * (_u64)num_pq_chunks * sizeof(uint32_t)); + std::unique_ptr<uint32_t[]> block_compressed_base = +
std::make_unique<uint32_t[]>(block_size * (size_t)num_pq_chunks); + std::memset(block_compressed_base.get(), 0, block_size * (size_t)num_pq_chunks * sizeof(uint32_t)); std::unique_ptr<T[]> block_data_T = std::make_unique<T[]>(block_size * dim); std::unique_ptr<float[]> block_data_float = std::make_unique<float[]>(block_size * dim); @@ -832,7 +837,7 @@ int generate_pq_data_from_pivots(const std::string data_file, unsigned num_cente diskann::cout << "Processing points [" << start_id << ", " << end_id << ").." << std::flush; - for (uint64_t p = 0; p < cur_blk_size; p++) + for (size_t p = 0; p < cur_blk_size; p++) { for (uint64_t d = 0; d < dim; d++) { @@ -840,7 +845,7 @@ int generate_pq_data_from_pivots(const std::string data_file, unsigned num_cente } } - for (uint64_t p = 0; p < cur_blk_size; p++) + for (size_t p = 0; p < cur_blk_size; p++) { for (uint64_t d = 0; d < dim; d++) { @@ -850,7 +855,8 @@ int generate_pq_data_from_pivots(const std::string data_file, unsigned num_cente if (use_opq) { - // rotate the current block with the trained rotation matrix before PQ + // rotate the current block with the trained rotation matrix before + // PQ cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasNoTrans, (MKL_INT)cur_blk_size, (MKL_INT)dim, (MKL_INT)dim, 1.0f, block_data_float.get(), (MKL_INT)dim, rotmat_tr.get(), (MKL_INT)dim, 0.0f, block_data_tmp.get(), (MKL_INT)dim); @@ -868,14 +874,14 @@ int generate_pq_data_from_pivots(const std::string data_file, unsigned num_cente std::unique_ptr<uint32_t[]> closest_center = std::make_unique<uint32_t[]>(cur_blk_size); #pragma omp parallel for schedule(static, 8192) - for (int64_t j = 0; j < (_s64)cur_blk_size; j++) + for (int64_t j = 0; j < (int64_t)cur_blk_size; j++) { - for (uint64_t k = 0; k < cur_chunk_size; k++) + for (size_t k = 0; k < cur_chunk_size; k++) cur_data[j * cur_chunk_size + k] = block_data_float[j * dim + chunk_offsets[i] + k]; } #pragma omp parallel for schedule(static, 1) - for (int64_t j = 0; j < (_s64)num_centers; j++) + for (int64_t j = 0; j < (int64_t)num_centers; j++) { std::memcpy(cur_pivot_data.get() + j * cur_chunk_size, full_pivot_data.get() + j * dim + chunk_offsets[i], cur_chunk_size * sizeof(float)); @@ -885,11 +891,11 @@ int generate_pq_data_from_pivots(const std::string data_file, unsigned num_cente num_centers, 1, closest_center.get()); #pragma omp parallel for schedule(static, 8192) - for (int64_t j = 0; j < (_s64)cur_blk_size; j++) + for (int64_t j = 0; j < (int64_t)cur_blk_size; j++) { block_compressed_base[j * num_pq_chunks + i] = closest_center[j]; #ifdef SAVE_INFLATED_PQ - for (uint64_t k = 0; k < cur_chunk_size; k++) + for (size_t k = 0; k < cur_chunk_size; k++) block_inflated_base[j * dim + chunk_offsets[i] + k] = cur_pivot_data[closest_center[j] * cur_chunk_size + k] + centroid[chunk_offsets[i] + k]; #endif @@ -926,8 +932,8 @@ int generate_pq_data_from_pivots(const std::string data_file, unsigned num_cente } template <typename T> -void generate_disk_quantized_data(const std::string data_file_to_use, const std::string disk_pq_pivots_path, - const std::string disk_pq_compressed_vectors_path, diskann::Metric compareMetric, +void generate_disk_quantized_data(const std::string &data_file_to_use, const std::string &disk_pq_pivots_path, + const std::string &disk_pq_compressed_vectors_path, diskann::Metric compareMetric, const double p_val, size_t &disk_pq_dims) { size_t train_size, train_dim; @@ -944,97 +950,108 @@ void generate_disk_quantized_data(const std::string data_file_to_use, const std: generate_pq_pivots(train_data, train_size, (uint32_t)train_dim, 256, (uint32_t)disk_pq_dims,
NUM_KMEANS_REPS_PQ, disk_pq_pivots_path, false); if (compareMetric == diskann::Metric::INNER_PRODUCT) - generate_pq_data_from_pivots<float>(data_file_to_use.c_str(), 256, (uint32_t)disk_pq_dims, disk_pq_pivots_path, + generate_pq_data_from_pivots<float>(data_file_to_use, 256, (uint32_t)disk_pq_dims, disk_pq_pivots_path, disk_pq_compressed_vectors_path); else - generate_pq_data_from_pivots<T>(data_file_to_use.c_str(), 256, (uint32_t)disk_pq_dims, disk_pq_pivots_path, + generate_pq_data_from_pivots<T>(data_file_to_use, 256, (uint32_t)disk_pq_dims, disk_pq_pivots_path, disk_pq_compressed_vectors_path); delete[] train_data; } template <typename T> -void generate_quantized_data(const std::string data_file_to_use, const std::string pq_pivots_path, - const std::string pq_compressed_vectors_path, diskann::Metric compareMetric, - const double p_val, const size_t num_pq_chunks, const bool use_opq) +void generate_quantized_data(const std::string &data_file_to_use, const std::string &pq_pivots_path, + const std::string &pq_compressed_vectors_path, diskann::Metric compareMetric, + const double p_val, const size_t num_pq_chunks, const bool use_opq, + const std::string &codebook_prefix) { size_t train_size, train_dim; float *train_data; + if (!file_exists(codebook_prefix)) + { + // instantiates train_data with random sample updates train_size + gen_random_slice<T>(data_file_to_use.c_str(), p_val, train_data, train_size, train_dim); + diskann::cout << "Training data with " << train_size << " samples loaded." << std::endl; - // instantiates train_data with random sample updates train_size - gen_random_slice<T>(data_file_to_use.c_str(), p_val, train_data, train_size, train_dim); - diskann::cout << "Training data with " << train_size << " samples loaded." << std::endl; - - bool make_zero_mean = true; - if (compareMetric == diskann::Metric::INNER_PRODUCT) - make_zero_mean = false; - if (use_opq) // we also do not center the data for OPQ - make_zero_mean = false; + bool make_zero_mean = true; + if (compareMetric == diskann::Metric::INNER_PRODUCT) + make_zero_mean = false; + if (use_opq) // we also do not center the data for OPQ + make_zero_mean = false; - if (!use_opq) - { - generate_pq_pivots(train_data, train_size, (uint32_t)train_dim, NUM_PQ_CENTROIDS, (uint32_t)num_pq_chunks, - NUM_KMEANS_REPS_PQ, pq_pivots_path, make_zero_mean); + if (!use_opq) + { + generate_pq_pivots(train_data, train_size, (uint32_t)train_dim, NUM_PQ_CENTROIDS, (uint32_t)num_pq_chunks, + NUM_KMEANS_REPS_PQ, pq_pivots_path, make_zero_mean); + } + else + { + generate_opq_pivots(train_data, train_size, (uint32_t)train_dim, NUM_PQ_CENTROIDS, (uint32_t)num_pq_chunks, + pq_pivots_path, make_zero_mean); + } + delete[] train_data; } else { - generate_opq_pivots(train_data, train_size, (_u32)train_dim, NUM_PQ_CENTROIDS, (_u32)num_pq_chunks, - pq_pivots_path, make_zero_mean); + diskann::cout << "Skip Training with predefined pivots in: " << pq_pivots_path << std::endl; } - generate_pq_data_from_pivots<T>(data_file_to_use.c_str(), NUM_PQ_CENTROIDS, (uint32_t)num_pq_chunks, pq_pivots_path, + generate_pq_data_from_pivots<T>(data_file_to_use, NUM_PQ_CENTROIDS, (uint32_t)num_pq_chunks, pq_pivots_path, pq_compressed_vectors_path, use_opq); - - delete[] train_data; } // Instantiations of supported templates -template DISKANN_DLLEXPORT int generate_pq_data_from_pivots<int8_t>(const std::string data_file, unsigned num_centers, - unsigned num_pq_chunks, std::string pq_pivots_path, - std::string pq_compressed_vectors_path, +template DISKANN_DLLEXPORT int generate_pq_data_from_pivots<int8_t>(const std::string
&data_file, uint32_t num_centers, + uint32_t num_pq_chunks, + const std::string &pq_pivots_path, + const std::string &pq_compressed_vectors_path, bool use_opq); -template DISKANN_DLLEXPORT int generate_pq_data_from_pivots<uint8_t>(const std::string data_file, unsigned num_centers, - unsigned num_pq_chunks, std::string pq_pivots_path, - std::string pq_compressed_vectors_path, +template DISKANN_DLLEXPORT int generate_pq_data_from_pivots<uint8_t>(const std::string &data_file, uint32_t num_centers, + uint32_t num_pq_chunks, + const std::string &pq_pivots_path, + const std::string &pq_compressed_vectors_path, bool use_opq); -template DISKANN_DLLEXPORT int generate_pq_data_from_pivots<float>(const std::string data_file, unsigned num_centers, - unsigned num_pq_chunks, std::string pq_pivots_path, - std::string pq_compressed_vectors_path, +template DISKANN_DLLEXPORT int generate_pq_data_from_pivots<float>(const std::string &data_file, uint32_t num_centers, + uint32_t num_pq_chunks, + const std::string &pq_pivots_path, + const std::string &pq_compressed_vectors_path, bool use_opq); -template DISKANN_DLLEXPORT void generate_disk_quantized_data<int8_t>(const std::string data_file_to_use, - const std::string disk_pq_pivots_path, - const std::string disk_pq_compressed_vectors_path, +template DISKANN_DLLEXPORT void generate_disk_quantized_data<int8_t>(const std::string &data_file_to_use, - const std::string &disk_pq_pivots_path, + const std::string &disk_pq_compressed_vectors_path, diskann::Metric compareMetric, const double p_val, size_t &disk_pq_dims); -template DISKANN_DLLEXPORT void generate_disk_quantized_data<uint8_t>(const std::string data_file_to_use, - const std::string disk_pq_pivots_path, - const std::string disk_pq_compressed_vectors_path, - diskann::Metric compareMetric, const double p_val, - size_t &disk_pq_dims); +template DISKANN_DLLEXPORT void generate_disk_quantized_data<uint8_t>( + const std::string &data_file_to_use, const std::string &disk_pq_pivots_path, + const std::string &disk_pq_compressed_vectors_path, diskann::Metric compareMetric, const double p_val, + size_t &disk_pq_dims); -template DISKANN_DLLEXPORT void generate_disk_quantized_data<float>(const std::string data_file_to_use, - const std::string disk_pq_pivots_path, - const std::string disk_pq_compressed_vectors_path, +template DISKANN_DLLEXPORT void generate_disk_quantized_data<float>(const std::string &data_file_to_use, + const std::string &disk_pq_pivots_path, + const std::string &disk_pq_compressed_vectors_path, diskann::Metric compareMetric, const double p_val, size_t &disk_pq_dims); -template DISKANN_DLLEXPORT void generate_quantized_data<int8_t>(const std::string data_file_to_use, - const std::string pq_pivots_path, - const std::string pq_compressed_vectors_path, +template DISKANN_DLLEXPORT void generate_quantized_data<int8_t>(const std::string &data_file_to_use, + const std::string &pq_pivots_path, + const std::string &pq_compressed_vectors_path, diskann::Metric compareMetric, const double p_val, - const size_t num_pq_chunks, const bool use_opq); + const size_t num_pq_chunks, const bool use_opq, + const std::string &codebook_prefix); -template DISKANN_DLLEXPORT void generate_quantized_data<uint8_t>(const std::string data_file_to_use, - const std::string pq_pivots_path, - const std::string pq_compressed_vectors_path, +template DISKANN_DLLEXPORT void generate_quantized_data<uint8_t>(const std::string &data_file_to_use, + const std::string &pq_pivots_path, + const std::string &pq_compressed_vectors_path, diskann::Metric compareMetric, const double p_val, - const size_t
num_pq_chunks, const bool use_opq, + const size_t num_pq_chunks, const bool use_opq, + const std::string &codebook_prefix); -template DISKANN_DLLEXPORT void generate_quantized_data<float>(const std::string data_file_to_use, - const std::string pq_pivots_path, - const std::string pq_compressed_vectors_path, +template DISKANN_DLLEXPORT void generate_quantized_data<float>(const std::string &data_file_to_use, + const std::string &pq_pivots_path, + const std::string &pq_compressed_vectors_path, diskann::Metric compareMetric, const double p_val, - const size_t num_pq_chunks, const bool use_opq); + const size_t num_pq_chunks, const bool use_opq, + const std::string &codebook_prefix); } // namespace diskann diff --git a/src/pq_flash_index.cpp b/src/pq_flash_index.cpp index 050d20546..943fed44c 100644 --- a/src/pq_flash_index.cpp +++ b/src/pq_flash_index.cpp @@ -14,34 +14,35 @@ #include "linux_aligned_file_reader.h" #endif -#define READ_U64(stream, val) stream.read((char *)&val, sizeof(_u64)) -#define READ_U32(stream, val) stream.read((char *)&val, sizeof(_u32)) +#define READ_U64(stream, val) stream.read((char *)&val, sizeof(uint64_t)) +#define READ_U32(stream, val) stream.read((char *)&val, sizeof(uint32_t)) #define READ_UNSIGNED(stream, val) stream.read((char *)&val, sizeof(unsigned)) // sector # on disk where node_id is present with in the graph part -#define NODE_SECTOR_NO(node_id) (((_u64)(node_id)) / nnodes_per_sector + 1) +#define NODE_SECTOR_NO(node_id) (((uint64_t)(node_id)) / nnodes_per_sector + 1) // obtains region of sector containing node -#define OFFSET_TO_NODE(sector_buf, node_id) ((char *)sector_buf + (((_u64)node_id) % nnodes_per_sector) * max_node_len) +#define OFFSET_TO_NODE(sector_buf, node_id) \ + ((char *)sector_buf + (((uint64_t)node_id) % nnodes_per_sector) * max_node_len) -// returns region of `node_buf` containing [NNBRS][NBR_ID(_u32)] +// returns region of `node_buf` containing [NNBRS][NBR_ID(uint32_t)] #define OFFSET_TO_NODE_NHOOD(node_buf) (unsigned *)((char *)node_buf + disk_bytes_per_point) // returns region of `node_buf` containing [COORD(T)] #define OFFSET_TO_NODE_COORDS(node_buf) (T *)(node_buf) // sector # beyond the end of graph where data for id is present for reordering -#define VECTOR_SECTOR_NO(id) (((_u64)(id)) / nvecs_per_sector + reorder_data_start_sector) +#define VECTOR_SECTOR_NO(id) (((uint64_t)(id)) / nvecs_per_sector + reorder_data_start_sector) // sector # beyond the end of graph where data for id is present for reordering -#define VECTOR_SECTOR_OFFSET(id) ((((_u64)(id)) % nvecs_per_sector) * data_dim * sizeof(float)) +#define VECTOR_SECTOR_OFFSET(id) ((((uint64_t)(id)) % nvecs_per_sector) * data_dim * sizeof(float)) namespace diskann { template <typename T, typename LabelT> PQFlashIndex<T, LabelT>::PQFlashIndex(std::shared_ptr<AlignedFileReader> &fileReader, diskann::Metric m) - : reader(fileReader), metric(m) + : reader(fileReader), metric(m), thread_data(nullptr) { if (m == diskann::Metric::COSINE || m == diskann::Metric::INNER_PRODUCT) { @@ -102,12 +103,12 @@ template <typename T, typename LabelT> PQFlashIndex<T, LabelT>::~PQFlashIndex() } template <typename T, typename LabelT> -void PQFlashIndex<T, LabelT>::setup_thread_data(_u64 nthreads, _u64 visited_reserve) +void PQFlashIndex<T, LabelT>::setup_thread_data(uint64_t nthreads, uint64_t visited_reserve) { diskann::cout << "Setting up thread-specific contexts for nthreads: " << nthreads << std::endl; // omp parallel for to generate unique thread IDs #pragma omp parallel for num_threads((int)nthreads) - for (_s64 thread = 0; thread < (_s64)nthreads; thread++) + for (int64_t thread = 0; thread < (int64_t)nthreads; thread++) { #pragma omp critical { @@ -123,30 +124,30 @@ void PQFlashIndex<T, LabelT>::setup_thread_data(_u64
nthreads, _u64 visited_rese template <typename T, typename LabelT> void PQFlashIndex<T, LabelT>::load_cache_list(std::vector<uint32_t> &node_list) { diskann::cout << "Loading the cache list into memory.." << std::flush; - _u64 num_cached_nodes = node_list.size(); + size_t num_cached_nodes = node_list.size(); // borrow thread data ScratchStoreManager<SSDThreadData<T>> manager(this->thread_data); auto this_thread_data = manager.scratch_space(); IOContext &ctx = this_thread_data->ctx; - nhood_cache_buf = new unsigned[num_cached_nodes * (max_degree + 1)]; + nhood_cache_buf = new uint32_t[num_cached_nodes * (max_degree + 1)]; memset(nhood_cache_buf, 0, num_cached_nodes * (max_degree + 1)); - _u64 coord_cache_buf_len = num_cached_nodes * aligned_dim; + size_t coord_cache_buf_len = num_cached_nodes * aligned_dim; diskann::alloc_aligned((void **)&coord_cache_buf, coord_cache_buf_len * sizeof(T), 8 * sizeof(T)); memset(coord_cache_buf, 0, coord_cache_buf_len * sizeof(T)); size_t BLOCK_SIZE = 8; size_t num_blocks = DIV_ROUND_UP(num_cached_nodes, BLOCK_SIZE); - for (_u64 block = 0; block < num_blocks; block++) + for (size_t block = 0; block < num_blocks; block++) { - _u64 start_idx = block * BLOCK_SIZE; - _u64 end_idx = (std::min)(num_cached_nodes, (block + 1) * BLOCK_SIZE); + size_t start_idx = block * BLOCK_SIZE; + size_t end_idx = (std::min)(num_cached_nodes, (block + 1) * BLOCK_SIZE); std::vector<AlignedRead> read_reqs; - std::vector<std::pair<_u32, char *>> nhoods; - for (_u64 node_idx = start_idx; node_idx < end_idx; node_idx++) + std::vector<std::pair<uint32_t, char *>> nhoods; + for (size_t node_idx = start_idx; node_idx < end_idx; node_idx++) { AlignedRead read; char *buf = nullptr; @@ -160,8 +161,8 @@ template <typename T, typename LabelT> void PQFlashIndex<T, LabelT>::load_cache_ reader->read(read_reqs, ctx); - _u64 node_idx = start_idx; - for (_u32 i = 0; i < read_reqs.size(); i++) + size_t node_idx = start_idx; + for (uint32_t i = 0; i < read_reqs.size(); i++) { #if defined(_WINDOWS) && defined(USE_BING_INFRA) // this block is to handle failed reads in // production settings @@ -178,14 +179,14 @@ template <typename T, typename LabelT> void PQFlashIndex<T, LabelT>::load_cache_ coord_cache.insert(std::make_pair(nhood.first, cached_coords)); // insert node nhood into nhood_cache - unsigned *node_nhood = OFFSET_TO_NODE_NHOOD(node_buf); + uint32_t *node_nhood = OFFSET_TO_NODE_NHOOD(node_buf); auto nnbrs = *node_nhood; - unsigned *nbrs = node_nhood + 1; - std::pair<_u32, unsigned *> cnhood; + uint32_t *nbrs = node_nhood + 1; + std::pair<uint32_t, uint32_t *> cnhood; cnhood.first = nnbrs; cnhood.second = nhood_cache_buf + node_idx * (max_degree + 1); - memcpy(cnhood.second, nbrs, nnbrs * sizeof(unsigned)); + memcpy(cnhood.second, nbrs, nnbrs * sizeof(uint32_t)); nhood_cache.insert(std::make_pair(nhood.first, cnhood)); aligned_free(nhood.second); node_idx++; @@ -197,28 +198,39 @@ template <typename T, typename LabelT> void PQFlashIndex<T, LabelT>::load_cache_ #ifdef EXEC_ENV_OLS template <typename T, typename LabelT> void PQFlashIndex<T, LabelT>::generate_cache_list_from_sample_queries(MemoryMappedFiles &files, std::string sample_bin, - _u64 l_search, _u64 beamwidth, - _u64 num_nodes_to_cache, uint32_t nthreads, + uint64_t l_search, uint64_t beamwidth, + uint64_t num_nodes_to_cache, uint32_t nthreads, std::vector<uint32_t> &node_list) { #else template <typename T, typename LabelT> -void PQFlashIndex<T, LabelT>::generate_cache_list_from_sample_queries(std::string sample_bin, _u64 l_search, - _u64 beamwidth, _u64 num_nodes_to_cache, +void PQFlashIndex<T, LabelT>::generate_cache_list_from_sample_queries(std::string sample_bin, uint64_t l_search, + uint64_t beamwidth, uint64_t num_nodes_to_cache, uint32_t nthreads, std::vector<uint32_t> &node_list) { #endif + if (num_nodes_to_cache >= this->num_points) + { + // for small num_points and big num_nodes_to_cache, use the following approach to get
the node_list quickly + node_list.resize(this->num_points); + for (uint32_t i = 0; i < this->num_points; ++i) + { + node_list[i] = i; + } + return; + } + this->count_visited_nodes = true; this->node_visit_counter.clear(); this->node_visit_counter.resize(this->num_points); - for (_u32 i = 0; i < node_visit_counter.size(); i++) + for (uint32_t i = 0; i < node_visit_counter.size(); i++) { this->node_visit_counter[i].first = i; this->node_visit_counter[i].second = 0; } - _u64 sample_num, sample_dim, sample_aligned_dim; + uint64_t sample_num, sample_dim, sample_aligned_dim; T *samples; #ifdef EXEC_ENV_OLS @@ -241,19 +253,34 @@ void PQFlashIndex<T, LabelT>::generate_cache_list_from_sample_queries(std::strin std::vector<uint64_t> tmp_result_ids_64(sample_num, 0); std::vector<float> tmp_result_dists(sample_num, 0); + bool filtered_search = false; + std::vector<LabelT> random_query_filters(sample_num); + if (_filter_to_medoid_ids.size() != 0) + { + filtered_search = true; + generate_random_labels(random_query_filters, (uint32_t)sample_num, nthreads); + } + #pragma omp parallel for schedule(dynamic, 1) num_threads(nthreads) - for (_s64 i = 0; i < (int64_t)sample_num; i++) + for (int64_t i = 0; i < (int64_t)sample_num; i++) { - cached_beam_search(samples + (i * sample_aligned_dim), 1, l_search, tmp_result_ids_64.data() + (i * 1), - tmp_result_dists.data() + (i * 1), beamwidth); + auto &label_for_search = random_query_filters[i]; + // run a search on the sample query with a random label (sampled from base label distribution), and it will + // concurrently update the node_visit_counter to track most visited nodes. The last false is to not use the + // "use_reorder_data" option which enables a final reranking if the disk index itself contains only PQ data. + cached_beam_search(samples + (i * sample_aligned_dim), 1, l_search, tmp_result_ids_64.data() + i, + tmp_result_dists.data() + i, beamwidth, filtered_search, label_for_search, false); } std::sort(this->node_visit_counter.begin(), node_visit_counter.end(), - [](std::pair<_u32, _u32> &left, std::pair<_u32, _u32> &right) { return left.second > right.second; }); + [](std::pair<uint32_t, uint32_t> &left, std::pair<uint32_t, uint32_t> &right) { + return left.second > right.second; + }); node_list.clear(); node_list.shrink_to_fit(); + num_nodes_to_cache = std::min(num_nodes_to_cache, this->node_visit_counter.size()); node_list.reserve(num_nodes_to_cache); - for (_u64 i = 0; i < num_nodes_to_cache; i++) + for (uint64_t i = 0; i < num_nodes_to_cache; i++) { node_list.push_back(this->node_visit_counter[i].first); } @@ -263,7 +290,7 @@ void PQFlashIndex<T, LabelT>::generate_cache_list_from_sample_queries(std::strin } template <typename T, typename LabelT> -void PQFlashIndex<T, LabelT>::cache_bfs_levels(_u64 num_nodes_to_cache, std::vector<uint32_t> &node_list, +void PQFlashIndex<T, LabelT>::cache_bfs_levels(uint64_t num_nodes_to_cache, std::vector<uint32_t> &node_list, const bool shuffle) { std::random_device rng; @@ -272,7 +299,7 @@ void PQFlashIndex<T, LabelT>::cache_bfs_levels(_u64 num_nodes_to_cache, std::vec tsl::robin_set<uint32_t> node_set; // Do not cache more than 10% of the nodes in the index - _u64 tenp_nodes = (_u64)(std::round(this->num_points * 0.1)); + uint64_t tenp_nodes = (uint64_t)(std::round(this->num_points * 0.1)); if (num_nodes_to_cache > tenp_nodes) { diskann::cout << "Reducing nodes to cache from: " << num_nodes_to_cache << " to: " << tenp_nodes @@ -286,16 +313,31 @@ void PQFlashIndex<T, LabelT>::cache_bfs_levels(_u64 num_nodes_to_cache, std::vec auto this_thread_data = manager.scratch_space(); IOContext &ctx = this_thread_data->ctx; - std::unique_ptr<tsl::robin_set<unsigned>> cur_level, prev_level; - cur_level = std::make_unique<tsl::robin_set<unsigned>>(); - prev_level =
std::make_unique<tsl::robin_set<unsigned>>(); + std::unique_ptr<tsl::robin_set<uint32_t>> cur_level, prev_level; + cur_level = std::make_unique<tsl::robin_set<uint32_t>>(); + prev_level = std::make_unique<tsl::robin_set<uint32_t>>(); - for (_u64 miter = 0; miter < num_medoids; miter++) + for (uint64_t miter = 0; miter < num_medoids && cur_level->size() < num_nodes_to_cache; miter++) { cur_level->insert(medoids[miter]); } - _u64 lvl = 1; + if ((_filter_to_medoid_ids.size() > 0) && (cur_level->size() < num_nodes_to_cache)) + { + for (auto &x : _filter_to_medoid_ids) + { + for (auto &y : x.second) + { + cur_level->insert(y); + if (cur_level->size() == num_nodes_to_cache) + break; + } + if (cur_level->size() == num_nodes_to_cache) + break; + } + } + + uint64_t lvl = 1; uint64_t prev_node_set_size = 0; while ((node_set.size() + cur_level->size() < num_nodes_to_cache) && cur_level->size() != 0) { @@ -304,9 +346,9 @@ void PQFlashIndex<T, LabelT>::cache_bfs_levels(_u64 num_nodes_to_cache, std::vec // clear cur_level cur_level->clear(); - std::vector<unsigned> nodes_to_expand; + std::vector<uint32_t> nodes_to_expand; - for (const unsigned &id : *prev_level) + for (const uint32_t &id : *prev_level) { if (node_set.find(id) != node_set.end()) { @@ -332,7 +374,7 @@ void PQFlashIndex<T, LabelT>::cache_bfs_levels(_u64 num_nodes_to_cache, std::vec size_t start = block * BLOCK_SIZE; size_t end = (std::min)((block + 1) * BLOCK_SIZE, nodes_to_expand.size()); std::vector<AlignedRead> read_reqs; - std::vector<std::pair<_u32, char *>> nhoods; + std::vector<std::pair<uint32_t, char *>> nhoods; for (size_t cur_pt = start; cur_pt < end; cur_pt++) { char *buf = nullptr; @@ -349,7 +391,7 @@ void PQFlashIndex<T, LabelT>::cache_bfs_levels(_u64 num_nodes_to_cache, std::vec reader->read(read_reqs, ctx); // process each nhood buf - for (_u32 i = 0; i < read_reqs.size(); i++) + for (uint32_t i = 0; i < read_reqs.size(); i++) { #if defined(_WINDOWS) && defined(USE_BING_INFRA) // this block is to handle read failures in // production settings @@ -362,11 +404,11 @@ void PQFlashIndex<T, LabelT>::cache_bfs_levels(_u64 num_nodes_to_cache, std::vec // insert node coord into coord_cache char *node_buf = OFFSET_TO_NODE(nhood.second, nhood.first); - unsigned *node_nhood = OFFSET_TO_NODE_NHOOD(node_buf); - _u64 nnbrs = (_u64)*node_nhood; - unsigned *nbrs = node_nhood + 1; + uint32_t *node_nhood = OFFSET_TO_NODE_NHOOD(node_buf); + uint64_t nnbrs = (uint64_t)*node_nhood; + uint32_t *nbrs = node_nhood + 1; // explore next level - for (_u64 j = 0; j < nnbrs && !finish_flag; j++) + for (uint64_t j = 0; j < nnbrs && !finish_flag; j++) { if (node_set.find(nbrs[j]) == node_set.end()) { @@ -382,7 +424,7 @@ void PQFlashIndex<T, LabelT>::cache_bfs_levels(_u64 num_nodes_to_cache, std::vec } diskann::cout << ".
#nodes: " << node_set.size() - prev_node_set_size - << ", #nodes thus far: " << node_list.size() << std::endl; + << ", #nodes thus far: " << node_set.size() << std::endl; prev_node_set_size = node_set.size(); lvl++; } @@ -441,7 +483,7 @@ template void PQFlashIndex::use_medoids } else { - disk_pq_table.inflate_vector((_u8 *)medoid_coords, (centroid_data + cur_m * aligned_dim)); + disk_pq_table.inflate_vector((uint8_t *)medoid_coords, (centroid_data + cur_m * aligned_dim)); } aligned_free(medoid_buf); @@ -453,7 +495,7 @@ template inline int32_t PQFlashIndex::get_filter_number(const LabelT &filter_label) { int idx = -1; - for (_u32 i = 0; i < _filter_list.size(); i++) + for (uint32_t i = 0; i < _filter_list.size(); i++) { if (_filter_list[i] == filter_label) { @@ -464,6 +506,41 @@ inline int32_t PQFlashIndex::get_filter_number(const LabelT &filter_l return idx; } +template +void PQFlashIndex::generate_random_labels(std::vector &labels, const uint32_t num_labels, + const uint32_t nthreads) +{ + std::random_device rd; + labels.clear(); + labels.resize(num_labels); + + uint64_t num_total_labels = + _pts_to_label_offsets[num_points - 1] + _pts_to_labels[_pts_to_label_offsets[num_points - 1]]; + std::mt19937 gen(rd()); + std::uniform_int_distribution dis(0, num_total_labels); + + tsl::robin_set skip_locs; + for (uint32_t i = 0; i < num_points; i++) + { + skip_locs.insert(_pts_to_label_offsets[i]); + } + +#pragma omp parallel for schedule(dynamic, 1) num_threads(nthreads) + for (int64_t i = 0; i < num_labels; i++) + { + bool found_flag = false; + while (!found_flag) + { + uint64_t rnd_loc = dis(gen); + if (skip_locs.find(rnd_loc) == skip_locs.end()) + { + found_flag = true; + labels[i] = _filter_list[_pts_to_labels[rnd_loc]]; + } + } + } +} + template std::unordered_map PQFlashIndex::load_label_map(const std::string &labels_map_file) { @@ -478,7 +555,7 @@ std::unordered_map PQFlashIndex::load_label_map( getline(iss, token, '\t'); label_str = token; getline(iss, token, '\t'); - token_as_num = std::stoul(token); + token_as_num = (LabelT)std::stoul(token); string_to_int_mp[label_str] = token_as_num; } return string_to_int_mp; @@ -502,7 +579,8 @@ LabelT PQFlashIndex::get_converted_label(const std::string &filter_la } template -void PQFlashIndex::get_label_file_metadata(std::string map_file, _u32 &num_pts, _u32 &num_total_labels) +void PQFlashIndex::get_label_file_metadata(std::string map_file, uint32_t &num_pts, + uint32_t &num_total_labels) { std::ifstream infile(map_file); std::string line, token; @@ -527,12 +605,12 @@ void PQFlashIndex::get_label_file_metadata(std::string map_file, _u32 } template -inline bool PQFlashIndex::point_has_label(_u32 point_id, _u32 label_id) +inline bool PQFlashIndex::point_has_label(uint32_t point_id, uint32_t label_id) { - _u32 start_vec = _pts_to_label_offsets[point_id]; - _u32 num_lbls = _pts_to_labels[start_vec]; + uint32_t start_vec = _pts_to_label_offsets[point_id]; + uint32_t num_lbls = _pts_to_labels[start_vec]; bool ret_val = false; - for (_u32 i = 0; i < num_lbls; i++) + for (uint32_t i = 0; i < num_lbls; i++) { if (_pts_to_labels[start_vec + 1 + i] == label_id) { @@ -553,23 +631,23 @@ void PQFlashIndex::parse_label_file(const std::string &label_file, si } std::string line, token; - _u32 line_cnt = 0; + uint32_t line_cnt = 0; - _u32 num_pts_in_label_file; - _u32 num_total_labels; + uint32_t num_pts_in_label_file; + uint32_t num_total_labels; get_label_file_metadata(label_file, num_pts_in_label_file, num_total_labels); - _pts_to_label_offsets = new 
_u32[num_pts_in_label_file]; - _pts_to_labels = new _u32[num_pts_in_label_file + num_total_labels]; - _u32 counter = 0; + _pts_to_label_offsets = new uint32_t[num_pts_in_label_file]; + _pts_to_labels = new uint32_t[num_pts_in_label_file + num_total_labels]; + uint32_t counter = 0; while (std::getline(infile, line)) { std::istringstream iss(line); - std::vector<_u32> lbls(0); + std::vector<uint32_t> lbls(0); _pts_to_label_offsets[line_cnt] = counter; - _u32 &num_lbls_in_cur_pt = _pts_to_labels[counter]; + uint32_t &num_lbls_in_cur_pt = _pts_to_labels[counter]; num_lbls_in_cur_pt = 0; counter++; getline(iss, token, '\t'); @@ -578,7 +656,7 @@ void PQFlashIndex<T, LabelT>::parse_label_file(const std::string &label_file, si { token.erase(std::remove(token.begin(), token.end(), '\n'), token.end()); token.erase(std::remove(token.begin(), token.end(), '\r'), token.end()); - LabelT token_as_num = std::stoul(token); + LabelT token_as_num = (LabelT)std::stoul(token); if (_labels.find(token_as_num) == _labels.end()) { _filter_list.emplace_back(token_as_num); @@ -610,13 +688,12 @@ template <typename T, typename LabelT> void PQFlashIndex<T, LabelT>::set_univers int32_t temp_filter_num = get_filter_number(label); if (temp_filter_num == -1) { - diskann::cout << "Error, could not find universal label. Exitting." << std::endl; - exit(-1); + diskann::cout << "Error, could not find universal label." << std::endl; } else { _use_universal_label = true; - _universal_filter_num = (_u32)temp_filter_num; + _universal_filter_num = (uint32_t)temp_filter_num; } } @@ -631,12 +708,19 @@ template <typename T, typename LabelT> int PQFlashIndex<T, LabelT>::load(uint32_ std::string pq_table_bin = std::string(index_prefix) + "_pq_pivots.bin"; std::string pq_compressed_vectors = std::string(index_prefix) + "_pq_compressed.bin"; std::string disk_index_file = std::string(index_prefix) + "_disk.index"; + std::string labels_file = std::string(index_prefix) + "_labels.txt"; + std::string labels_to_medoids = std::string(index_prefix) + "_labels_to_medoids.txt"; + std::string labels_map_file = std::string(index_prefix) + "_labels_map.txt"; + std::string univ_label_file = std::string(index_prefix) + "_universal_label.txt"; + #ifdef EXEC_ENV_OLS return load_from_separate_paths(files, num_threads, disk_index_file.c_str(), pq_table_bin.c_str(), - pq_compressed_vectors.c_str()); + pq_compressed_vectors.c_str(), labels_file.c_str(), labels_to_medoids.c_str(), + labels_map_file.c_str(), univ_label_file.c_str()); #else return load_from_separate_paths(num_threads, disk_index_file.c_str(), pq_table_bin.c_str(), - pq_compressed_vectors.c_str()); + pq_compressed_vectors.c_str(), labels_file.c_str(), labels_to_medoids.c_str(), + labels_map_file.c_str(), univ_label_file.c_str()); #endif } @@ -644,12 +728,16 @@ template <typename T, typename LabelT> int PQFlashIndex<T, LabelT>::load(uint32_ template <typename T, typename LabelT> int PQFlashIndex<T, LabelT>::load_from_separate_paths(diskann::MemoryMappedFiles &files, uint32_t num_threads, const char *index_filepath, const char *pivots_filepath, - const char *compressed_filepath) + const char *compressed_filepath, + const char *labels_filepath, const char *labels_to_medoids_filepath, + const char *labels_map_filepath, const char *unv_label_filepath) { #else template <typename T, typename LabelT> int PQFlashIndex<T, LabelT>::load_from_separate_paths(uint32_t num_threads, const char *index_filepath, - const char *pivots_filepath, const char *compressed_filepath) + const char *pivots_filepath, const char *compressed_filepath, + const char *labels_filepath, const char *labels_to_medoids_filepath, + const char *labels_map_filepath, const char *unv_label_filepath) { #endif std::string pq_table_bin = pivots_filepath; @@ -658,10
+746,10 @@ int PQFlashIndex<T, LabelT>::load_from_separate_paths(uint32_t num_threads, cons std::string medoids_file = std::string(disk_index_file) + "_medoids.bin"; std::string centroids_file = std::string(disk_index_file) + "_centroids.bin"; - std::string labels_file = std ::string(disk_index_file) + "_labels.txt"; - std::string labels_to_medoids = std ::string(disk_index_file) + "_labels_to_medoids.txt"; + std::string labels_file = (labels_filepath == nullptr ? "" : labels_filepath); + std::string labels_to_medoids = (labels_to_medoids_filepath == nullptr ? "" : labels_to_medoids_filepath); std::string dummy_map_file = std::string(disk_index_file) + "_dummy_map.txt"; - std::string labels_map_file = std ::string(disk_index_file) + "_labels_map.txt"; + std::string labels_map_file = (labels_map_filepath == nullptr ? "" : labels_map_filepath); size_t num_pts_in_label_file = 0; size_t pq_file_dim, pq_file_num_centroids; @@ -689,9 +777,9 @@ int PQFlashIndex<T, LabelT>::load_from_separate_paths(uint32_t num_threads, cons size_t npts_u64, nchunks_u64; #ifdef EXEC_ENV_OLS - diskann::load_bin<_u8>(files, pq_compressed_vectors, this->data, npts_u64, nchunks_u64); + diskann::load_bin<uint8_t>(files, pq_compressed_vectors, this->data, npts_u64, nchunks_u64); #else - diskann::load_bin<_u8>(pq_compressed_vectors, this->data, npts_u64, nchunks_u64); + diskann::load_bin<uint8_t>(pq_compressed_vectors, this->data, npts_u64, nchunks_u64); #endif this->num_points = npts_u64; @@ -707,24 +795,24 @@ int PQFlashIndex<T, LabelT>::load_from_separate_paths(uint32_t num_threads, cons assert(medoid_stream.is_open()); std::string line, token; - _filter_to_medoid_id.clear(); + _filter_to_medoid_ids.clear(); try { while (std::getline(medoid_stream, line)) { std::istringstream iss(line); - _u32 cnt = 0; - _u32 medoid; + uint32_t cnt = 0; + std::vector<uint32_t> medoids; LabelT label; while (std::getline(iss, token, ',')) { if (cnt == 0) - label = std::stoul(token); + label = (LabelT)std::stoul(token); else - medoid = (_u32)stoul(token); + medoids.push_back((uint32_t)stoul(token)); cnt++; } - _filter_to_medoid_id[label] = medoid; + _filter_to_medoid_ids[label].swap(medoids); } } catch (std::system_error &e) @@ -732,7 +820,8 @@ int PQFlashIndex<T, LabelT>::load_from_separate_paths(uint32_t num_threads, cons throw FileException(labels_to_medoids, e, __FUNCSIG__, __FILE__, __LINE__); } } - std::string univ_label_file = std ::string(disk_index_file) + "_universal_label.txt"; + + std::string univ_label_file = (unv_label_filepath == nullptr ?
"" : unv_label_filepath); if (file_exists(univ_label_file)) { std::ifstream universal_label_reader(univ_label_file); @@ -740,7 +829,7 @@ int PQFlashIndex::load_from_separate_paths(uint32_t num_threads, cons std::string univ_label; universal_label_reader >> univ_label; universal_label_reader.close(); - LabelT label_as_num = std::stoul(univ_label); + LabelT label_as_num = (LabelT)std::stoul(univ_label); set_universal_label(label_as_num); } if (file_exists(dummy_map_file)) @@ -752,15 +841,15 @@ int PQFlashIndex::load_from_separate_paths(uint32_t num_threads, cons while (std::getline(dummy_map_stream, line)) { std::istringstream iss(line); - _u32 cnt = 0; - _u32 dummy_id; - _u32 real_id; + uint32_t cnt = 0; + uint32_t dummy_id; + uint32_t real_id; while (std::getline(iss, token, ',')) { if (cnt == 0) - dummy_id = (_u32)stoul(token); + dummy_id = (uint32_t)stoul(token); else - real_id = (_u32)stoul(token); + real_id = (uint32_t)stoul(token); cnt++; } _dummy_pts.insert(dummy_id); @@ -768,7 +857,7 @@ int PQFlashIndex::load_from_separate_paths(uint32_t num_threads, cons _dummy_to_real_map[dummy_id] = real_id; if (_real_to_dummy_map.find(real_id) == _real_to_dummy_map.end()) - _real_to_dummy_map[real_id] = std::vector<_u32>(); + _real_to_dummy_map[real_id] = std::vector(); _real_to_dummy_map[real_id].emplace_back(dummy_id); } @@ -809,7 +898,8 @@ int PQFlashIndex::load_from_separate_paths(uint32_t num_threads, cons disk_pq_table.load_pq_centroid_bin(disk_pq_pivots_path.c_str(), 0); #endif disk_pq_n_chunks = disk_pq_table.get_num_chunks(); - disk_bytes_per_point = disk_pq_n_chunks * sizeof(_u8); // revising disk_bytes_per_point since DISK PQ is used. + disk_bytes_per_point = + disk_pq_n_chunks * sizeof(uint8_t); // revising disk_bytes_per_point since DISK PQ is used. diskann::cout << "Disk index uses PQ data compressed down to " << disk_pq_n_chunks << " bytes per point." 
<< std::endl; } @@ -832,13 +922,13 @@ int PQFlashIndex<T, LabelT>::load_from_separate_paths(uint32_t num_threads, cons std::ifstream index_metadata(disk_index_file, std::ios::binary); #endif - _u32 nr, nc; // metadata itself is stored as bin format (nr is number of - // metadata, nc should be 1) + uint32_t nr, nc; // metadata itself is stored as bin format (nr is number of + // metadata, nc should be 1) READ_U32(index_metadata, nr); READ_U32(index_metadata, nc); - _u64 disk_nnodes; - _u64 disk_ndims; // can be disk PQ dim if disk_PQ is set to true + uint64_t disk_nnodes; + uint64_t disk_ndims; // can be disk PQ dim if disk_PQ is set to true READ_U64(index_metadata, disk_nnodes); READ_U64(index_metadata, disk_ndims); @@ -854,7 +944,7 @@ int PQFlashIndex<T, LabelT>::load_from_separate_paths(uint32_t num_threads, cons READ_U64(index_metadata, medoid_id_on_file); READ_U64(index_metadata, max_node_len); READ_U64(index_metadata, nnodes_per_sector); - max_degree = ((max_node_len - disk_bytes_per_point) / sizeof(unsigned)) - 1; + max_degree = ((max_node_len - disk_bytes_per_point) / sizeof(uint32_t)) - 1; if (max_degree > MAX_GRAPH_DEGREE) { @@ -867,7 +957,7 @@ int PQFlashIndex<T, LabelT>::load_from_separate_paths(uint32_t num_threads, cons // setting up concept of frozen points in disk index for streaming-DiskANN READ_U64(index_metadata, this->num_frozen_points); - _u64 file_frozen_id; + uint64_t file_frozen_id; READ_U64(index_metadata, file_frozen_id); if (this->num_frozen_points == 1) this->frozen_location = file_frozen_id; @@ -882,8 +972,9 @@ int PQFlashIndex<T, LabelT>::load_from_separate_paths(uint32_t num_threads, cons { if (this->use_disk_index_pq == false) { - throw ANNException("Reordering is designed for used with disk PQ compression option", -1, __FUNCSIG__, - __FILE__, __LINE__); + throw ANNException("Reordering is designed for use with disk PQ " + "compression option", + -1, __FUNCSIG__, __FILE__, __LINE__); } READ_U64(index_metadata, this->reorder_data_start_sector); READ_U64(index_metadata, this->ndims_reorder_vecs); @@ -954,7 +1045,8 @@ int PQFlashIndex<T, LabelT>::load_from_separate_paths(uint32_t num_threads, cons if (aligned_tmp_dim != aligned_dim || num_centroids != num_medoids) { std::stringstream stream; - stream << "Error loading centroids data file. Expected bin format of " + stream << "Error loading centroids data file.
Expected bin format " + "of " "m times data_dim vector of float, where m is number of " "medoids " "in medoids file."; @@ -967,7 +1059,7 @@ int PQFlashIndex<T, LabelT>::load_from_separate_paths(uint32_t num_threads, cons { num_medoids = 1; medoids = new uint32_t[1]; - medoids[0] = (_u32)(medoid_id_on_file); + medoids[0] = (uint32_t)(medoid_id_on_file); use_medoids_data_as_centroids(); } @@ -975,7 +1067,7 @@ int PQFlashIndex<T, LabelT>::load_from_separate_paths(uint32_t num_threads, cons if (file_exists(norm_file) && metric == diskann::Metric::INNER_PRODUCT) { - _u64 dumr, dumc; + uint64_t dumr, dumc; float *norm_val; diskann::load_bin<float>(norm_file, norm_val, dumr, dumc); this->max_base_norm = norm_val[0]; @@ -1022,39 +1114,41 @@ bool getNextCompletedRequest(const IOContext &ctx, size_t size, int &completedIn #endif template <typename T, typename LabelT> -void PQFlashIndex<T, LabelT>::cached_beam_search(const T *query1, const _u64 k_search, const _u64 l_search, - _u64 *indices, float *distances, const _u64 beam_width, +void PQFlashIndex<T, LabelT>::cached_beam_search(const T *query1, const uint64_t k_search, const uint64_t l_search, + uint64_t *indices, float *distances, const uint64_t beam_width, const bool use_reorder_data, QueryStats *stats) { - cached_beam_search(query1, k_search, l_search, indices, distances, beam_width, std::numeric_limits<_u32>::max(), + cached_beam_search(query1, k_search, l_search, indices, distances, beam_width, std::numeric_limits<uint32_t>::max(), use_reorder_data, stats); } template <typename T, typename LabelT> -void PQFlashIndex<T, LabelT>::cached_beam_search(const T *query1, const _u64 k_search, const _u64 l_search, - _u64 *indices, float *distances, const _u64 beam_width, +void PQFlashIndex<T, LabelT>::cached_beam_search(const T *query1, const uint64_t k_search, const uint64_t l_search, + uint64_t *indices, float *distances, const uint64_t beam_width, const bool use_filter, const LabelT &filter_label, const bool use_reorder_data, QueryStats *stats) { cached_beam_search(query1, k_search, l_search, indices, distances, beam_width, use_filter, filter_label, - std::numeric_limits<_u32>::max(), use_reorder_data, stats); + std::numeric_limits<uint32_t>::max(), use_reorder_data, stats); } template <typename T, typename LabelT> -void PQFlashIndex<T, LabelT>::cached_beam_search(const T *query1, const _u64 k_search, const _u64 l_search, - _u64 *indices, float *distances, const _u64 beam_width, - const _u32 io_limit, const bool use_reorder_data, QueryStats *stats) +void PQFlashIndex<T, LabelT>::cached_beam_search(const T *query1, const uint64_t k_search, const uint64_t l_search, + uint64_t *indices, float *distances, const uint64_t beam_width, + const uint32_t io_limit, const bool use_reorder_data, + QueryStats *stats) { LabelT dummy_filter = 0; - cached_beam_search(query1, k_search, l_search, indices, distances, beam_width, false, dummy_filter, - std::numeric_limits<_u32>::max(), use_reorder_data, stats); + cached_beam_search(query1, k_search, l_search, indices, distances, beam_width, false, dummy_filter, io_limit, + use_reorder_data, stats); } template <typename T, typename LabelT> -void PQFlashIndex<T, LabelT>::cached_beam_search(const T *query1, const _u64 k_search, const _u64 l_search, - _u64 *indices, float *distances, const _u64 beam_width, - const bool use_filter, const LabelT &filter_label, const _u32 io_limit, - const bool use_reorder_data, QueryStats *stats) +void PQFlashIndex<T, LabelT>::cached_beam_search(const T *query1, const uint64_t k_search, const uint64_t l_search, + uint64_t *indices, float *distances, const uint64_t beam_width, + const bool use_filter, const LabelT &filter_label, + const uint32_t io_limit, const bool use_reorder_data, + QueryStats *stats) { int32_t filter_num = 0; if (use_filter)
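Worth noting in the hunk above: the io_limit overload of cached_beam_search previously discarded the caller's limit and forwarded std::numeric_limits<_u32>::max() to the full overload; it now forwards io_limit, so a cap on sector reads actually takes effect. A minimal usage sketch, assuming an already-loaded PQFlashIndex<float, uint32_t> named index and illustrative values for query, K, L, beam width, and the I/O cap (none of these names come from the patch):

    std::vector<uint64_t> result_ids(K);
    std::vector<float> result_dists(K);
    // Search stops expanding once io_limit sector reads have been issued,
    // trading a little recall for bounded disk traffic.
    index.cached_beam_search(query, K, L, result_ids.data(), result_dists.data(),
                             /*beam_width=*/4, /*io_limit=*/IO_LIMIT,
                             /*use_reorder_data=*/false, /*stats=*/nullptr);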
@@ -1107,7 +1201,7 @@ void PQFlashIndex<T, LabelT>::cached_beam_search(const T *query1, const _u64 k_s
         for (size_t i = 0; i < this->data_dim - 1; i++)
         {
-            aligned_query_T[i] /= query_norm;
+            aligned_query_T[i] = (T)(aligned_query_T[i] / query_norm);
         }
         pq_query_scratch->set(this->data_dim, aligned_query_T);
     }
@@ -1122,12 +1216,11 @@ void PQFlashIndex<T, LabelT>::cached_beam_search(const T *query1, const _u64 k_s
     // pointers to buffers for data
     T *data_buf = query_scratch->coord_scratch;
-    _u64 &data_buf_idx = query_scratch->coord_idx;
     _mm_prefetch((char *)data_buf, _MM_HINT_T1);
 
     // sector scratch
     char *sector_scratch = query_scratch->sector_scratch;
-    _u64 &sector_scratch_idx = query_scratch->sector_idx;
+    uint64_t &sector_scratch_idx = query_scratch->sector_idx;
 
     // query <-> PQ chunk centers distances
     pq_table.preprocess_query(query_rotated); // center the query and rotate if
@@ -1137,28 +1230,29 @@ void PQFlashIndex<T, LabelT>::cached_beam_search(const T *query1, const _u64 k_s
 
     // query <-> neighbor list
     float *dist_scratch = pq_query_scratch->aligned_dist_scratch;
-    _u8 *pq_coord_scratch = pq_query_scratch->aligned_pq_coord_scratch;
+    uint8_t *pq_coord_scratch = pq_query_scratch->aligned_pq_coord_scratch;
 
     // lambda to batch compute query<-> node distances in PQ space
-    auto compute_dists = [this, pq_coord_scratch, pq_dists](const unsigned *ids, const _u64 n_ids, float *dists_out) {
+    auto compute_dists = [this, pq_coord_scratch, pq_dists](const uint32_t *ids, const uint64_t n_ids,
+                                                            float *dists_out) {
         diskann::aggregate_coords(ids, n_ids, this->data, this->n_chunks, pq_coord_scratch);
         diskann::pq_dist_lookup(pq_coord_scratch, n_ids, this->n_chunks, pq_dists, dists_out);
     };
     Timer query_timer, io_timer, cpu_timer;
 
-    tsl::robin_set<_u64> &visited = query_scratch->visited;
+    tsl::robin_set<uint64_t> &visited = query_scratch->visited;
     NeighborPriorityQueue &retset = query_scratch->retset;
     retset.reserve(l_search);
     std::vector<Neighbor> &full_retset = query_scratch->full_retset;
 
-    _u32 best_medoid = 0;
+    uint32_t best_medoid = 0;
     float best_dist = (std::numeric_limits<float>::max)();
     if (!use_filter)
     {
-        for (_u64 cur_m = 0; cur_m < num_medoids; cur_m++)
+        for (uint64_t cur_m = 0; cur_m < num_medoids; cur_m++)
        {
            float cur_expanded_dist =
-                dist_cmp_float->compare(query_float, centroid_data + aligned_dim * cur_m, (unsigned)aligned_dim);
+                dist_cmp_float->compare(query_float, centroid_data + aligned_dim * cur_m, (uint32_t)aligned_dim);
            if (cur_expanded_dist < best_dist)
            {
                best_medoid = medoids[cur_m];
@@ -1166,31 +1260,46 @@ void PQFlashIndex<T, LabelT>::cached_beam_search(const T *query1, const _u64 k_s
            }
        }
    }
-    else if (_filter_to_medoid_id.find(filter_label) != _filter_to_medoid_id.end())
-    {
-        best_medoid = _filter_to_medoid_id[filter_label];
-    }
    else
    {
-        throw ANNException("Cannot find medoid for specified filter.", -1, __FUNCSIG__, __FILE__, __LINE__);
+        if (_filter_to_medoid_ids.find(filter_label) != _filter_to_medoid_ids.end())
+        {
+            const auto &medoid_ids = _filter_to_medoid_ids[filter_label];
+            for (uint64_t cur_m = 0; cur_m < medoid_ids.size(); cur_m++)
+            {
+                // for filtered index, we don't store global centroid data as for unfiltered index, so we use PQ distance
+                // as approximation to decide closest medoid matching the query filter.
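+                // Seeding with a PQ-approximate medoid is cheap (no extra disk read), and any
+                // inaccuracy is corrected later: full-precision distances are recomputed for
+                // every expanded node before results are returned.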
+                compute_dists(&medoid_ids[cur_m], 1, dist_scratch);
+                float cur_expanded_dist = dist_scratch[0];
+                if (cur_expanded_dist < best_dist)
+                {
+                    best_medoid = medoid_ids[cur_m];
+                    best_dist = cur_expanded_dist;
+                }
+            }
+        }
+        else
+        {
+            throw ANNException("Cannot find medoid for specified filter.", -1, __FUNCSIG__, __FILE__, __LINE__);
+        }
     }
 
     compute_dists(&best_medoid, 1, dist_scratch);
     retset.insert(Neighbor(best_medoid, dist_scratch[0]));
     visited.insert(best_medoid);
 
-    unsigned cmps = 0;
-    unsigned hops = 0;
-    unsigned num_ios = 0;
+    uint32_t cmps = 0;
+    uint32_t hops = 0;
+    uint32_t num_ios = 0;
 
     // cleared every iteration
-    std::vector<unsigned> frontier;
+    std::vector<uint32_t> frontier;
     frontier.reserve(2 * beam_width);
-    std::vector<std::pair<unsigned, char *>> frontier_nhoods;
+    std::vector<std::pair<uint32_t, char *>> frontier_nhoods;
     frontier_nhoods.reserve(2 * beam_width);
     std::vector<AlignedRead> frontier_read_reqs;
     frontier_read_reqs.reserve(2 * beam_width);
-    std::vector<std::pair<_u64, std::pair<unsigned, unsigned *>>> cached_nhoods;
+    std::vector<std::pair<uint32_t, std::pair<uint32_t, uint32_t *>>> cached_nhoods;
     cached_nhoods.reserve(2 * beam_width);
 
     while (retset.has_unexpanded_node() && num_ios < io_limit)
@@ -1202,7 +1311,7 @@ void PQFlashIndex<T, LabelT>::cached_beam_search(const T *query1, const _u64 k_s
         cached_nhoods.clear();
         sector_scratch_idx = 0;
         // find new beam
-        _u32 num_seen = 0;
+        uint32_t num_seen = 0;
         while (retset.has_unexpanded_node() && frontier.size() < beam_width && num_seen < beam_width)
         {
             auto nbr = retset.closest_unexpanded();
@@ -1222,7 +1331,7 @@ void PQFlashIndex<T, LabelT>::cached_beam_search(const T *query1, const _u64 k_s
             }
             if (this->count_visited_nodes)
             {
-                reinterpret_cast<std::atomic<_u32> &>(this->node_visit_counter[nbr.id].second).fetch_add(1);
+                reinterpret_cast<std::atomic<uint32_t> &>(this->node_visit_counter[nbr.id].second).fetch_add(1);
             }
         }
 
@@ -1231,10 +1340,10 @@ void PQFlashIndex<T, LabelT>::cached_beam_search(const T *query1, const _u64 k_s
         {
             if (stats != nullptr)
                 stats->n_hops++;
-            for (_u64 i = 0; i < frontier.size(); i++)
+            for (uint64_t i = 0; i < frontier.size(); i++)
             {
                 auto id = frontier[i];
-                std::pair<_u32, char *> fnhood;
+                std::pair<uint32_t, char *> fnhood;
                 fnhood.first = id;
                 fnhood.second = sector_scratch + sector_scratch_idx * SECTOR_LEN;
                 sector_scratch_idx++;
@@ -1249,13 +1358,14 @@ void PQFlashIndex<T, LabelT>::cached_beam_search(const T *query1, const _u64 k_s
             }
             io_timer.reset();
 #ifdef USE_BING_INFRA
-            reader->read(frontier_read_reqs, ctx, true); // async reader windows.
+            reader->read(frontier_read_reqs, ctx,
+                         true); // async reader windows.
 #else
             reader->read(frontier_read_reqs, ctx); // synchronous IO linux
 #endif
             if (stats != nullptr)
             {
-                stats->io_us += (double)io_timer.elapsed();
+                stats->io_us += (float)io_timer.elapsed();
             }
         }
 
@@ -1267,40 +1377,41 @@ void PQFlashIndex<T, LabelT>::cached_beam_search(const T *query1, const _u64 k_s
             float cur_expanded_dist;
             if (!use_disk_index_pq)
             {
-                cur_expanded_dist = dist_cmp->compare(aligned_query_T, node_fp_coords_copy, (unsigned)aligned_dim);
+                cur_expanded_dist = dist_cmp->compare(aligned_query_T, node_fp_coords_copy, (uint32_t)aligned_dim);
             }
             else
             {
                 if (metric == diskann::Metric::INNER_PRODUCT)
-                    cur_expanded_dist = disk_pq_table.inner_product(query_float, (_u8 *)node_fp_coords_copy);
+                    cur_expanded_dist = disk_pq_table.inner_product(query_float, (uint8_t *)node_fp_coords_copy);
                 else
                     cur_expanded_dist = disk_pq_table.l2_distance( // disk_pq does not support OPQ yet
-                        query_float, (_u8 *)node_fp_coords_copy);
+                        query_float, (uint8_t *)node_fp_coords_copy);
             }
-            full_retset.push_back(Neighbor((unsigned)cached_nhood.first, cur_expanded_dist));
+            full_retset.push_back(Neighbor((uint32_t)cached_nhood.first, cur_expanded_dist));
 
-            _u64 nnbrs = cached_nhood.second.first;
-            unsigned *node_nbrs = cached_nhood.second.second;
+            uint64_t nnbrs = cached_nhood.second.first;
+            uint32_t *node_nbrs = cached_nhood.second.second;
 
             // compute node_nbrs <-> query dists in PQ space
             cpu_timer.reset();
             compute_dists(node_nbrs, nnbrs, dist_scratch);
             if (stats != nullptr)
             {
-                stats->n_cmps += (double)nnbrs;
-                stats->cpu_us += (double)cpu_timer.elapsed();
+                stats->n_cmps += (uint32_t)nnbrs;
+                stats->cpu_us += (float)cpu_timer.elapsed();
             }
 
             // process prefetched nhood
-            for (_u64 m = 0; m < nnbrs; ++m)
+            for (uint64_t m = 0; m < nnbrs; ++m)
             {
-                unsigned id = node_nbrs[m];
+                uint32_t id = node_nbrs[m];
                 if (visited.insert(id).second)
                 {
                     if (!use_filter && _dummy_pts.find(id) != _dummy_pts.end())
                         continue;
 
-                    if (use_filter && !point_has_label(id, filter_num) && !point_has_label(id, _universal_filter_num))
+                    if (use_filter && !point_has_label(id, filter_num)
+                        && (!_use_universal_label || !point_has_label(id, _universal_filter_num)))
                         continue;
                     cmps++;
                     float dist = dist_scratch[m];
@@ -1313,8 +1424,8 @@ void PQFlashIndex<T, LabelT>::cached_beam_search(const T *query1, const _u64 k_s
         // process each frontier nhood - compute distances to unvisited nodes
         int completedIndex = -1;
         long requestCount = static_cast<long>(frontier_read_reqs.size());
-        // If we issued read requests and if a read is complete or there are reads
-        // in wait state, then enter the while loop.
+        // If we issued read requests and if a read is complete or there are
+        // reads in wait state, then enter the while loop.
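+        // getNextCompletedRequest() hands back finished reads one at a time, so each frontier
+        // neighborhood is processed in completion order, overlapping this CPU work with the
+        // reads still in flight.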
         while (requestCount > 0 && getNextCompletedRequest(ctx, requestCount, completedIndex))
         {
             assert(completedIndex >= 0);
@@ -1325,50 +1436,45 @@ void PQFlashIndex<T, LabelT>::cached_beam_search(const T *query1, const _u64 k_s
             {
 #endif
                 char *node_disk_buf = OFFSET_TO_NODE(frontier_nhood.second, frontier_nhood.first);
-                unsigned *node_buf = OFFSET_TO_NODE_NHOOD(node_disk_buf);
-                _u64 nnbrs = (_u64)(*node_buf);
+                uint32_t *node_buf = OFFSET_TO_NODE_NHOOD(node_disk_buf);
+                uint64_t nnbrs = (uint64_t)(*node_buf);
                 T *node_fp_coords = OFFSET_TO_NODE_COORDS(node_disk_buf);
-                // assert(data_buf_idx < MAX_N_CMPS);
-                if (data_buf_idx == MAX_N_CMPS)
-                    data_buf_idx = 0;
-
-                T *node_fp_coords_copy = data_buf + (data_buf_idx * aligned_dim);
-                data_buf_idx++;
-                memcpy(node_fp_coords_copy, node_fp_coords, disk_bytes_per_point);
+                memcpy(data_buf, node_fp_coords, disk_bytes_per_point);
                 float cur_expanded_dist;
                 if (!use_disk_index_pq)
                 {
-                    cur_expanded_dist = dist_cmp->compare(aligned_query_T, node_fp_coords_copy, (unsigned)aligned_dim);
+                    cur_expanded_dist = dist_cmp->compare(aligned_query_T, data_buf, (uint32_t)aligned_dim);
                 }
                 else
                 {
                     if (metric == diskann::Metric::INNER_PRODUCT)
-                        cur_expanded_dist = disk_pq_table.inner_product(query_float, (_u8 *)node_fp_coords_copy);
+                        cur_expanded_dist = disk_pq_table.inner_product(query_float, (uint8_t *)data_buf);
                     else
-                        cur_expanded_dist = disk_pq_table.l2_distance(query_float, (_u8 *)node_fp_coords_copy);
+                        cur_expanded_dist = disk_pq_table.l2_distance(query_float, (uint8_t *)data_buf);
                 }
                 full_retset.push_back(Neighbor(frontier_nhood.first, cur_expanded_dist));
 
-                unsigned *node_nbrs = (node_buf + 1);
+                uint32_t *node_nbrs = (node_buf + 1);
                 // compute node_nbrs <-> query dist in PQ space
                 cpu_timer.reset();
                 compute_dists(node_nbrs, nnbrs, dist_scratch);
                 if (stats != nullptr)
                 {
-                    stats->n_cmps += (double)nnbrs;
-                    stats->cpu_us += (double)cpu_timer.elapsed();
+                    stats->n_cmps += (uint32_t)nnbrs;
+                    stats->cpu_us += (float)cpu_timer.elapsed();
                 }
 
                 cpu_timer.reset();
 
                 // process prefetch-ed nhood
-                for (_u64 m = 0; m < nnbrs; ++m)
+                for (uint64_t m = 0; m < nnbrs; ++m)
                 {
-                    unsigned id = node_nbrs[m];
+                    uint32_t id = node_nbrs[m];
                     if (visited.insert(id).second)
                     {
                         if (!use_filter && _dummy_pts.find(id) != _dummy_pts.end())
                             continue;
 
-                        if (use_filter && !point_has_label(id, filter_num) && !point_has_label(id, _universal_filter_num))
+                        if (use_filter && !point_has_label(id, filter_num)
+                            && (!_use_universal_label || !point_has_label(id, _universal_filter_num)))
                             continue;
                         cmps++;
                         float dist = dist_scratch[m];
@@ -1384,7 +1490,7 @@ void PQFlashIndex<T, LabelT>::cached_beam_search(const T *query1, const _u64 k_s
 
         if (stats != nullptr)
         {
-            stats->cpu_us += (double)cpu_timer.elapsed();
+            stats->cpu_us += (float)cpu_timer.elapsed();
         }
     }
 
@@ -1398,7 +1504,8 @@ void PQFlashIndex<T, LabelT>::cached_beam_search(const T *query1, const _u64 k_s
     {
         if (!(this->reorder_data_exists))
         {
-            throw ANNException("Requested use of reordering data which does not exist in index "
+            throw ANNException("Requested use of reordering data which does "
+                               "not exist in index "
                                "file",
                                -1, __FUNCSIG__, __FILE__, __LINE__);
         }
@@ -1435,20 +1542,20 @@ void PQFlashIndex<T, LabelT>::cached_beam_search(const T *query1, const _u64 k_s
         {
             auto id = full_retset[i].id;
             auto location = (sector_scratch + i * SECTOR_LEN) + VECTOR_SECTOR_OFFSET(id);
-            full_retset[i].distance = dist_cmp->compare(aligned_query_T, (T *)location, this->data_dim);
+            full_retset[i].distance = dist_cmp->compare(aligned_query_T, (T *)location, (uint32_t)this->data_dim);
         }
 
         std::sort(full_retset.begin(), full_retset.end());
     }
 
     // copy k_search values
-    for (_u64 i = 0; i < k_search; i++)
+    for (uint64_t i = 0; i < k_search; i++)
     {
         indices[i] = full_retset[i].id;
-
-        if (_dummy_pts.find(indices[i]) != _dummy_pts.end())
+        auto key = (uint32_t)indices[i];
+        if (_dummy_pts.find(key) != _dummy_pts.end())
         {
-            indices[i] = _dummy_to_real_map[indices[i]];
+            indices[i] = _dummy_to_real_map[key];
         }
 
         if (distances != nullptr)
@@ -1458,8 +1565,8 @@ void PQFlashIndex<T, LabelT>::cached_beam_search(const T *query1, const _u64 k_s
             {
                 // flip the sign to convert min to max
                 distances[i] = (-distances[i]);
-                // rescale to revert back to original norms (cancelling the effect of
-                // base and query pre-processing)
+                // rescale to revert back to original norms (cancelling the
+                // effect of base and query pre-processing)
                 if (max_base_norm != 0)
                     distances[i] *= (max_base_norm * query_norm);
             }
@@ -1472,7 +1579,7 @@ void PQFlashIndex<T, LabelT>::cached_beam_search(const T *query1, const _u64 k_s
 
     if (stats != nullptr)
     {
-        stats->total_us = (double)query_timer.elapsed();
+        stats->total_us = (float)query_timer.elapsed();
     }
 }
 
@@ -1480,25 +1587,26 @@ void PQFlashIndex<T, LabelT>::cached_beam_search(const T *query1, const _u64 k_s
 // indices and distances need to be pre-allocated of size l_search and the
 // return value is the number of matching hits.
 template <typename T, typename LabelT>
-_u32 PQFlashIndex<T, LabelT>::range_search(const T *query1, const double range, const _u64 min_l_search,
-                                           const _u64 max_l_search, std::vector<_u64> &indices,
-                                           std::vector<float> &distances, const _u64 min_beam_width, QueryStats *stats)
+uint32_t PQFlashIndex<T, LabelT>::range_search(const T *query1, const double range, const uint64_t min_l_search,
+                                               const uint64_t max_l_search, std::vector<uint64_t> &indices,
+                                               std::vector<float> &distances, const uint64_t min_beam_width,
+                                               QueryStats *stats)
 {
-    _u32 res_count = 0;
+    uint32_t res_count = 0;
 
     bool stop_flag = false;
 
-    _u32 l_search = min_l_search; // starting size of the candidate list
+    uint32_t l_search = (uint32_t)min_l_search; // starting size of the candidate list
     while (!stop_flag)
     {
         indices.resize(l_search);
         distances.resize(l_search);
-        _u64 cur_bw = min_beam_width > (l_search / 5) ? min_beam_width : l_search / 5;
+        uint64_t cur_bw = min_beam_width > (l_search / 5) ? min_beam_width : l_search / 5;
         cur_bw = (cur_bw > 100) ? 100 : cur_bw;
         for (auto &x : distances)
             x = std::numeric_limits<float>::max();
         this->cached_beam_search(query1, l_search, l_search, indices.data(), distances.data(), cur_bw, false, stats);
-        for (_u32 i = 0; i < l_search; i++)
+        for (uint32_t i = 0; i < l_search; i++)
         {
             if (distances[i] > (float)range)
             {
@@ -1508,7 +1616,7 @@ _u32 PQFlashIndex<T, LabelT>::range_search(const T *query1, const double range,
             else if (i == l_search - 1)
                 res_count = l_search;
         }
-        if (res_count < (_u32)(l_search / 2.0))
+        if (res_count < (uint32_t)(l_search / 2.0))
             stop_flag = true;
         l_search = l_search * 2;
         if (l_search > max_l_search)
@@ -1519,7 +1627,7 @@ _u32 PQFlashIndex<T, LabelT>::range_search(const T *query1, const double range,
     return res_count;
 }
 
-template <typename T, typename LabelT> _u64 PQFlashIndex<T, LabelT>::get_data_dim()
+template <typename T, typename LabelT> uint64_t PQFlashIndex<T, LabelT>::get_data_dim()
 {
     return data_dim;
 }
@@ -1548,11 +1656,11 @@ template <typename T, typename LabelT> char *PQFlashIndex<T, LabelT>::getHeaderB
 #endif
 
 // instantiations
-template class PQFlashIndex<_u8>;
-template class PQFlashIndex<_s8>;
+template class PQFlashIndex<uint8_t>;
+template class PQFlashIndex<int8_t>;
 template class PQFlashIndex<float>;
-template class PQFlashIndex<_u8, uint16_t>;
-template class PQFlashIndex<_s8, uint16_t>;
+template class PQFlashIndex<uint8_t, uint16_t>;
+template class PQFlashIndex<int8_t, uint16_t>;
 template class PQFlashIndex<float, uint16_t>;
 
 } // namespace diskann
diff --git a/src/restapi/search_wrapper.cpp b/src/restapi/search_wrapper.cpp
index 595882acb..dc9f5734e 100644
--- a/src/restapi/search_wrapper.cpp
+++ b/src/restapi/search_wrapper.cpp
@@ -176,7 +176,7 @@ template <class T>
 SearchResult PQFlashSearch<T>::search(const T *query, const unsigned int dimensions, const unsigned int K,
                                       const unsigned int Ls)
 {
-    _u64 *indices_u64 = new _u64[K];
+    uint64_t *indices_u64 = new uint64_t[K];
     unsigned *indices = new unsigned[K];
     float *distances = new float[K];
diff --git a/src/scratch.cpp b/src/scratch.cpp
index 9f369c61f..745daa6a7 100644
--- a/src/scratch.cpp
+++ b/src/scratch.cpp
@@ -13,7 +13,7 @@ namespace diskann
 //
 template <typename T>
 InMemQueryScratch<T>::InMemQueryScratch(uint32_t search_l, uint32_t indexing_l, uint32_t r, uint32_t maxc, size_t dim,
-                                        bool init_pq_scratch, size_t bitmask_size)
+                                        size_t aligned_dim, size_t alignment_factor, bool init_pq_scratch, size_t bitmask_size)
     : _L(0), _R(r), _maxc(maxc)
 {
     if (search_l == 0 || indexing_l == 0 || r == 0 || dim == 0)
@@ -24,8 +24,7 @@ InMemQueryScratch<T>::InMemQueryScratch(uint32_t search_l, uint32_t indexing_l,
         throw diskann::ANNException(ss.str(), -1);
     }
 
-    auto aligned_dim = ROUND_UP(dim, 8);
-    alloc_aligned(((void **)&_aligned_query), aligned_dim * sizeof(T), 8 * sizeof(T));
+    alloc_aligned(((void **)&_aligned_query), aligned_dim * sizeof(T), alignment_factor * sizeof(T));
     memset(_aligned_query, 0, aligned_dim * sizeof(T));
 
     if (init_pq_scratch)
@@ -35,8 +34,8 @@ InMemQueryScratch<T>::InMemQueryScratch(uint32_t search_l, uint32_t indexing_l,
 
     _occlude_factor.reserve(maxc);
     _inserted_into_pool_bs = new boost::dynamic_bitset<>();
-    _id_scratch.reserve(std::ceil(1.5 * GRAPH_SLACK_FACTOR * _R));
-    _dist_scratch.reserve(std::ceil(1.5 * GRAPH_SLACK_FACTOR * _R));
+    _id_scratch.reserve((size_t)std::ceil(1.5 * GRAPH_SLACK_FACTOR * _R));
+    _dist_scratch.reserve((size_t)std::ceil(1.5 * GRAPH_SLACK_FACTOR * _R));
 
     resize_for_new_L(std::max(search_l, indexing_l));
@@ -92,7 +91,6 @@ template <typename T> InMemQueryScratch<T>::~InMemQueryScratch()
 //
 template <typename T> void SSDQueryScratch<T>::reset()
 {
-    coord_idx = 0;
     sector_idx = 0;
     visited.clear();
     retset.clear();
@@ -101,15 +99,15 @@ template <typename T> void SSDQueryScratch<T>::reset()
 
 template <typename T> SSDQueryScratch<T>::SSDQueryScratch(size_t aligned_dim, size_t visited_reserve)
 {
-    _u64 coord_alloc_size = ROUND_UP(MAX_N_CMPS * aligned_dim, 256);
+    size_t coord_alloc_size = ROUND_UP(sizeof(T) * aligned_dim, 256);
 
     diskann::alloc_aligned((void **)&coord_scratch, coord_alloc_size, 256);
-    diskann::alloc_aligned((void **)&sector_scratch, (_u64)MAX_N_SECTOR_READS * (_u64)SECTOR_LEN, SECTOR_LEN);
+    diskann::alloc_aligned((void **)&sector_scratch, (size_t)MAX_N_SECTOR_READS * (size_t)SECTOR_LEN, SECTOR_LEN);
     diskann::alloc_aligned((void **)&aligned_query_T, aligned_dim * sizeof(T), 8 * sizeof(T));
 
     _pq_scratch = new PQScratch<T>(MAX_GRAPH_DEGREE, aligned_dim);
 
-    memset(coord_scratch, 0, MAX_N_CMPS * aligned_dim);
+    memset(coord_scratch, 0, coord_alloc_size);
     memset(aligned_query_T, 0, aligned_dim * sizeof(T));
 
     visited.reserve(visited_reserve);
@@ -120,6 +118,7 @@ template <typename T> SSDQueryScratch<T>::~SSDQueryScratch()
 {
     diskann::aligned_free((void *)coord_scratch);
     diskann::aligned_free((void *)sector_scratch);
+    diskann::aligned_free((void *)aligned_query_T);
 
     delete[] _pq_scratch;
 }
@@ -138,11 +137,11 @@ template DISKANN_DLLEXPORT class InMemQueryScratch<int8_t>;
 template DISKANN_DLLEXPORT class InMemQueryScratch<uint8_t>;
 template DISKANN_DLLEXPORT class InMemQueryScratch<float>;
 
-template DISKANN_DLLEXPORT class SSDQueryScratch<_u8>;
-template DISKANN_DLLEXPORT class SSDQueryScratch<_s8>;
+template DISKANN_DLLEXPORT class SSDQueryScratch<uint8_t>;
+template DISKANN_DLLEXPORT class SSDQueryScratch<int8_t>;
 template DISKANN_DLLEXPORT class SSDQueryScratch<float>;
 
-template DISKANN_DLLEXPORT class SSDThreadData<_u8>;
-template DISKANN_DLLEXPORT class SSDThreadData<_s8>;
+template DISKANN_DLLEXPORT class SSDThreadData<uint8_t>;
+template DISKANN_DLLEXPORT class SSDThreadData<int8_t>;
 template DISKANN_DLLEXPORT class SSDThreadData<float>;
-} // namespace diskann
\ No newline at end of file
+} // namespace diskann
diff --git a/src/utils.cpp b/src/utils.cpp
index ec25dead3..b675e656d 100644
--- a/src/utils.cpp
+++ b/src/utils.cpp
@@ -73,20 +73,20 @@ bool AvxSupportedCPU = false;
 namespace diskann
 {
 
-void block_convert(std::ofstream &writr, std::ifstream &readr, float *read_buf, _u64 npts, _u64 ndims)
+void block_convert(std::ofstream &writr, std::ifstream &readr, float *read_buf, size_t npts, size_t ndims)
 {
     readr.read((char *)read_buf, npts * ndims * sizeof(float));
-    _u32 ndims_u32 = (_u32)ndims;
+    uint32_t ndims_u32 = (uint32_t)ndims;
 #pragma omp parallel for
-    for (_s64 i = 0; i < (_s64)npts; i++)
+    for (int64_t i = 0; i < (int64_t)npts; i++)
     {
         float norm_pt = std::numeric_limits<float>::epsilon();
-        for (_u32 dim = 0; dim < ndims_u32; dim++)
+        for (uint32_t dim = 0; dim < ndims_u32; dim++)
         {
             norm_pt += *(read_buf + i * ndims + dim) * *(read_buf + i * ndims + dim);
         }
         norm_pt = std::sqrt(norm_pt);
-        for (_u32 dim = 0; dim < ndims_u32; dim++)
+        for (uint32_t dim = 0; dim < ndims_u32; dim++)
         {
             *(read_buf + i * ndims + dim) = *(read_buf + i * ndims + dim) / norm_pt;
         }
@@ -100,24 +100,25 @@ void normalize_data_file(const std::string &inFileName, const std::string &outFi
     std::ofstream writr(outFileName, std::ios::binary);
 
     int npts_s32, ndims_s32;
-    readr.read((char *)&npts_s32, sizeof(_s32));
-    readr.read((char *)&ndims_s32, sizeof(_s32));
+    readr.read((char *)&npts_s32, sizeof(int32_t));
+    readr.read((char *)&ndims_s32, sizeof(int32_t));
 
-    writr.write((char *)&npts_s32, sizeof(_s32));
-    writr.write((char *)&ndims_s32, sizeof(_s32));
+    writr.write((char *)&npts_s32, sizeof(int32_t));
+    writr.write((char *)&ndims_s32, sizeof(int32_t));
 
-    _u64 npts = (_u64)npts_s32, ndims = (_u64)ndims_s32;
+    size_t npts = (size_t)npts_s32;
+    size_t ndims = (size_t)ndims_s32;
     diskann::cout << "Normalizing FLOAT vectors in file: " << inFileName << std::endl;
"Normalizing FLOAT vectors in file: " << inFileName << std::endl; diskann::cout << "Dataset: #pts = " << npts << ", # dims = " << ndims << std::endl; - _u64 blk_size = 131072; - _u64 nblks = ROUND_UP(npts, blk_size) / blk_size; + size_t blk_size = 131072; + size_t nblks = ROUND_UP(npts, blk_size) / blk_size; diskann::cout << "# blks: " << nblks << std::endl; float *read_buf = new float[npts * ndims]; - for (_u64 i = 0; i < nblks; i++) + for (size_t i = 0; i < nblks; i++) { - _u64 cblk_size = std::min(npts - i * blk_size, blk_size); + size_t cblk_size = std::min(npts - i * blk_size, blk_size); block_convert(writr, readr, read_buf, cblk_size, ndims); } delete[] read_buf; @@ -125,18 +126,18 @@ void normalize_data_file(const std::string &inFileName, const std::string &outFi diskann::cout << "Wrote normalized points to file: " << outFileName << std::endl; } -double calculate_recall(unsigned num_queries, unsigned *gold_std, float *gs_dist, unsigned dim_gs, - unsigned *our_results, unsigned dim_or, unsigned recall_at) +double calculate_recall(uint32_t num_queries, uint32_t *gold_std, float *gs_dist, uint32_t dim_gs, + uint32_t *our_results, uint32_t dim_or, uint32_t recall_at) { double total_recall = 0; - std::set gt, res; + std::set gt, res; for (size_t i = 0; i < num_queries; i++) { gt.clear(); res.clear(); - unsigned *gt_vec = gold_std + dim_gs * i; - unsigned *res_vec = our_results + dim_or * i; + uint32_t *gt_vec = gold_std + dim_gs * i; + uint32_t *res_vec = our_results + dim_or * i; size_t tie_breaker = recall_at; if (gs_dist != nullptr) { @@ -148,9 +149,9 @@ double calculate_recall(unsigned num_queries, unsigned *gold_std, float *gs_dist gt.insert(gt_vec, gt_vec + tie_breaker); res.insert(res_vec, - res_vec + recall_at); // change to recall_at for recall k@k or - // dim_or for k@dim_or - unsigned cur_recall = 0; + res_vec + recall_at); // change to recall_at for recall k@k + // or dim_or for k@dim_or + uint32_t cur_recall = 0; for (auto &v : gt) { if (res.find(v) != res.end()) @@ -163,22 +164,22 @@ double calculate_recall(unsigned num_queries, unsigned *gold_std, float *gs_dist return total_recall / (num_queries) * (100.0 / recall_at); } -double calculate_recall(unsigned num_queries, unsigned *gold_std, float *gs_dist, unsigned dim_gs, - unsigned *our_results, unsigned dim_or, unsigned recall_at, - const tsl::robin_set &active_tags) +double calculate_recall(uint32_t num_queries, uint32_t *gold_std, float *gs_dist, uint32_t dim_gs, + uint32_t *our_results, uint32_t dim_or, uint32_t recall_at, + const tsl::robin_set &active_tags) { double total_recall = 0; - std::set gt, res; + std::set gt, res; bool printed = false; for (size_t i = 0; i < num_queries; i++) { gt.clear(); res.clear(); - unsigned *gt_vec = gold_std + dim_gs * i; - unsigned *res_vec = our_results + dim_or * i; + uint32_t *gt_vec = gold_std + dim_gs * i; + uint32_t *res_vec = our_results + dim_or * i; size_t tie_breaker = recall_at; - unsigned active_points_count = 0; - unsigned cur_counter = 0; + uint32_t active_points_count = 0; + uint32_t cur_counter = 0; while (active_points_count < recall_at && cur_counter < dim_gs) { if (active_tags.find(*(gt_vec + cur_counter)) != active_tags.end()) @@ -209,7 +210,7 @@ double calculate_recall(unsigned num_queries, unsigned *gold_std, float *gs_dist gt.insert(gt_vec, gt_vec + tie_breaker); res.insert(res_vec, res_vec + recall_at); - unsigned cur_recall = 0; + uint32_t cur_recall = 0; for (auto &v : res) { if (gt.find(v) != gt.end()) @@ -222,11 +223,11 @@ double calculate_recall(unsigned 
     return ((double)(total_recall / (num_queries))) * ((double)(100.0 / recall_at));
 }
 
-double calculate_range_search_recall(unsigned num_queries, std::vector<std::vector<_u32>> &groundtruth,
-                                     std::vector<std::vector<_u32>> &our_results)
+double calculate_range_search_recall(uint32_t num_queries, std::vector<std::vector<uint32_t>> &groundtruth,
+                                     std::vector<std::vector<uint32_t>> &our_results)
 {
     double total_recall = 0;
-    std::set<unsigned> gt, res;
+    std::set<uint32_t> gt, res;
 
     for (size_t i = 0; i < num_queries; i++)
     {
@@ -235,7 +236,7 @@ double calculate_range_search_recall(unsigned num_queries, std::vector
diff --git a/src/windows_aligned_file_reader.cpp b/src/windows_aligned_file_reader.cpp
 #include <iostream>
 #include "utils.h"
+#include
 
 #define SECTOR_LEN 4096
 
 void WindowsAlignedFileReader::open(const std::string &fname)
 {
-
 #ifdef UNICODE
     m_filename = std::wstring(fname.begin(), fname.end());
 #else
@@ -44,15 +44,24 @@ void WindowsAlignedFileReader::register_thread()
                             FILE_ATTRIBUTE_READONLY | FILE_FLAG_NO_BUFFERING | FILE_FLAG_OVERLAPPED |
                                 FILE_FLAG_RANDOM_ACCESS,
                             NULL);
     if (ctx.fhandle == INVALID_HANDLE_VALUE)
     {
-        diskann::cout << "Error opening " << std::string(m_filename.begin(), m_filename.end())
-                      << " -- error=" << GetLastError() << std::endl;
+        const size_t c_max_filepath_len = 256;
+        size_t actual_len = 0;
+        char filePath[c_max_filepath_len];
+        if (wcstombs_s(&actual_len, filePath, c_max_filepath_len, m_filename.c_str(), m_filename.length()) == 0)
+        {
+            diskann::cout << "Error opening " << filePath << " -- error=" << GetLastError() << std::endl;
+        }
+        else
+        {
+            diskann::cout << "Error converting wchar to char -- error=" << GetLastError() << std::endl;
+        }
     }
 
     // create IOCompletionPort
     ctx.iocp = CreateIoCompletionPort(ctx.fhandle, ctx.iocp, 0, 0);
 
     // create MAX_DEPTH # of reqs
-    for (_u64 i = 0; i < MAX_IO_DEPTH; i++)
+    for (uint64_t i = 0; i < MAX_IO_DEPTH; i++)
     {
         OVERLAPPED os;
         memset(&os, 0, sizeof(OVERLAPPED));
@@ -80,9 +89,9 @@ void WindowsAlignedFileReader::read(std::vector<AlignedRead> &read_reqs, IOConte
 {
     using namespace std::chrono_literals;
     // execute each request sequentially
-    _u64 n_reqs = read_reqs.size();
-    _u64 n_batches = ROUND_UP(n_reqs, MAX_IO_DEPTH) / MAX_IO_DEPTH;
-    for (_u64 i = 0; i < n_batches; i++)
+    size_t n_reqs = read_reqs.size();
+    uint64_t n_batches = ROUND_UP(n_reqs, MAX_IO_DEPTH) / MAX_IO_DEPTH;
+    for (uint64_t i = 0; i < n_batches; i++)
     {
         // reset all OVERLAPPED objects
         for (auto &os : ctx.reqs)
@@ -100,17 +109,17 @@ void WindowsAlignedFileReader::read(std::vector<AlignedRead> &read_reqs, IOConte
         }
 
         // batch start/end
-        _u64 batch_start = MAX_IO_DEPTH * i;
-        _u64 batch_size = std::min((_u64)(n_reqs - batch_start), (_u64)MAX_IO_DEPTH);
+        uint64_t batch_start = MAX_IO_DEPTH * i;
+        uint64_t batch_size = std::min((uint64_t)(n_reqs - batch_start), (uint64_t)MAX_IO_DEPTH);
 
         // fill OVERLAPPED and issue them
-        for (_u64 j = 0; j < batch_size; j++)
+        for (uint64_t j = 0; j < batch_size; j++)
         {
             AlignedRead &req = read_reqs[batch_start + j];
             OVERLAPPED &os = ctx.reqs[j];
 
-            _u64 offset = req.offset;
-            _u64 nbytes = req.len;
+            uint64_t offset = req.offset;
+            uint64_t nbytes = req.len;
             char *read_buf = (char *)req.buf;
             assert(IS_ALIGNED(read_buf, SECTOR_LEN));
             assert(IS_ALIGNED(offset, SECTOR_LEN));
@@ -120,7 +129,7 @@ void WindowsAlignedFileReader::read(std::vector<AlignedRead> &read_reqs, IOConte
             os.Offset = offset & 0xffffffff;
             os.OffsetHigh = (offset >> 32);
 
-            BOOL ret = ReadFile(ctx.fhandle, read_buf, nbytes, NULL, &os);
+            BOOL ret = ReadFile(ctx.fhandle, read_buf, (DWORD)nbytes, NULL, &os);
             if (ret == FALSE)
             {
                 auto error = GetLastError();
@@ -135,7 +144,7 @@ void WindowsAlignedFileReader::read(std::vector<AlignedRead> &read_reqs, IOConte
             }
         }
         DWORD n_read = 0;
-        _u64 n_complete = 0;
+        uint64_t n_complete = 0;
         ULONG_PTR completion_key = 0;
         OVERLAPPED *lp_os;
         while (n_complete < batch_size)
@@ -153,11 +162,14 @@ void WindowsAlignedFileReader::read(std::vector<AlignedRead> &read_reqs, IOConte
                 DWORD error = GetLastError();
                 if (error != WAIT_TIMEOUT)
                 {
-                    diskann::cerr << "GetQueuedCompletionStatus() failed with error = " << error << std::endl;
+                    diskann::cerr << "GetQueuedCompletionStatus() failed "
+                                     "with error = "
+                                  << error << std::endl;
                     throw diskann::ANNException("GetQueuedCompletionStatus failed with error: ", error, __FUNCSIG__,
                                                 __FILE__, __LINE__);
                 }
-                // no completion packet dequeued ==> sleep for 5us and try again
+                // no completion packet dequeued ==> sleep for 5us and try
+                // again
                 std::this_thread::sleep_for(5us);
             }
             else
diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt
index 6aa5532ef..6af8405cc 100644
--- a/tests/CMakeLists.txt
+++ b/tests/CMakeLists.txt
@@ -1,29 +1,41 @@
 # Copyright (c) Microsoft Corporation. All rights reserved.
 # Licensed under the MIT license.
 
-set(CMAKE_CXX_STANDARD 14)
+set(CMAKE_COMPILE_WARNING_AS_ERROR ON)
 
-add_executable(build_memory_index build_memory_index.cpp)
-target_link_libraries(build_memory_index ${PROJECT_NAME} ${DISKANN_TOOLS_TCMALLOC_LINK_OPTIONS} Boost::program_options)
+find_package(Boost COMPONENTS unit_test_framework)
 
-add_executable(build_stitched_index build_stitched_index.cpp)
-target_link_libraries(build_stitched_index ${PROJECT_NAME} ${DISKANN_TOOLS_TCMALLOC_LINK_OPTIONS} Boost::program_options)
+# For Windows, fall back to nuget version if find_package didn't find it.
+if (MSVC AND NOT Boost_FOUND)
+    set(DISKANN_BOOST_INCLUDE "${DISKANN_MSVC_PACKAGES}/boost/lib/native/include")
+    # Multi-threaded static library.
+    set(UNIT_TEST_FRAMEWORK_LIB_PATTERN "${DISKANN_MSVC_PACKAGES}/boost_unit_test_framework-vc${MSVC_TOOLSET_VERSION}/lib/native/libboost_unit_test_framework-vc${MSVC_TOOLSET_VERSION}-mt-x64-*.lib")
+    file(GLOB DISKANN_BOOST_UNIT_TEST_FRAMEWORK_LIB ${UNIT_TEST_FRAMEWORK_LIB_PATTERN})
 
-add_executable(search_memory_index search_memory_index.cpp)
-target_link_libraries(search_memory_index ${PROJECT_NAME} ${DISKANN_ASYNC_LIB} ${DISKANN_TOOLS_TCMALLOC_LINK_OPTIONS} Boost::program_options)
+    set(UNIT_TEST_FRAMEWORK_DLIB_PATTERN "${DISKANN_MSVC_PACKAGES}/boost_unit_test_framework-vc${MSVC_TOOLSET_VERSION}/lib/native/libboost_unit_test_framework-vc${MSVC_TOOLSET_VERSION}-mt-gd-x64-*.lib")
+    file(GLOB DISKANN_BOOST_UNIT_TEST_FRAMEWORK_DLIB ${UNIT_TEST_FRAMEWORK_DLIB_PATTERN})
 
-add_executable(build_disk_index build_disk_index.cpp)
-target_link_libraries(build_disk_index ${PROJECT_NAME} ${DISKANN_TOOLS_TCMALLOC_LINK_OPTIONS} ${DISKANN_ASYNC_LIB} Boost::program_options)
+    if (EXISTS ${DISKANN_BOOST_INCLUDE} AND EXISTS ${DISKANN_BOOST_UNIT_TEST_FRAMEWORK_LIB} AND EXISTS ${DISKANN_BOOST_UNIT_TEST_FRAMEWORK_DLIB})
+        set(Boost_FOUND ON)
+        set(Boost_INCLUDE_DIR ${DISKANN_BOOST_INCLUDE})
+        add_library(Boost::unit_test_framework STATIC IMPORTED)
+        set_target_properties(Boost::unit_test_framework PROPERTIES IMPORTED_LOCATION_RELEASE "${DISKANN_BOOST_UNIT_TEST_FRAMEWORK_LIB}")
+        set_target_properties(Boost::unit_test_framework PROPERTIES IMPORTED_LOCATION_DEBUG "${DISKANN_BOOST_UNIT_TEST_FRAMEWORK_DLIB}")
+        message(STATUS "Falling back to using Boost from the nuget package")
+    else()
+        message(WARNING "Couldn't find Boost. Was looking for ${DISKANN_BOOST_INCLUDE} and ${UNIT_TEST_FRAMEWORK_LIB_PATTERN}")
+    endif()
+endif()
 
-add_executable(search_disk_index search_disk_index.cpp)
-target_link_libraries(search_disk_index ${PROJECT_NAME} ${DISKANN_ASYNC_LIB} ${DISKANN_TOOLS_TCMALLOC_LINK_OPTIONS} Boost::program_options)
+if (NOT Boost_FOUND)
+    message(FATAL_ERROR "Couldn't find Boost dependency")
+endif()
 
-add_executable(range_search_disk_index range_search_disk_index.cpp)
-target_link_libraries(range_search_disk_index ${PROJECT_NAME} ${DISKANN_ASYNC_LIB} ${DISKANN_TOOLS_TCMALLOC_LINK_OPTIONS} Boost::program_options)
 
-add_executable(test_streaming_scenario test_streaming_scenario.cpp)
-target_link_libraries(test_streaming_scenario ${PROJECT_NAME} ${DISKANN_TOOLS_TCMALLOC_LINK_OPTIONS} Boost::program_options)
+set(DISKANN_UNIT_TEST_SOURCES main.cpp index_write_parameters_builder_tests.cpp)
 
-add_executable(test_insert_deletes_consolidate test_insert_deletes_consolidate.cpp)
-target_link_libraries(test_insert_deletes_consolidate ${PROJECT_NAME} ${DISKANN_TOOLS_TCMALLOC_LINK_OPTIONS} Boost::program_options)
+add_executable(${PROJECT_NAME}_unit_tests ${DISKANN_SOURCES} ${DISKANN_UNIT_TEST_SOURCES})
+target_link_libraries(${PROJECT_NAME}_unit_tests ${PROJECT_NAME} ${DISKANN_TOOLS_TCMALLOC_LINK_OPTIONS} Boost::unit_test_framework)
+
+add_test(NAME ${PROJECT_NAME}_unit_tests COMMAND ${PROJECT_NAME}_unit_tests)
diff --git a/tests/README.md b/tests/README.md
new file mode 100644
index 000000000..113c9980b
--- /dev/null
+++ b/tests/README.md
@@ -0,0 +1,11 @@
+# Unit Test project
+
+This unit test project is based on the [Boost unit test framework](https://www.boost.org/doc/libs/1_78_0/libs/test/doc/html/index.html). Below are the steps to add a new unit test; you can find more usage guidance in the [Boost unit test documentation](https://www.boost.org/doc/libs/1_78_0/libs/test/doc/html/index.html).
+
+## How to add a unit test
+
+- Create a new [BOOST_AUTO_TEST_SUITE](https://www.boost.org/doc/libs/1_78_0/libs/test/doc/html/boost_test/utf_reference/test_org_reference/test_org_boost_auto_test_suite.html) for each class in an individual cpp file (a minimal sketch follows this file's diff)
+
+- Add a [BOOST_AUTO_TEST_CASE](https://www.boost.org/doc/libs/1_78_0/libs/test/doc/html/boost_test/utf_reference/test_org_reference/test_org_boost_auto_test_case.html) for each test case in the [BOOST_AUTO_TEST_SUITE](https://www.boost.org/doc/libs/1_78_0/libs/test/doc/html/boost_test/utf_reference/test_org_reference/test_org_boost_auto_test_suite.html)
+
+- Update the [CMakeLists.txt](CMakeLists.txt) file to add the new cpp file to the test project
\ No newline at end of file
diff --git a/tests/build_disk_index.cpp b/tests/build_disk_index.cpp
deleted file mode 100644
index 8bb89141e..000000000
--- a/tests/build_disk_index.cpp
+++ /dev/null
@@ -1,181 +0,0 @@
-// Copyright (c) Microsoft Corporation. All rights reserved.
-// Licensed under the MIT license.
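To make the README's recipe concrete, here is a minimal sketch of what a new suite file could look like. It assumes that `main.cpp` (listed in `DISKANN_UNIT_TEST_SOURCES` above) defines the Boost.Test module entry point; the file name, suite name, and `add` helper below are hypothetical illustrations, not part of DiskANN.

```cpp
// Hypothetical file: tests/example_tests.cpp
// BOOST_TEST_MODULE is assumed to be defined once in tests/main.cpp, so this
// translation unit only declares a suite and its cases.
#include <boost/test/unit_test.hpp>

namespace
{
// Toy function under test; stands in for whatever class the suite covers.
int add(int a, int b)
{
    return a + b;
}
} // namespace

BOOST_AUTO_TEST_SUITE(example_tests)

BOOST_AUTO_TEST_CASE(add_returns_sum)
{
    BOOST_CHECK_EQUAL(add(2, 3), 5);
    BOOST_CHECK_EQUAL(add(-1, 1), 0);
}

BOOST_AUTO_TEST_SUITE_END()
```

The new file would then be appended to `DISKANN_UNIT_TEST_SOURCES` in tests/CMakeLists.txt, and the registered `add_test` entry lets `ctest` run the resulting binary.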
- -#include -#include - -#include "utils.h" -#include "disk_utils.h" -#include "math_utils.h" -#include "index.h" -#include "partition.h" - -namespace po = boost::program_options; - -int main(int argc, char **argv) -{ - std::string data_type, dist_fn, data_path, index_path_prefix, label_file, universal_label, label_type; - unsigned num_threads, R, L, disk_PQ, build_PQ, Lf, filter_threshold; - float B, M; - bool append_reorder_data = false; - bool use_opq = false; - - po::options_description desc{"Arguments"}; - try - { - desc.add_options()("help,h", "Print information on arguments"); - desc.add_options()("data_type", po::value(&data_type)->required(), "data type "); - desc.add_options()("dist_fn", po::value(&dist_fn)->required(), "distance function "); - desc.add_options()("data_path", po::value(&data_path)->required(), - "Input data file in bin format"); - desc.add_options()("index_path_prefix", po::value(&index_path_prefix)->required(), - "Path prefix for saving index file components"); - desc.add_options()("max_degree,R", po::value(&R)->default_value(64), "Maximum graph degree"); - desc.add_options()("Lbuild,L", po::value(&L)->default_value(100), - "Build complexity, higher value results in better graphs"); - desc.add_options()("search_DRAM_budget,B", po::value(&B)->required(), - "DRAM budget in GB for searching the index to set the " - "compressed level for data while search happens"); - desc.add_options()("build_DRAM_budget,M", po::value(&M)->required(), - "DRAM budget in GB for building the index"); - desc.add_options()("num_threads,T", po::value(&num_threads)->default_value(omp_get_num_procs()), - "Number of threads used for building index (defaults to " - "omp_get_num_procs())"); - desc.add_options()("PQ_disk_bytes", po::value(&disk_PQ)->default_value(0), - "Number of bytes to which vectors should be compressed " - "on SSD; 0 for no compression"); - desc.add_options()("append_reorder_data", po::bool_switch()->default_value(false), - "Include full precision data in the index. Use only in " - "conjuction with compressed data on SSD."); - desc.add_options()("build_PQ_bytes", po::value(&build_PQ)->default_value(0), - "Number of PQ bytes to build the index; 0 for full precision build"); - desc.add_options()("use_opq", po::bool_switch()->default_value(false), - "Use Optimized Product Quantization (OPQ)."); - desc.add_options()("label_file", po::value(&label_file)->default_value(""), - "Input label file in txt format for Filtered Index build ." - "The file should contain comma separated filters for each node " - "with each line corresponding to a graph node"); - desc.add_options()("universal_label", po::value(&universal_label)->default_value(""), - "Universal label, Use only in conjuction with label file for filtered " - "index build. 
If a graph node has all the labels against it, we can " - "assign a special universal filter to the point instead of comma " - "separated filters for that point"); - desc.add_options()("FilteredLbuild,Lf", po::value(&Lf)->default_value(0), - "Build complexity for filtered points, higher value " - "results in better graphs"); - desc.add_options()("filter_threshold,F", po::value(&filter_threshold)->default_value(0), - "Threshold to break up the existing nodes to generate new graph " - "internally where each node has a maximum F labels."); - desc.add_options()("label_type", po::value(&label_type)->default_value("uint"), - "Storage type of Labels , default value is uint which " - "will consume memory 4 bytes per filter"); - - po::variables_map vm; - po::store(po::parse_command_line(argc, argv, desc), vm); - if (vm.count("help")) - { - std::cout << desc; - return 0; - } - po::notify(vm); - if (vm["append_reorder_data"].as()) - append_reorder_data = true; - if (vm["use_opq"].as()) - use_opq = true; - } - catch (const std::exception &ex) - { - std::cerr << ex.what() << '\n'; - return -1; - } - - bool use_filters = false; - if (label_file != "") - { - use_filters = true; - } - - diskann::Metric metric; - if (dist_fn == std::string("l2")) - metric = diskann::Metric::L2; - else if (dist_fn == std::string("mips")) - metric = diskann::Metric::INNER_PRODUCT; - else - { - std::cout << "Error. Only l2 and mips distance functions are supported" << std::endl; - return -1; - } - - if (append_reorder_data) - { - if (disk_PQ == 0) - { - std::cout << "Error: It is not necessary to append data for reordering " - "when vectors are not compressed on disk." - << std::endl; - return -1; - } - if (data_type != std::string("float")) - { - std::cout << "Error: Appending data for reordering currently only " - "supported for float data type." - << std::endl; - return -1; - } - } - - std::string params = std::string(std::to_string(R)) + " " + std::string(std::to_string(L)) + " " + - std::string(std::to_string(B)) + " " + std::string(std::to_string(M)) + " " + - std::string(std::to_string(num_threads)) + " " + std::string(std::to_string(disk_PQ)) + " " + - std::string(std::to_string(append_reorder_data)) + " " + std::string(std::to_string(build_PQ)); - - try - { - if (label_file != "" && label_type == "ushort") - { - if (data_type == std::string("int8")) - return diskann::build_disk_index(data_path.c_str(), index_path_prefix.c_str(), - params.c_str(), metric, use_opq, use_filters, - label_file, universal_label, filter_threshold, Lf); - else if (data_type == std::string("uint8")) - return diskann::build_disk_index(data_path.c_str(), index_path_prefix.c_str(), - params.c_str(), metric, use_opq, use_filters, - label_file, universal_label, filter_threshold, Lf); - else if (data_type == std::string("float")) - return diskann::build_disk_index(data_path.c_str(), index_path_prefix.c_str(), - params.c_str(), metric, use_opq, use_filters, - label_file, universal_label, filter_threshold, Lf); - else - { - diskann::cerr << "Error. 
Unsupported data type" << std::endl; - return -1; - } - } - else - { - if (data_type == std::string("int8")) - return diskann::build_disk_index(data_path.c_str(), index_path_prefix.c_str(), params.c_str(), - metric, use_opq, use_filters, label_file, universal_label, - filter_threshold, Lf); - else if (data_type == std::string("uint8")) - return diskann::build_disk_index(data_path.c_str(), index_path_prefix.c_str(), params.c_str(), - metric, use_opq, use_filters, label_file, universal_label, - filter_threshold, Lf); - else if (data_type == std::string("float")) - return diskann::build_disk_index(data_path.c_str(), index_path_prefix.c_str(), params.c_str(), - metric, use_opq, use_filters, label_file, universal_label, - filter_threshold, Lf); - else - { - diskann::cerr << "Error. Unsupported data type" << std::endl; - return -1; - } - } - } - catch (const std::exception &e) - { - std::cout << std::string(e.what()) << std::endl; - diskann::cerr << "Index build failed." << std::endl; - return -1; - } -} diff --git a/tests/build_memory_index.cpp b/tests/build_memory_index.cpp deleted file mode 100644 index 07dd67938..000000000 --- a/tests/build_memory_index.cpp +++ /dev/null @@ -1,202 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT license. - -#include -#include -#include - -#include "index.h" -#include "utils.h" - -#ifndef _WINDOWS -#include -#include -#else -#include -#endif - -#include "memory_mapper.h" -#include "ann_exception.h" - -namespace po = boost::program_options; - -template -int build_in_memory_index(const diskann::Metric &metric, const std::string &data_path, const unsigned R, - const unsigned L, const float alpha, const std::string &save_path, const unsigned num_threads, - const bool use_pq_build, const size_t num_pq_bytes, const bool use_opq, - const std::string &label_file, const std::string &universal_label, const _u32 Lf) -{ - diskann::Parameters paras; - paras.Set("R", R); - paras.Set("L", L); - paras.Set("Lf", Lf); - paras.Set("C", 750); // maximum candidate set size during pruning procedure - paras.Set("alpha", alpha); - paras.Set("saturate_graph", 0); - paras.Set("num_threads", num_threads); - std::string labels_file_to_use = save_path + "_label_formatted.txt"; - std::string mem_labels_int_map_file = save_path + "_labels_map.txt"; - - _u64 data_num, data_dim; - diskann::get_bin_metadata(data_path, data_num, data_dim); - - diskann::Index index(metric, data_dim, data_num, false, false, false, use_pq_build, num_pq_bytes, - use_opq); - auto s = std::chrono::high_resolution_clock::now(); - if (label_file == "") - { - index.build(data_path.c_str(), data_num, paras); - } - else - { - convert_labels_string_to_int(label_file, labels_file_to_use, mem_labels_int_map_file, universal_label); - if (universal_label != "") - { - LabelT unv_label_as_num = 0; - index.set_universal_label(unv_label_as_num); - } - index.build_filtered_index(data_path.c_str(), labels_file_to_use, data_num, paras); - } - std::chrono::duration diff = std::chrono::high_resolution_clock::now() - s; - - std::cout << "Indexing time: " << diff.count() << "\n"; - index.save(save_path.c_str()); - if (label_file != "") - std::remove(labels_file_to_use.c_str()); - return 0; -} - -int main(int argc, char **argv) -{ - std::string data_type, dist_fn, data_path, index_path_prefix, label_file, universal_label, label_type; - unsigned num_threads, R, L, Lf, build_PQ_bytes; - float alpha; - bool use_pq_build, use_opq; - - po::options_description desc{"Arguments"}; - try - { - 
desc.add_options()("help,h", "Print information on arguments"); - desc.add_options()("data_type", po::value(&data_type)->required(), "data type "); - desc.add_options()("dist_fn", po::value(&dist_fn)->required(), "distance function "); - desc.add_options()("data_path", po::value(&data_path)->required(), - "Input data file in bin format"); - desc.add_options()("index_path_prefix", po::value(&index_path_prefix)->required(), - "Path prefix for saving index file components"); - desc.add_options()("max_degree,R", po::value(&R)->default_value(64), "Maximum graph degree"); - desc.add_options()("Lbuild,L", po::value(&L)->default_value(100), - "Build complexity, higher value results in better graphs"); - desc.add_options()("alpha", po::value(&alpha)->default_value(1.2f), - "alpha controls density and diameter of graph, set 1 for sparse graph, " - "1.2 or 1.4 for denser graphs with lower diameter"); - desc.add_options()("num_threads,T", po::value(&num_threads)->default_value(omp_get_num_procs()), - "Number of threads used for building index (defaults to " - "omp_get_num_procs())"); - desc.add_options()("build_PQ_bytes", po::value(&build_PQ_bytes)->default_value(0), - "Number of PQ bytes to build the index; 0 for full precision build"); - desc.add_options()("use_opq", po::bool_switch()->default_value(false), - "Set true for OPQ compression while using PQ distance comparisons for " - "building the index, and false for PQ compression"); - desc.add_options()("label_file", po::value(&label_file)->default_value(""), - "Input label file in txt format for Filtered Index search. " - "The file should contain comma separated filters for each node " - "with each line corresponding to a graph node"); - desc.add_options()("universal_label", po::value(&universal_label)->default_value(""), - "Universal label, if using it, only in conjunction with labels_file"); - desc.add_options()("FilteredLbuild,Lf", po::value(&Lf)->default_value(0), - "Build complexity for filtered points, higher value " - "results in better graphs"); - desc.add_options()("label_type", po::value(&label_type)->default_value("uint"), - "Storage type of Labels , default value is uint which " - "will consume memory 4 bytes per filter"); - - po::variables_map vm; - po::store(po::parse_command_line(argc, argv, desc), vm); - if (vm.count("help")) - { - std::cout << desc; - return 0; - } - po::notify(vm); - use_pq_build = (build_PQ_bytes > 0); - use_opq = vm["use_opq"].as(); - } - catch (const std::exception &ex) - { - std::cerr << ex.what() << '\n'; - return -1; - } - - diskann::Metric metric; - if (dist_fn == std::string("mips")) - { - metric = diskann::Metric::INNER_PRODUCT; - } - else if (dist_fn == std::string("l2")) - { - metric = diskann::Metric::L2; - } - else if (dist_fn == std::string("cosine")) - { - metric = diskann::Metric::COSINE; - } - else - { - std::cout << "Unsupported distance function. Currently only L2/ Inner " - "Product/Cosine are supported." 
- << std::endl; - return -1; - } - - try - { - diskann::cout << "Starting index build with R: " << R << " Lbuild: " << L << " alpha: " << alpha - << " #threads: " << num_threads << std::endl; - if (label_file != "" && label_type == "ushort") - { - if (data_type == std::string("int8")) - return build_in_memory_index( - metric, data_path, R, L, alpha, index_path_prefix, num_threads, use_pq_build, build_PQ_bytes, - use_opq, label_file, universal_label, Lf); - else if (data_type == std::string("uint8")) - return build_in_memory_index( - metric, data_path, R, L, alpha, index_path_prefix, num_threads, use_pq_build, build_PQ_bytes, - use_opq, label_file, universal_label, Lf); - else if (data_type == std::string("float")) - return build_in_memory_index( - metric, data_path, R, L, alpha, index_path_prefix, num_threads, use_pq_build, build_PQ_bytes, - use_opq, label_file, universal_label, Lf); - else - { - std::cout << "Unsupported type. Use one of int8, uint8 or float." << std::endl; - return -1; - } - } - else - { - if (data_type == std::string("int8")) - return build_in_memory_index(metric, data_path, R, L, alpha, index_path_prefix, num_threads, - use_pq_build, build_PQ_bytes, use_opq, label_file, universal_label, - Lf); - else if (data_type == std::string("uint8")) - return build_in_memory_index(metric, data_path, R, L, alpha, index_path_prefix, num_threads, - use_pq_build, build_PQ_bytes, use_opq, label_file, - universal_label, Lf); - else if (data_type == std::string("float")) - return build_in_memory_index(metric, data_path, R, L, alpha, index_path_prefix, num_threads, - use_pq_build, build_PQ_bytes, use_opq, label_file, universal_label, - Lf); - else - { - std::cout << "Unsupported type. Use one of int8, uint8 or float." << std::endl; - return -1; - } - } - } - catch (const std::exception &e) - { - std::cout << std::string(e.what()) << std::endl; - diskann::cerr << "Index build failed." << std::endl; - return -1; - } -} diff --git a/tests/build_stitched_index.cpp b/tests/build_stitched_index.cpp deleted file mode 100644 index 22c93d846..000000000 --- a/tests/build_stitched_index.cpp +++ /dev/null @@ -1,800 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT license. - -#include -#include -#include -#include -#include -#include -#include - -#include -#ifndef _WINDOWS -#include -#endif - -#include "index.h" -#include "memory_mapper.h" -#include "parameters.h" -#include "utils.h" - -namespace po = boost::program_options; - -// macros -#define PBSTR "||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||" -#define PBWIDTH 60 - -// custom types (for readability) -typedef tsl::robin_set label_set; -typedef std::string path; - -// structs for returning multiple items from a function -typedef std::tuple, tsl::robin_map, label_set> parse_label_file_return_values; -typedef std::tuple>, _u64> load_label_index_return_values; -typedef std::tuple>, _u64> stitch_indices_return_values; - -/* - * Inline function to display progress bar. - */ -inline void print_progress(double percentage) -{ - int val = (int)(percentage * 100); - int lpad = (int)(percentage * PBWIDTH); - int rpad = PBWIDTH - lpad; - printf("\r%3d%% [%.*s%*s]", val, lpad, PBSTR, rpad, ""); - fflush(stdout); -} - -/* - * Inline function to generate a random integer in a range. 
- */ -inline size_t random(size_t range_from, size_t range_to) -{ - std::random_device rand_dev; - std::mt19937 generator(rand_dev()); - std::uniform_int_distribution distr(range_from, range_to); - return distr(generator); -} - -/* - * function to handle command line parsing. - * - * Arguments are merely the inputs from the command line. - */ -void handle_args(int argc, char **argv, std::string &data_type, path &input_data_path, path &final_index_path_prefix, - path &label_data_path, std::string &universal_label, unsigned &num_threads, unsigned &R, unsigned &L, - unsigned &stitched_R, float &alpha) -{ - po::options_description desc{"Arguments"}; - try - { - desc.add_options()("help,h", "Print information on arguments"); - desc.add_options()("data_type", po::value(&data_type)->required(), "data type "); - desc.add_options()("data_path", po::value(&input_data_path)->required(), "Input data file in bin format"); - desc.add_options()("index_path_prefix", po::value(&final_index_path_prefix)->required(), - "Path prefix for saving index file components"); - desc.add_options()("max_degree,R", po::value(&R)->default_value(64), "Maximum graph degree"); - desc.add_options()("Lbuild,L", po::value(&L)->default_value(100), - "Build complexity, higher value results in better graphs"); - desc.add_options()("stitched_R", po::value(&stitched_R)->default_value(100), - "Degree to prune final graph down to"); - desc.add_options()("alpha", po::value(&alpha)->default_value(1.2f), - "alpha controls density and diameter of graph, set 1 for sparse graph, " - "1.2 or 1.4 for denser graphs with lower diameter"); - desc.add_options()("num_threads,T", po::value(&num_threads)->default_value(omp_get_num_procs()), - "Number of threads used for building index (defaults to " - "omp_get_num_procs())"); - desc.add_options()("label_file", po::value(&label_data_path)->default_value(""), - "Input label file in txt format if present"); - desc.add_options()("universal_label", po::value(&universal_label)->default_value(""), - "If a point comes with the specified universal label (and only the " - "univ. " - "label), then the point is considered to have every possible label"); - - po::variables_map vm; - po::store(po::parse_command_line(argc, argv, desc), vm); - if (vm.count("help")) - { - std::cout << desc; - exit(0); - } - po::notify(vm); - } - catch (const std::exception &ex) - { - std::cerr << ex.what() << '\n'; - throw; - } -} - -/* - * Parses the label datafile, which has comma-separated labels on - * each line. Line i corresponds to point id i. - * - * Returns three objects via std::tuple: - * 1. map: key is point id, value is vector of labels said point has - * 2. map: key is label, value is number of points with the label - * 3. 
the label universe as a set - */ -parse_label_file_return_values parse_label_file(path label_data_path, std::string universal_label) -{ - std::ifstream label_data_stream(label_data_path); - std::string line, token; - unsigned line_cnt = 0; - - // allows us to reserve space for the points_to_labels vector - while (std::getline(label_data_stream, line)) - line_cnt++; - label_data_stream.clear(); - label_data_stream.seekg(0, std::ios::beg); - - // values to return - std::vector point_ids_to_labels(line_cnt); - tsl::robin_map labels_to_number_of_points; - label_set all_labels; - - std::vector<_u32> points_with_universal_label; - line_cnt = 0; - while (std::getline(label_data_stream, line)) - { - std::istringstream current_labels_comma_separated(line); - label_set current_labels; - - // get point id - _u32 point_id = line_cnt; - - // parse comma separated labels - bool current_universal_label_check = false; - while (getline(current_labels_comma_separated, token, ',')) - { - token.erase(std::remove(token.begin(), token.end(), '\n'), token.end()); - token.erase(std::remove(token.begin(), token.end(), '\r'), token.end()); - - // if token is empty, there's no labels for the point - if (token == universal_label) - { - points_with_universal_label.push_back(point_id); - current_universal_label_check = true; - } - else - { - all_labels.insert(token); - current_labels.insert(token); - labels_to_number_of_points[token]++; - } - } - - if (current_labels.size() <= 0 && !current_universal_label_check) - { - std::cerr << "Error: " << point_id << " has no labels." << std::endl; - exit(-1); - } - point_ids_to_labels[point_id] = current_labels; - line_cnt++; - } - - // for every point with universal label, set its label set to all labels - // also, increment the count for number of points a label has - for (const auto &point_id : points_with_universal_label) - { - point_ids_to_labels[point_id] = all_labels; - for (const auto &lbl : all_labels) - labels_to_number_of_points[lbl]++; - } - - std::cout << "Identified " << all_labels.size() << " distinct label(s) for " << point_ids_to_labels.size() - << " points\n" - << std::endl; - - return std::make_tuple(point_ids_to_labels, labels_to_number_of_points, all_labels); -} - -/* - * For each label, generates a file containing all vectors that have said label. - * Also copies data from original bin file to new dimension-aligned file. - * - * Utilizes POSIX functions mmap and writev in order to minimize memory - * overhead, so we include an STL version as well. - * - * Each data file is saved under the following format: - * input_data_path + "_" + label - */ -template -tsl::robin_map> generate_label_specific_vector_files( - path input_data_path, tsl::robin_map labels_to_number_of_points, - std::vector point_ids_to_labels, label_set all_labels) -{ - auto file_writing_timer = std::chrono::high_resolution_clock::now(); - diskann::MemoryMapper input_data(input_data_path); - char *input_start = input_data.getBuf(); - - _u32 number_of_points, dimension; - std::memcpy(&number_of_points, input_start, sizeof(_u32)); - std::memcpy(&dimension, input_start + sizeof(_u32), sizeof(_u32)); - const _u32 VECTOR_SIZE = dimension * sizeof(T); - const size_t METADATA = 2 * sizeof(_u32); - if (number_of_points != point_ids_to_labels.size()) - { - std::cerr << "Error: number of points in labels file and data file differ." 
<< std::endl; - throw; - } - - tsl::robin_map label_to_iovec_map; - tsl::robin_map label_to_curr_iovec; - tsl::robin_map> label_id_to_orig_id; - - // setup iovec list for each label - for (const auto &lbl : all_labels) - { - iovec *label_iovecs = (iovec *)malloc(labels_to_number_of_points[lbl] * sizeof(iovec)); - if (label_iovecs == nullptr) - { - throw; - } - label_to_iovec_map[lbl] = label_iovecs; - label_to_curr_iovec[lbl] = 0; - label_id_to_orig_id[lbl].reserve(labels_to_number_of_points[lbl]); - } - - // each point added to corresponding per-label iovec list - for (_u32 point_id = 0; point_id < number_of_points; point_id++) - { - char *curr_point = input_start + METADATA + (VECTOR_SIZE * point_id); - iovec curr_iovec; - - curr_iovec.iov_base = curr_point; - curr_iovec.iov_len = VECTOR_SIZE; - for (const auto &lbl : point_ids_to_labels[point_id]) - { - *(label_to_iovec_map[lbl] + label_to_curr_iovec[lbl]) = curr_iovec; - label_to_curr_iovec[lbl]++; - label_id_to_orig_id[lbl].push_back(point_id); - } - } - - // write each label iovec to resp. file - for (const auto &lbl : all_labels) - { - int label_input_data_fd; - path curr_label_input_data_path(input_data_path + "_" + lbl); - _u32 curr_num_pts = labels_to_number_of_points[lbl]; - - label_input_data_fd = - open(curr_label_input_data_path.c_str(), O_CREAT | O_WRONLY | O_TRUNC | O_APPEND, (mode_t)0644); - if (label_input_data_fd == -1) - throw; - - // write metadata - _u32 metadata[2] = {curr_num_pts, dimension}; - int return_value = write(label_input_data_fd, metadata, sizeof(_u32) * 2); - if (return_value == -1) - { - throw; - } - - // limits on number of iovec structs per writev means we need to perform - // multiple writevs - size_t i = 0; - while (curr_num_pts > IOV_MAX) - { - return_value = writev(label_input_data_fd, (label_to_iovec_map[lbl] + (IOV_MAX * i)), IOV_MAX); - if (return_value == -1) - { - close(label_input_data_fd); - throw; - } - curr_num_pts -= IOV_MAX; - i += 1; - } - return_value = writev(label_input_data_fd, (label_to_iovec_map[lbl] + (IOV_MAX * i)), curr_num_pts); - if (return_value == -1) - { - close(label_input_data_fd); - throw; - } - - free(label_to_iovec_map[lbl]); - close(label_input_data_fd); - } - - std::chrono::duration file_writing_time = std::chrono::high_resolution_clock::now() - file_writing_timer; - std::cout << "generated " << all_labels.size() << " label-specific vector files for index building in time " - << file_writing_time.count() << "\n" - << std::endl; - - return label_id_to_orig_id; -} - -// for use on systems without writev (i.e. Windows) -template -tsl::robin_map> generate_label_specific_vector_files_compat( - path input_data_path, tsl::robin_map labels_to_number_of_points, - std::vector point_ids_to_labels, label_set all_labels) -{ - auto file_writing_timer = std::chrono::high_resolution_clock::now(); - std::ifstream input_data_stream(input_data_path); - - _u32 number_of_points, dimension; - input_data_stream.read((char *)&number_of_points, sizeof(_u32)); - input_data_stream.read((char *)&dimension, sizeof(_u32)); - const _u32 VECTOR_SIZE = dimension * sizeof(T); - if (number_of_points != point_ids_to_labels.size()) - { - std::cerr << "Error: number of points in labels file and data file differ." 
<< std::endl; - throw; - } - - tsl::robin_map<std::string, char *> labels_to_vectors; - tsl::robin_map<std::string, _u32> labels_to_curr_vector; - tsl::robin_map<std::string, std::vector<_u32>> label_id_to_orig_id; - - for (const auto &lbl : all_labels) - { - _u32 number_of_label_pts = labels_to_number_of_points[lbl]; - char *vectors = (char *)malloc(number_of_label_pts * VECTOR_SIZE); - if (vectors == nullptr) - { - throw; - } - labels_to_vectors[lbl] = vectors; - labels_to_curr_vector[lbl] = 0; - label_id_to_orig_id[lbl].reserve(number_of_label_pts); - } - - for (_u32 point_id = 0; point_id < number_of_points; point_id++) - { - char *curr_vector = (char *)malloc(VECTOR_SIZE); - input_data_stream.read(curr_vector, VECTOR_SIZE); - for (const auto &lbl : point_ids_to_labels[point_id]) - { - char *curr_label_vector_ptr = labels_to_vectors[lbl] + (labels_to_curr_vector[lbl] * VECTOR_SIZE); - memcpy(curr_label_vector_ptr, curr_vector, VECTOR_SIZE); - labels_to_curr_vector[lbl]++; - label_id_to_orig_id[lbl].push_back(point_id); - } - free(curr_vector); - } - - for (const auto &lbl : all_labels) - { - path curr_label_input_data_path(input_data_path + "_" + lbl); - _u32 number_of_label_pts = labels_to_number_of_points[lbl]; - - std::ofstream label_file_stream; - label_file_stream.exceptions(std::ios::badbit | std::ios::failbit); - label_file_stream.open(curr_label_input_data_path, std::ios_base::binary); - label_file_stream.write((char *)&number_of_label_pts, sizeof(_u32)); - label_file_stream.write((char *)&dimension, sizeof(_u32)); - label_file_stream.write((char *)labels_to_vectors[lbl], number_of_label_pts * VECTOR_SIZE); - - label_file_stream.close(); - free(labels_to_vectors[lbl]); - } - input_data_stream.close(); - - std::chrono::duration<double> file_writing_time = std::chrono::high_resolution_clock::now() - file_writing_timer; - std::cout << "generated " << all_labels.size() << " label-specific vector files for index building in time " - << file_writing_time.count() << "\n" - << std::endl; - - return label_id_to_orig_id; -} - -/* - * Using passed in parameters and files generated from step 3, - * builds a vanilla diskANN index for each label. - * - * Each index is saved under the following path: - * final_index_path_prefix + "_" + label - */ -template <typename T> -void generate_label_indices(path input_data_path, path final_index_path_prefix, label_set all_labels, unsigned R, - unsigned L, float alpha, unsigned num_threads) -{ - diskann::Parameters label_index_build_parameters; - label_index_build_parameters.Set("R", R); - label_index_build_parameters.Set("L", L); - label_index_build_parameters.Set("C", 750); - label_index_build_parameters.Set("Lf", 0); - label_index_build_parameters.Set("saturate_graph", 0); - label_index_build_parameters.Set("alpha", alpha); - label_index_build_parameters.Set("num_threads", num_threads); - - std::cout << "Generating indices per label..." << std::endl; - // for each label, build an index on resp.
points - double total_indexing_time = 0.0, indexing_percentage = 0.0; - std::cout.setstate(std::ios_base::failbit); - diskann::cout.setstate(std::ios_base::failbit); - for (const auto &lbl : all_labels) - { - path curr_label_input_data_path(input_data_path + "_" + lbl); - path curr_label_index_path(final_index_path_prefix + "_" + lbl); - - size_t number_of_label_points, dimension; - diskann::get_bin_metadata(curr_label_input_data_path, number_of_label_points, dimension); - diskann::Index<T> index(diskann::Metric::L2, dimension, number_of_label_points, false, false); - - auto index_build_timer = std::chrono::high_resolution_clock::now(); - index.build(curr_label_input_data_path.c_str(), number_of_label_points, label_index_build_parameters); - std::chrono::duration<double> current_indexing_time = - std::chrono::high_resolution_clock::now() - index_build_timer; - - total_indexing_time += current_indexing_time.count(); - indexing_percentage += (1 / (double)all_labels.size()); - print_progress(indexing_percentage); - - index.save(curr_label_index_path.c_str()); - } - std::cout.clear(); - diskann::cout.clear(); - - std::cout << "\nDone. Generated per-label indices in " << total_indexing_time << " seconds\n" << std::endl; -} - -/* - * Manually loads a graph index in from a given file. - * - * Returns both the graph index and the size of the file in bytes. - */ -load_label_index_return_values load_label_index(path label_index_path, _u32 label_number_of_points) -{ - std::ifstream label_index_stream; - label_index_stream.exceptions(std::ios::badbit | std::ios::failbit); - label_index_stream.open(label_index_path, std::ios::binary); - - _u64 index_file_size, index_num_frozen_points; - _u32 index_max_observed_degree, index_entry_point; - const size_t INDEX_METADATA = 2 * sizeof(_u64) + 2 * sizeof(_u32); - label_index_stream.read((char *)&index_file_size, sizeof(_u64)); - label_index_stream.read((char *)&index_max_observed_degree, sizeof(_u32)); - label_index_stream.read((char *)&index_entry_point, sizeof(_u32)); - label_index_stream.read((char *)&index_num_frozen_points, sizeof(_u64)); - size_t bytes_read = INDEX_METADATA; - - std::vector<std::vector<_u32>> label_index(label_number_of_points); - _u32 nodes_read = 0; - while (bytes_read != index_file_size) - { - _u32 current_node_num_neighbors; - label_index_stream.read((char *)&current_node_num_neighbors, sizeof(_u32)); - nodes_read++; - - std::vector<_u32> current_node_neighbors(current_node_num_neighbors); - label_index_stream.read((char *)current_node_neighbors.data(), current_node_num_neighbors * sizeof(_u32)); - label_index[nodes_read - 1].swap(current_node_neighbors); - bytes_read += sizeof(_u32) * (current_node_num_neighbors + 1); - } - - return std::make_tuple(label_index, index_file_size); -} - -/* - * Custom index save to write the in-memory index to disk. - * Also writes required files for diskANN API - - * 1. labels_to_medoids - * 2. universal_label - * 3. data (redundant for static indices) - * 4. labels (redundant for static indices) - */ -void save_full_index(path final_index_path_prefix, path input_data_path, _u64 final_index_size, - std::vector<std::vector<_u32>> stitched_graph, tsl::robin_map<std::string, _u32> entry_points, - std::string universal_label, path label_data_path) -{ - // aux.
file 1 - auto saving_index_timer = std::chrono::high_resolution_clock::now(); - std::ifstream original_label_data_stream; - original_label_data_stream.exceptions(std::ios::badbit | std::ios::failbit); - original_label_data_stream.open(label_data_path, std::ios::binary); - std::ofstream new_label_data_stream; - new_label_data_stream.exceptions(std::ios::badbit | std::ios::failbit); - new_label_data_stream.open(final_index_path_prefix + "_labels.txt", std::ios::binary); - new_label_data_stream << original_label_data_stream.rdbuf(); - original_label_data_stream.close(); - new_label_data_stream.close(); - - // aux. file 2 - std::ifstream original_input_data_stream; - original_input_data_stream.exceptions(std::ios::badbit | std::ios::failbit); - original_input_data_stream.open(input_data_path, std::ios::binary); - std::ofstream new_input_data_stream; - new_input_data_stream.exceptions(std::ios::badbit | std::ios::failbit); - new_input_data_stream.open(final_index_path_prefix + ".data", std::ios::binary); - new_input_data_stream << original_input_data_stream.rdbuf(); - original_input_data_stream.close(); - new_input_data_stream.close(); - - // aux. file 3 - std::ofstream labels_to_medoids_writer; - labels_to_medoids_writer.exceptions(std::ios::badbit | std::ios::failbit); - labels_to_medoids_writer.open(final_index_path_prefix + "_labels_to_medoids.txt"); - for (auto iter : entry_points) - labels_to_medoids_writer << iter.first << ", " << iter.second << std::endl; - labels_to_medoids_writer.close(); - - // aux. file 4 (only if we're using a universal label) - if (universal_label != "") - { - std::ofstream universal_label_writer; - universal_label_writer.exceptions(std::ios::badbit | std::ios::failbit); - universal_label_writer.open(final_index_path_prefix + "_universal_label.txt"); - universal_label_writer << universal_label << std::endl; - universal_label_writer.close(); - } - - // main index - _u64 index_num_frozen_points = 0, index_num_edges = 0; - _u32 index_max_observed_degree = 0, index_entry_point = 0; - const size_t METADATA = 2 * sizeof(_u64) + 2 * sizeof(_u32); - for (auto &point_neighbors : stitched_graph) - { - index_max_observed_degree = std::max(index_max_observed_degree, (_u32)point_neighbors.size()); - } - - std::ofstream stitched_graph_writer; - stitched_graph_writer.exceptions(std::ios::badbit | std::ios::failbit); - stitched_graph_writer.open(final_index_path_prefix, std::ios_base::binary); - - stitched_graph_writer.write((char *)&final_index_size, sizeof(_u64)); - stitched_graph_writer.write((char *)&index_max_observed_degree, sizeof(_u32)); - stitched_graph_writer.write((char *)&index_entry_point, sizeof(_u32)); - stitched_graph_writer.write((char *)&index_num_frozen_points, sizeof(_u64)); - - size_t bytes_written = METADATA; - for (_u32 node_point = 0; node_point < stitched_graph.size(); node_point++) - { - _u32 current_node_num_neighbors = stitched_graph[node_point].size(); - std::vector<_u32> current_node_neighbors = stitched_graph[node_point]; - stitched_graph_writer.write((char *)&current_node_num_neighbors, sizeof(_u32)); - bytes_written += sizeof(_u32); - for (const auto &current_node_neighbor : current_node_neighbors) - { - stitched_graph_writer.write((char *)&current_node_neighbor, sizeof(_u32)); - bytes_written += sizeof(_u32); - } - index_num_edges += current_node_num_neighbors; - } - - if (bytes_written != final_index_size) - { - std::cerr << "Error: written bytes does not match allocated space" << std::endl; - throw; - } - - stitched_graph_writer.close(); - - 
std::chrono::duration<double> saving_index_time = std::chrono::high_resolution_clock::now() - saving_index_timer; - std::cout << "Stitched graph written in " << saving_index_time.count() << " seconds" << std::endl; - std::cout << "Stitched graph average degree: " << ((float)index_num_edges) / ((float)(stitched_graph.size())) - << std::endl; - std::cout << "Stitched graph max degree: " << index_max_observed_degree << std::endl << std::endl; -} - -/* - * Unions the per-label graph indices together via the following policy: - * - any two nodes can only have at most one edge between them - - * - * Returns the "stitched" graph and its expected file size. - */ -template <typename T> -stitch_indices_return_values stitch_label_indices( - path final_index_path_prefix, _u32 total_number_of_points, label_set all_labels, - tsl::robin_map<std::string, _u32> labels_to_number_of_points, tsl::robin_map<std::string, _u32> &label_entry_points, - tsl::robin_map<std::string, std::vector<_u32>> label_id_to_orig_id_map) -{ - size_t final_index_size = 0; - std::vector<std::vector<_u32>> stitched_graph(total_number_of_points); - - auto stitching_index_timer = std::chrono::high_resolution_clock::now(); - for (const auto &lbl : all_labels) - { - path curr_label_index_path(final_index_path_prefix + "_" + lbl); - std::vector<std::vector<_u32>> curr_label_index; - _u64 curr_label_index_size; - _u32 curr_label_entry_point; - - std::tie(curr_label_index, curr_label_index_size) = - load_label_index(curr_label_index_path, labels_to_number_of_points[lbl]); - curr_label_entry_point = random(0, curr_label_index.size()); - label_entry_points[lbl] = label_id_to_orig_id_map[lbl][curr_label_entry_point]; - - for (_u32 node_point = 0; node_point < curr_label_index.size(); node_point++) - { - _u32 original_point_id = label_id_to_orig_id_map[lbl][node_point]; - for (auto &node_neighbor : curr_label_index[node_point]) - { - _u32 original_neighbor_id = label_id_to_orig_id_map[lbl][node_neighbor]; - std::vector<_u32> curr_point_neighbors = stitched_graph[original_point_id]; - if (std::find(curr_point_neighbors.begin(), curr_point_neighbors.end(), original_neighbor_id) == - curr_point_neighbors.end()) - { - stitched_graph[original_point_id].push_back(original_neighbor_id); - final_index_size += sizeof(_u32); - } - } - } - } - - const size_t METADATA = 2 * sizeof(_u64) + 2 * sizeof(_u32); - final_index_size += (total_number_of_points * sizeof(_u32) + METADATA); - - std::chrono::duration<double> stitching_index_time = - std::chrono::high_resolution_clock::now() - stitching_index_timer; - std::cout << "stitched graph generated in memory in " << stitching_index_time.count() << " seconds" << std::endl; - - return std::make_tuple(stitched_graph, final_index_size); -} - -/* - * Applies the prune_neighbors function from src/index.cpp to - * every node in the stitched graph. - * - * This is an optional step, hence the saving of both the full - * and pruned graph.
- */ -template <typename T> -void prune_and_save(path final_index_path_prefix, path full_index_path_prefix, path input_data_path, - std::vector<std::vector<_u32>> stitched_graph, unsigned stitched_R, - tsl::robin_map<std::string, _u32> label_entry_points, std::string universal_label, - path label_data_path, unsigned num_threads) -{ - size_t dimension, number_of_label_points; - auto diskann_cout_buffer = diskann::cout.rdbuf(nullptr); - auto std_cout_buffer = std::cout.rdbuf(nullptr); - auto pruning_index_timer = std::chrono::high_resolution_clock::now(); - - diskann::get_bin_metadata(input_data_path, number_of_label_points, dimension); - diskann::Index<T> index(diskann::Metric::L2, dimension, number_of_label_points, false, false); - - // not searching this index, set search_l to 0 - index.load(full_index_path_prefix.c_str(), num_threads, 1); - - diskann::Parameters paras; - paras.Set("R", stitched_R); - paras.Set("C", 750); // maximum candidate set size during pruning procedure - paras.Set("alpha", 1.2); - paras.Set("saturate_graph", 1); - std::cout << "parsing labels" << std::endl; - - index.prune_all_nbrs(paras); - index.save((final_index_path_prefix).c_str()); - - diskann::cout.rdbuf(diskann_cout_buffer); - std::cout.rdbuf(std_cout_buffer); - std::chrono::duration<double> pruning_index_time = std::chrono::high_resolution_clock::now() - pruning_index_timer; - std::cout << "pruning performed in " << pruning_index_time.count() << " seconds\n" << std::endl; -} - -/* - * Delete all temporary artifacts. - * In the process of creating the stitched index, some temporary artifacts are - * created: - * 1. the separate bin files for each labels' points - * 2. the separate diskANN indices built for each label - * 3. the '.data' file created while generating the indices - */ -void clean_up_artifacts(path input_data_path, path final_index_path_prefix, label_set all_labels) -{ - for (const auto &lbl : all_labels) - { - path curr_label_input_data_path(input_data_path + "_" + lbl); - path curr_label_index_path(final_index_path_prefix + "_" + lbl); - path curr_label_index_path_data(curr_label_index_path + ".data"); - - if (std::remove(curr_label_index_path.c_str()) != 0) - throw; - if (std::remove(curr_label_input_data_path.c_str()) != 0) - throw; - if (std::remove(curr_label_index_path_data.c_str()) != 0) - throw; - } -} - -int main(int argc, char **argv) -{ - // 1. handle cmdline inputs - std::string data_type; - path input_data_path, final_index_path_prefix, label_data_path; - std::string universal_label; - unsigned num_threads, R, L, stitched_R; - float alpha; - - auto index_timer = std::chrono::high_resolution_clock::now(); - handle_args(argc, argv, data_type, input_data_path, final_index_path_prefix, label_data_path, universal_label, - num_threads, R, L, stitched_R, alpha); - - path labels_file_to_use = final_index_path_prefix + "_label_formatted.txt"; - path labels_map_file = final_index_path_prefix + "_labels_map.txt"; - - convert_labels_string_to_int(label_data_path, labels_file_to_use, labels_map_file, universal_label); - - // 2. parse label file and create necessary data structures - std::vector<label_set> point_ids_to_labels; - tsl::robin_map<std::string, _u32> labels_to_number_of_points; - label_set all_labels; - - std::tie(point_ids_to_labels, labels_to_number_of_points, all_labels) = - parse_label_file(labels_file_to_use, universal_label); - - // 3.
for each label, make a separate data file - tsl::robin_map<std::string, std::vector<_u32>> label_id_to_orig_id_map; - _u32 total_number_of_points = point_ids_to_labels.size(); - -#ifndef _WINDOWS - if (data_type == "uint8") - label_id_to_orig_id_map = generate_label_specific_vector_files<uint8_t>( - input_data_path, labels_to_number_of_points, point_ids_to_labels, all_labels); - else if (data_type == "int8") - label_id_to_orig_id_map = generate_label_specific_vector_files<int8_t>( - input_data_path, labels_to_number_of_points, point_ids_to_labels, all_labels); - else if (data_type == "float") - label_id_to_orig_id_map = generate_label_specific_vector_files<float>( - input_data_path, labels_to_number_of_points, point_ids_to_labels, all_labels); - else - throw; -#else - if (data_type == "uint8") - label_id_to_orig_id_map = generate_label_specific_vector_files_compat<uint8_t>( - input_data_path, labels_to_number_of_points, point_ids_to_labels, all_labels); - else if (data_type == "int8") - label_id_to_orig_id_map = generate_label_specific_vector_files_compat<int8_t>( - input_data_path, labels_to_number_of_points, point_ids_to_labels, all_labels); - else if (data_type == "float") - label_id_to_orig_id_map = generate_label_specific_vector_files_compat<float>( - input_data_path, labels_to_number_of_points, point_ids_to_labels, all_labels); - else - throw; -#endif - - // 4. for each created data file, create a vanilla diskANN index - if (data_type == "uint8") - generate_label_indices<uint8_t>(input_data_path, final_index_path_prefix, all_labels, R, L, alpha, num_threads); - else if (data_type == "int8") - generate_label_indices<int8_t>(input_data_path, final_index_path_prefix, all_labels, R, L, alpha, num_threads); - else if (data_type == "float") - generate_label_indices<float>(input_data_path, final_index_path_prefix, all_labels, R, L, alpha, num_threads); - else - throw; - - // 5. "stitch" the indices together - std::vector<std::vector<_u32>> stitched_graph; - tsl::robin_map<std::string, _u32> label_entry_points; - _u64 stitched_graph_size; - - if (data_type == "uint8") - std::tie(stitched_graph, stitched_graph_size) = - stitch_label_indices<uint8_t>(final_index_path_prefix, total_number_of_points, all_labels, - labels_to_number_of_points, label_entry_points, label_id_to_orig_id_map); - else if (data_type == "int8") - std::tie(stitched_graph, stitched_graph_size) = - stitch_label_indices<int8_t>(final_index_path_prefix, total_number_of_points, all_labels, - labels_to_number_of_points, label_entry_points, label_id_to_orig_id_map); - else if (data_type == "float") - std::tie(stitched_graph, stitched_graph_size) = - stitch_label_indices<float>(final_index_path_prefix, total_number_of_points, all_labels, - labels_to_number_of_points, label_entry_points, label_id_to_orig_id_map); - else - throw; - path full_index_path_prefix = final_index_path_prefix + "_full"; - // 5a. save the stitched graph to disk - save_full_index(full_index_path_prefix, input_data_path, stitched_graph_size, stitched_graph, label_entry_points, - universal_label, labels_file_to_use); - - // 6.
run a prune on the stitched index, and save to disk - if (data_type == "uint8") - prune_and_save<uint8_t>(final_index_path_prefix, full_index_path_prefix, input_data_path, stitched_graph, - stitched_R, label_entry_points, universal_label, labels_file_to_use, num_threads); - else if (data_type == "int8") - prune_and_save<int8_t>(final_index_path_prefix, full_index_path_prefix, input_data_path, stitched_graph, - stitched_R, label_entry_points, universal_label, labels_file_to_use, num_threads); - else if (data_type == "float") - prune_and_save<float>(final_index_path_prefix, full_index_path_prefix, input_data_path, stitched_graph, - stitched_R, label_entry_points, universal_label, labels_file_to_use, num_threads); - else - throw; - - std::chrono::duration<double> index_time = std::chrono::high_resolution_clock::now() - index_timer; - std::cout << "pruned/stitched graph generated in " << index_time.count() << " seconds" << std::endl; - - clean_up_artifacts(input_data_path, final_index_path_prefix, all_labels); -} diff --git a/tests/index_write_parameters_builder_tests.cpp b/tests/index_write_parameters_builder_tests.cpp new file mode 100644 index 000000000..acd5e2227 --- /dev/null +++ b/tests/index_write_parameters_builder_tests.cpp @@ -0,0 +1,63 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT license. + +#include <boost/test/unit_test.hpp> + +#include "parameters.h" + +BOOST_AUTO_TEST_SUITE(IndexWriteParametersBuilder_tests) + +BOOST_AUTO_TEST_CASE(test_build) +{ + uint32_t search_list_size = rand(); + uint32_t max_degree = rand(); + float alpha = (float)rand(); + uint32_t filter_list_size = rand(); + uint32_t max_occlusion_size = rand(); + uint32_t num_frozen_points = rand(); + bool saturate_graph = true; + + diskann::IndexWriteParametersBuilder builder(search_list_size, max_degree); + + builder.with_alpha(alpha) + .with_filter_list_size(filter_list_size) + .with_max_occlusion_size(max_occlusion_size) + .with_num_frozen_points(num_frozen_points) + .with_num_threads(0) + .with_saturate_graph(saturate_graph); + + { + auto parameters = builder.build(); + + BOOST_TEST(search_list_size == parameters.search_list_size); + BOOST_TEST(max_degree == parameters.max_degree); + BOOST_TEST(alpha == parameters.alpha); + BOOST_TEST(filter_list_size == parameters.filter_list_size); + BOOST_TEST(max_occlusion_size == parameters.max_occlusion_size); + BOOST_TEST(num_frozen_points == parameters.num_frozen_points); + BOOST_TEST(saturate_graph == parameters.saturate_graph); + + BOOST_TEST(parameters.num_threads > (uint32_t)0); + } + + { + uint32_t num_threads = rand() + 1; + saturate_graph = false; + builder.with_num_threads(num_threads) + .with_saturate_graph(saturate_graph); + + auto parameters = builder.build(); + + BOOST_TEST(search_list_size == parameters.search_list_size); + BOOST_TEST(max_degree == parameters.max_degree); + BOOST_TEST(alpha == parameters.alpha); + BOOST_TEST(filter_list_size == parameters.filter_list_size); + BOOST_TEST(max_occlusion_size == parameters.max_occlusion_size); + BOOST_TEST(num_frozen_points == parameters.num_frozen_points); + BOOST_TEST(saturate_graph == parameters.saturate_graph); + + BOOST_TEST(num_threads == parameters.num_threads); + } +} + +BOOST_AUTO_TEST_SUITE_END() diff --git a/tests/main.cpp b/tests/main.cpp new file mode 100644 index 000000000..53440a17a --- /dev/null +++ b/tests/main.cpp @@ -0,0 +1,6 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT license.
+ +#define BOOST_TEST_MODULE diskann_unit_tests + +#include <boost/test/unit_test.hpp> diff --git a/windows/packages.config.in b/windows/packages.config.in index d444e66b1..f8eecf02f 100644 --- a/windows/packages.config.in +++ b/windows/packages.config.in @@ -4,6 +4,7 @@ + diff --git a/workflows/SSD_index.md b/workflows/SSD_index.md index 1b8440ea5..f86856796 100644 --- a/workflows/SSD_index.md +++ b/workflows/SSD_index.md @@ -1,7 +1,7 @@ **Usage for SSD-based indices** =============================== -To generate an SSD-friendly index, use the `tests/build_disk_index` program. +To generate an SSD-friendly index, use the `apps/build_disk_index` program. ---------------------------------------------------------------------------- The arguments are as follows: @@ -19,7 +19,7 @@ The arguments are as follows: 11. **--build_PQ_bytes** (default is 0): Set to a positive value less than the dimensionality of the data to enable faster index build with PQ based distance comparisons. 12. **--use_opq**: use the flag to use OPQ rather than PQ compression. OPQ is more space efficient for some high dimensional datasets, but also needs a bit more build time. -To search the SSD-index, use the `tests/search_disk_index` program. +To search the SSD-index, use the `apps/search_disk_index` program. ------------------------------------------------------------------- The arguments are as follows: @@ -31,7 +31,7 @@ The arguments are as follows: 5. **-T (--num_threads)** (default is omp_get_num_procs()): The number of threads used for searching. Threads run in parallel and one thread handles one query at a time. More threads will result in higher aggregate query throughput, but will also use more IOs/second across the system, which may lead to higher per-query latency. So find the balance depending on the maximum number of IOPs supported by the SSD. 6. **-W (--beamwidth)** (default is 2): The beamwidth to be used for search. This is the maximum number of IO requests each query will issue per iteration of search code. Larger beamwidth will result in fewer IO round-trips per query, but might result in slightly higher total number of IO requests to SSD per query. For the highest query throughput with a fixed SSD IOps rating, use `W=1`. For best latency, use `W=4,8` or higher complexity search. Specifying 0 will optimize the beamwidth depending on the number of threads performing search, but will involve some tuning overhead. 7. **--query_file**: The queries to be searched on in same binary file format as the data file in arg (2) above. The query file must be the same type as argument (1). -8. **--gt_file**: The ground truth file for the queries in arg (7) and data file used in index construction. 
The binary file must start with *n*, the number of queries (4 bytes), followed by *d*, the number of ground truth elements per query (4 bytes), followed by `n*d` entries per query representing the d closest IDs per query in integer format, followed by `n*d` entries representing the corresponding distances (float). Total file size is `8 + 4*n*d + 4*n*d` bytes. The groundtruth file, if not available, can be calculated using the program `tests/utils/compute_groundtruth`. Use "null" if you do not have this file and if you do not want to compute recall. +8. **--gt_file**: The ground truth file for the queries in arg (7) and data file used in index construction. The binary file must start with *n*, the number of queries (4 bytes), followed by *d*, the number of ground truth elements per query (4 bytes), followed by `n*d` entries per query representing the d closest IDs per query in integer format, followed by `n*d` entries representing the corresponding distances (float). Total file size is `8 + 4*n*d + 4*n*d` bytes. The groundtruth file, if not available, can be calculated using the program `apps/utils/compute_groundtruth`. Use "null" if you do not have this file and if you do not want to compute recall. 9. **K**: search for *K* neighbors and measure *K*-recall@*K*, meaning the intersection between the retrieved top-*K* nearest neighbors and ground truth *K* nearest neighbors. 10. **result_output_prefix**: Search results will be stored in files with specified prefix, in bin format. 11. **-L (--search_list)**: A list of search_list sizes to perform search with. Larger parameters will result in slower latencies, but higher accuracies. Must be at least the value of *K* in arg (9). @@ -48,16 +48,16 @@ mkdir -p DiskANN/build/data && cd DiskANN/build/data wget ftp://ftp.irisa.fr/local/texmex/corpus/sift.tar.gz tar -xf sift.tar.gz cd .. -./tests/utils/fvecs_to_bin float data/sift/sift_learn.fvecs data/sift/sift_learn.fbin -./tests/utils/fvecs_to_bin float data/sift/sift_query.fvecs data/sift/sift_query.fbin +./apps/utils/fvecs_to_bin float data/sift/sift_learn.fvecs data/sift/sift_learn.fbin +./apps/utils/fvecs_to_bin float data/sift/sift_query.fvecs data/sift/sift_query.fbin ``` Now build and search the index and measure the recall using ground truth computed using bruteforce. ```bash -./tests/utils/compute_groundtruth --data_type float --dist_fn l2 --base_file data/sift/sift_learn.fbin --query_file data/sift/sift_query.fbin --gt_file data/sift/sift_query_learn_gt100 --K 100 # Using 0.003GB search memory budget for 100K vectors implies 32 byte PQ compression -./tests/build_disk_index --data_type float --dist_fn l2 --data_path data/sift/sift_learn.fbin --index_path_prefix data/sift/disk_index_sift_learn_R32_L50_A1.2 -R 32 -L50 -B 0.003 -M 1 - ./tests/search_disk_index --data_type float --dist_fn l2 --index_path_prefix data/sift/disk_index_sift_learn_R32_L50_A1.2 --query_file data/sift/sift_query.fbin --gt_file data/sift/sift_query_learn_gt100 -K 10 -L 10 20 30 40 50 100 --result_path data/sift/res --num_nodes_to_cache 10000 +./apps/utils/compute_groundtruth --data_type float --dist_fn l2 --base_file data/sift/sift_learn.fbin --query_file data/sift/sift_query.fbin --gt_file data/sift/sift_query_learn_gt100 --K 100 # Using 0.003GB search memory budget for 100K vectors implies 32 byte PQ compression +./apps/build_disk_index --data_type float --dist_fn l2 --data_path data/sift/sift_learn.fbin --index_path_prefix data/sift/disk_index_sift_learn_R32_L50_A1.2 -R 32 -L50 -B 0.003 -M 1 + ./apps/search_disk_index --data_type float --dist_fn l2 --index_path_prefix data/sift/disk_index_sift_learn_R32_L50_A1.2 --query_file data/sift/sift_query.fbin --gt_file data/sift/sift_query_learn_gt100 -K 10 -L 10 20 30 40 50 100 --result_path data/sift/res --num_nodes_to_cache 10000 ``` The search might be slower on machines with remote SSDs. The output lists the query throughput, the mean and 99.9pc latency in microseconds and mean number of 4KB IOs to disk for each `L` parameter provided. 
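For reference, the ground-truth layout documented in SSD_index.md above (n, then d, then `n*d` uint32 IDs, then `n*d` float distances, `8 + 8*n*d` bytes in total) can be read with a few lines of C++. This is a minimal illustrative sketch, not DiskANN code; the helper name `read_groundtruth` is assumed for the example.

```cpp
// Minimal reader for the ground-truth layout documented above:
// n (uint32), d (uint32), n*d closest ids (uint32), n*d distances (float).
// Illustrative only; not part of the DiskANN API.
#include <cstdint>
#include <fstream>
#include <stdexcept>
#include <string>
#include <vector>

void read_groundtruth(const std::string &path, uint32_t &n, uint32_t &d,
                      std::vector<uint32_t> &ids, std::vector<float> &dists)
{
    std::ifstream in(path, std::ios::binary);
    if (!in)
        throw std::runtime_error("could not open " + path);
    in.read(reinterpret_cast<char *>(&n), sizeof(n));
    in.read(reinterpret_cast<char *>(&d), sizeof(d));
    ids.resize(static_cast<size_t>(n) * d);
    dists.resize(static_cast<size_t>(n) * d);
    in.read(reinterpret_cast<char *>(ids.data()), ids.size() * sizeof(uint32_t));
    in.read(reinterpret_cast<char *>(dists.data()), dists.size() * sizeof(float));
    if (!in)
        throw std::runtime_error("truncated ground-truth file: expected 8 + 8*n*d bytes");
}
```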
diff --git a/workflows/dynamic_index.md b/workflows/dynamic_index.md index e5ad0f0d9..ca3bfbf68 100644 --- a/workflows/dynamic_index.md +++ b/workflows/dynamic_index.md @@ -9,20 +9,20 @@ While eager deletes can be supported by DiskANN, `lazy_deletes` are the preferre A sequence of lazy deletions must be followed by an invocation of the `consolidate_deletes` method that frees up slots in the index and edits the graph to maintain good recall. -The program `tests/test_insert_deletes_consolidate` demonstrates this functionality. It allows the user to specify which points from the data file will be used +The program `apps/test_insert_deletes_consolidate` demonstrates this functionality. It allows the user to specify which points from the data file will be used to initially build the index, which points will be deleted from the index, and which points will be inserted into the index. Insertions, searches and lazy deletions can be performed concurrently. Consolidation of lazy deletes can be performed synchronously or concurrently with insertions and deletions. When modifying the index sequentially, the user has the ability to take *snapshots*-- that is, save the index to memory for every *m* insertions or deletions instead of only at the end of the build. -The program `tests/test_streaming_scenario` simulates a scenario where the index actively maintains a sliding window of active points from a larger dataset. +The program `apps/test_streaming_scenario` simulates a scenario where the index actively maintains a sliding window of active points from a larger dataset. The program starts with an index build over the first `active_window` set of points from a data file. The program then simultaneously inserts newer points drawn from the file and deletes older points from the index in chunks of `consolidate_interval` points so that the number of active points in the index is approximately `active_window`. It terminates when the end of the data file is reached, and the final index has `active_window + consolidate_interval` number of points. -`tests/test_insert_deletes_consolidate` to try inserting, lazy deletes and consolidate_delete +`apps/test_insert_deletes_consolidate` to try inserting, lazy deletes and consolidate_delete --------------------------------------------------------------------------------------------- The arguments are as follows: @@ -44,7 +44,7 @@ The arguments are as follows: 15. **--start_point_norm**: Set the starting node to a random point on a sphere of this radius. A reasonable choice is to set this to the average norm of the data set. Use when starting an index with zero points. 16. **--do_concurrent** (default false): whether to perform consolidate_deletes and other updates concurrently or sequentially. If concurrent is specified, half the threads are used for insertions and half the threads are used for processing deletes. Note that insertions are performed before deletions if this flag is set to false, so in this case it is possible to delete more than beginning_index_size points.
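To make the lazy-delete contract above concrete, the following sketch shows the shape of the update sequence: mark points deleted cheaply, then consolidate once to reclaim slots and repair the graph. The method names `lazy_delete` and `consolidate_deletes` come from the text above, but the exact `diskann::Index` template parameters and argument types vary by version, so treat the signatures as assumptions rather than the program's actual code.

```cpp
// Sketch of the lazy-delete / consolidate sequence described above.
// Signatures are illustrative; consult index.h for the exact API.
#include <vector>

template <typename T, typename TagT>
void delete_batch(diskann::Index<T, TagT> &index, const std::vector<TagT> &tags,
                  const diskann::Parameters &params)
{
    for (const TagT &tag : tags)
        index.lazy_delete(tag); // cheap: only marks the point as deleted
    // required after a batch of lazy deletes: frees up slots in the index
    // and edits the graph to maintain good recall
    index.consolidate_deletes(params);
}
```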
-`tests/test_streaming_scenario` to try inserting, lazy deletes and consolidate_delete +`apps/test_streaming_scenario` to try inserting, lazy deletes and consolidate_delete --------------------------------------------------------------------------------------------- The arguments are as follows: @@ -65,7 +65,7 @@ The arguments are as follows: -To search the generated index, use the `tests/search_memory_index` program: +To search the generated index, use the `apps/search_memory_index` program: --------------------------------------------------------------------------- @@ -76,7 +76,7 @@ The arguments are as follows: 3. **memory_index_path**: index built above in argument (4). 4. **T**: The number of threads used for searching. Threads run in parallel and one thread handles one query at a time. More threads will result in higher aggregate query throughput, but may lead to higher per-query latency, especially if the DRAM bandwidth is a bottleneck. So find the balance depending on throughput and latency required for your application. 5. **query_bin**: The queries to be searched on in same binary file format as the data file (ii) above. The query file must be the same type as in argument (1). -6. **truthset.bin**: The ground truth file for the queries in arg (7) and data file used in index construction. The binary file must start with *n*, the number of queries (4 bytes), followed by *d*, the number of ground truth elements per query (4 bytes), followed by `n*d` entries per query representing the d closest IDs per query in integer format, followed by `n*d` entries representing the corresponding distances (float). Total file size is `8 + 4*n*d + 4*n*d` bytes. The groundtruth file, if not available, can be calculated using the program `tests/utils/compute_groundtruth`. Use "null" if you do not have this file and if you do not want to compute recall. +6. **truthset.bin**: The ground truth file for the queries in arg (7) and data file used in index construction. The binary file must start with *n*, the number of queries (4 bytes), followed by *d*, the number of ground truth elements per query (4 bytes), followed by `n*d` entries per query representing the d closest IDs per query in integer format, followed by `n*d` entries representing the corresponding distances (float). Total file size is `8 + 4*n*d + 4*n*d` bytes. The groundtruth file, if not available, can be calculated using the program `apps/utils/compute_groundtruth`. Use "null" if you do not have this file and if you do not want to compute recall. 7. **K**: search for *K* neighbors and measure *K*-recall@*K*, meaning the intersection between the retrieved top-*K* nearest neighbors and ground truth *K* nearest neighbors. 8. **result_output_prefix**: search results will be stored in files, one per L value (see next arg), with specified prefix, in binary format. 9. **-L (--search_list)**: A list of search_list sizes to perform search with. Larger parameters will result in slower latencies, but higher accuracies. Must be at least the value of *K* in (7). @@ -95,8 +95,8 @@ mkdir -p DiskANN/build/data && cd DiskANN/build/data wget ftp://ftp.irisa.fr/local/texmex/corpus/sift.tar.gz tar -xf sift.tar.gz cd .. 
-./tests/utils/fvecs_to_bin float data/sift/sift_learn.fvecs data/sift/sift_learn.fbin -./tests/utils/fvecs_to_bin float data/sift/sift_query.fvecs data/sift/sift_query.fbin +./apps/utils/fvecs_to_bin float data/sift/sift_learn.fvecs data/sift/sift_learn.fbin +./apps/utils/fvecs_to_bin float data/sift/sift_query.fvecs data/sift/sift_query.fbin ``` The example below tests the following scenario: using a file with 100000 points, the index is incrementally constructed point by point. After the first 50000 points are inserted, another concurrent job deletes the first 25000 points from the index and consolidates the index (edits the graph and cleans up resources). At the same time an additional 25000 points (i.e. points 50001 to 75000) are concurrently inserted into the index. Note that the index should be built **before** calculating the ground truth, since the memory index returns the slice of the sift100K dataset that was used to build the final graph (that is, points 25001-75000 in the original index). @@ -115,15 +115,15 @@ thr=64 index=${index_prefix}.after-concurrent-delete-del${deletes}-${inserts} gt_file=data/sift/gt100_learn-conc-${deletes}-${inserts} - ~/DiskANN/build/tests/test_insert_deletes_consolidate --data_type ${type} --dist_fn l2 --data_path ${data} --index_path_prefix ${index_prefix} -R 64 -L 300 --alpha 1.2 -T ${thr} --points_to_skip 0 --max_points_to_insert ${inserts} --beginning_index_size ${begin} --points_per_checkpoint ${pts_per_checkpoint} --checkpoints_per_snapshot 0 --points_to_delete_from_beginning ${deletes} --start_deletes_after ${deletes_after} --do_concurrent true; + ~/DiskANN/build/apps/test_insert_deletes_consolidate --data_type ${type} --dist_fn l2 --data_path ${data} --index_path_prefix ${index_prefix} -R 64 -L 300 --alpha 1.2 -T ${thr} --points_to_skip 0 --max_points_to_insert ${inserts} --beginning_index_size ${begin} --points_per_checkpoint ${pts_per_checkpoint} --checkpoints_per_snapshot 0 --points_to_delete_from_beginning ${deletes} --start_deletes_after ${deletes_after} --do_concurrent true; - ~/DiskANN/build/tests/utils/compute_groundtruth --data_type ${type} --dist_fn l2 --base_file ${index}.data --query_file ${query} --K 100 --gt_file ${gt_file} --tags_file ${index}.tags + ~/DiskANN/build/apps/utils/compute_groundtruth --data_type ${type} --dist_fn l2 --base_file ${index}.data --query_file ${query} --K 100 --gt_file ${gt_file} --tags_file ${index}.tags -~/DiskANN/build/tests/search_memory_index --data_type ${type} --dist_fn l2 --index_path_prefix ${index} --result_path ${result} --query_file ${query} --gt_file ${gt_file} -K 10 -L 20 40 60 80 100 -T ${thr} --dynamic true --tags 1 +~/DiskANN/build/apps/search_memory_index --data_type ${type} --dist_fn l2 --index_path_prefix ${index} --result_path ${result} --query_file ${query} --gt_file ${gt_file} -K 10 -L 20 40 60 80 100 -T ${thr} --dynamic true --tags 1 ``` The example below tests the following scenario: using a file with 100000 points, insert 10000 points at a time. After the first 40000 -are inserted, start deleteing the first 10000 points while inserting points 40000--50000. Then delete points 10000--20000 while inserting +are inserted, start deleting the first 10000 points while inserting points 40000--50000. Then delete points 10000--20000 while inserting points 50000--60000 and so on until the index is left with points 60000-100000. 
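The chunk bookkeeping in this scenario is plain arithmetic; the standalone sketch below (not DiskANN code, with constants taken from the example) prints which chunk is inserted and which is deleted at each step, ending with points 60000-100000 in the index. The shell invocation that actually runs the scenario follows.

```cpp
// Standalone illustration of the sliding-window schedule described above:
// keep roughly `active_window` points, inserting and deleting in chunks of
// `consolidate_interval` (`chunk`). Just the bookkeeping, not DiskANN code.
#include <cstdint>
#include <iostream>

int main()
{
    const uint32_t total = 100000, active_window = 30000, chunk = 10000;
    uint32_t inserted = 0, deleted = 0;
    while (inserted < total)
    {
        std::cout << "insert [" << inserted << ", " << inserted + chunk << ")";
        inserted += chunk;
        if (inserted > active_window + chunk) // window full: age out the oldest chunk
        {
            std::cout << ", delete [" << deleted << ", " << deleted + chunk << ")";
            deleted += chunk;
        }
        std::cout << "\n";
    }
    // Final index holds [deleted, total): here [60000, 100000), i.e.
    // active_window + consolidate_interval = 40000 points.
    return 0;
}
```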
``` @@ -140,7 +140,7 @@ cons_int=10000 index=${index_prefix}.after-streaming-act${active}-cons${cons_int}-max${inserts} gt=data/sift/gt100_learn-act${active}-cons${cons_int}-max${inserts} -./tests/test_streaming_scenario --data_type ${type} --dist_fn l2 --data_path ${data} --index_path_prefix ${index_prefix} -R 64 -L 600 --alpha 1.2 --insert_threads ${ins_thr} --consolidate_threads ${cons_thr} --max_points_to_insert ${inserts} --active_window ${active} --consolidate_interval ${cons_int} --start_point_norm 508; -./tests/utils/compute_groundtruth --data_type ${type} --dist_fn l2 --base_file ${index}.data --query_file ${query} --K 100 --gt_file ${gt} --tags_file ${index}.tags -./tests/search_memory_index --data_type ${type} --dist_fn l2 --index_path_prefix ${index} --result_path ${result} --query_file ${query} --gt_file ${gt} -K 10 -L 20 40 60 80 100 -T 64 --dynamic true --tags 1 +./apps/test_streaming_scenario --data_type ${type} --dist_fn l2 --data_path ${data} --index_path_prefix ${index_prefix} -R 64 -L 600 --alpha 1.2 --insert_threads ${ins_thr} --consolidate_threads ${cons_thr} --max_points_to_insert ${inserts} --active_window ${active} --consolidate_interval ${cons_int} --start_point_norm 508; +./apps/utils/compute_groundtruth --data_type ${type} --dist_fn l2 --base_file ${index}.data --query_file ${query} --K 100 --gt_file ${gt} --tags_file ${index}.tags +./apps/search_memory_index --data_type ${type} --dist_fn l2 --index_path_prefix ${index} --result_path ${result} --query_file ${query} --gt_file ${gt} -K 10 -L 20 40 60 80 100 -T 64 --dynamic true --tags 1 ``` \ No newline at end of file diff --git a/workflows/filtered_in_memory.md b/workflows/filtered_in_memory.md index c3b652685..fe34b80f8 100644 --- a/workflows/filtered_in_memory.md +++ b/workflows/filtered_in_memory.md @@ -1,7 +1,7 @@ **Usage for filtered indices** ================================ ## Building a filtered Index -DiskANN provides two algorithms for building an index with filters support: filtered-vamana and stitched-vamana. Here, we describe the parameters for building both. `tests/build_memory_index.cpp` and `tests/build_stitched_index.cpp` are respectively used to build each kind of index. +DiskANN provides two algorithms for building an index with filter support: filtered-vamana and stitched-vamana. Here, we describe the parameters for building both. `apps/build_memory_index.cpp` and `apps/build_stitched_index.cpp` are respectively used to build each kind of index. ### 1. filtered-vamana @@ -32,7 +32,7 @@ DiskANN provides two algorithms for building an index with filt 10. **`--Stitched_R`**: Once all sub-indices are "stitched" together, we prune the resulting graph down to the degree given by this parameter. ## Computing a groundtruth file for a filtered index -In order to evaluate the performance of our algorithms, we can compare its results (i.e. the top `k` neighbors found for each query) against the results found by an exact nearest neighbor search. We provide the program `tests/utils/compute_groundtruth.cpp` to provide the results for the latter: +In order to evaluate the performance of our algorithms, we can compare their results (i.e. the top `k` neighbors found for each query) against the results found by an exact nearest neighbor search. We provide the program `apps/utils/compute_groundtruth.cpp` to compute the results for the latter: 1. **`--data_type`** The type of dataset you built an index with. float(32 bit), signed int8 and unsigned uint8 are supported. 2.
**`--dist_fn`**: There are two distance functions supported: l2 and mips. @@ -48,7 +48,7 @@ In order to evaluate the performance of our algorithms, we can compare its resul ## Searching a Filtered Index -Searching a filtered index uses the `tests/search_memory_index.cpp`: +Searching a filtered index uses the `apps/search_memory_index.cpp`: 1. **`--data_type`**: The type of dataset you built the index on. float(32 bit), signed int8 and unsigned uint8 are supported. Use the same data type as in arg (1) above used in building the index. 2. **`--dist_fn`**: There are two distance functions supported: l2 and mips. There is an additional *fast_l2* implementation that could provide faster results for small (about a million-sized) indices. Use the same distance as in arg (2) above used in building the index. Note that stitched-vamana only supports l2. @@ -64,23 +64,23 @@ Searching a filtered index uses the `tests/search_memory_index.cpp`: Example with SIFT10K: -------------------- We demonstrate how to work through this pipeline using the SIFT10K dataset (http://corpus-texmex.irisa.fr/). Before starting, make sure you have compiled diskANN according to the instructions in the README and can see the following binaries (paths with respect to repository root): -- `build/tests/utils/compute_groundtruth` -- `build/tests/utils/fvecs_to_bin` -- `build/tests/build_memory_index` -- `build/tests/build_stitched_index` -- `build/tests/search_memory_index` +- `build/apps/utils/compute_groundtruth` +- `build/apps/utils/fvecs_to_bin` +- `build/apps/build_memory_index` +- `build/apps/build_stitched_index` +- `build/apps/search_memory_index` Now, download the base and query set and convert the data to binary format: ```bash wget ftp://ftp.irisa.fr/local/texmex/corpus/siftsmall.tar.gz tar -zxvf siftsmall.tar.gz -build/tests/utils/fvecs_to_bin float siftsmall/siftsmall_base.fvecs siftsmall/siftsmall_base.bin -build/tests/utils/fvecs_to_bin float siftsmall/siftsmall_query.fvecs siftsmall/siftsmall_query.bin +build/apps/utils/fvecs_to_bin float siftsmall/siftsmall_base.fvecs siftsmall/siftsmall_base.bin +build/apps/utils/fvecs_to_bin float siftsmall/siftsmall_query.fvecs siftsmall/siftsmall_query.bin ``` We now need to make a label file for our vectors. For convenience, we've included a synthetic label generator through which we can generate a label file as follows ```bash - build/tests/utils/generate_synthetic_labels --num_labels 50 --num_points 10000 --output_file ./rand_labels_50_10K.txt --distribution_type zipf + build/apps/utils/generate_synthetic_labels --num_labels 50 --num_points 10000 --output_file ./rand_labels_50_10K.txt --distribution_type zipf ``` Note: `distribution_type` can be `rand` or `zipf` @@ -88,18 +88,18 @@ This will genearate label file with 10000 data points with 50 distinct labels, r Label count for each unique label in the generated label file can be printed with the help of the following command ```bash - build/tests/utils/stats_label_data.exe --labels_file ./rand_labels_50_10K.txt --universal_label 0 + build/apps/utils/stats_label_data.exe --labels_file ./rand_labels_50_10K.txt --universal_label 0 ``` Note that neither approach is designed for use with random synthetic labels, which will lead to unpredictable accuracy at search time. Now build and search the index and measure the recall using ground truth computed using bruteforce. We search for results with the filter 35. 
```bash -build/tests/utils/compute_groundtruth --data_type float --dist_fn l2 --base_file siftsmall/siftsmall_base.bin --query_file siftsmall/siftsmall_query.bin --gt_file siftsmall/siftsmall_gt_35.bin --K 100 --label_file ./rand_labels_50_10K.txt --filter_label 35 --universal_label 0 -build/tests/build_memory_index --data_type float --dist_fn l2 --data_path siftsmall/siftsmall_base.bin --index_path_prefix siftsmall/siftsmall_R32_L50_filtered_index -R 32 --FilteredLbuild 50 --alpha 1.2 --label_file ./rand_labels_50_10K.txt --universal_label 0 -build/tests/build_stitched_index --data_type float --data_path siftsmall/siftsmall_base.bin --index_path_prefix siftsmall/siftsmall_R20_L40_SR32_stitched_index -R 20 -L 40 --stitched_R 32 --alpha 1.2 --label_file ./rand_labels_50_10K.txt --universal_label 0 -build/tests/search_memory_index --data_type float --dist_fn l2 --index_path_prefix data/sift/siftsmall_R20_L40_SR32_filtered_index --query_file siftsmall/siftsmall_query.bin --gt_file siftsmall/siftsmall_gt_35.bin --filter_label 35 -K 10 -L 10 20 30 40 50 100 --result_path siftsmall/filtered_search_results -build/tests/search_memory_index --data_type float --dist_fn l2 --index_path_prefix data/sift/siftsmall_R20_L40_SR32_stitched_index --query_file siftsmall/siftsmall_query.bin --gt_file siftsmall/siftsmall_gt_35.bin --filter_label 35 -K 10 -L 10 20 30 40 50 100 --result_path siftsmall/stitched_search_results +build/apps/utils/compute_groundtruth --data_type float --dist_fn l2 --base_file siftsmall/siftsmall_base.bin --query_file siftsmall/siftsmall_query.bin --gt_file siftsmall/siftsmall_gt_35.bin --K 100 --label_file ./rand_labels_50_10K.txt --filter_label 35 --universal_label 0 +build/apps/build_memory_index --data_type float --dist_fn l2 --data_path siftsmall/siftsmall_base.bin --index_path_prefix siftsmall/siftsmall_R32_L50_filtered_index -R 32 --FilteredLbuild 50 --alpha 1.2 --label_file ./rand_labels_50_10K.txt --universal_label 0 +build/apps/build_stitched_index --data_type float --data_path siftsmall/siftsmall_base.bin --index_path_prefix siftsmall/siftsmall_R20_L40_SR32_stitched_index -R 20 -L 40 --stitched_R 32 --alpha 1.2 --label_file ./rand_labels_50_10K.txt --universal_label 0 +build/apps/search_memory_index --data_type float --dist_fn l2 --index_path_prefix data/sift/siftsmall_R20_L40_SR32_filtered_index --query_file siftsmall/siftsmall_query.bin --gt_file siftsmall/siftsmall_gt_35.bin --filter_label 35 -K 10 -L 10 20 30 40 50 100 --result_path siftsmall/filtered_search_results +build/apps/search_memory_index --data_type float --dist_fn l2 --index_path_prefix data/sift/siftsmall_R20_L40_SR32_stitched_index --query_file siftsmall/siftsmall_query.bin --gt_file siftsmall/siftsmall_gt_35.bin --filter_label 35 -K 10 -L 10 20 30 40 50 100 --result_path siftsmall/stitched_search_results ``` The output of both searches is listed below: the throughput (Queries/sec) as well as the mean and 99.9 percentile latency in microseconds for each `L` parameter provided. (Measured on a physical machine with an Intel(R) Xeon(R) W-2145 CPU and 64 GB RAM) diff --git a/workflows/filtered_ssd_index.md b/workflows/filtered_ssd_index.md index 1b602985e..272100e6d 100644 --- a/workflows/filtered_ssd_index.md +++ b/workflows/filtered_ssd_index.md @@ -1,7 +1,7 @@ **Usage for filtered indices** ================================ -To generate an SSD-friendly index, use the `tests/build_disk_index` program. +To generate an SSD-friendly index, use the `apps/build_disk_index` program. 
---------------------------------------------------------------------------- ## Building an SSD based filtered Index @@ -27,7 +27,7 @@ To generate an SSD-friendly index, use the `tests/build_disk_index` program. ## Computing a groundtruth file for a filtered index -In order to evaluate the performance of our algorithms, we can compare its results (i.e. the top `k` neighbors found for each query) against the results found by an exact nearest neighbor search. We provide the program `tests/utils/compute_groundtruth.cpp` to provide the results for the latter: +In order to evaluate the performance of our algorithms, we can compare their results (i.e. the top `k` neighbors found for each query) against the results found by an exact nearest neighbor search. We provide the program `apps/utils/compute_groundtruth.cpp` to compute the results for the latter: 1. **`--data_type`** The type of dataset you built an index with. float(32 bit), signed int8 and unsigned uint8 are supported. 2. **`--dist_fn`**: There are two distance functions supported: l2 and mips. @@ -41,7 +41,7 @@ In order to evaluate the performance of our algorithms, we can compare its resul ## Searching a Filtered Index -Searching a filtered index uses the `tests/search_disk_index.cpp`: +Searching a filtered index uses the `apps/search_disk_index.cpp`: 1. **--data_type**: The type of dataset you wish to build an index on. float(32 bit), signed int8 and unsigned uint8 are supported. Use the same data type as in arg (1) above used in building the index. 2. **--dist_fn**: There are two distance functions supported: minimum Euclidean distance (l2) and maximum inner product (mips). Use the same distance as in arg (2) above used in building the index. @@ -50,7 +50,7 @@ Searching a filtered index uses the `tests/search_disk_index.cpp`: 5. **-T (--num_threads)** (default is omp_get_num_procs()): The number of threads used for searching. Threads run in parallel and one thread handles one query at a time. More threads will result in higher aggregate query throughput, but will also use more IOs/second across the system, which may lead to higher per-query latency. So find the balance depending on the maximum number of IOPs supported by the SSD. 6. **-W (--beamwidth)** (default is 2): The beamwidth to be used for search. This is the maximum number of IO requests each query will issue per iteration of search code. Larger beamwidth will result in fewer IO round-trips per query, but might result in slightly higher total number of IO requests to SSD per query. For the highest query throughput with a fixed SSD IOps rating, use `W=1`. For best latency, use `W=4,8` or higher complexity search. Specifying 0 will optimize the beamwidth depending on the number of threads performing search, but will involve some tuning overhead. 7. **--query_file**: The queries to be searched on in same binary file format as the data file in arg (2) above. The query file must be the same type as argument (1). -8. **--gt_file**: The ground truth file for the queries in arg (7) and data file used in index construction. The binary file must start with *n*, the number of queries (4 bytes), followed by *d*, the number of ground truth elements per query (4 bytes), followed by `n*d` entries per query representing the d closest IDs per query in integer format, followed by `n*d` entries representing the corresponding distances (float). Total file size is `8 + 4*n*d + 4*n*d` bytes. 
The groundtruth file, if not available, can be calculated using the program `tests/utils/compute_groundtruth`. Use "null" if you do not have this file and if you do not want to compute recall. +8. **--gt_file**: The ground truth file for the queries in arg (7) and data file used in index construction. The binary file must start with *n*, the number of queries (4 bytes), followed by *d*, the number of ground truth elements per query (4 bytes), followed by `n*d` entries per query representing the d closest IDs per query in integer format, followed by `n*d` entries representing the corresponding distances (float). Total file size is `8 + 4*n*d + 4*n*d` bytes. The groundtruth file, if not available, can be calculated using the program `apps/utils/compute_groundtruth`. Use "null" if you do not have this file and if you do not want to compute recall. 9. **-K**: search for *K* neighbors and measure *K*-recall@*K*, meaning the intersection between the retrieved top-*K* nearest neighbors and ground truth *K* nearest neighbors. 10. **--result_path**: Search results will be stored in files with specified prefix, in bin format. 11. **-L (--search_list)**: A list of search_list sizes to perform search with. Larger parameters will result in slower latencies, but higher accuracies. Must be at least the value of *K* in arg (9). @@ -60,22 +60,22 @@ Searching a filtered index uses the `tests/search_disk_index.cpp`: Example with SIFT10K: -------------------- We demonstrate how to work through this pipeline using the SIFT10K dataset (http://corpus-texmex.irisa.fr/). Before starting, make sure you have compiled diskANN according to the instructions in the README and can see the following binaries (paths with respect to repository root): -- `build/tests/utils/compute_groundtruth` -- `build/tests/utils/fvecs_to_bin` -- `build/tests/build_disk_index` -- `build/tests/search_disk_index` +- `build/apps/utils/compute_groundtruth` +- `build/apps/utils/fvecs_to_bin` +- `build/apps/build_disk_index` +- `build/apps/search_disk_index` Now, download the base and query set and convert the data to binary format: ```bash wget ftp://ftp.irisa.fr/local/texmex/corpus/siftsmall.tar.gz tar -zxvf siftsmall.tar.gz -build/tests/utils/fvecs_to_bin float siftsmall/siftsmall_base.fvecs siftsmall/siftsmall_base.bin -build/tests/utils/fvecs_to_bin float siftsmall/siftsmall_query.fvecs siftsmall/siftsmall_query.bin +build/apps/utils/fvecs_to_bin float siftsmall/siftsmall_base.fvecs siftsmall/siftsmall_base.bin +build/apps/utils/fvecs_to_bin float siftsmall/siftsmall_query.fvecs siftsmall/siftsmall_query.bin ``` We now need to make a label file for our vectors. For convenience, we've included a synthetic label generator through which we can generate a label file as follows ```bash - build/tests/utils/generate_synthetic_labels --num_labels 50 --num_points 10000 --output_file ./rand_labels_50_10K.txt --distribution_type zipf + build/apps/utils/generate_synthetic_labels --num_labels 50 --num_points 10000 --output_file ./rand_labels_50_10K.txt --distribution_type zipf ``` Note: `distribution_type` can be `rand` or `zipf` @@ -83,9 +83,9 @@ This will genearate label file with 10000 data points with 50 distinct labels, r Now build and search the index and measure the recall using ground truth computed using bruteforce. We search for results with the filter 35. 
```bash -build/tests/utils/compute_groundtruth --data_type float --dist_fn l2 --base_file siftsmall/siftsmall_base.bin --query_file siftsmall/siftsmall_query.bin --gt_file siftsmall_gt_35.bin --K 100 --label_file rand_labels_50_10K.txt --filter_label 35 --universal_label 0 -build/tests/build_disk_index --data_type float --dist_fn l2 --data_path siftsmall/siftsmall_base.bin --index_path_prefix data/sift/siftsmall_R32_L50_filtered -R 32 --FilteredLbuild 50 -B 1 -M 1 --label_file rand_labels_50_10K.txt --universal_label 0 -F 0 -build/tests/search_disk_index --data_type float --dist_fn l2 --index_path_prefix data/sift/siftsmall_R32_L50_filtered --result_path siftsmall/search_35 --query_file siftsmall/siftsmall_query.bin --gt_file siftsmall_gt_35.bin -K 10 -L 10 20 30 40 50 100 --filter_label 35 -W 4 -T 8 +build/apps/utils/compute_groundtruth --data_type float --dist_fn l2 --base_file siftsmall/siftsmall_base.bin --query_file siftsmall/siftsmall_query.bin --gt_file siftsmall_gt_35.bin --K 100 --label_file rand_labels_50_10K.txt --filter_label 35 --universal_label 0 +build/apps/build_disk_index --data_type float --dist_fn l2 --data_path siftsmall/siftsmall_base.bin --index_path_prefix data/sift/siftsmall_R32_L50_filtered -R 32 --FilteredLbuild 50 -B 1 -M 1 --label_file rand_labels_50_10K.txt --universal_label 0 -F 0 +build/apps/search_disk_index --data_type float --dist_fn l2 --index_path_prefix data/sift/siftsmall_R32_L50_filtered --result_path siftsmall/search_35 --query_file siftsmall/siftsmall_query.bin --gt_file siftsmall_gt_35.bin -K 10 -L 10 20 30 40 50 100 --filter_label 35 -W 4 -T 8 ``` The output of both searches is listed below: the throughput (Queries/sec) as well as the mean and 99.9 percentile latency in microseconds for each `L` parameter provided. (Measured on a physical machine with an 11th Gen Intel(R) Core(TM) i7-1185G7 CPU and 32 GB RAM) diff --git a/workflows/in_memory_index.md b/workflows/in_memory_index.md index cc59a2c91..6d783204a 100644 --- a/workflows/in_memory_index.md +++ b/workflows/in_memory_index.md @@ -1,7 +1,7 @@ **Usage for in-memory indices** ================================ -To generate index, use the `tests/build_memory_index` program. +To generate an index, use the `apps/build_memory_index` program. -------------------------------------------------------------- The arguments are as follows: @@ -18,7 +18,7 @@ The arguments are as follows: 10. **--use_opq**: use the flag to use OPQ rather than PQ compression. OPQ is more space efficient for some high dimensional datasets, but also needs a bit more build time. -To search the generated index, use the `tests/search_memory_index` program: +To search the generated index, use the `apps/search_memory_index` program: --------------------------------------------------------------------------- @@ -29,7 +29,7 @@ The arguments are as follows: 3. **memory_index_path**: index built above in argument (4). 4. **T**: The number of threads used for searching. Threads run in parallel and one thread handles one query at a time. More threads will result in higher aggregate query throughput, but may lead to higher per-query latency, especially if the DRAM bandwidth is a bottleneck. So find the balance depending on throughput and latency required for your application. 5. **query_bin**: The queries to be searched on in same binary file format as the data file (ii) above. The query file must be the same type as in argument (1). -6. **truthset.bin**: The ground truth file for the queries in arg (7) and data file used in index construction. 
The binary file must start with *n*, the number of queries (4 bytes), followed by *d*, the number of ground truth elements per query (4 bytes), followed by `n*d` entries per query representing the d closest IDs per query in integer format, followed by `n*d` entries representing the corresponding distances (float). Total file size is `8 + 4*n*d + 4*n*d` bytes. The groundtruth file, if not available, can be calculated using the program `tests/utils/compute_groundtruth`. Use "null" if you do not have this file and if you do not want to compute recall. +6. **truthset.bin**: The ground truth file for the queries in arg (7) and data file used in index construction. The binary file must start with *n*, the number of queries (4 bytes), followed by *d*, the number of ground truth elements per query (4 bytes), followed by `n*d` entries per query representing the d closest IDs per query in integer format, followed by `n*d` entries representing the corresponding distances (float). Total file size is `8 + 4*n*d + 4*n*d` bytes. The groundtruth file, if not available, can be calculated using the program `apps/utils/compute_groundtruth`. Use "null" if you do not have this file and if you do not want to compute recall. 7. **K**: search for *K* neighbors and measure *K*-recall@*K*, meaning the intersection between the retrieved top-*K* nearest neighbors and ground truth *K* nearest neighbors. 8. **result_output_prefix**: search results will be stored in files, one per L value (see next arg), with specified prefix, in binary format. 9. **-L (--search_list)**: A list of search_list sizes to perform search with. Larger parameters will result in slower latencies, but higher accuracies. Must be atleast the value of *K* in (7). @@ -46,15 +46,15 @@ mkdir -p DiskANN/build/data && cd DiskANN/build/data wget ftp://ftp.irisa.fr/local/texmex/corpus/sift.tar.gz tar -xf sift.tar.gz cd .. -./tests/utils/fvecs_to_bin float data/sift/sift_learn.fvecs data/sift/sift_learn.fbin -./tests/utils/fvecs_to_bin float data/sift/sift_query.fvecs data/sift/sift_query.fbin +./apps/utils/fvecs_to_bin float data/sift/sift_learn.fvecs data/sift/sift_learn.fbin +./apps/utils/fvecs_to_bin float data/sift/sift_query.fvecs data/sift/sift_query.fbin ``` Now build and search the index and measure the recall using ground truth computed using brutefoce. 
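Before building, it is worth confirming that the conversion produced sensible headers. The sketch below assumes the `.fbin` files begin with the point count and the dimension as two 4-byte integers (mirroring the header convention of the ground truth format described above) and that the machine is little-endian:

```bash
# Print the first two 4-byte integers of each converted file. For the SIFT
# learn set this should show 100000 points of dimension 128, and the query
# set should show 10000 points of dimension 128.
od -A n -t d4 -N 8 data/sift/sift_learn.fbin
od -A n -t d4 -N 8 data/sift/sift_query.fbin
```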
 Now build and search the index and measure the recall using ground truth computed using bruteforce.
 ```bash
-./tests/utils/compute_groundtruth --data_type float --dist_fn l2 --base_file data/sift/sift_learn.fbin --query_file data/sift/sift_query.fbin --gt_file data/sift/sift_query_learn_gt100 --K 100
-./tests/build_memory_index --data_type float --dist_fn l2 --data_path data/sift/sift_learn.fbin --index_path_prefix data/sift/index_sift_learn_R32_L50_A1.2 -R 32 -L 50 --alpha 1.2
- ./tests/search_memory_index --data_type float --dist_fn l2 --index_path_prefix data/sift/index_sift_learn_R32_L50_A1.2 --query_file data/sift/sift_query.fbin --gt_file data/sift/sift_query_learn_gt100 -K 10 -L 10 20 30 40 50 100 --result_path data/sift/res
+./apps/utils/compute_groundtruth --data_type float --dist_fn l2 --base_file data/sift/sift_learn.fbin --query_file data/sift/sift_query.fbin --gt_file data/sift/sift_query_learn_gt100 --K 100
+./apps/build_memory_index --data_type float --dist_fn l2 --data_path data/sift/sift_learn.fbin --index_path_prefix data/sift/index_sift_learn_R32_L50_A1.2 -R 32 -L 50 --alpha 1.2
+ ./apps/search_memory_index --data_type float --dist_fn l2 --index_path_prefix data/sift/index_sift_learn_R32_L50_A1.2 --query_file data/sift/sift_query.fbin --gt_file data/sift/sift_query_learn_gt100 -K 10 -L 10 20 30 40 50 100 --result_path data/sift/res
 ```
diff --git a/workflows/rest_api.md b/workflows/rest_api.md
index bae965996..2a88d721d 100644
--- a/workflows/rest_api.md
+++ b/workflows/rest_api.md
@@ -20,10 +20,10 @@ Follow the instructions for [building an in-memory DiskANN index](/workflows/in_
 ```bash
 # To start serving an in-memory index
-./tests/restapi/inmem_server --address --data_type --data_file --index_path_prefix --num_threads --l_search --tags_file [tags_file]
+./apps/restapi/inmem_server --address --data_type --data_file --index_path_prefix --num_threads --l_search --tags_file [tags_file]
 # To start serving an SSD-based index.
-./tests/restapi/ssd_server --address --data_type --index_path_prefix --num_nodes_to_cache --num_threads --tags_file [tags_file]
+./apps/restapi/ssd_server --address --data_type --index_path_prefix --num_nodes_to_cache --num_threads --tags_file [tags_file]
 ```
 The `data_type` and the `data_file` should be the same as those used in the construction of the index. The server returns the ids and distances of the closest vectors in the index to the query. The ids are implicitly defined by the order of the vectors in the data file. If you wish to assign a different numbering, GUID, or URL to the vectors in the index, use the optional `tags_file`. This should be a file that lists a "tag" string for each vector in the index, one string per line. The string on line `n` is considered the tag corresponding to vector `n` in the index (in the implicit order defined in the `data_file`).
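As an illustration of the `tags_file` layout just described, the sketch below writes one tag string per line so that line `n` tags vector `n`. The `vec_` naming scheme, the output path, and the 100000-line count (matching the `sift_learn` set from the in-memory example) are assumptions for illustration only.

```bash
# Generate a trivial tags file: the string on line n becomes the tag of
# vector n in the index (implicit order of the data file)
seq 0 99999 | sed 's/^/vec_/' > data/sift/sift_learn_tags.txt
wc -l data/sift/sift_learn_tags.txt   # expect 100000 lines
```

Under this assumption, the server would then report tags like `vec_42` in place of raw vector ids.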