Skip to content

Commit

Permalink
makes weaviate running
Browse files Browse the repository at this point in the history
  • Loading branch information
rudolfix committed Jun 12, 2024
1 parent 1f17a44 commit 71e418b
Show file tree
Hide file tree
Showing 10 changed files with 146 additions and 101 deletions.
20 changes: 20 additions & 0 deletions dlt/common/normalizers/naming/sql_cs_v1.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
from typing import Any, Sequence

from dlt.common.normalizers.naming.naming import NamingConvention as BaseNamingConvention


class NamingConvention(BaseNamingConvention):
    """Naming convention that uses a double underscore as the path separator.

    On top of the base convention's normalization, the characters ``.``,
    newline, carriage return, ``'``, ``"`` and ``▶`` are each replaced with
    an underscore.
    """

    PATH_SEPARATOR = "__"

    # Translation table mapping every character of the first string to "_".
    _CLEANUP_TABLE = str.maketrans(".\n\r'\"▶", "______")

    def normalize_identifier(self, identifier: str) -> str:
        """Normalize `identifier` via the base class, then replace unwanted characters.

        The cleaned value, the base-normalized value and `self.max_length` are
        handed to `shorten_identifier`, whose result is returned.
        """
        base_normalized = super().normalize_identifier(identifier)
        cleaned = base_normalized.translate(self._CLEANUP_TABLE)
        return self.shorten_identifier(cleaned, base_normalized, self.max_length)

    def make_path(self, *identifiers: Any) -> str:
        """Join `identifiers` with `PATH_SEPARATOR`, leaving out blank ones."""
        # only identifiers that are non-empty after stripping whitespace take part
        non_blank = (part for part in identifiers if part.strip())
        return self.PATH_SEPARATOR.join(non_blank)

    def break_path(self, path: str) -> Sequence[str]:
        """Split `path` on `PATH_SEPARATOR` and return the non-blank segments."""
        segments = path.split(self.PATH_SEPARATOR)
        return list(filter(lambda segment: segment.strip(), segments))
17 changes: 9 additions & 8 deletions dlt/destinations/impl/weaviate/weaviate_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,12 +35,10 @@
get_columns_names_with_prop,
loads_table,
normalize_table_identifiers,
pipeline_state_table,
version_table,
)
from dlt.common.destination import DestinationCapabilitiesContext
from dlt.common.destination.reference import TLoadJobState, LoadJob, JobClientBase, WithStateSync
from dlt.common.data_types import TDataType
from dlt.common.storages import FileStorage

from dlt.destinations.impl.weaviate.weaviate_adapter import VECTORIZE_HINT, TOKENIZATION_HINT
Expand All @@ -49,6 +47,7 @@
from dlt.destinations.impl.weaviate.configuration import WeaviateClientConfiguration
from dlt.destinations.impl.weaviate.exceptions import PropertyNameConflict, WeaviateGrpcError
from dlt.destinations.type_mapping import TypeMapper
from dlt.destinations.utils import get_pipeline_state_query_columns


NON_VECTORIZED_CLASS = {
Expand Down Expand Up @@ -251,11 +250,13 @@ def __init__(
self.version_collection_properties = list(version_table_["columns"].keys())
loads_table_ = normalize_table_identifiers(loads_table(), schema.naming)
self.loads_collection_properties = list(loads_table_["columns"].keys())
state_table_ = normalize_table_identifiers(pipeline_state_table(), schema.naming)
state_table_ = normalize_table_identifiers(
get_pipeline_state_query_columns(), schema.naming
)
self.pipeline_state_properties = list(state_table_["columns"].keys())

self.config: WeaviateClientConfiguration = config
self.db_client = self.create_db_client(config)
self.db_client: weaviate.Client = None

self._vectorizer_config = {
"vectorizer": config.vectorizer,
Expand Down Expand Up @@ -529,7 +530,7 @@ def get_stored_state(self, pipeline_name: str) -> Optional[StateInfo]:
if len(state_records) == 0:
return None
for state in state_records:
load_id = state["_dlt_load_id"]
load_id = state[p_dlt_load_id]
load_records = self.get_records(
self.schema.loads_table_name,
where={
Expand All @@ -543,7 +544,6 @@ def get_stored_state(self, pipeline_name: str) -> Optional[StateInfo]:
# if there is a load for this state which was successful, return the state
if len(load_records):
state["dlt_load_id"] = state.pop(p_dlt_load_id)
state.pop("version_hash")
return StateInfo(**state)

def get_stored_schema(self) -> Optional[StorageSchemaInfo]:
Expand Down Expand Up @@ -582,7 +582,6 @@ def get_stored_schema_by_hash(self, schema_hash: str) -> Optional[StorageSchemaI
return None

@wrap_weaviate_error
# @wrap_grpc_error
def get_records(
self,
table_name: str,
Expand Down Expand Up @@ -697,6 +696,7 @@ def complete_load(self, load_id: str) -> None:
self.create_object(properties, self.schema.loads_table_name)

def __enter__(self) -> "WeaviateClient":
    """Create the db client from `self.config` and return this instance."""
    # the client is built on entering the context
    client = self.create_db_client(self.config)
    self.db_client = client
    return self

def __exit__(
Expand All @@ -705,7 +705,8 @@ def __exit__(
exc_val: BaseException,
exc_tb: TracebackType,
) -> None:
pass
if self.db_client:
self.db_client = None

def _update_schema_in_storage(self, schema: Schema) -> None:
schema_str = json.dumps(schema.to_dict())
Expand Down
11 changes: 10 additions & 1 deletion dlt/destinations/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,12 @@
from dlt.common import logger
from dlt.common.schema import Schema
from dlt.common.schema.exceptions import SchemaCorruptedException
from dlt.common.schema.typing import MERGE_STRATEGIES
from dlt.common.schema.typing import MERGE_STRATEGIES, TTableSchema
from dlt.common.schema.utils import (
get_columns_names_with_prop,
get_first_column_name_with_prop,
has_column_with_prop,
pipeline_state_table,
)
from typing import Any, cast, Tuple, Dict, Type

Expand Down Expand Up @@ -51,6 +52,14 @@ def parse_db_data_type_str_with_precision(db_type: str) -> Tuple[str, Optional[i
return db_type, None, None


def get_pipeline_state_query_columns() -> TTableSchema:
    """Return the pipeline state table definition without the `version_hash` column.

    The definition is obtained from `pipeline_state_table()`; `version_hash`
    is removed because the state query does not need it.
    """
    table = pipeline_state_table()
    columns = table["columns"]
    # leaving version_hash out of the query keeps it backward compatible
    # for as long as possible; a missing column raises KeyError
    del columns["version_hash"]
    return table


def verify_sql_job_client_schema(schema: Schema, warnings: bool = True) -> List[Exception]:
log = logger.warning if warnings else logger.info
# collect all exceptions to show all problems in the schema
Expand Down
2 changes: 1 addition & 1 deletion docs/website/docs/dlt-ecosystem/destinations/weaviate.md
Original file line number Diff line number Diff line change
Expand Up @@ -252,7 +252,7 @@ it will be normalized to:
so your best course of action is to clean up the data yourself before loading and use the default naming convention. Nevertheless, you can configure the alternative in `config.toml`:
```toml
[schema]
naming="dlt.destinations.weaviate.impl.ci_naming"
naming="dlt.destinations.impl.weaviate.ci_naming"
```

## Additional destination options
Expand Down
File renamed without changes.
38 changes: 12 additions & 26 deletions tests/common/schema/test_normalize_identifiers.py
Original file line number Diff line number Diff line change
@@ -1,36 +1,22 @@
from copy import deepcopy
import os
from typing import Callable, List, Sequence, cast
from typing import Callable
import pytest

from dlt.common import pendulum, json
from dlt.common import json
from dlt.common.configuration import resolve_configuration
from dlt.common.configuration.container import Container
from dlt.common.normalizers.naming.naming import NamingConvention
from dlt.common.schema.migrations import migrate_schema
from dlt.common.storages import SchemaStorageConfiguration
from dlt.common.destination.capabilities import DestinationCapabilitiesContext
from dlt.common.exceptions import DictValidationException
from dlt.common.normalizers.naming import snake_case, direct, sql_upper
from dlt.common.typing import DictStrAny, StrAny
from dlt.common.utils import uniq_id
from dlt.common.schema import TColumnSchema, Schema, TStoredSchema, utils, TColumnHint
from dlt.common.schema.exceptions import (
InvalidSchemaName,
ParentTableNotFoundException,
SchemaEngineNoUpgradePathException,
TableIdentifiersFrozen,
)
from dlt.common.schema.typing import (
LOADS_TABLE_NAME,
SIMPLE_REGEX_PREFIX,
VERSION_TABLE_NAME,
TColumnName,
TSimpleRegex,
COLUMN_HINTS,
)
from dlt.common.normalizers.naming import snake_case, direct
from dlt.common.schema import TColumnSchema, Schema, TStoredSchema, utils
from dlt.common.schema.exceptions import TableIdentifiersFrozen
from dlt.common.schema.typing import SIMPLE_REGEX_PREFIX
from dlt.common.storages import SchemaStorage
from tests.common.utils import load_json_case, load_yml_case, COMMON_TEST_CASES_PATH

from tests.common.cases.normalizers import sql_upper
from tests.common.utils import load_json_case, load_yml_case


@pytest.fixture
Expand Down Expand Up @@ -178,7 +164,7 @@ def test_update_normalizers() -> None:
# save default hints in original form
default_hints = schema._settings["default_hints"]

os.environ["SCHEMA__NAMING"] = "sql_upper"
os.environ["SCHEMA__NAMING"] = "tests.common.cases.normalizers.sql_upper"
schema.update_normalizers()
assert isinstance(schema.naming, sql_upper.NamingConvention)
# print(schema.to_pretty_yaml())
Expand All @@ -188,7 +174,7 @@ def test_update_normalizers() -> None:
assert schema.tables["ISSUES"]["resource"] == "issues"

# make sure normalizer config is replaced
assert schema._normalizers_config["names"] == "sql_upper"
assert schema._normalizers_config["names"] == "tests.common.cases.normalizers.sql_upper"
assert "allow_identifier_change_on_table_with_data" not in schema._normalizers_config

# regexes are uppercased
Expand Down Expand Up @@ -273,7 +259,7 @@ def test_raise_on_change_identifier_table_with_data() -> None:
# mark issues table to seen data and change naming to sql upper
issues_table = schema.tables["issues"]
issues_table["x-normalizer"] = {"seen-data": True}
os.environ["SCHEMA__NAMING"] = "sql_upper"
os.environ["SCHEMA__NAMING"] = "tests.common.cases.normalizers.sql_upper"
with pytest.raises(TableIdentifiersFrozen) as fr_ex:
schema.update_normalizers()
assert fr_ex.value.table_name == "issues"
Expand Down
6 changes: 5 additions & 1 deletion tests/load/pipeline/test_duckdb.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import pytest
import os

from dlt.common.schema.exceptions import SchemaIdentifierNormalizationClash
from dlt.common.time import ensure_pendulum_datetime
from dlt.destinations.exceptions import DatabaseTerminalException
from dlt.pipeline.exceptions import PipelineStepFailed
Expand Down Expand Up @@ -54,7 +55,10 @@ def test_duck_case_names(destination_config: DestinationTestConfiguration) -> No
table_name="🦚peacocks🦚",
loader_file_format=destination_config.file_format,
)
assert isinstance(pip_ex.value.__context__, DatabaseTerminalException)
assert isinstance(pip_ex.value.__context__, SchemaIdentifierNormalizationClash)
assert pip_ex.value.__context__.conflict_identifier_name == "🦚Peacocks🦚"
assert pip_ex.value.__context__.identifier_name == "🦚peacocks🦚"
assert pip_ex.value.__context__.identifier_type == "table"

# show tables and columns
with pipeline.sql_client() as client:
Expand Down
39 changes: 31 additions & 8 deletions tests/load/pipeline/test_restore_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ def test_restore_state_utils(destination_config: DestinationTestConfiguration) -
initial_state["_local"]["_last_extracted_at"] = pendulum.now()
initial_state["_local"]["_last_extracted_hash"] = initial_state["_version_hash"]
# add _dlt_id and _dlt_load_id
resource, _ = state_resource(initial_state)
resource, _ = state_resource(initial_state, "not_used_load_id")
resource.apply_hints(
columns={
"_dlt_id": {"name": "_dlt_id", "data_type": "text", "nullable": False},
Expand Down Expand Up @@ -195,15 +195,19 @@ def test_silently_skip_on_invalid_credentials(
)
@pytest.mark.parametrize("use_single_dataset", [True, False])
@pytest.mark.parametrize(
"naming_convention", ["tests.common.cases.normalizers.title_case", "snake_case"]
"naming_convention",
[
"tests.common.cases.normalizers.title_case",
"snake_case",
"tests.common.cases.normalizers.sql_upper",
],
)
def test_get_schemas_from_destination(
destination_config: DestinationTestConfiguration,
use_single_dataset: bool,
naming_convention: str,
) -> None:
# use specific naming convention
os.environ["SCHEMA__NAMING"] = naming_convention
set_naming_env(destination_config.destination, naming_convention)

pipeline_name = "pipe_" + uniq_id()
dataset_name = "state_test_" + uniq_id()
Expand Down Expand Up @@ -288,13 +292,17 @@ def _make_dn_name(schema_name: str) -> str:
ids=lambda x: x.name,
)
@pytest.mark.parametrize(
"naming_convention", ["tests.common.cases.normalizers.title_case", "snake_case", "sql_upper"]
"naming_convention",
[
"tests.common.cases.normalizers.title_case",
"snake_case",
"tests.common.cases.normalizers.sql_upper",
],
)
def test_restore_state_pipeline(
destination_config: DestinationTestConfiguration, naming_convention: str
) -> None:
# use specific naming convention
os.environ["SCHEMA__NAMING"] = naming_convention
set_naming_env(destination_config.destination, naming_convention)
# enable restoring from destination
os.environ["RESTORE_FROM_DESTINATION"] = "True"
pipeline_name = "pipe_" + uniq_id()
Expand Down Expand Up @@ -471,6 +479,9 @@ def test_restore_schemas_while_import_schemas_exist(
# make sure schema got imported
schema = p.schemas["ethereum"]
assert "blocks" in schema.tables
# allow to modify tables even if naming convention is changed. some of the tables in ethereum schema
# have processing hints that lock the table schema. so when weaviate changes naming convention we have an exception
os.environ["SCHEMA__ALLOW_IDENTIFIER_CHANGE_ON_TABLE_WITH_DATA"] = "true"

# extract some additional data to upgrade schema in the pipeline
p.run(
Expand Down Expand Up @@ -516,7 +527,7 @@ def test_restore_schemas_while_import_schemas_exist(
assert normalized_annotations in schema.tables

# check if attached to import schema
assert schema._imported_version_hash == IMPORTED_VERSION_HASH_ETH_V9
assert schema._imported_version_hash == IMPORTED_VERSION_HASH_ETH_V9()
# extract some data with restored pipeline
p.run(
["C", "D", "E"], table_name="blacklist", loader_file_format=destination_config.file_format
Expand Down Expand Up @@ -729,3 +740,15 @@ def prepare_import_folder(p: Pipeline) -> None:
common_yml_case_path("schemas/eth/ethereum_schema_v5"),
os.path.join(p._schema_storage.config.import_schema_path, "ethereum.schema.yaml"),
)


def set_naming_env(destination: str, naming_convention: str) -> None:
    """Export `naming_convention` through the SCHEMA__NAMING environment variable.

    Nothing is set for "snake_case", which is the default convention.
    For the "weaviate" destination any other convention is replaced with
    weaviate's `ci_naming`, except conventions ending in "sql_upper", for
    which the running test is skipped.
    """
    # snake case is the default convention so there is nothing to set
    if naming_convention == "snake_case":
        return

    effective = naming_convention
    if destination == "weaviate":
        if naming_convention.endswith("sql_upper"):
            pytest.skip(f"{naming_convention} not supported on weaviate")
        else:
            # use the module path so weaviate's own ci_naming gets tested
            effective = "dlt.destinations.impl.weaviate.ci_naming"
    os.environ["SCHEMA__NAMING"] = effective
10 changes: 5 additions & 5 deletions tests/load/weaviate/test_weaviate_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,11 +59,11 @@ def make_client(naming_convention: str) -> Iterator[WeaviateClient]:
"test_schema",
{"names": f"dlt.destinations.impl.weaviate.{naming_convention}", "json": None},
)
_client = get_client_instance(schema)
try:
yield _client
finally:
_client.drop_storage()
with get_client_instance(schema) as _client:
try:
yield _client
finally:
_client.drop_storage()


@pytest.fixture
Expand Down
Loading

0 comments on commit 71e418b

Please sign in to comment.