diff --git a/.github/workflows/release.yaml b/.github/workflows/release.yaml index 591d4c87..123d5f92 100644 --- a/.github/workflows/release.yaml +++ b/.github/workflows/release.yaml @@ -16,4 +16,4 @@ jobs: uses: SneaksAndData/github-actions/semver_release@v0.1.4 with: major_v: 2 - minor_v: 8 + minor_v: 9 diff --git a/adapta/storage/distributed_object_store/datastax_astra/astra_client.py b/adapta/storage/distributed_object_store/datastax_astra/astra_client.py index 704f8edd..f9552cbf 100644 --- a/adapta/storage/distributed_object_store/datastax_astra/astra_client.py +++ b/adapta/storage/distributed_object_store/datastax_astra/astra_client.py @@ -70,6 +70,7 @@ TModel = TypeVar("TModel") # pylint: disable=C0103 +@typing.final class AstraClient: """ DataStax Astra (https://astra.datastax.com) credentials provider. @@ -94,7 +95,7 @@ class AstraClient: def __init__( self, client_name: str, - keyspace: str, + keyspace: Optional[str] = None, secure_connect_bundle_bytes: Optional[str] = None, client_id: Optional[str] = None, client_secret: Optional[str] = None, @@ -127,9 +128,9 @@ def __init__( if log_transient_errors: logging.getLogger("backoff").addHandler(logging.StreamHandler()) - def __enter__(self) -> "AstraClient": + def connect(self) -> None: """ - Creates an Astra client for this context. + Connects to the Astra database """ tmp_bundle_file_name = str(uuid4()) os.makedirs(self._tmp_bundle_path, exist_ok=True) @@ -174,11 +175,22 @@ def __enter__(self) -> "AstraClient": os.remove(os.path.join(self._tmp_bundle_path, tmp_bundle_file_name)) + def disconnect(self) -> None: + """ + Disconnect from the database and destroy the session. + """ + self._session.shutdown() + self._session = None + + def __enter__(self) -> "AstraClient": + """ + Creates an Astra client for this context. + """ + self.connect() return self def __exit__(self, exc_type, exc_val, exc_tb): - self._session.shutdown() - self._session = None + self.disconnect() def get_table_metadata(self, table_name: str) -> TableMetadata: """ @@ -216,6 +228,7 @@ def filter_entities( self, model_class: Type[TModel], key_column_filter_values: Union[Expression, List[Dict[str, Any]]], + keyspace: Optional[str] = None, table_name: Optional[str] = None, select_columns: Optional[List[str]] = None, primary_keys: Optional[List[str]] = None, @@ -239,6 +252,7 @@ class Test: :param: model_class: A dataclass type that should be mapped to Astra Model. :param: key_column_filter_values: Primary key filters in a form of list of dictionaries of my_key: my_value. Multiple entries will result in multiple queries being run and concatenated + :param: keyspace: Optional keyspace name, if not provided in the client constructor :param: table_name: Optional Astra table name, if it cannot be inferred from class name by converting it to snake_case. :param: select_columns: An optional list of columns to return with the query. :param: primary_keys: An optional list of columns that constitute a primary key, if it cannot be inferred from is_primary_key metadata on a dataclass field. @@ -291,6 +305,7 @@ def to_pandas( model_class: Type[Model] = self._model_dataclass( value=model_class, + keyspace=keyspace, table_name=table_name, primary_keys=primary_keys, partition_keys=partition_keys, @@ -348,6 +363,7 @@ def get_entities_raw(self, query: str) -> DataFrame: def _model_dataclass( self, value: Type[TModel], + keyspace: Optional[str] = None, table_name: Optional[str] = None, primary_keys: Optional[List[str]] = None, partition_keys: Optional[List[str]] = None, @@ -358,6 +374,7 @@ def _model_dataclass( Maps a Python dataclass to Cassandra model. :param: value: A dataclass type that should be mapped to Astra Model. + :param: keyspace: Optional keyspace name, if not provided in the client constructor. :param: table_name: Astra table name, if it cannot be inferred from class name by converting it to snake_case. :param: primary_keys: An optional list of columns that constitute a primary key, if it cannot be inferred from is_primary_key metadata on a dataclass field. :param: partition_keys: An optional list of columns that constitute a partition key, if it cannot be inferred from is_partition_key metadata on a dataclass field. @@ -467,7 +484,7 @@ def map_to_cassandra( table_name = table_name or self._snake_pattern.sub("_", value.__name__).lower() - models_attributes: Dict[str, Column] = { + models_attributes: Dict[str, Union[Column, str]] = { field.name: map_to_cassandra( field.type, field.name, @@ -478,6 +495,9 @@ def map_to_cassandra( for field in selected_fields } + if keyspace: + models_attributes |= {"__keyspace__": keyspace} + return type(table_name, (Model,), models_attributes) def set_table_option(self, table_name: str, option_name: str, option_value: str) -> None: @@ -490,12 +510,13 @@ def set_table_option(self, table_name: str, option_name: str, option_value: str) """ self._session.execute(f"ALTER TABLE {self._keyspace}.{table_name} with {option_name}={option_value};") - def delete_entity(self, entity: TModel, table_name: Optional[str] = None) -> None: + def delete_entity(self, entity: TModel, table_name: Optional[str] = None, keyspace: Optional[str] = None) -> None: """ Delete an entity from Astra table :param: entity: entity to delete :param: table_name: Table to delete entity from. + :param: keyspace: Optional keyspace name, if not provided in the client constructor. """ @on_exception( @@ -514,13 +535,16 @@ def _delete_entity(model_class: Type[Model], key_filter: Dict): primary_keys = [field.name for field in fields(type(entity)) if field.metadata.get("is_primary_key", False)] _delete_entity( - model_class=self._model_dataclass(value=type(entity), table_name=table_name, primary_keys=primary_keys), + model_class=self._model_dataclass( + value=type(entity), table_name=table_name, primary_keys=primary_keys, keyspace=keyspace + ), key_filter={key: getattr(entity, key) for key in primary_keys}, ) def upsert_entity( self, entity: TModel, + keyspace: Optional[str] = None, table_name: Optional[str] = None, client_rate_limit: str = "1000 per second", ) -> None: @@ -529,6 +553,7 @@ def upsert_entity( :param: entity: an object to insert :param: table_name: Table to insert entity into. + :param: keyspace: Optional keyspace name, if not provided in the client constructor. :param: client_rate_limit: the limit string to parse (eg: "1 per hour"), default: "1000 per second" """ @@ -544,13 +569,16 @@ def _save_entity(model_object: Model): model_object.save() primary_keys = [field.name for field in fields(type(entity)) if field.metadata.get("is_primary_key", False)] - model_class = self._model_dataclass(value=type(entity), table_name=table_name, primary_keys=primary_keys) + model_class = self._model_dataclass( + value=type(entity), table_name=table_name, primary_keys=primary_keys, keyspace=keyspace + ) _save_entity(model_class(**asdict(entity))) def upsert_batch( self, entities: List[dict], entity_type: Type[TModel], + keyspace: Optional[str] = None, table_name: Optional[str] = None, batch_size=1000, client_rate_limit: str = "1000 per second", @@ -559,7 +587,8 @@ def upsert_batch( Inserts a batch into existing table. :param: entities: entity batch to insert. - :param: entity_type: type of entity in a batch . + :param: entity_type: type of entity in a batch. + :param: keyspace: Optional keyspace name, if not provided in the client constructor. :param: table_name: Table to insert entity into. :param: batch_size: elements per batch to upsert. :param: client_rate_limit: the limit string to parse (eg: "1 per hour"), default: "1000 per second" @@ -582,7 +611,9 @@ def _save_entities(model_class: Type[Model], values: List[dict]): for chunk in chunk_list(entities, batch_size): _save_entities( - model_class=self._model_dataclass(value=entity_type, table_name=table_name, primary_keys=primary_keys), + model_class=self._model_dataclass( + value=entity_type, table_name=table_name, primary_keys=primary_keys, keyspace=keyspace + ), values=chunk, ) diff --git a/adapta/storage/query_enabled_store/_models.py b/adapta/storage/query_enabled_store/_models.py index b9d41f0d..88359ef3 100644 --- a/adapta/storage/query_enabled_store/_models.py +++ b/adapta/storage/query_enabled_store/_models.py @@ -76,6 +76,12 @@ def open(self, path: DataPath) -> "QueryConfigurationBuilder": """ return QueryConfigurationBuilder(self, path) + @abstractmethod + def close(self) -> None: + """ + Optional logic to dispose of the store connections and related resources. + """ + @abstractmethod def _apply_filter( self, path: DataPath, filter_expression: Expression, columns: list[str] @@ -92,15 +98,23 @@ def _apply_query(self, query: str) -> Union[DataFrame, Iterator[DataFrame]]: @classmethod @abstractmethod - def _from_connection_string(cls, connection_string: str) -> "QueryEnabledStore[TCredential, TSettings]": + def _from_connection_string( + cls, connection_string: str, lazy_init: bool = False + ) -> "QueryEnabledStore[TCredential, TSettings]": """ Constructs the connection from a connection string + + :param: connection_string: QES connection string. + :param: lazy_init: Whether to set this instance QES for querying eagerly or lazily. """ @staticmethod - def from_string(connection_string: str) -> "QueryEnabledStore[TCredential, TSettings]": + def from_string(connection_string: str, lazy_init: bool = False) -> "QueryEnabledStore[TCredential, TSettings]": """ Constructs a concrete QES instance from a connection string. + + :param: connection_string: QES connection string. + :param: lazy_init: Whether to set this instance QES for querying eagerly or lazily. """ def get_qes_class(name: str) -> Type[QueryEnabledStore[TCredential, TSettings]]: @@ -112,7 +126,7 @@ def get_qes_class(name: str) -> Type[QueryEnabledStore[TCredential, TSettings]]: raise ModuleNotFoundError( f"Cannot locate QES implementation: {class_name}. Please check the name for spelling errors and make sure your application can resolve the import" ) - return class_object._from_connection_string(connection_string) + return class_object._from_connection_string(connection_string, lazy_init) @final diff --git a/adapta/storage/query_enabled_store/_qes_astra.py b/adapta/storage/query_enabled_store/_qes_astra.py index 0812f9ca..70ae8871 100644 --- a/adapta/storage/query_enabled_store/_qes_astra.py +++ b/adapta/storage/query_enabled_store/_qes_astra.py @@ -62,38 +62,60 @@ class AstraQueryEnabledStore(QueryEnabledStore[AstraCredential, AstraSettings]): QES Client for Astra DB (Cassandra). """ + def close(self) -> None: + if not self._lazy: + self._astra_client.disconnect() + + def __init__(self, credentials: AstraCredential, settings: AstraSettings, lazy_init: bool): + super().__init__(credentials, settings) + self._astra_client = AstraClient( + client_name=self.settings.client_name, + secure_connect_bundle_bytes=self.credentials.secret_connection_bundle_bytes, + client_id=self.credentials.client_id, + client_secret=self.credentials.client_secret, + ) + self._lazy = lazy_init + if not lazy_init: + self._astra_client.connect() + def _apply_filter( self, path: DataPath, filter_expression: Expression, columns: list[str] ) -> Union[DataFrame, Iterator[DataFrame]]: assert isinstance(path, AstraPath) astra_path: AstraPath = path - - with AstraClient( - client_name=self.settings.client_name, + if self._lazy: + with self._astra_client as astra_client: + return astra_client.filter_entities( + model_class=astra_path.model_class(), + key_column_filter_values=filter_expression, + keyspace=astra_path.keyspace, + table_name=astra_path.table, + select_columns=columns, + num_threads=-1, # auto-infer, see method documentation + ) + + return self._astra_client.filter_entities( + model_class=astra_path.model_class(), + key_column_filter_values=filter_expression, keyspace=astra_path.keyspace, - secure_connect_bundle_bytes=self.credentials.secret_connection_bundle_bytes, - client_id=self.credentials.client_id, - client_secret=self.credentials.client_secret, - ) as astra_client: - return astra_client.filter_entities( - model_class=astra_path.model_class(), - key_column_filter_values=filter_expression, - table_name=astra_path.table, - select_columns=columns, - num_threads=-1, # auto-infer, see method documentation - ) + table_name=astra_path.table, + select_columns=columns, + num_threads=-1, # auto-infer, see method documentation + ) def _apply_query(self, query: str) -> Union[DataFrame, Iterator[DataFrame]]: - with AstraClient( - client_name=self.settings.client_name, - keyspace=self.settings.keyspace, - secure_connect_bundle_bytes=self.credentials.secret_connection_bundle_bytes, - client_id=self.credentials.client_id, - client_secret=self.credentials.client_secret, - ) as astra_client: - return astra_client.get_entities_raw(query) + if self._lazy: + with self._astra_client as astra_client: + return astra_client.get_entities_raw(query) + return self._astra_client.get_entities_raw(query) @classmethod - def _from_connection_string(cls, connection_string: str) -> "QueryEnabledStore[AstraCredential, AstraSettings]": + def _from_connection_string( + cls, connection_string: str, lazy_init: bool = False + ) -> "QueryEnabledStore[AstraCredential, AstraSettings]": _, credentials, settings = re.findall(re.compile(CONNECTION_STRING_REGEX), connection_string)[0] - return cls(credentials=AstraCredential.from_json(credentials), settings=AstraSettings.from_json(settings)) + return cls( + credentials=AstraCredential.from_json(credentials), + settings=AstraSettings.from_json(settings), + lazy_init=lazy_init, + ) diff --git a/adapta/storage/query_enabled_store/_qes_delta.py b/adapta/storage/query_enabled_store/_qes_delta.py index 6e7a8acd..03b6b2a2 100644 --- a/adapta/storage/query_enabled_store/_qes_delta.py +++ b/adapta/storage/query_enabled_store/_qes_delta.py @@ -46,8 +46,13 @@ class DeltaQueryEnabledStore(QueryEnabledStore[DeltaCredential, DeltaSettings]): QES Client for Delta Lake reads using delta-rs. """ + def close(self) -> None: + pass + @classmethod - def _from_connection_string(cls, connection_string: str) -> "QueryEnabledStore[DeltaCredential, DeltaSettings]": + def _from_connection_string( + cls, connection_string: str, lazy_init: bool = False + ) -> "QueryEnabledStore[DeltaCredential, DeltaSettings]": _, credentials, settings = re.findall(re.compile(CONNECTION_STRING_REGEX), connection_string)[0] return cls(credentials=DeltaCredential.from_json(credentials), settings=DeltaSettings.from_json(settings)) diff --git a/tests/test_query_enabled_store.py b/tests/test_query_enabled_store.py index 9b943ecf..f4c828e1 100644 --- a/tests/test_query_enabled_store.py +++ b/tests/test_query_enabled_store.py @@ -50,7 +50,7 @@ def test_query_store_instantiation( connection_string: str, expected_store_type: Union[Type[QueryEnabledStore], Exception] ): try: - store = QueryEnabledStore.from_string(connection_string) + store = QueryEnabledStore.from_string(connection_string, lazy_init=True) assert isinstance(store, expected_store_type) except Exception as load_error: