Skip to content

Commit

Permalink
Implement index tags
Browse files Browse the repository at this point in the history
  • Loading branch information
jhamon committed Dec 19, 2024
1 parent 67704d5 commit 57cfac0
Show file tree
Hide file tree
Showing 5 changed files with 82 additions and 3 deletions.
28 changes: 25 additions & 3 deletions pinecone/control/pinecone.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
ConfigureIndexRequestSpecPod,
DeletionProtection,
IndexSpec,
IndexTags,
ServerlessSpec as ServerlessSpecModel,
PodSpec as PodSpecModel,
PodSpecMetadataConfig,
Expand Down Expand Up @@ -173,6 +174,7 @@ def create_index(
timeout: Optional[int] = None,
deletion_protection: Optional[Literal["enabled", "disabled"]] = "disabled",
vector_type: Optional[Literal["dense", "sparse"]] = "dense",
tags: Optional[Dict[str, str]] = None,
):
api_instance = self.index_api

Expand All @@ -184,6 +186,11 @@ def create_index(
else:
raise ValueError("deletion_protection must be either 'enabled' or 'disabled'")

if tags is None:
tags_obj = None
else:
tags_obj = IndexTags(**tags)

index_spec = self._parse_index_spec(spec)

api_instance.create_index(
Expand All @@ -196,6 +203,7 @@ def create_index(
("spec", index_spec),
("deletion_protection", dp),
("vector_type", vector_type),
("tags", tags_obj),
]
)
)
Expand Down Expand Up @@ -280,17 +288,31 @@ def configure_index(
replicas: Optional[int] = None,
pod_type: Optional[str] = None,
deletion_protection: Optional[Literal["enabled", "disabled"]] = None,
tags: Optional[Dict[str, str]] = None,
):
api_instance = self.index_api
description = self.describe_index(name=name)

if deletion_protection is None:
description = self.describe_index(name=name)
dp = DeletionProtection(description.deletion_protection)
elif deletion_protection in ["enabled", "disabled"]:
dp = DeletionProtection(deletion_protection)
else:
raise ValueError("deletion_protection must be either 'enabled' or 'disabled'")

fetched_tags = description.tags
if fetched_tags is None:
starting_tags = {}
else:
starting_tags = fetched_tags.to_dict()

if tags is None:
# Do not modify tags if none are provided
tags = starting_tags
else:
# Merge existing tags with new tags
tags = {**starting_tags, **tags}

pod_config_args: Dict[str, Any] = {}
if pod_type:
pod_config_args.update(pod_type=pod_type)
Expand All @@ -299,9 +321,9 @@ def configure_index(

if pod_config_args != {}:
spec = ConfigureIndexRequestSpec(pod=ConfigureIndexRequestSpecPod(**pod_config_args))
req = ConfigureIndexRequest(deletion_protection=dp, spec=spec)
req = ConfigureIndexRequest(deletion_protection=dp, spec=spec, tags=IndexTags(**tags))
else:
req = ConfigureIndexRequest(deletion_protection=dp)
req = ConfigureIndexRequest(deletion_protection=dp, tags=IndexTags(**tags))

api_instance.configure_index(name, configure_index_request=req)

Expand Down
10 changes: 10 additions & 0 deletions tests/integration/control/pod/test_create_index.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
class TestCreateIndexPods:
def test_create_with_optional_tags(self, client, create_index_params):
index_name = create_index_params["name"]
tags = {"foo": "FOO", "bar": "BAR"}
create_index_params["tags"] = create_index_params

client.create_index(**create_index_params)

desc = client.describe_index(name=index_name)
assert desc.tags.to_dict() == tags
40 changes: 40 additions & 0 deletions tests/integration/control/serverless/test_configure_index_tags.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
import pytest


class TestIndexTags:
def test_index_tags_none_by_default(self, client, ready_sl_index):
client.describe_index(name=ready_sl_index)
assert client.describe_index(name=ready_sl_index).tags is None

def test_add_index_tags(self, client, ready_sl_index):
client.configure_index(name=ready_sl_index, tags={"foo": "FOO", "bar": "BAR"})
assert client.describe_index(name=ready_sl_index).tags.to_dict() == {
"foo": "FOO",
"bar": "BAR",
}

def test_remove_tags_by_setting_empty_value_for_key(self, client, ready_sl_index):
client.configure_index(name=ready_sl_index, tags={"foo": "FOO", "bar": "BAR"})
client.configure_index(name=ready_sl_index, tags={})
assert client.describe_index(name=ready_sl_index).tags.to_dict() == {
"foo": "FOO",
"bar": "BAR",
}

client.configure_index(name=ready_sl_index, tags={"foo": ""})
assert client.describe_index(name=ready_sl_index).tags.to_dict() == {"bar": "BAR"}

def test_merge_new_tags_with_existing_tags(self, client, ready_sl_index):
client.configure_index(name=ready_sl_index, tags={"foo": "FOO", "bar": "BAR"})
client.configure_index(name=ready_sl_index, tags={"baz": "BAZ"})
assert client.describe_index(name=ready_sl_index).tags.to_dict() == {
"foo": "FOO",
"bar": "BAR",
"baz": "BAZ",
}

@pytest.mark.skip(reason="Backend bug filed")
def test_remove_all_tags(self, client, ready_sl_index):
client.configure_index(name=ready_sl_index, tags={"foo": "FOO", "bar": "BAR"})
client.configure_index(name=ready_sl_index, tags={"foo": "", "bar": ""})
assert client.describe_index(name=ready_sl_index).tags is None
Original file line number Diff line number Diff line change
Expand Up @@ -29,3 +29,10 @@ def test_create_dense_index_with_metric(self, client, create_sl_index_params, me
desc = client.describe_index(create_sl_index_params["name"])
assert desc.metric == metric
assert desc.vector_type == "dense"

def test_create_with_optional_tags(self, client, create_sl_index_params):
tags = {"foo": "FOO", "bar": "BAR"}
create_sl_index_params["tags"] = tags
client.create_index(**create_sl_index_params)
desc = client.describe_index(create_sl_index_params["name"])
assert desc.tags.to_dict() == tags

0 comments on commit 57cfac0

Please sign in to comment.