Skip to content

Commit

Permalink
update from_dict for v4 API
Browse files Browse the repository at this point in the history
  • Loading branch information
hsm207 committed Feb 28, 2024
1 parent 4a31050 commit 928c88b
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 20 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@

import weaviate
from weaviate.collections.classes.internal import Object
from weaviate.config import AdditionalConfig, Config, ConnectionConfig
from weaviate.config import AdditionalConfig # , Config, ConnectionConfig
from weaviate.embedded import EmbeddedOptions
from weaviate.util import generate_uuid5

Expand Down Expand Up @@ -155,17 +155,17 @@ def to_dict(self) -> Dict[str, Any]:

@classmethod
def from_dict(cls, data: Dict[str, Any]) -> "WeaviateDocumentStore":
if (timeout_config := data["init_parameters"].get("timeout_config")) is not None:
data["init_parameters"]["timeout_config"] = (
tuple(timeout_config) if isinstance(timeout_config, list) else timeout_config
)
# if (timeout_config := data["init_parameters"].get("timeout_config")) is not None:
# data["init_parameters"]["timeout_config"] = (
# tuple(timeout_config) if isinstance(timeout_config, list) else timeout_config
# )
if (auth_client_secret := data["init_parameters"].get("auth_client_secret")) is not None:
data["init_parameters"]["auth_client_secret"] = AuthCredentials.from_dict(auth_client_secret)
if (embedded_options := data["init_parameters"].get("embedded_options")) is not None:
data["init_parameters"]["embedded_options"] = EmbeddedOptions(**embedded_options)
if (additional_config := data["init_parameters"].get("additional_config")) is not None:
additional_config["connection_config"] = ConnectionConfig(**additional_config["connection_config"])
data["init_parameters"]["additional_config"] = Config(**additional_config)
# additional_config["connection_config"] = ConnectionConfig(**additional_config["connection_config"])
data["init_parameters"]["additional_config"] = AdditionalConfig(**additional_config)
return default_from_dict(
cls,
data,
Expand Down
24 changes: 11 additions & 13 deletions integrations/weaviate/tests/test_document_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -270,11 +270,7 @@ def test_from_dict(self, _mock_weaviate, monkeypatch):
"api_key": {"env_vars": ["WEAVIATE_API_KEY"], "strict": True, "type": "env_var"}
},
},
"timeout_config": [10, 60],
"proxies": {"http": "http://proxy:1234"},
"trust_env": False,
"additional_headers": {"X-HuggingFace-Api-Key": "MY_HUGGINGFACE_KEY"},
"startup_period": 5,
"embedded_options": {
"persistence_data_path": DEFAULT_PERSISTENCE_DATA_PATH,
"binary_path": DEFAULT_BINARY_PATH,
Expand All @@ -285,11 +281,13 @@ def test_from_dict(self, _mock_weaviate, monkeypatch):
"grpc_port": DEFAULT_GRPC_PORT,
},
"additional_config": {
"grpc_port_experimental": 12345,
"connection_config": {
"connection": {
"session_pool_connections": 20,
"session_pool_maxsize": 20,
},
"proxies": {"http": "http://proxy:1234"},
"timeout": [10, 60],
"trust_env": False,
},
},
}
Expand All @@ -309,21 +307,21 @@ def test_from_dict(self, _mock_weaviate, monkeypatch):
],
}
assert document_store._auth_client_secret == AuthApiKey()
assert document_store._timeout_config == (10, 60)
assert document_store._proxies == {"http": "http://proxy:1234"}
assert not document_store._trust_env
assert document_store._additional_config.timeout == (10, 60)
assert document_store._additional_config.proxies == {"http": "http://proxy:1234"}
assert not document_store._additional_config.trust_env
assert document_store._additional_headers == {"X-HuggingFace-Api-Key": "MY_HUGGINGFACE_KEY"}
assert document_store._startup_period == 5
# assert document_store._startup_period == 5
assert document_store._embedded_options.persistence_data_path == DEFAULT_PERSISTENCE_DATA_PATH
assert document_store._embedded_options.binary_path == DEFAULT_BINARY_PATH
assert document_store._embedded_options.version == "1.23.0"
assert document_store._embedded_options.port == DEFAULT_PORT
assert document_store._embedded_options.hostname == "127.0.0.1"
assert document_store._embedded_options.additional_env_vars is None
assert document_store._embedded_options.grpc_port == DEFAULT_GRPC_PORT
assert document_store._additional_config.grpc_port_experimental == 12345
assert document_store._additional_config.connection_config.session_pool_connections == 20
assert document_store._additional_config.connection_config.session_pool_maxsize == 20
# assert document_store._additional_config.grpc_port_experimental == 12345
assert document_store._additional_config.connection.session_pool_connections == 20
assert document_store._additional_config.connection.session_pool_maxsize == 20

def test_to_data_object(self, document_store, test_files_path):
doc = Document(content="test doc")
Expand Down

0 comments on commit 928c88b

Please sign in to comment.