-
Notifications
You must be signed in to change notification settings - Fork 6
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add a CI job to generate backwards compatibility test indexes and a p…
…ython unit test to query them (#206)
- Loading branch information
1 parent
ba91a70
commit 7b4374a
Showing
6 changed files
with
163 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -12,8 +12,58 @@ on: | |
- '*wheel*' # must quote since "*" is a YAML reserved character; we want a string | ||
|
||
jobs: | ||
generate_backwards_compatibility_data: | ||
name: Generate Backwards Compatibility Data | ||
runs-on: ubuntu-latest | ||
steps: | ||
- name: Checkout code | ||
uses: actions/checkout@v3 | ||
|
||
# Based on https://github.com/TileDB-Inc/conda-forge-nightly-controller/blob/51519a0f8340b32cf737fcb59b76c6a91c42dc47/.github/workflows/activity.yml#L19C10-L19C10 | ||
- name: Setup git | ||
run: | | ||
git config user.name "GitHub Actions" | ||
git config user.email "[email protected]" | ||
- name: Set up Python | ||
uses: actions/setup-python@v2 | ||
with: | ||
python-version: "3.9" | ||
|
||
- name: Print Python version | ||
run: | | ||
which python | ||
which pip | ||
python --version | ||
- name: Build Indexes | ||
run: | | ||
# Get the release tag. | ||
release_tag=$(git describe --tags --abbrev=0) | ||
echo $release_tag | ||
# Install dependencies. | ||
cd apis/python && pip install . && cd ../.. | ||
# Generate data. | ||
python backwards-compatibility-data/generate_data.py $release_tag | ||
# Push this data to a new branch and create a PR from it. | ||
git fetch | ||
branch_name="update-backwards-compatibility-data-${release_tag}" | ||
echo $branch_name | ||
git checkout -b "$branch_name" | ||
git add backwards-compatibility-data/data/ | ||
git commit -m "[automated] Update backwards-compatibility-data for release $release_tag" | ||
git push origin "$branch_name" | ||
gh pr create --base main --head "$branch_name" --title "[automated] Update backwards-compatibility-data for release $release_tag" | ||
env: | ||
GH_TOKEN: ${{ github.token }} | ||
|
||
build_wheels: | ||
name: Build wheels on ${{ matrix.os }} | ||
# TODO(paris): Add this back once generate_backwards_compatibility_data is confirmed to work. | ||
# needs: generate_backwards_compatibility_data | ||
runs-on: ${{ matrix.os }} | ||
strategy: | ||
matrix: | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,38 @@ | ||
from common import * | ||
|
||
from tiledb.vector_search.flat_index import FlatIndex | ||
from tiledb.vector_search.ivf_flat_index import IVFFlatIndex | ||
from tiledb.vector_search.utils import load_fvecs | ||
|
||
MINIMUM_ACCURACY = 0.85 | ||
|
||
def test_query_old_indices(): | ||
''' | ||
Tests that current code can query indices which were written to disk by old code. | ||
''' | ||
backwards_compatibility_path = os.path.join(os.path.dirname(__file__), '..', '..', '..', 'backwards-compatibility-data') | ||
datasets_path = os.path.join(backwards_compatibility_path, 'data') | ||
base = load_fvecs(os.path.join(backwards_compatibility_path, 'siftmicro_base.fvecs')) | ||
query_indices = [0, 3, 4, 8, 10, 19, 28, 31, 39, 40, 41, 47, 49, 50, 56, 64, 68, 70, 71, 79, 82, 89, 90, 94] | ||
queries = base[query_indices] | ||
|
||
for directory_name in os.listdir(datasets_path): | ||
version_path = os.path.join(datasets_path, directory_name) | ||
if not os.path.isdir(version_path): | ||
continue | ||
|
||
for index_name in os.listdir(version_path): | ||
index_uri = os.path.join(version_path, index_name) | ||
if not os.path.isdir(index_uri): | ||
continue | ||
|
||
if "ivf_flat" in index_name: | ||
index = IVFFlatIndex(uri=index_uri) | ||
elif "flat" in index_name: | ||
index = FlatIndex(uri=index_uri) | ||
else: | ||
assert False, f"Unknown index name: {index_name}" | ||
|
||
result_d, result_i = index.query(queries, k=1) | ||
assert query_indices == result_i.flatten().tolist() | ||
assert result_d.flatten().tolist() == [0 for _ in range(len(query_indices))] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,19 @@ | ||
### What | ||
This folder contains test indices built using different versions of TileDB-Vector-Search. It is used to test the ability of the latest version of TileDB-Vector-Search to load and query arrays built by previous versions. | ||
|
||
### Usage | ||
To generate new data, run: | ||
```bash | ||
cd apis/python | ||
pip install . | ||
cd ../.. | ||
python generate_data.py my_version | ||
``` | ||
This will build new indexes and save them to `backwards-compatibility-data/data/my_version`. | ||
|
||
To run the backwards compability test: | ||
```bash | ||
cd apis/python | ||
pip install ".[test]" | ||
pytest test/test_backwards_compatibility.py -s | ||
``` |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,2 @@ | ||
### What | ||
Holds test indices built using different versions of TileDB-Vector-Search. |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,54 @@ | ||
import os | ||
import shutil | ||
|
||
from tiledb.vector_search.ingestion import ingest | ||
from tiledb.vector_search.utils import load_fvecs, write_fvecs | ||
|
||
def create_sift_micro(): | ||
''' | ||
Create a smaller version of the base SIFT 10K dataset (http://corpus-texmex.irisa.fr). You | ||
don't need to run this again, but it's saved here just in case. To query an index built with | ||
this data just select vectors from this file as the query vectors. | ||
''' | ||
script_dir = os.path.dirname(os.path.abspath(__file__)) | ||
base_uri = os.path.join(script_dir, "..", "apis", "python", "test", "data", "siftsmall", "siftsmall_base.fvecs") | ||
write_fvecs(os.path.join(script_dir, "siftmicro_base.fvecs"), load_fvecs(base_uri)[:100]) | ||
|
||
def generate_release_data(version): | ||
script_dir = os.path.dirname(os.path.abspath(__file__)) | ||
|
||
# Create the new release directory. | ||
release_dir = os.path.join(script_dir, "data", version) | ||
shutil.rmtree(release_dir, ignore_errors=True) | ||
os.makedirs(release_dir, exist_ok=True) | ||
|
||
# Get the data we'll use to generate the index. | ||
base_uri = os.path.join(script_dir, "siftmicro_base.fvecs") | ||
base = load_fvecs(base_uri) | ||
indices = [0, 3, 4, 8, 10, 19, 28, 31, 39, 40, 41, 47, 49, 50, 56, 64, 68, 70, 71, 79, 82, 89, 90, 94] | ||
queries = base[indices] | ||
|
||
# Generate each index and query to make sure it works before we write it. | ||
index_types = ["FLAT", "IVF_FLAT"] | ||
data_types = ["float32", "uint8"] | ||
for index_type in index_types: | ||
for data_type in data_types: | ||
index_uri = f"{release_dir}/{index_type.lower()}_{data_type}" | ||
print(f"Creating index at {index_uri}") | ||
index = ingest( | ||
index_type=index_type, | ||
index_uri=index_uri, | ||
input_vectors=base.astype(data_type), | ||
) | ||
|
||
result_d, result_i = index.query(queries, k=1) | ||
assert indices == result_i.flatten().tolist() | ||
assert result_d.flatten().tolist() == [0 for _ in range(len(indices))] | ||
|
||
if __name__ == "__main__": | ||
import argparse | ||
p = argparse.ArgumentParser() | ||
p.add_argument("version", help="The name of the of the TileDB-Vector-Search version which we are creating indices for.") | ||
args = p.parse_args() | ||
print(f"Building indexes for version {args.version}") | ||
generate_release_data(args.version) |
Binary file not shown.