diff --git a/setup.cfg b/setup.cfg index afbe4bc..84ff72e 100644 --- a/setup.cfg +++ b/setup.cfg @@ -37,7 +37,7 @@ classifiers = include_package_data = True python_requires = >=3.9 install_requires = - weaviate-client>=4.10.0 + weaviate-client>=4.11.0 click==8.1.7 semver>=3.0.2 numpy>=1.24.0 diff --git a/weaviate_cli/commands/create.py b/weaviate_cli/commands/create.py index fa8eac4..eb219e2 100644 --- a/weaviate_cli/commands/create.py +++ b/weaviate_cli/commands/create.py @@ -60,6 +60,7 @@ def create() -> None: "hnsw_bq", "hnsw_sq", "hnsw_acorn", + "hnsw_multivector", "flat_bq", ] ), @@ -110,11 +111,18 @@ def create() -> None: "ollama", "cohere", "jinaai", + "jinaai_colbert", + "none_multi_vector", "weaviate", ] ), help="Vectorizer to use.", ) +@click.option( + "--multi_vector", + is_flag=True, + help="Enable multi-vector (default: False).", +) @click.option( "--replication_deletion_strategy", default=CreateCollectionDefaults.replication_deletion_strategy, @@ -139,6 +147,7 @@ def create_collection_cli( shards: int, vectorizer: Optional[str], replication_deletion_strategy: str, + multi_vector: bool, ) -> None: """Create a collection in Weaviate.""" @@ -161,6 +170,7 @@ def create_collection_cli( shards=shards, vectorizer=vectorizer, replication_deletion_strategy=replication_deletion_strategy, + multi_vector=multi_vector, ) except Exception as e: click.echo(f"Error: {e}") @@ -311,6 +321,11 @@ def create_backup_cli(ctx, backend, backup_id, include, exclude, wait, cpu_for_b default=None, help="UUID of the object to be used when the data is randomized. It requires --limit=1 and --randomize to be enabled.", ) +@click.option( + "--multi_vector", + is_flag=True, + help="Enable multi-vector (default: False).", +) @click.pass_context def create_data_cli( ctx, @@ -321,6 +336,7 @@ def create_data_cli( auto_tenants, vector_dimensions, uuid, + multi_vector, ): """Ingest data into a collection in Weaviate.""" @@ -351,6 +367,7 @@ def create_data_cli( auto_tenants=auto_tenants, vector_dimensions=vector_dimensions, uuid=uuid, + multi_vector=multi_vector, ) except Exception as e: click.echo(f"Error: {e}") diff --git a/weaviate_cli/defaults.py b/weaviate_cli/defaults.py index c1e61b4..edfc688 100644 --- a/weaviate_cli/defaults.py +++ b/weaviate_cli/defaults.py @@ -39,6 +39,8 @@ class CreateCollectionDefaults: shards: int = 1 vectorizer: Optional[str] = None replication_deletion_strategy: str = "delete_on_conflict" + multi_vector: bool = False + named_vector: Optional[str] = "default" @dataclass @@ -67,6 +69,7 @@ class CreateDataDefaults: randomize: bool = False auto_tenants: int = 0 vector_dimensions: int = 1536 + multi_vector: bool = False @dataclass diff --git a/weaviate_cli/managers/collection_manager.py b/weaviate_cli/managers/collection_manager.py index 45f91d1..ca58f95 100644 --- a/weaviate_cli/managers/collection_manager.py +++ b/weaviate_cli/managers/collection_manager.py @@ -112,6 +112,8 @@ def create_collection( replication_deletion_strategy: Optional[ str ] = CreateCollectionDefaults.replication_deletion_strategy, + multi_vector: bool = CreateCollectionDefaults.multi_vector, + named_vector: Optional[str] = CreateCollectionDefaults.named_vector, ) -> None: if self.client.collections.exists(collection): @@ -143,6 +145,9 @@ def create_collection( "hnsw_acorn": wvc.Configure.VectorIndex.hnsw( filter_strategy=VectorFilterStrategy.ACORN ), + "hnsw_multivector": wvc.Configure.VectorIndex.hnsw( + multi_vector=wvc.Configure.VectorIndex.MultiVector.multi_vector(), + ), # Should fail at the moment as Flat and PQ are not compatible "flat_pq": wvc.Configure.VectorIndex.flat( quantizer=wvc.Configure.VectorIndex.Quantizer.pq() @@ -169,6 +174,16 @@ def create_collection( "cohere": wvc.Configure.Vectorizer.text2vec_cohere(), "jinaai": wvc.Configure.Vectorizer.text2vec_jinaai(), "weaviate": wvc.Configure.Vectorizer.text2vec_weaviate(), + "jinaai_colbert": wvc.Configure.NamedVectors.text2colbert_jinaai( + name=named_vector, + vector_index_config=vector_index_map[vector_index], + ), + "none_multi_vector": [ + wvc.Configure.NamedVectors.none( + name=named_vector, + vector_index_config=vector_index_map[vector_index], + ) + ], } inverted_index_map: Dict[str, wvc.InvertedIndexConfig] = { @@ -219,7 +234,9 @@ def create_collection( try: self.client.collections.create( name=collection, - vector_index_config=vector_index_map[vector_index], + vector_index_config=( + vector_index_map[vector_index] if not multi_vector else None + ), inverted_index_config=( inverted_index_map[inverted_index] if inverted_index else None ), @@ -240,8 +257,14 @@ def create_collection( auto_tenant_creation=auto_tenant_creation, auto_tenant_activation=auto_tenant_activation, ), - vectorizer_config=(vectorizer_map[vectorizer] if vectorizer else None), - properties=properties if not force_auto_schema else None, + vectorizer_config=( + vectorizer_map[vectorizer] + if vectorizer and not multi_vector + else None + ), + properties=( + properties if not force_auto_schema and not multi_vector else None + ), ) except Exception as e: diff --git a/weaviate_cli/managers/data_manager.py b/weaviate_cli/managers/data_manager.py index 4ba247b..d709f3e 100644 --- a/weaviate_cli/managers/data_manager.py +++ b/weaviate_cli/managers/data_manager.py @@ -13,6 +13,7 @@ from weaviate.collections import Collection from datetime import datetime, timedelta from weaviate_cli.defaults import ( + CreateCollectionDefaults, CreateDataDefaults, QueryDataDefaults, UpdateDataDefaults, @@ -147,6 +148,7 @@ def __ingest_data( vector_dimensions: Optional[int] = 1536, uuid: Optional[str] = None, named_vectors: Optional[List[str]] = None, + multi_vector: bool = False, ) -> Collection: if randomize: counter = 0 @@ -173,11 +175,21 @@ def __ingest_data( if named_vectors is None: vector = (2 * np.random.rand(vector_dimensions) - 1).tolist() batch.add_object(properties=obj, uuid=uuid, vector=vector) + else: - vector = { - name: (2 * np.random.rand(vector_dimensions) - 1).tolist() - for name in named_vectors - } + if multi_vector: + vector = { + CreateCollectionDefaults.named_vector: [ + (2 * np.random.rand(vector_dimensions) - 1).tolist() + ] + } + else: + vector = { + name: ( + 2 * np.random.rand(vector_dimensions) - 1 + ).tolist() + for name in named_vectors + } batch.add_object(properties=obj, uuid=uuid, vector=vector) counter += 1 @@ -207,6 +219,7 @@ def create_data( vector_dimensions: Optional[int] = CreateDataDefaults.vector_dimensions, uuid: Optional[str] = None, named_vectors: Optional[List[str]] = None, + multi_vector: bool = CreateDataDefaults.multi_vector, ) -> Collection: if not self.client.collections.exists(collection): @@ -261,6 +274,7 @@ def create_data( vector_dimensions, uuid, named_vectors, + multi_vector, ) else: click.echo(f"Processing tenant '{tenant}'") @@ -272,6 +286,7 @@ def create_data( vector_dimensions, uuid, named_vectors, + multi_vector, ) if len(collection) != limit: