Skip to content

Commit

Permalink
Allow client reuse for Astra QES (#401)
Browse files Browse the repository at this point in the history
  • Loading branch information
george-zubrienko authored Mar 13, 2024
1 parent 163465d commit 81bb767
Show file tree
Hide file tree
Showing 6 changed files with 113 additions and 41 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/release.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -16,4 +16,4 @@ jobs:
uses: SneaksAndData/github-actions/[email protected]
with:
major_v: 2
minor_v: 8
minor_v: 9
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,7 @@
TModel = TypeVar("TModel") # pylint: disable=C0103


@typing.final
class AstraClient:
"""
DataStax Astra (https://astra.datastax.com) credentials provider.
Expand All @@ -94,7 +95,7 @@ class AstraClient:
def __init__(
self,
client_name: str,
keyspace: str,
keyspace: Optional[str] = None,
secure_connect_bundle_bytes: Optional[str] = None,
client_id: Optional[str] = None,
client_secret: Optional[str] = None,
Expand Down Expand Up @@ -127,9 +128,9 @@ def __init__(
if log_transient_errors:
logging.getLogger("backoff").addHandler(logging.StreamHandler())

def __enter__(self) -> "AstraClient":
def connect(self) -> None:
"""
Creates an Astra client for this context.
Connects to the Astra database
"""
tmp_bundle_file_name = str(uuid4())
os.makedirs(self._tmp_bundle_path, exist_ok=True)
Expand Down Expand Up @@ -174,11 +175,22 @@ def __enter__(self) -> "AstraClient":

os.remove(os.path.join(self._tmp_bundle_path, tmp_bundle_file_name))

def disconnect(self) -> None:
"""
Disconnect from the database and destroy the session.
"""
self._session.shutdown()
self._session = None

def __enter__(self) -> "AstraClient":
"""
Creates an Astra client for this context.
"""
self.connect()
return self

def __exit__(self, exc_type, exc_val, exc_tb):
self._session.shutdown()
self._session = None
self.disconnect()

def get_table_metadata(self, table_name: str) -> TableMetadata:
"""
Expand Down Expand Up @@ -216,6 +228,7 @@ def filter_entities(
self,
model_class: Type[TModel],
key_column_filter_values: Union[Expression, List[Dict[str, Any]]],
keyspace: Optional[str] = None,
table_name: Optional[str] = None,
select_columns: Optional[List[str]] = None,
primary_keys: Optional[List[str]] = None,
Expand All @@ -239,6 +252,7 @@ class Test:
:param: model_class: A dataclass type that should be mapped to Astra Model.
:param: key_column_filter_values: Primary key filters in a form of list of dictionaries of my_key: my_value. Multiple entries will result in multiple queries being run and concatenated
:param: keyspace: Optional keyspace name, if not provided in the client constructor
:param: table_name: Optional Astra table name, if it cannot be inferred from class name by converting it to snake_case.
:param: select_columns: An optional list of columns to return with the query.
:param: primary_keys: An optional list of columns that constitute a primary key, if it cannot be inferred from is_primary_key metadata on a dataclass field.
Expand Down Expand Up @@ -291,6 +305,7 @@ def to_pandas(

model_class: Type[Model] = self._model_dataclass(
value=model_class,
keyspace=keyspace,
table_name=table_name,
primary_keys=primary_keys,
partition_keys=partition_keys,
Expand Down Expand Up @@ -348,6 +363,7 @@ def get_entities_raw(self, query: str) -> DataFrame:
def _model_dataclass(
self,
value: Type[TModel],
keyspace: Optional[str] = None,
table_name: Optional[str] = None,
primary_keys: Optional[List[str]] = None,
partition_keys: Optional[List[str]] = None,
Expand All @@ -358,6 +374,7 @@ def _model_dataclass(
Maps a Python dataclass to Cassandra model.
:param: value: A dataclass type that should be mapped to Astra Model.
:param: keyspace: Optional keyspace name, if not provided in the client constructor.
:param: table_name: Astra table name, if it cannot be inferred from class name by converting it to snake_case.
:param: primary_keys: An optional list of columns that constitute a primary key, if it cannot be inferred from is_primary_key metadata on a dataclass field.
:param: partition_keys: An optional list of columns that constitute a partition key, if it cannot be inferred from is_partition_key metadata on a dataclass field.
Expand Down Expand Up @@ -467,7 +484,7 @@ def map_to_cassandra(

table_name = table_name or self._snake_pattern.sub("_", value.__name__).lower()

models_attributes: Dict[str, Column] = {
models_attributes: Dict[str, Union[Column, str]] = {
field.name: map_to_cassandra(
field.type,
field.name,
Expand All @@ -478,6 +495,9 @@ def map_to_cassandra(
for field in selected_fields
}

if keyspace:
models_attributes |= {"__keyspace__": keyspace}

return type(table_name, (Model,), models_attributes)

def set_table_option(self, table_name: str, option_name: str, option_value: str) -> None:
Expand All @@ -490,12 +510,13 @@ def set_table_option(self, table_name: str, option_name: str, option_value: str)
"""
self._session.execute(f"ALTER TABLE {self._keyspace}.{table_name} with {option_name}={option_value};")

def delete_entity(self, entity: TModel, table_name: Optional[str] = None) -> None:
def delete_entity(self, entity: TModel, table_name: Optional[str] = None, keyspace: Optional[str] = None) -> None:
"""
Delete an entity from Astra table
:param: entity: entity to delete
:param: table_name: Table to delete entity from.
:param: keyspace: Optional keyspace name, if not provided in the client constructor.
"""

@on_exception(
Expand All @@ -514,13 +535,16 @@ def _delete_entity(model_class: Type[Model], key_filter: Dict):
primary_keys = [field.name for field in fields(type(entity)) if field.metadata.get("is_primary_key", False)]

_delete_entity(
model_class=self._model_dataclass(value=type(entity), table_name=table_name, primary_keys=primary_keys),
model_class=self._model_dataclass(
value=type(entity), table_name=table_name, primary_keys=primary_keys, keyspace=keyspace
),
key_filter={key: getattr(entity, key) for key in primary_keys},
)

def upsert_entity(
self,
entity: TModel,
keyspace: Optional[str] = None,
table_name: Optional[str] = None,
client_rate_limit: str = "1000 per second",
) -> None:
Expand All @@ -529,6 +553,7 @@ def upsert_entity(
:param: entity: an object to insert
:param: table_name: Table to insert entity into.
:param: keyspace: Optional keyspace name, if not provided in the client constructor.
:param: client_rate_limit: the limit string to parse (eg: "1 per hour"), default: "1000 per second"
"""

Expand All @@ -544,13 +569,16 @@ def _save_entity(model_object: Model):
model_object.save()

primary_keys = [field.name for field in fields(type(entity)) if field.metadata.get("is_primary_key", False)]
model_class = self._model_dataclass(value=type(entity), table_name=table_name, primary_keys=primary_keys)
model_class = self._model_dataclass(
value=type(entity), table_name=table_name, primary_keys=primary_keys, keyspace=keyspace
)
_save_entity(model_class(**asdict(entity)))

def upsert_batch(
self,
entities: List[dict],
entity_type: Type[TModel],
keyspace: Optional[str] = None,
table_name: Optional[str] = None,
batch_size=1000,
client_rate_limit: str = "1000 per second",
Expand All @@ -559,7 +587,8 @@ def upsert_batch(
Inserts a batch into existing table.
:param: entities: entity batch to insert.
:param: entity_type: type of entity in a batch .
:param: entity_type: type of entity in a batch.
:param: keyspace: Optional keyspace name, if not provided in the client constructor.
:param: table_name: Table to insert entity into.
:param: batch_size: elements per batch to upsert.
:param: client_rate_limit: the limit string to parse (eg: "1 per hour"), default: "1000 per second"
Expand All @@ -582,7 +611,9 @@ def _save_entities(model_class: Type[Model], values: List[dict]):

for chunk in chunk_list(entities, batch_size):
_save_entities(
model_class=self._model_dataclass(value=entity_type, table_name=table_name, primary_keys=primary_keys),
model_class=self._model_dataclass(
value=entity_type, table_name=table_name, primary_keys=primary_keys, keyspace=keyspace
),
values=chunk,
)

Expand Down
20 changes: 17 additions & 3 deletions adapta/storage/query_enabled_store/_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,12 @@ def open(self, path: DataPath) -> "QueryConfigurationBuilder":
"""
return QueryConfigurationBuilder(self, path)

@abstractmethod
def close(self) -> None:
"""
Optional logic to dispose of the store connections and related resources.
"""

@abstractmethod
def _apply_filter(
self, path: DataPath, filter_expression: Expression, columns: list[str]
Expand All @@ -92,15 +98,23 @@ def _apply_query(self, query: str) -> Union[DataFrame, Iterator[DataFrame]]:

@classmethod
@abstractmethod
def _from_connection_string(cls, connection_string: str) -> "QueryEnabledStore[TCredential, TSettings]":
def _from_connection_string(
cls, connection_string: str, lazy_init: bool = False
) -> "QueryEnabledStore[TCredential, TSettings]":
"""
Constructs the connection from a connection string
:param: connection_string: QES connection string.
:param: lazy_init: Whether to set this instance QES for querying eagerly or lazily.
"""

@staticmethod
def from_string(connection_string: str) -> "QueryEnabledStore[TCredential, TSettings]":
def from_string(connection_string: str, lazy_init: bool = False) -> "QueryEnabledStore[TCredential, TSettings]":
"""
Constructs a concrete QES instance from a connection string.
:param: connection_string: QES connection string.
:param: lazy_init: Whether to set this instance QES for querying eagerly or lazily.
"""

def get_qes_class(name: str) -> Type[QueryEnabledStore[TCredential, TSettings]]:
Expand All @@ -112,7 +126,7 @@ def get_qes_class(name: str) -> Type[QueryEnabledStore[TCredential, TSettings]]:
raise ModuleNotFoundError(
f"Cannot locate QES implementation: {class_name}. Please check the name for spelling errors and make sure your application can resolve the import"
)
return class_object._from_connection_string(connection_string)
return class_object._from_connection_string(connection_string, lazy_init)


@final
Expand Down
70 changes: 46 additions & 24 deletions adapta/storage/query_enabled_store/_qes_astra.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,38 +62,60 @@ class AstraQueryEnabledStore(QueryEnabledStore[AstraCredential, AstraSettings]):
QES Client for Astra DB (Cassandra).
"""

def close(self) -> None:
if not self._lazy:
self._astra_client.disconnect()

def __init__(self, credentials: AstraCredential, settings: AstraSettings, lazy_init: bool):
super().__init__(credentials, settings)
self._astra_client = AstraClient(
client_name=self.settings.client_name,
secure_connect_bundle_bytes=self.credentials.secret_connection_bundle_bytes,
client_id=self.credentials.client_id,
client_secret=self.credentials.client_secret,
)
self._lazy = lazy_init
if not lazy_init:
self._astra_client.connect()

def _apply_filter(
self, path: DataPath, filter_expression: Expression, columns: list[str]
) -> Union[DataFrame, Iterator[DataFrame]]:
assert isinstance(path, AstraPath)
astra_path: AstraPath = path

with AstraClient(
client_name=self.settings.client_name,
if self._lazy:
with self._astra_client as astra_client:
return astra_client.filter_entities(
model_class=astra_path.model_class(),
key_column_filter_values=filter_expression,
keyspace=astra_path.keyspace,
table_name=astra_path.table,
select_columns=columns,
num_threads=-1, # auto-infer, see method documentation
)

return self._astra_client.filter_entities(
model_class=astra_path.model_class(),
key_column_filter_values=filter_expression,
keyspace=astra_path.keyspace,
secure_connect_bundle_bytes=self.credentials.secret_connection_bundle_bytes,
client_id=self.credentials.client_id,
client_secret=self.credentials.client_secret,
) as astra_client:
return astra_client.filter_entities(
model_class=astra_path.model_class(),
key_column_filter_values=filter_expression,
table_name=astra_path.table,
select_columns=columns,
num_threads=-1, # auto-infer, see method documentation
)
table_name=astra_path.table,
select_columns=columns,
num_threads=-1, # auto-infer, see method documentation
)

def _apply_query(self, query: str) -> Union[DataFrame, Iterator[DataFrame]]:
with AstraClient(
client_name=self.settings.client_name,
keyspace=self.settings.keyspace,
secure_connect_bundle_bytes=self.credentials.secret_connection_bundle_bytes,
client_id=self.credentials.client_id,
client_secret=self.credentials.client_secret,
) as astra_client:
return astra_client.get_entities_raw(query)
if self._lazy:
with self._astra_client as astra_client:
return astra_client.get_entities_raw(query)
return self._astra_client.get_entities_raw(query)

@classmethod
def _from_connection_string(cls, connection_string: str) -> "QueryEnabledStore[AstraCredential, AstraSettings]":
def _from_connection_string(
cls, connection_string: str, lazy_init: bool = False
) -> "QueryEnabledStore[AstraCredential, AstraSettings]":
_, credentials, settings = re.findall(re.compile(CONNECTION_STRING_REGEX), connection_string)[0]
return cls(credentials=AstraCredential.from_json(credentials), settings=AstraSettings.from_json(settings))
return cls(
credentials=AstraCredential.from_json(credentials),
settings=AstraSettings.from_json(settings),
lazy_init=lazy_init,
)
7 changes: 6 additions & 1 deletion adapta/storage/query_enabled_store/_qes_delta.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,8 +46,13 @@ class DeltaQueryEnabledStore(QueryEnabledStore[DeltaCredential, DeltaSettings]):
QES Client for Delta Lake reads using delta-rs.
"""

def close(self) -> None:
pass

@classmethod
def _from_connection_string(cls, connection_string: str) -> "QueryEnabledStore[DeltaCredential, DeltaSettings]":
def _from_connection_string(
cls, connection_string: str, lazy_init: bool = False
) -> "QueryEnabledStore[DeltaCredential, DeltaSettings]":
_, credentials, settings = re.findall(re.compile(CONNECTION_STRING_REGEX), connection_string)[0]
return cls(credentials=DeltaCredential.from_json(credentials), settings=DeltaSettings.from_json(settings))

Expand Down
2 changes: 1 addition & 1 deletion tests/test_query_enabled_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ def test_query_store_instantiation(
connection_string: str, expected_store_type: Union[Type[QueryEnabledStore], Exception]
):
try:
store = QueryEnabledStore.from_string(connection_string)
store = QueryEnabledStore.from_string(connection_string, lazy_init=True)

assert isinstance(store, expected_store_type)
except Exception as load_error:
Expand Down

0 comments on commit 81bb767

Please sign in to comment.