Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for Multivector #113

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ classifiers =
include_package_data = True
python_requires = >=3.9
install_requires =
weaviate-client>=4.10.0
weaviate-client>=4.11.0
click==8.1.7
semver>=3.0.2
numpy>=1.24.0
Expand Down
17 changes: 17 additions & 0 deletions weaviate_cli/commands/create.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@ def create() -> None:
"hnsw_bq",
"hnsw_sq",
"hnsw_acorn",
"hnsw_multivector",
"flat_bq",
]
),
Expand Down Expand Up @@ -110,11 +111,18 @@ def create() -> None:
"ollama",
"cohere",
"jinaai",
"jinaai_colbert",
"none_multi_vector",
"weaviate",
]
),
help="Vectorizer to use.",
)
@click.option(
"--multi_vector",
is_flag=True,
help="Enable multi-vector (default: False).",
)
@click.option(
"--replication_deletion_strategy",
default=CreateCollectionDefaults.replication_deletion_strategy,
Expand All @@ -139,6 +147,7 @@ def create_collection_cli(
shards: int,
vectorizer: Optional[str],
replication_deletion_strategy: str,
multi_vector: bool,
) -> None:
"""Create a collection in Weaviate."""

Expand All @@ -161,6 +170,7 @@ def create_collection_cli(
shards=shards,
vectorizer=vectorizer,
replication_deletion_strategy=replication_deletion_strategy,
multi_vector=multi_vector,
)
except Exception as e:
click.echo(f"Error: {e}")
Expand Down Expand Up @@ -311,6 +321,11 @@ def create_backup_cli(ctx, backend, backup_id, include, exclude, wait, cpu_for_b
default=None,
help="UUID of the object to be used when the data is randomized. It requires --limit=1 and --randomize to be enabled.",
)
@click.option(
"--multi_vector",
is_flag=True,
help="Enable multi-vector (default: False).",
)
@click.pass_context
def create_data_cli(
ctx,
Expand All @@ -321,6 +336,7 @@ def create_data_cli(
auto_tenants,
vector_dimensions,
uuid,
multi_vector,
):
"""Ingest data into a collection in Weaviate."""

Expand Down Expand Up @@ -351,6 +367,7 @@ def create_data_cli(
auto_tenants=auto_tenants,
vector_dimensions=vector_dimensions,
uuid=uuid,
multi_vector=multi_vector,
)
except Exception as e:
click.echo(f"Error: {e}")
Expand Down
3 changes: 3 additions & 0 deletions weaviate_cli/defaults.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,8 @@ class CreateCollectionDefaults:
shards: int = 1
vectorizer: Optional[str] = None
replication_deletion_strategy: str = "delete_on_conflict"
multi_vector: bool = False
named_vector: Optional[str] = "default"


@dataclass
Expand Down Expand Up @@ -67,6 +69,7 @@ class CreateDataDefaults:
randomize: bool = False
auto_tenants: int = 0
vector_dimensions: int = 1536
multi_vector: bool = False


@dataclass
Expand Down
29 changes: 26 additions & 3 deletions weaviate_cli/managers/collection_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,8 @@ def create_collection(
replication_deletion_strategy: Optional[
str
] = CreateCollectionDefaults.replication_deletion_strategy,
multi_vector: bool = CreateCollectionDefaults.multi_vector,
named_vector: Optional[str] = CreateCollectionDefaults.named_vector,
) -> None:

if self.client.collections.exists(collection):
Expand Down Expand Up @@ -143,6 +145,9 @@ def create_collection(
"hnsw_acorn": wvc.Configure.VectorIndex.hnsw(
filter_strategy=VectorFilterStrategy.ACORN
),
"hnsw_multivector": wvc.Configure.VectorIndex.hnsw(
multi_vector=wvc.Configure.VectorIndex.MultiVector.multi_vector(),
),
# Expected to fail for now: the flat index is not compatible with PQ quantization.
"flat_pq": wvc.Configure.VectorIndex.flat(
quantizer=wvc.Configure.VectorIndex.Quantizer.pq()
Expand All @@ -169,6 +174,16 @@ def create_collection(
"cohere": wvc.Configure.Vectorizer.text2vec_cohere(),
"jinaai": wvc.Configure.Vectorizer.text2vec_jinaai(),
"weaviate": wvc.Configure.Vectorizer.text2vec_weaviate(),
"jinaai_colbert": wvc.Configure.NamedVectors.text2colbert_jinaai(
name=named_vector,
vector_index_config=vector_index_map[vector_index],
),
"none_multi_vector": [
wvc.Configure.NamedVectors.none(
name=named_vector,
vector_index_config=vector_index_map[vector_index],
)
],
}

inverted_index_map: Dict[str, wvc.InvertedIndexConfig] = {
Expand Down Expand Up @@ -219,7 +234,9 @@ def create_collection(
try:
self.client.collections.create(
name=collection,
vector_index_config=vector_index_map[vector_index],
vector_index_config=(
vector_index_map[vector_index] if not multi_vector else None
),
inverted_index_config=(
inverted_index_map[inverted_index] if inverted_index else None
),
Expand All @@ -240,8 +257,14 @@ def create_collection(
auto_tenant_creation=auto_tenant_creation,
auto_tenant_activation=auto_tenant_activation,
),
vectorizer_config=(vectorizer_map[vectorizer] if vectorizer else None),
properties=properties if not force_auto_schema else None,
vectorizer_config=(
vectorizer_map[vectorizer]
if vectorizer and not multi_vector
else None
),
properties=(
properties if not force_auto_schema and not multi_vector else None
),
)
except Exception as e:

Expand Down
23 changes: 19 additions & 4 deletions weaviate_cli/managers/data_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from weaviate.collections import Collection
from datetime import datetime, timedelta
from weaviate_cli.defaults import (
CreateCollectionDefaults,
CreateDataDefaults,
QueryDataDefaults,
UpdateDataDefaults,
Expand Down Expand Up @@ -147,6 +148,7 @@ def __ingest_data(
vector_dimensions: Optional[int] = 1536,
uuid: Optional[str] = None,
named_vectors: Optional[List[str]] = None,
multi_vector: bool = False,
) -> Collection:
if randomize:
counter = 0
Expand All @@ -173,11 +175,21 @@ def __ingest_data(
if named_vectors is None:
vector = (2 * np.random.rand(vector_dimensions) - 1).tolist()
batch.add_object(properties=obj, uuid=uuid, vector=vector)

else:
vector = {
name: (2 * np.random.rand(vector_dimensions) - 1).tolist()
for name in named_vectors
}
if multi_vector:
vector = {
CreateCollectionDefaults.named_vector: [
(2 * np.random.rand(vector_dimensions) - 1).tolist()
]
}
else:
vector = {
name: (
2 * np.random.rand(vector_dimensions) - 1
).tolist()
for name in named_vectors
}
batch.add_object(properties=obj, uuid=uuid, vector=vector)
counter += 1

Expand Down Expand Up @@ -207,6 +219,7 @@ def create_data(
vector_dimensions: Optional[int] = CreateDataDefaults.vector_dimensions,
uuid: Optional[str] = None,
named_vectors: Optional[List[str]] = None,
multi_vector: bool = CreateDataDefaults.multi_vector,
) -> Collection:

if not self.client.collections.exists(collection):
Expand Down Expand Up @@ -261,6 +274,7 @@ def create_data(
vector_dimensions,
uuid,
named_vectors,
multi_vector,
)
else:
click.echo(f"Processing tenant '{tenant}'")
Expand All @@ -272,6 +286,7 @@ def create_data(
vector_dimensions,
uuid,
named_vectors,
multi_vector,
)

if len(collection) != limit:
Expand Down
Loading