From 3bdb7d895c6b89fc3109187940534757f38af268 Mon Sep 17 00:00:00 2001
From: Anton Burnashev
Date: Wed, 3 Apr 2024 13:44:58 +0300
Subject: [PATCH] Change formatting line length to 88 chars

---
 dlt/__init__.py | 4 +- dlt/cli/_dlt.py | 208 +++++++--- dlt/cli/config_toml_writer.py | 4 +- dlt/cli/deploy_command.py | 112 ++++-- dlt/cli/deploy_command_helpers.py | 100 +++-- dlt/cli/echo.py | 4 +- dlt/cli/init_command.py | 137 +++++-- dlt/cli/pipeline_command.py | 63 ++- dlt/cli/pipeline_files.py | 47 ++- dlt/cli/source_detection.py | 34 +- dlt/cli/telemetry_command.py | 8 +- dlt/cli/utils.py | 11 +- dlt/common/arithmetics.py | 4 +- dlt/common/configuration/__init__.py | 7 +- dlt/common/configuration/accessors.py | 17 +- dlt/common/configuration/container.py | 4 +- dlt/common/configuration/exceptions.py | 74 ++-- dlt/common/configuration/inject.py | 40 +- dlt/common/configuration/providers/airflow.py | 8 +- .../configuration/providers/google_secrets.py | 29 +- .../configuration/providers/provider.py | 4 +- dlt/common/configuration/providers/toml.py | 43 +- dlt/common/configuration/resolve.py | 74 +++- .../configuration/specs/api_credentials.py | 9 +- .../configuration/specs/aws_credentials.py | 16 +- .../configuration/specs/base_configuration.py | 58 ++- .../specs/config_providers_context.py | 19 +- .../specs/config_section_context.py | 41 +- .../specs/connection_string_credentials.py | 9 +- dlt/common/configuration/specs/exceptions.py | 30 +- .../configuration/specs/gcp_credentials.py | 57 ++- .../configuration/specs/run_configuration.py | 11 +- dlt/common/configuration/utils.py | 13 +- dlt/common/data_writers/__init__.py | 6 +- dlt/common/data_writers/buffered.py | 28 +- dlt/common/data_writers/escape.py | 13 +- dlt/common/data_writers/exceptions.py | 7 +- dlt/common/data_writers/writers.py | 21 +- dlt/common/destination/__init__.py | 6 +- dlt/common/destination/capabilities.py | 16 +- dlt/common/destination/exceptions.py | 61 ++- dlt/common/destination/reference.py | 67 ++-- dlt/common/exceptions.py | 52 ++- dlt/common/git.py | 14 +- dlt/common/json/__init__.py | 12 +- dlt/common/json/_orjson.py | 14 +- dlt/common/json/_simplejson.py | 4 +- dlt/common/jsonpath.py | 4 +- dlt/common/libs/pyarrow.py | 46 ++- dlt/common/libs/pydantic.py | 14 +- dlt/common/libs/sql_alchemy.py | 25 +- dlt/common/logger.py | 4 +- dlt/common/normalizers/configuration.py | 4 +- dlt/common/normalizers/json/__init__.py | 15 +- dlt/common/normalizers/json/relational.py | 73 +++- dlt/common/normalizers/naming/direct.py | 4 +- dlt/common/normalizers/naming/duck_case.py | 8 +- dlt/common/normalizers/naming/exceptions.py | 9 +- dlt/common/normalizers/naming/naming.py | 20 +- dlt/common/normalizers/naming/snake_case.py | 8 +- dlt/common/normalizers/utils.py | 12 +- dlt/common/pipeline.py | 65 ++- dlt/common/reflection/spec.py | 10 +- dlt/common/reflection/utils.py | 12 +- dlt/common/runners/pool_runner.py | 15 +- dlt/common/runners/stdout.py | 8 +- dlt/common/runners/venv.py | 5 +- dlt/common/runtime/collector.py | 81 +++- dlt/common/runtime/exec_info.py | 16 +- dlt/common/runtime/init.py | 6 +- dlt/common/runtime/json_logging.py | 8 +- dlt/common/runtime/prometheus.py | 4 +- dlt/common/runtime/segment.py | 21 +- dlt/common/runtime/sentry.py | 7 +- dlt/common/runtime/slack.py | 4 +- dlt/common/runtime/telemetry.py | 11 +- dlt/common/schema/exceptions.py | 36 +- dlt/common/schema/migrations.py | 4 +- dlt/common/schema/schema.py | 129 ++++-- dlt/common/schema/typing.py | 12 +- dlt/common/schema/utils.py | 107 +++--
dlt/common/storages/configuration.py | 33 +- dlt/common/storages/data_item_storage.py | 14 +- dlt/common/storages/exceptions.py | 34 +- dlt/common/storages/file_storage.py | 52 ++- dlt/common/storages/fsspec_filesystem.py | 59 ++- dlt/common/storages/fsspecs/google_drive.py | 50 ++- dlt/common/storages/load_package.py | 95 +++-- dlt/common/storages/load_storage.py | 60 ++- dlt/common/storages/normalize_storage.py | 29 +- dlt/common/storages/schema_storage.py | 32 +- dlt/common/storages/transactional_file.py | 19 +- dlt/common/storages/versioned_storage.py | 20 +- dlt/common/time.py | 4 +- dlt/common/typing.py | 8 +- dlt/common/utils.py | 40 +- dlt/common/validation.py | 38 +- dlt/common/versioned_state.py | 4 +- dlt/common/warnings.py | 15 +- dlt/destinations/decorators.py | 17 +- dlt/destinations/exceptions.py | 37 +- dlt/destinations/impl/athena/athena.py | 96 +++-- dlt/destinations/impl/athena/configuration.py | 4 +- dlt/destinations/impl/bigquery/bigquery.py | 72 +++- .../impl/bigquery/bigquery_adapter.py | 48 ++- .../impl/bigquery/configuration.py | 13 +- dlt/destinations/impl/bigquery/sql_client.py | 35 +- dlt/destinations/impl/databricks/__init__.py | 5 +- .../impl/databricks/configuration.py | 9 +- .../impl/databricks/databricks.py | 90 +++-- .../impl/databricks/sql_client.py | 24 +- dlt/destinations/impl/destination/__init__.py | 4 +- .../impl/destination/configuration.py | 4 +- .../impl/destination/destination.py | 17 +- dlt/destinations/impl/destination/factory.py | 23 +- dlt/destinations/impl/duckdb/__init__.py | 5 +- dlt/destinations/impl/duckdb/configuration.py | 31 +- dlt/destinations/impl/duckdb/duck.py | 16 +- dlt/destinations/impl/duckdb/factory.py | 5 +- dlt/destinations/impl/duckdb/sql_client.py | 16 +- dlt/destinations/impl/dummy/__init__.py | 12 +- dlt/destinations/impl/dummy/dummy.py | 13 +- dlt/destinations/impl/filesystem/factory.py | 8 +- .../impl/filesystem/filesystem.py | 29 +- dlt/destinations/impl/motherduck/__init__.py | 5 +- .../impl/motherduck/configuration.py | 21 +- .../impl/motherduck/motherduck.py | 4 +- .../impl/motherduck/sql_client.py | 9 +- dlt/destinations/impl/mssql/__init__.py | 5 +- dlt/destinations/impl/mssql/configuration.py | 20 +- dlt/destinations/impl/mssql/factory.py | 5 +- dlt/destinations/impl/mssql/mssql.py | 23 +- dlt/destinations/impl/mssql/sql_client.py | 23 +- dlt/destinations/impl/postgres/__init__.py | 10 +- .../impl/postgres/configuration.py | 8 +- dlt/destinations/impl/postgres/postgres.py | 16 +- dlt/destinations/impl/postgres/sql_client.py | 16 +- dlt/destinations/impl/qdrant/factory.py | 5 +- .../impl/qdrant/qdrant_adapter.py | 3 +- dlt/destinations/impl/qdrant/qdrant_client.py | 73 +++- dlt/destinations/impl/redshift/__init__.py | 5 +- dlt/destinations/impl/redshift/redshift.py | 41 +- .../impl/snowflake/configuration.py | 13 +- dlt/destinations/impl/snowflake/snowflake.py | 39 +- dlt/destinations/impl/snowflake/sql_client.py | 12 +- dlt/destinations/impl/synapse/__init__.py | 9 +- dlt/destinations/impl/synapse/sql_client.py | 4 +- dlt/destinations/impl/synapse/synapse.py | 48 ++- dlt/destinations/impl/weaviate/exceptions.py | 15 +- dlt/destinations/impl/weaviate/naming.py | 4 +- .../impl/weaviate/weaviate_adapter.py | 7 +- .../impl/weaviate/weaviate_client.py | 77 +++- dlt/destinations/insert_job_client.py | 31 +- dlt/destinations/job_client_impl.py | 130 ++++-- dlt/destinations/job_impl.py | 23 +- dlt/destinations/path_utils.py | 35 +- dlt/destinations/sql_client.py | 32 +- dlt/destinations/sql_jobs.py | 89 +++-- 
dlt/destinations/type_mapping.py | 12 +- dlt/extract/concurrency.py | 32 +- dlt/extract/decorators.py | 56 ++- dlt/extract/exceptions.py | 173 ++++---- dlt/extract/extract.py | 64 ++- dlt/extract/extractors.py | 50 ++- dlt/extract/hints.py | 51 ++- dlt/extract/incremental/__init__.py | 119 ++++-- dlt/extract/incremental/exceptions.py | 14 +- dlt/extract/incremental/transform.py | 62 ++- dlt/extract/items.py | 4 +- dlt/extract/pipe.py | 66 +++- dlt/extract/pipe_iterator.py | 73 +++- dlt/extract/resource.py | 72 +++- dlt/extract/source.py | 67 +++- dlt/extract/storage.py | 24 +- dlt/extract/utils.py | 29 +- dlt/extract/validation.py | 19 +- dlt/helpers/airflow_helper.py | 48 ++- dlt/helpers/dbt/configuration.py | 9 +- dlt/helpers/dbt/dbt_utils.py | 4 +- dlt/helpers/dbt/runner.py | 33 +- dlt/helpers/dbt_cloud/client.py | 20 +- dlt/helpers/streamlit_app/blocks/load_info.py | 8 +- dlt/helpers/streamlit_app/blocks/query.py | 7 +- dlt/helpers/streamlit_app/pages/load_info.py | 29 +- dlt/helpers/streamlit_app/utils.py | 4 +- dlt/load/exceptions.py | 21 +- dlt/load/load.py | 147 +++++-- dlt/load/utils.py | 32 +- dlt/normalize/exceptions.py | 3 +- dlt/normalize/items_normalizers.py | 81 +++- dlt/normalize/normalize.py | 104 +++-- dlt/pipeline/__init__.py | 23 +- dlt/pipeline/configuration.py | 8 +- dlt/pipeline/current.py | 6 +- dlt/pipeline/dbt.py | 8 +- dlt/pipeline/exceptions.py | 44 ++- dlt/pipeline/helpers.py | 62 ++- dlt/pipeline/pipeline.py | 226 ++++++++--- dlt/pipeline/platform.py | 21 +- dlt/pipeline/progress.py | 5 +- dlt/pipeline/state_sync.py | 15 +- dlt/pipeline/trace.py | 19 +- dlt/pipeline/track.py | 36 +- dlt/pipeline/warnings.py | 9 +- dlt/reflection/script_inspector.py | 12 +- dlt/reflection/script_visitor.py | 13 +- dlt/sources/credentials.py | 5 +- dlt/sources/helpers/requests/retry.py | 17 +- dlt/sources/helpers/requests/session.py | 4 +- dlt/sources/helpers/rest_client/auth.py | 11 +- dlt/sources/helpers/rest_client/paginators.py | 4 +- docs/examples/_template/_template.py | 4 +- docs/examples/archive/_helpers.py | 7 +- docs/examples/archive/credentials/explicit.py | 8 +- docs/examples/archive/dbt_run_jaffle.py | 14 +- docs/examples/archive/quickstart.py | 7 +- docs/examples/archive/rasa_example.py | 4 +- docs/examples/archive/read_table.py | 9 +- docs/examples/archive/singer_tap_example.py | 8 +- .../archive/singer_tap_jsonl_example.py | 8 +- .../examples/archive/sources/google_sheets.py | 5 +- docs/examples/archive/sources/rasa/rasa.py | 7 +- docs/examples/archive/sources/singer_tap.py | 13 +- docs/examples/archive/sources/sql_query.py | 4 +- docs/examples/chess/chess.py | 9 +- .../chess_production/chess_production.py | 25 +- .../custom_destination_bigquery.py | 4 +- .../incremental_loading.py | 4 +- docs/examples/nested_data/nested_data.py | 4 +- .../pdf_to_weaviate/pdf_to_weaviate.py | 4 +- .../examples/qdrant_zendesk/qdrant_zendesk.py | 8 +- docs/tools/check_embedded_snippets.py | 38 +- docs/tools/fix_grammar_gpt.py | 21 +- docs/tools/lint_setup/template.py | 12 +- .../transformations/dbt/dbt-snippets.py | 8 +- .../snippets/destination-snippets.py | 11 +- docs/website/docs/getting-started-snippets.py | 20 +- docs/website/docs/intro-snippets.py | 8 +- .../performance-snippets.py | 10 +- .../load-data-from-an-api-snippets.py | 12 +- docs/website/docs/utils.py | 3 +- pyproject.toml | 4 +- tests/cases.py | 71 +++- tests/cli/common/test_cli_invoke.py | 74 +++- tests/cli/common/test_telemetry_command.py | 20 +- tests/cli/conftest.py | 7 +- tests/cli/test_config_toml_writer.py | 
50 ++- tests/cli/test_deploy_command.py | 9 +- tests/cli/test_init_command.py | 118 ++++-- tests/cli/test_pipeline_command.py | 16 +- tests/common/configuration/test_accessors.py | 59 ++- .../configuration/test_configuration.py | 146 +++++-- tests/common/configuration/test_container.py | 32 +- .../common/configuration/test_credentials.py | 43 +- .../configuration/test_environ_provider.py | 26 +- tests/common/configuration/test_inject.py | 33 +- tests/common/configuration/test_sections.py | 56 ++- tests/common/configuration/test_spec_union.py | 25 +- .../configuration/test_toml_provider.py | 75 +++- tests/common/configuration/utils.py | 4 +- .../common/data_writers/test_data_writers.py | 32 +- .../common/normalizers/custom_normalizers.py | 8 +- .../normalizers/test_import_normalizers.py | 9 +- .../normalizers/test_json_relational.py | 173 +++++--- tests/common/normalizers/test_naming.py | 56 ++- .../normalizers/test_naming_duck_case.py | 4 +- .../normalizers/test_naming_snake_case.py | 50 ++- tests/common/reflection/test_reflect_spec.py | 18 +- tests/common/runners/test_pipes.py | 42 +- tests/common/runners/test_runnable.py | 4 +- tests/common/runners/test_venv.py | 32 +- tests/common/runners/utils.py | 13 +- tests/common/runtime/test_collector.py | 4 +- tests/common/runtime/test_telemetry.py | 4 +- tests/common/schema/test_coercion.py | 71 ++-- tests/common/schema/test_detections.py | 18 +- tests/common/schema/test_filtering.py | 24 +- tests/common/schema/test_inference.py | 32 +- tests/common/schema/test_merges.py | 36 +- tests/common/schema/test_schema.py | 132 +++++-- tests/common/schema/test_schema_contract.py | 97 ++++- tests/common/schema/test_versioning.py | 6 +- tests/common/storages/test_file_storage.py | 15 +- tests/common/storages/test_load_package.py | 37 +- tests/common/storages/test_load_storage.py | 80 +++- tests/common/storages/test_schema_storage.py | 36 +- .../storages/test_transactional_file.py | 16 +- .../common/storages/test_versioned_storage.py | 5 +- tests/common/test_arithmetics.py | 6 +- tests/common/test_destination.py | 30 +- tests/common/test_git.py | 34 +- tests/common/test_json.py | 3 +- tests/common/test_pipeline_state.py | 4 +- tests/common/test_time.py | 27 +- tests/common/test_utils.py | 53 ++- tests/common/test_validation.py | 21 +- tests/common/test_versioned_state.py | 4 +- tests/common/test_wei.py | 8 +- tests/common/utils.py | 8 +- tests/conftest.py | 16 +- tests/destinations/test_custom_destination.py | 62 ++- .../test_destination_name_and_config.py | 40 +- tests/destinations/test_path_utils.py | 8 +- .../data_writers/test_buffered_writer.py | 34 +- .../data_writers/test_data_item_storage.py | 21 +- tests/extract/test_decorators.py | 58 ++- tests/extract/test_extract.py | 38 +- tests/extract/test_extract_pipe.py | 44 ++- tests/extract/test_incremental.py | 369 +++++++++++++----- tests/extract/test_sources.py | 82 +++- tests/extract/test_validation.py | 16 +- tests/extract/utils.py | 8 +- tests/helpers/airflow_tests/conftest.py | 7 +- .../airflow_tests/test_airflow_provider.py | 24 +- .../airflow_tests/test_airflow_wrapper.py | 193 ++++++--- .../test_join_airflow_scheduler.py | 49 ++- tests/helpers/airflow_tests/utils.py | 4 +- .../helpers/dbt_cloud_tests/test_dbt_cloud.py | 4 +- .../helpers/dbt_tests/local/test_dbt_utils.py | 15 +- .../local/test_runner_destinations.py | 34 +- .../dbt_tests/test_runner_dbt_versions.py | 46 ++- .../providers/test_google_secrets_provider.py | 43 +- .../test_streamlit_show_resources.py | 13 +- 
tests/libs/test_parquet_writer.py | 47 ++- tests/libs/test_pyarrow.py | 14 +- tests/libs/test_pydantic.py | 55 ++- .../athena_iceberg/test_athena_iceberg.py | 23 +- tests/load/bigquery/test_bigquery_client.py | 63 ++- .../bigquery/test_bigquery_table_builder.py | 212 +++++++--- .../test_databricks_configuration.py | 4 +- tests/load/duckdb/test_duckdb_client.py | 51 ++- .../load/duckdb/test_duckdb_table_builder.py | 4 +- .../load/filesystem/test_azure_credentials.py | 19 +- .../load/filesystem/test_filesystem_client.py | 20 +- .../load/filesystem/test_filesystem_common.py | 33 +- tests/load/filesystem/utils.py | 4 +- tests/load/mssql/test_mssql_credentials.py | 8 +- tests/load/mssql/test_mssql_table_builder.py | 9 +- tests/load/pipeline/test_arrow_loading.py | 16 +- tests/load/pipeline/test_athena.py | 36 +- tests/load/pipeline/test_bigquery.py | 19 +- tests/load/pipeline/test_dbt_helper.py | 19 +- tests/load/pipeline/test_drop.py | 140 +++++-- tests/load/pipeline/test_duckdb.py | 20 +- .../load/pipeline/test_filesystem_pipeline.py | 13 +- tests/load/pipeline/test_merge_disposition.py | 264 ++++++++++--- tests/load/pipeline/test_pipelines.py | 142 +++++-- tests/load/pipeline/test_redshift.py | 12 +- .../load/pipeline/test_replace_disposition.py | 85 +++- tests/load/pipeline/test_restore_state.py | 155 ++++++-- tests/load/pipeline/test_stage_loading.py | 62 ++- .../test_write_disposition_changes.py | 54 +-- tests/load/pipeline/utils.py | 4 +- tests/load/postgres/test_postgres_client.py | 19 +- .../postgres/test_postgres_table_builder.py | 6 +- tests/load/qdrant/test_pipeline.py | 18 +- tests/load/qdrant/utils.py | 3 +- tests/load/redshift/test_redshift_client.py | 31 +- .../redshift/test_redshift_table_builder.py | 22 +- .../snowflake/test_snowflake_configuration.py | 24 +- .../snowflake/test_snowflake_table_builder.py | 16 +- .../synapse/test_synapse_table_builder.py | 6 +- .../synapse/test_synapse_table_indexing.py | 8 +- tests/load/synapse/utils.py | 4 +- tests/load/test_dummy_client.py | 182 ++++++--- tests/load/test_insert_job_client.py | 71 +++- tests/load/test_job_client.py | 227 ++++++++--- tests/load/test_sql_client.py | 138 +++++-- tests/load/utils.py | 75 +++- tests/load/weaviate/test_naming.py | 4 +- tests/load/weaviate/test_pipeline.py | 27 +- tests/load/weaviate/test_weaviate_client.py | 48 ++- tests/load/weaviate/utils.py | 9 +- tests/normalize/mock_rasa_json_normalizer.py | 4 +- tests/normalize/test_normalize.py | 144 +++++-- .../cases/github_pipeline/github_extract.py | 5 +- .../cases/github_pipeline/github_pipeline.py | 17 +- tests/pipeline/test_arrow_sources.py | 22 +- tests/pipeline/test_dlt_versions.py | 111 ++++-- tests/pipeline/test_import_export_schema.py | 22 +- tests/pipeline/test_pipeline.py | 155 ++++++-- tests/pipeline/test_pipeline_extra.py | 32 +- .../test_pipeline_file_format_resolver.py | 24 +- tests/pipeline/test_pipeline_state.py | 75 +++- tests/pipeline/test_pipeline_trace.py | 65 ++- tests/pipeline/test_platform_connection.py | 5 +- tests/pipeline/test_resources_evaluation.py | 9 +- tests/pipeline/test_schema_contracts.py | 49 ++- tests/pipeline/test_schema_updates.py | 4 +- tests/pipeline/utils.py | 11 +- tests/reflection/test_script_inspector.py | 20 +- .../helpers/rest_client/test_paginators.py | 8 +- tests/sources/helpers/test_requests.py | 6 +- tests/utils.py | 49 ++- 394 files changed, 10065 insertions(+), 3575 deletions(-) diff --git a/dlt/__init__.py b/dlt/__init__.py index eee105e47e..f0ad25de90 100644 --- a/dlt/__init__.py +++ b/dlt/__init__.py 
@@ -23,7 +23,9 @@ from dlt.version import __version__ from dlt.common.configuration.accessors import config, secrets from dlt.common.typing import TSecretValue as _TSecretValue -from dlt.common.configuration.specs import CredentialsConfiguration as _CredentialsConfiguration +from dlt.common.configuration.specs import ( + CredentialsConfiguration as _CredentialsConfiguration, +) from dlt.common.pipeline import source_state as state from dlt.common.schema import Schema diff --git a/dlt/cli/_dlt.py b/dlt/cli/_dlt.py index 2332c0286c..cd72d55b9f 100644 --- a/dlt/cli/_dlt.py +++ b/dlt/cli/_dlt.py @@ -59,7 +59,9 @@ def init_command_wrapper( branch: str, ) -> int: try: - init_command(source_name, destination_type, use_generic_template, repo_location, branch) + init_command( + source_name, destination_type, use_generic_template, repo_location, branch + ) except Exception as ex: on_exception(ex, DLT_INIT_DOCS_URL) return -1 @@ -102,19 +104,22 @@ def deploy_command_wrapper( ) except (CannotRestorePipelineException, PipelineWasNotRun) as ex: fmt.note( - "You must run the pipeline locally successfully at least once in order to deploy it." + "You must run the pipeline locally successfully at least once in order to" + " deploy it." ) on_exception(ex, DLT_DEPLOY_DOCS_URL) return -2 except InvalidGitRepositoryError: click.secho( - "No git repository found for pipeline script %s." % fmt.bold(pipeline_script_path), + "No git repository found for pipeline script %s." + % fmt.bold(pipeline_script_path), err=True, fg="red", ) fmt.note("If you do not have a repository yet, you can do either of:") fmt.note( - "- Run the following command to initialize new repository: %s" % fmt.bold("git init") + "- Run the following command to initialize new repository: %s" + % fmt.bold("git init") ) fmt.note( "- Add your local code to Github as described here: %s" @@ -122,10 +127,14 @@ def deploy_command_wrapper( "https://docs.github.com/en/get-started/importing-your-projects-to-github/importing-source-code-to-github/adding-locally-hosted-code-to-github" ) ) - fmt.note("Please refer to %s for further assistance" % fmt.bold(DLT_DEPLOY_DOCS_URL)) + fmt.note( + "Please refer to %s for further assistance" % fmt.bold(DLT_DEPLOY_DOCS_URL) + ) return -3 except NoSuchPathError as path_ex: - click.secho("The pipeline script does not exist\n%s" % str(path_ex), err=True, fg="red") + click.secho( + "The pipeline script does not exist\n%s" % str(path_ex), err=True, fg="red" + ) return -4 except Exception as ex: on_exception(ex, DLT_DEPLOY_DOCS_URL) @@ -135,10 +144,16 @@ def deploy_command_wrapper( @utils.track_command("pipeline", True, "operation") def pipeline_command_wrapper( - operation: str, pipeline_name: str, pipelines_dir: str, verbosity: int, **command_kwargs: Any + operation: str, + pipeline_name: str, + pipelines_dir: str, + verbosity: int, + **command_kwargs: Any, ) -> int: try: - pipeline_command(operation, pipeline_name, pipelines_dir, verbosity, **command_kwargs) + pipeline_command( + operation, pipeline_name, pipelines_dir, verbosity, **command_kwargs + ) return 0 except CannotRestorePipelineException as ex: click.secho(str(ex), err=True, fg="red") @@ -205,7 +220,11 @@ def __init__( help: str = None, # noqa ) -> None: super(TelemetryAction, self).__init__( - option_strings=option_strings, dest=dest, default=default, nargs=0, help=help + option_strings=option_strings, + dest=dest, + default=default, + nargs=0, + help=help, ) def __call__( @@ -230,7 +249,11 @@ def __init__( help: str = None, # noqa ) -> None: 
super(NonInteractiveAction, self).__init__( - option_strings=option_strings, dest=dest, default=default, nargs=0, help=help + option_strings=option_strings, + dest=dest, + default=default, + nargs=0, + help=help, ) def __call__( @@ -252,7 +275,11 @@ def __init__( help: str = None, # noqa ) -> None: super(DebugAction, self).__init__( - option_strings=option_strings, dest=dest, default=default, nargs=0, help=help + option_strings=option_strings, + dest=dest, + default=default, + nargs=0, + help=help, ) def __call__( @@ -273,7 +300,9 @@ def main() -> int: formatter_class=argparse.ArgumentDefaultsHelpFormatter, ) parser.add_argument( - "--version", action="version", version="%(prog)s {version}".format(version=__version__) + "--version", + action="version", + version="%(prog)s {version}".format(version=__version__), ) parser.add_argument( "--disable-telemetry", @@ -289,8 +318,8 @@ def main() -> int: "--non-interactive", action=NonInteractiveAction, help=( - "Non interactive mode. Default choices are automatically made for confirmations and" - " prompts." + "Non interactive mode. Default choices are automatically made for" + " confirmations and prompts." ), ) parser.add_argument( @@ -301,8 +330,8 @@ def main() -> int: init_cmd = subparsers.add_parser( "init", help=( - "Creates a pipeline project in the current folder by adding existing verified source or" - " creating a new one from template." + "Creates a pipeline project in the current folder by adding existing" + " verified source or creating a new one from template." ), ) init_cmd.add_argument( @@ -316,9 +345,9 @@ def main() -> int: "source", nargs="?", help=( - "Name of data source for which to create a pipeline. Adds existing verified source or" - " creates a new pipeline template if verified source for your data source is not yet" - " implemented." + "Name of data source for which to create a pipeline. Adds existing verified" + " source or creates a new pipeline template if verified source for your" + " data source is not yet implemented." ), ) init_cmd.add_argument( @@ -327,21 +356,27 @@ def main() -> int: init_cmd.add_argument( "--location", default=DEFAULT_VERIFIED_SOURCES_REPO, - help="Advanced. Uses a specific url or local path to verified sources repository.", + help=( + "Advanced. Uses a specific url or local path to verified sources" + " repository." + ), ) init_cmd.add_argument( "--branch", default=None, - help="Advanced. Uses specific branch of the init repository to fetch the template.", + help=( + "Advanced. Uses specific branch of the init repository to fetch the" + " template." + ), ) init_cmd.add_argument( "--generic", default=False, action="store_true", help=( - "When present uses a generic template with all the dlt loading code present will be" - " used. Otherwise a debug template is used that can be immediately run to get familiar" - " with the dlt sources." + "When present uses a generic template with all the dlt loading code present" + " will be used. Otherwise a debug template is used that can be immediately" + " run to get familiar with the dlt sources." ), ) @@ -359,14 +394,19 @@ def main() -> int: ) deploy_comm.add_argument( "--branch", - help="Advanced. Uses specific branch of the deploy repository to fetch the template.", + help=( + "Advanced. Uses specific branch of the deploy repository to fetch the" + " template." 
+ ), ) deploy_cmd = subparsers.add_parser( "deploy", help="Creates a deployment package for a selected pipeline script" ) deploy_cmd.add_argument( - "pipeline_script_path", metavar="pipeline-script-path", help="Path to a pipeline script" + "pipeline_script_path", + metavar="pipeline-script-path", + help="Path to a pipeline script", ) deploy_sub_parsers = deploy_cmd.add_subparsers(dest="deployment_method") @@ -380,9 +420,9 @@ def main() -> int: "--schedule", required=True, help=( - "A schedule with which to run the pipeline, in cron format. Example: '*/30 * * * *'" - " will run the pipeline every 30 minutes. Remember to enclose the scheduler" - " expression in quotation marks!" + "A schedule with which to run the pipeline, in cron format. Example:" + " '*/30 * * * *' will run the pipeline every 30 minutes. Remember to" + " enclose the scheduler expression in quotation marks!" ), ) deploy_github_cmd.add_argument( @@ -416,32 +456,49 @@ def main() -> int: deploy_cmd = subparsers.add_parser( "deploy", help=( - 'Install additional dependencies with pip install "dlt[cli]" to create deployment' - " packages" + 'Install additional dependencies with pip install "dlt[cli]" to create' + " deployment packages" ), add_help=False, ) deploy_cmd.add_argument("--help", "-h", nargs="?", const=True) deploy_cmd.add_argument( - "pipeline_script_path", metavar="pipeline-script-path", nargs=argparse.REMAINDER + "pipeline_script_path", + metavar="pipeline-script-path", + nargs=argparse.REMAINDER, ) - schema = subparsers.add_parser("schema", help="Shows, converts and upgrades schemas") + schema = subparsers.add_parser( + "schema", help="Shows, converts and upgrades schemas" + ) schema.add_argument( - "file", help="Schema file name, in yaml or json format, will autodetect based on extension" + "file", + help=( + "Schema file name, in yaml or json format, will autodetect based on" + " extension" + ), ) schema.add_argument( - "--format", choices=["json", "yaml"], default="yaml", help="Display schema in this format" + "--format", + choices=["json", "yaml"], + default="yaml", + help="Display schema in this format", ) schema.add_argument( - "--remove-defaults", action="store_true", help="Does not show default hint values" + "--remove-defaults", + action="store_true", + help="Does not show default hint values", ) pipe_cmd = subparsers.add_parser( "pipeline", help="Operations on pipelines that were ran locally" ) pipe_cmd.add_argument( - "--list-pipelines", "-l", default=False, action="store_true", help="List local pipelines" + "--list-pipelines", + "-l", + default=False, + action="store_true", + help="List local pipelines", ) pipe_cmd.add_argument( "--hot-reload", @@ -450,7 +507,9 @@ def main() -> int: help="Reload streamlit app (for core development)", ) pipe_cmd.add_argument("pipeline_name", nargs="?", help="Pipeline name") - pipe_cmd.add_argument("--pipelines-dir", help="Pipelines working directory", default=None) + pipe_cmd.add_argument( + "--pipelines-dir", help="Pipelines working directory", default=None + ) pipe_cmd.add_argument( "--verbose", "-v", @@ -464,10 +523,12 @@ def main() -> int: pipe_cmd_sync_parent = argparse.ArgumentParser(add_help=False) pipe_cmd_sync_parent.add_argument( - "--destination", help="Sync from this destination when local pipeline state is missing." + "--destination", + help="Sync from this destination when local pipeline state is missing.", ) pipe_cmd_sync_parent.add_argument( - "--dataset-name", help="Dataset name to sync from when local pipeline state is missing." 
+ "--dataset-name", + help="Dataset name to sync from when local pipeline state is missing.", ) pipeline_subparsers.add_parser( @@ -475,34 +536,40 @@ def main() -> int: ) pipeline_subparsers.add_parser( "show", - help="Generates and launches Streamlit app with the loading status and dataset explorer", + help=( + "Generates and launches Streamlit app with the loading status and dataset" + " explorer" + ), ) pipeline_subparsers.add_parser( "failed-jobs", help=( - "Displays information on all the failed loads in all completed packages, failed jobs" - " and associated error messages" + "Displays information on all the failed loads in all completed packages," + " failed jobs and associated error messages" ), ) pipeline_subparsers.add_parser( "drop-pending-packages", help=( - "Deletes all extracted and normalized packages including those that are partially" - " loaded." + "Deletes all extracted and normalized packages including those that are" + " partially loaded." ), ) pipeline_subparsers.add_parser( "sync", help=( - "Drops the local state of the pipeline and resets all the schemas and restores it from" - " destination. The destination state, data and schemas are left intact." + "Drops the local state of the pipeline and resets all the schemas and" + " restores it from destination. The destination state, data and schemas are" + " left intact." ), parents=[pipe_cmd_sync_parent], ) pipeline_subparsers.add_parser( "trace", help="Displays last run trace, use -v or -vv for more info" ) - pipe_cmd_schema = pipeline_subparsers.add_parser("schema", help="Displays default schema") + pipe_cmd_schema = pipeline_subparsers.add_parser( + "schema", help="Displays default schema" + ) pipe_cmd_schema.add_argument( "--format", choices=["json", "yaml"], @@ -510,7 +577,9 @@ def main() -> int: help="Display schema in this format", ) pipe_cmd_schema.add_argument( - "--remove-defaults", action="store_true", help="Does not show default hint values" + "--remove-defaults", + action="store_true", + help="Does not show default hint values", ) pipe_cmd_drop = pipeline_subparsers.add_parser( @@ -518,16 +587,16 @@ def main() -> int: help="Selectively drop tables and reset state", parents=[pipe_cmd_sync_parent], epilog=( - f"See {DLT_PIPELINE_COMMAND_DOCS_URL}#selectively-drop-tables-and-reset-state for more" - " info" + f"See {DLT_PIPELINE_COMMAND_DOCS_URL}#selectively-drop-tables-and-reset-state" + " for more info" ), ) pipe_cmd_drop.add_argument( "resources", nargs="*", help=( - "One or more resources to drop. Can be exact resource name(s) or regex pattern(s)." - " Regex patterns must start with re:" + "One or more resources to drop. Can be exact resource name(s) or regex" + " pattern(s). Regex patterns must start with re:" ), ) pipe_cmd_drop.add_argument( @@ -552,13 +621,17 @@ def main() -> int: ) pipe_cmd_package = pipeline_subparsers.add_parser( - "load-package", help="Displays information on load package, use -v or -vv for more info" + "load-package", + help="Displays information on load package, use -v or -vv for more info", ) pipe_cmd_package.add_argument( "load_id", metavar="load-id", nargs="?", - help="Load id of completed or normalized package. Defaults to the most recent package.", + help=( + "Load id of completed or normalized package. Defaults to the most recent" + " package." 
+ ), ) subparsers.add_parser("telemetry", help="Shows telemetry status") @@ -567,17 +640,19 @@ def main() -> int: if Venv.is_virtual_env() and not Venv.is_venv_activated(): fmt.warning( - "You are running dlt installed in the global environment, however you have virtual" - " environment activated. The dlt command will not see dependencies from virtual" - " environment. You should uninstall the dlt from global environment and install it in" - " the current virtual environment instead." + "You are running dlt installed in the global environment, however you have" + " virtual environment activated. The dlt command will not see dependencies" + " from virtual environment. You should uninstall the dlt from global" + " environment and install it in the current virtual environment instead." ) if args.command == "schema": return schema_command_wrapper(args.file, args.format, args.remove_defaults) elif args.command == "pipeline": if args.list_pipelines: - return pipeline_command_wrapper("list", "-", args.pipelines_dir, args.verbosity) + return pipeline_command_wrapper( + "list", "-", args.pipelines_dir, args.verbosity + ) else: command_kwargs = dict(args._get_kwargs()) if not command_kwargs.get("pipeline_name"): @@ -596,7 +671,11 @@ def main() -> int: return -1 else: return init_command_wrapper( - args.source, args.destination, args.generic, args.location, args.branch + args.source, + args.destination, + args.generic, + args.location, + args.branch, ) elif args.command == "deploy": try: @@ -610,12 +689,13 @@ def main() -> int: ) except (NameError, KeyError): fmt.warning( - "Please install additional command line dependencies to use deploy command:" + "Please install additional command line dependencies to use deploy" + " command:" ) fmt.secho('pip install "dlt[cli]"', bold=True) fmt.echo( - "We ask you to install those dependencies separately to keep our core library small" - " and make it work everywhere." + "We ask you to install those dependencies separately to keep our core" + " library small and make it work everywhere." 
) return -1 elif args.command == "telemetry": diff --git a/dlt/cli/config_toml_writer.py b/dlt/cli/config_toml_writer.py index 8cf831d725..53bd51df5a 100644 --- a/dlt/cli/config_toml_writer.py +++ b/dlt/cli/config_toml_writer.py @@ -85,7 +85,9 @@ def write_value( toml_table[name] = default_value -def write_spec(toml_table: TOMLTable, config: BaseConfiguration, overwrite_existing: bool) -> None: +def write_spec( + toml_table: TOMLTable, config: BaseConfiguration, overwrite_existing: bool +) -> None: for name, hint in config.get_resolvable_fields().items(): default_value = getattr(config, name, None) # check if field is of particular interest and should be included if it has default diff --git a/dlt/cli/deploy_command.py b/dlt/cli/deploy_command.py index 5a25752a6d..f7bace77d5 100644 --- a/dlt/cli/deploy_command.py +++ b/dlt/cli/deploy_command.py @@ -4,7 +4,11 @@ from enum import Enum from importlib.metadata import version as pkg_version -from dlt.common.configuration.providers import SECRETS_TOML, SECRETS_TOML_KEY, StringTomlProvider +from dlt.common.configuration.providers import ( + SECRETS_TOML, + SECRETS_TOML_KEY, + StringTomlProvider, +) from dlt.common.configuration.paths import make_dlt_settings_path from dlt.common.configuration.utils import serialize_value from dlt.common.git import is_dirty @@ -28,10 +32,10 @@ REQUIREMENTS_GITHUB_ACTION = "requirements_github_action.txt" DLT_DEPLOY_DOCS_URL = "https://dlthub.com/docs/walkthroughs/deploy-a-pipeline" -DLT_AIRFLOW_GCP_DOCS_URL = ( - "https://dlthub.com/docs/walkthroughs/deploy-a-pipeline/deploy-with-airflow-composer" +DLT_AIRFLOW_GCP_DOCS_URL = "https://dlthub.com/docs/walkthroughs/deploy-a-pipeline/deploy-with-airflow-composer" +AIRFLOW_GETTING_STARTED = ( + "https://airflow.apache.org/docs/apache-airflow/stable/start.html" ) -AIRFLOW_GETTING_STARTED = "https://airflow.apache.org/docs/apache-airflow/stable/start.html" AIRFLOW_DAG_TEMPLATE_SCRIPT = "dag_template.py" AIRFLOW_CLOUDBUILD_YAML = "cloudbuild.yaml" COMMAND_REPO_LOCATION = "https://github.com/dlt-hub/dlt-%s-template.git" @@ -69,7 +73,10 @@ def deploy_command( # command no longer needed kwargs.pop("command", None) deployment_class( - pipeline_script_path=pipeline_script_path, location=repo_location, branch=branch, **kwargs + pipeline_script_path=pipeline_script_path, + location=repo_location, + branch=branch, + **kwargs, ).run_deployment() @@ -96,8 +103,9 @@ def _generate_workflow(self, *args: Optional[Any]) -> None: if self.schedule_description is None: # TODO: move that check to _dlt and some intelligent help message on missing arg raise ValueError( - f"Setting 'schedule' for '{self.deployment_method}' is required! Use deploy command" - f" as 'dlt deploy chess.py {self.deployment_method} --schedule \"*/30 * * * *\"'." + f"Setting 'schedule' for '{self.deployment_method}' is required! Use" + " deploy command as 'dlt deploy chess.py" + f' {self.deployment_method} --schedule "*/30 * * * *"\'.' 
) workflow = self._create_new_workflow() serialized_workflow = serialize_templated_yaml(workflow) @@ -110,7 +118,9 @@ def _generate_workflow(self, *args: Optional[Any]) -> None: os.path.join(self.deployment_method, "requirements_blacklist.txt") ) as f: requirements_blacklist = f.readlines() - requirements_txt = generate_pip_freeze(requirements_blacklist, REQUIREMENTS_GITHUB_ACTION) + requirements_txt = generate_pip_freeze( + requirements_blacklist, REQUIREMENTS_GITHUB_ACTION + ) requirements_txt_name = REQUIREMENTS_GITHUB_ACTION # if repo_storage.has_file(utils.REQUIREMENTS_TXT): self.artifacts["requirements_txt"] = requirements_txt @@ -121,7 +131,9 @@ def _make_modification(self) -> None: self.repo_storage.create_folder(utils.GITHUB_WORKFLOWS_DIR) self.repo_storage.save( - os.path.join(utils.GITHUB_WORKFLOWS_DIR, self.artifacts["serialized_workflow_name"]), + os.path.join( + utils.GITHUB_WORKFLOWS_DIR, self.artifacts["serialized_workflow_name"] + ), self.artifacts["serialized_workflow"], ) self.repo_storage.save( @@ -135,7 +147,8 @@ def _create_new_workflow(self) -> Any: workflow = yaml.safe_load(f) # customize the workflow workflow["name"] = ( - f"Run {self.state['pipeline_name']} pipeline from {self.pipeline_script_path}" + f"Run {self.state['pipeline_name']} pipeline from" + f" {self.pipeline_script_path}" ) if self.run_on_push is False: del workflow["on"]["push"] @@ -194,9 +207,9 @@ def _echo_instructions(self, *args: Optional[Any]) -> None: ) ) fmt.echo( - "* The dependencies that will be used to run the pipeline are stored in %s. If you" - " change add more dependencies, remember to refresh your deployment by running the same" - " 'deploy' command again." + "* The dependencies that will be used to run the pipeline are stored in %s." + " If you change add more dependencies, remember to refresh your deployment" + " by running the same 'deploy' command again." % fmt.bold(self.artifacts["requirements_txt_name"]) ) fmt.echo() @@ -204,11 +217,12 @@ def _echo_instructions(self, *args: Optional[Any]) -> None: fmt.echo("1. Your pipeline does not seem to need any secrets.") else: fmt.echo( - "You should now add the secrets to github repository secrets, commit and push the" - " pipeline files to github." + "You should now add the secrets to github repository secrets, commit" + " and push the pipeline files to github." ) fmt.echo( - "1. Add the following secret values (typically stored in %s): \n%s\nin %s" + "1. Add the following secret values (typically stored in %s):" + " \n%s\nin %s" % ( fmt.bold(make_dlt_settings_path(SECRETS_TOML)), fmt.bold( @@ -217,33 +231,39 @@ def _echo_instructions(self, *args: Optional[Any]) -> None: for s_v in self.secret_envs ) ), - fmt.bold(github_origin_to_url(self.origin, "/settings/secrets/actions")), + fmt.bold( + github_origin_to_url(self.origin, "/settings/secrets/actions") + ), ) ) fmt.echo() self._echo_secrets() fmt.echo( - "2. Add stage deployment files to commit. Use your Git UI or the following command" + "2. Add stage deployment files to commit. Use your Git UI or the following" + " command" ) new_req_path = self.repo_storage.from_relative_path_to_wd( self.artifacts["requirements_txt_name"] ) new_workflow_path = self.repo_storage.from_relative_path_to_wd( - os.path.join(utils.GITHUB_WORKFLOWS_DIR, self.artifacts["serialized_workflow_name"]) + os.path.join( + utils.GITHUB_WORKFLOWS_DIR, self.artifacts["serialized_workflow_name"] + ) ) fmt.echo(fmt.bold(f"git add {new_req_path} {new_workflow_path}")) fmt.echo() fmt.echo("3. Commit the files above. 
Use your Git UI or the following command") fmt.echo( fmt.bold( - f"git commit -m 'run {self.state['pipeline_name']} pipeline with github action'" + f"git commit -m 'run {self.state['pipeline_name']} pipeline with github" + " action'" ) ) if is_dirty(self.repo): fmt.warning( - "You have modified files in your repository. Do not forget to push changes to your" - " pipeline script as well!" + "You have modified files in your repository. Do not forget to push" + " changes to your pipeline script as well!" ) fmt.echo() fmt.echo("4. Push changes to github. Use your Git UI or the following command") @@ -253,7 +273,8 @@ def _echo_instructions(self, *args: Optional[Any]) -> None: fmt.echo( fmt.bold( github_origin_to_url( - self.origin, f"/actions/workflows/{self.artifacts['serialized_workflow_name']}" + self.origin, + f"/actions/workflows/{self.artifacts['serialized_workflow_name']}", ) ) ) @@ -273,7 +294,9 @@ def __init__( def _generate_workflow(self, *args: Optional[Any]) -> None: self.deployment_method = DeploymentMethods.airflow_composer.value - req_dep = f"{DLT_PKG_NAME}[{Destination.to_name(self.state['destination_type'])}]" + req_dep = ( + f"{DLT_PKG_NAME}[{Destination.to_name(self.state['destination_type'])}]" + ) req_dep_line = f"{req_dep}>={pkg_version(DLT_PKG_NAME)}" self.artifacts["requirements_txt"] = req_dep_line @@ -304,16 +327,20 @@ def _make_modification(self) -> None: self.repo_storage.create_folder(utils.AIRFLOW_BUILD_FOLDER) # save cloudbuild.yaml only if not exist to allow to run the deploy command for many different pipelines - dest_cloud_build = os.path.join(utils.AIRFLOW_BUILD_FOLDER, AIRFLOW_CLOUDBUILD_YAML) + dest_cloud_build = os.path.join( + utils.AIRFLOW_BUILD_FOLDER, AIRFLOW_CLOUDBUILD_YAML + ) if not self.repo_storage.has_file(dest_cloud_build): self.repo_storage.save(dest_cloud_build, self.artifacts["cloudbuild_file"]) else: fmt.warning( - f"{AIRFLOW_CLOUDBUILD_YAML} already created. Delete the file and run the deploy" - " command again to re-create." + f"{AIRFLOW_CLOUDBUILD_YAML} already created. Delete the file and run" + " the deploy command again to re-create." ) - dest_dag_script = os.path.join(utils.AIRFLOW_DAGS_FOLDER, self.artifacts["dag_script_name"]) + dest_dag_script = os.path.join( + utils.AIRFLOW_DAGS_FOLDER, self.artifacts["dag_script_name"] + ) self.repo_storage.save(dest_dag_script, self.artifacts["dag_file"]) def _echo_instructions(self, *args: Optional[Any]) -> None: @@ -330,7 +357,10 @@ def _echo_instructions(self, *args: Optional[Any]) -> None: ) fmt.echo( "* The %s script was created in %s." - % (fmt.bold(self.artifacts["dag_script_name"]), fmt.bold(utils.AIRFLOW_DAGS_FOLDER)) + % ( + fmt.bold(self.artifacts["dag_script_name"]), + fmt.bold(utils.AIRFLOW_DAGS_FOLDER), + ) ) fmt.echo() @@ -347,11 +377,12 @@ def _echo_instructions(self, *args: Optional[Any]) -> None: fmt.echo() fmt.echo( - "If you are planning run the pipeline with Google Cloud Composer, follow the next" - " instructions:\n" + "If you are planning run the pipeline with Google Cloud Composer, follow" + " the next instructions:\n" ) fmt.echo( - "1. Read this doc and set up the Environment: %s" % (fmt.bold(DLT_AIRFLOW_GCP_DOCS_URL)) + "1. Read this doc and set up the Environment: %s" + % (fmt.bold(DLT_AIRFLOW_GCP_DOCS_URL)) ) fmt.echo( "2. Set _BUCKET_NAME up in %s/%s file. " @@ -365,8 +396,8 @@ def _echo_instructions(self, *args: Optional[Any]) -> None: else: if self.secrets_format == SecretFormats.env.value: fmt.echo( - "3. 
Add the following secret values (typically stored in %s): \n%s\n%s\nin" - " ENVIRONMENT VARIABLES using Google Composer UI" + "3. Add the following secret values (typically stored in %s):" + " \n%s\n%s\nin ENVIRONMENT VARIABLES using Google Composer UI" % ( fmt.bold(make_dlt_settings_path(SECRETS_TOML)), fmt.bold( @@ -377,7 +408,8 @@ def _echo_instructions(self, *args: Optional[Any]) -> None: ), fmt.bold( "\n".join( - self.env_prov.get_key_name(v.key, *v.sections) for v in self.envs + self.env_prov.get_key_name(v.key, *v.sections) + for v in self.envs ) ), ) @@ -410,7 +442,8 @@ def _echo_instructions(self, *args: Optional[Any]) -> None: ) fmt.echo("5. Commit and push the pipeline files to github:") fmt.echo( - "a. Add stage deployment files to commit. Use your Git UI or the following command" + "a. Add stage deployment files to commit. Use your Git UI or the following" + " command" ) dag_script_path = self.repo_storage.from_relative_path_to_wd( @@ -424,13 +457,14 @@ def _echo_instructions(self, *args: Optional[Any]) -> None: fmt.echo("b. Commit the files above. Use your Git UI or the following command") fmt.echo( fmt.bold( - f"git commit -m 'initiate {self.state['pipeline_name']} pipeline with Airflow'" + f"git commit -m 'initiate {self.state['pipeline_name']} pipeline with" + " Airflow'" ) ) if is_dirty(self.repo): fmt.warning( - "You have modified files in your repository. Do not forget to push changes to your" - " pipeline script as well!" + "You have modified files in your repository. Do not forget to push" + " changes to your pipeline script as well!" ) fmt.echo("c. Push changes to github. Use your Git UI or the following command") fmt.echo(fmt.bold("git push origin")) diff --git a/dlt/cli/deploy_command_helpers.py b/dlt/cli/deploy_command_helpers.py index 5065ba1cfc..87bc417013 100644 --- a/dlt/cli/deploy_command_helpers.py +++ b/dlt/cli/deploy_command_helpers.py @@ -84,7 +84,8 @@ def _prepare_deployment(self) -> None: # load a pipeline script and extract full_refresh and pipelines_dir args self.pipeline_script = self.repo_storage.load(self.repo_pipeline_script_path) fmt.echo( - "Looking up the deployment template scripts in %s...\n" % fmt.bold(self.repo_location) + "Looking up the deployment template scripts in %s...\n" + % fmt.bold(self.repo_location) ) self.template_storage = git.get_fresh_repo_files( self.repo_location, get_dlt_repos_dir(), branch=self.branch @@ -97,14 +98,17 @@ def _get_origin(self) -> str: if "github.com" not in origin: raise CliCommandException( "deploy", - f"Your current repository origin is not set to github but to {origin}.\nYou" - " must change it to be able to run the pipelines with github actions:" + "Your current repository origin is not set to github but to" + f" {origin}.\nYou" + " must change it to be able to run the pipelines with github" + " actions:" " https://docs.github.com/en/get-started/getting-started-with-git/managing-remote-repositories", ) except ValueError: raise CliCommandException( "deploy", - "Your current repository has no origin set. Please set it up to be able to run the" + "Your current repository has no origin set. 
Please set it up to be able" + " to run the" " pipelines with github actions:" " https://docs.github.com/en/get-started/importing-your-projects-to-github/importing-source-code-to-github/adding-locally-hosted-code-to-github", ) @@ -127,13 +131,17 @@ def run_deployment(self) -> None: elif len(uniq_possible_pipelines) > 1: choices = list(uniq_possible_pipelines.keys()) choices_str = "".join([str(i + 1) for i in range(len(choices))]) - choices_selection = [f"{idx+1}-{name}" for idx, name in enumerate(choices)] + choices_selection = [ + f"{idx+1}-{name}" for idx, name in enumerate(choices) + ] sel = fmt.prompt( "Several pipelines found in script, please select one: " + ", ".join(choices_selection), choices=choices_str, ) - pipeline_name, pipelines_dir = uniq_possible_pipelines[choices[int(sel) - 1]] + pipeline_name, pipelines_dir = uniq_possible_pipelines[ + choices[int(sel) - 1] + ] if pipelines_dir: self.pipelines_dir = os.path.abspath(pipelines_dir) @@ -146,11 +154,13 @@ def run_deployment(self) -> None: if not self.pipeline_name: self.pipeline_name = dlt.config.get("pipeline_name") if not self.pipeline_name: - self.pipeline_name = get_default_pipeline_name(self.pipeline_script_path) + self.pipeline_name = get_default_pipeline_name( + self.pipeline_script_path + ) fmt.warning( - f"Using default pipeline name {self.pipeline_name}. The pipeline name" - " is not passed as argument to dlt.pipeline nor configured via config" - " provides ie. config.toml" + f"Using default pipeline name {self.pipeline_name}. The" + " pipeline name is not passed as argument to dlt.pipeline" + " nor configured via config provides ie. config.toml" ) # fmt.echo("Generating deployment for pipeline %s" % fmt.bold(self.pipeline_name)) @@ -231,34 +241,40 @@ def get_state_and_trace(pipeline: Pipeline) -> Tuple[TPipelineState, PipelineTra trace = pipeline.last_trace if trace is None or len(trace.steps) == 0: raise PipelineWasNotRun( - "Pipeline run trace could not be found. Please run the pipeline at least once locally." + "Pipeline run trace could not be found. Please run the pipeline at least" + " once locally." ) last_step = trace.steps[-1] if last_step.step_exception is not None: raise PipelineWasNotRun( - "The last pipeline run ended with error. Please make sure that pipeline runs correctly" - f" before deployment.\n{last_step.step_exception}" + "The last pipeline run ended with error. Please make sure that pipeline" + f" runs correctly before deployment.\n{last_step.step_exception}" ) if not isinstance(last_step.step_info, LoadInfo): raise PipelineWasNotRun( - "The last pipeline run did not reach the load step. Please run the pipeline locally" - " until it loads data into destination." + "The last pipeline run did not reach the load step. Please run the pipeline" + " locally until it loads data into destination." 
) return pipeline.state, trace -def get_visitors(pipeline_script: str, pipeline_script_path: str) -> PipelineScriptVisitor: +def get_visitors( + pipeline_script: str, pipeline_script_path: str +) -> PipelineScriptVisitor: visitor = utils.parse_init_script("deploy", pipeline_script, pipeline_script_path) if n.RUN not in visitor.known_calls: raise CliCommandException( "deploy", - f"The pipeline script {pipeline_script_path} does not seem to run the pipeline.", + f"The pipeline script {pipeline_script_path} does not seem to run the" + " pipeline.", ) return visitor -def parse_pipeline_info(visitor: PipelineScriptVisitor) -> List[Tuple[str, Optional[str]]]: +def parse_pipeline_info( + visitor: PipelineScriptVisitor, +) -> List[Tuple[str, Optional[str]]]: pipelines: List[Tuple[str, Optional[str]]] = [] if n.PIPELINE in visitor.known_calls: for call_args in visitor.known_calls[n.PIPELINE]: @@ -268,17 +284,19 @@ def parse_pipeline_info(visitor: PipelineScriptVisitor) -> List[Tuple[str, Optio f_r_value = evaluate_node_literal(f_r_node) if f_r_value is None: fmt.warning( - "The value of `full_refresh` in call to `dlt.pipeline` cannot be" - f" determined from {unparse(f_r_node).strip()}. We assume that you know" - " what you are doing :)" + "The value of `full_refresh` in call to `dlt.pipeline` cannot" + f" be determined from {unparse(f_r_node).strip()}. We assume" + " that you know what you are doing :)" ) if f_r_value is True: if fmt.confirm( - "The value of 'full_refresh' is set to True. Do you want to abort to set it" - " to False?", + "The value of 'full_refresh' is set to True. Do you want to" + " abort to set it to False?", default=True, ): - raise CliCommandException("deploy", "Please set the full_refresh to False") + raise CliCommandException( + "deploy", "Please set the full_refresh to False" + ) p_d_node = call_args.arguments.get("pipelines_dir") if p_d_node: @@ -286,9 +304,10 @@ def parse_pipeline_info(visitor: PipelineScriptVisitor) -> List[Tuple[str, Optio if pipelines_dir is None: raise CliCommandException( "deploy", - "The value of 'pipelines_dir' argument in call to `dlt_pipeline` cannot be" - f" determined from {unparse(p_d_node).strip()}. Pipeline working dir will" - " be found. Pass it directly with --pipelines-dir option.", + "The value of 'pipelines_dir' argument in call to" + " `dlt_pipeline` cannot be determined from" + f" {unparse(p_d_node).strip()}. Pipeline working dir will be" + " found. Pass it directly with --pipelines-dir option.", ) p_n_node = call_args.arguments.get("pipeline_name") @@ -297,9 +316,10 @@ def parse_pipeline_info(visitor: PipelineScriptVisitor) -> List[Tuple[str, Optio if pipeline_name is None: raise CliCommandException( "deploy", - "The value of 'pipeline_name' argument in call to `dlt_pipeline` cannot be" - f" determined from {unparse(p_d_node).strip()}. Pipeline working dir will" - " be found. Pass it directly with --pipeline-name option.", + "The value of 'pipeline_name' argument in call to" + " `dlt_pipeline` cannot be determined from" + f" {unparse(p_d_node).strip()}. Pipeline working dir will be" + " found. 
Pass it directly with --pipeline-name option.", ) pipelines.append((pipeline_name, pipelines_dir)) @@ -309,7 +329,9 @@ def parse_pipeline_info(visitor: PipelineScriptVisitor) -> List[Tuple[str, Optio def str_representer(dumper: yaml.Dumper, data: str) -> yaml.ScalarNode: # format multiline strings as blocks with the exception of placeholders # that will be expanded as yaml - if len(data.splitlines()) > 1 and "{{ toYaml" not in data: # check for multiline string + if ( + len(data.splitlines()) > 1 and "{{ toYaml" not in data + ): # check for multiline string return dumper.represent_scalar("tag:yaml.org,2002:str", data, style="|") return dumper.represent_scalar("tag:yaml.org,2002:str", data) @@ -327,7 +349,9 @@ def serialize_templated_yaml(tree: StrAny) -> str: tree, allow_unicode=True, default_flow_style=False, sort_keys=False ) # removes apostrophes around the template - serialized = re.sub(r"'([\s\n]*?\${{.+?}})'", r"\1", serialized, flags=re.DOTALL) + serialized = re.sub( + r"'([\s\n]*?\${{.+?}})'", r"\1", serialized, flags=re.DOTALL + ) # print(serialized) # fix the new lines in templates ending }} serialized = re.sub(r"(\${{.+)\n.+(}})", r"\1 \2", serialized) @@ -336,7 +360,9 @@ def serialize_templated_yaml(tree: StrAny) -> str: yaml.add_representer(str, old_representer) -def generate_pip_freeze(requirements_blacklist: List[str], requirements_file_name: str) -> str: +def generate_pip_freeze( + requirements_blacklist: List[str], requirements_file_name: str +) -> str: pkgs = pipdeptree.get_installed_distributions(local_only=True, user_only=False) # construct graph with all packages @@ -346,8 +372,12 @@ def generate_pip_freeze(requirements_blacklist: List[str], requirements_file_nam nodes = [p for p in tree.keys() if p.key not in branch_keys] # compute excludes to compute includes as set difference - excludes = set(req.strip() for req in requirements_blacklist if not req.strip().startswith("#")) - includes = set(node.project_name for node in nodes if node.project_name not in excludes) + excludes = set( + req.strip() for req in requirements_blacklist if not req.strip().startswith("#") + ) + includes = set( + node.project_name for node in nodes if node.project_name not in excludes + ) # prepare new filtered DAG tree = tree.sort() diff --git a/dlt/cli/echo.py b/dlt/cli/echo.py index bd9cf24f64..bc84699aa9 100644 --- a/dlt/cli/echo.py +++ b/dlt/cli/echo.py @@ -8,7 +8,9 @@ @contextlib.contextmanager -def always_choose(always_choose_default: bool, always_choose_value: Any) -> Iterator[None]: +def always_choose( + always_choose_default: bool, always_choose_value: Any +) -> Iterator[None]: """Temporarily answer all confirmations and prompts with the values specified in arguments""" global ALWAYS_CHOOSE_DEFAULT, ALWAYS_CHOOSE_VALUE _always_choose_default = ALWAYS_CHOOSE_DEFAULT diff --git a/dlt/cli/init_command.py b/dlt/cli/init_command.py index 522b3a6712..6e5bffe68d 100644 --- a/dlt/cli/init_command.py +++ b/dlt/cli/init_command.py @@ -60,7 +60,9 @@ def _select_source_files( remote_deleted: Dict[str, TVerifiedSourceFileEntry], conflict_modified: Sequence[str], conflict_deleted: Sequence[str], -) -> Tuple[str, Dict[str, TVerifiedSourceFileEntry], Dict[str, TVerifiedSourceFileEntry]]: +) -> Tuple[ + str, Dict[str, TVerifiedSourceFileEntry], Dict[str, TVerifiedSourceFileEntry] +]: # some files were changed and cannot be updated (or are created without index) fmt.echo( "Existing files for %s source were changed and cannot be automatically updated" @@ -68,7 +70,8 @@ def _select_source_files( ) if 
conflict_modified: fmt.echo( - "Following files are MODIFIED locally and CONFLICT with incoming changes: %s" + "Following files are MODIFIED locally and CONFLICT with incoming" + " changes: %s" % fmt.bold(", ".join(conflict_modified)) ) if conflict_deleted: @@ -90,9 +93,12 @@ def _select_source_files( % fmt.bold(", ".join(can_delete_files)) ) prompt = ( - "Should incoming changes be Skipped, Applied (local changes will be lost) or Merged (%s" - " UPDATED | %s DELETED | all local changes remain)?" - % (fmt.bold(",".join(can_update_files)), fmt.bold(",".join(can_delete_files))) + "Should incoming changes be Skipped, Applied (local changes will be lost)" + " or Merged (%s UPDATED | %s DELETED | all local changes remain)?" + % ( + fmt.bold(",".join(can_update_files)), + fmt.bold(",".join(can_delete_files)), + ) ) choices = "sam" else: @@ -107,9 +113,15 @@ def _select_source_files( remote_deleted.clear() elif resolution == "m": # update what we can - fmt.echo("Merging the incoming changes. No files with local changes were modified.") - remote_modified = {n: e for n, e in remote_modified.items() if n in can_update_files} - remote_deleted = {n: e for n, e in remote_deleted.items() if n in can_delete_files} + fmt.echo( + "Merging the incoming changes. No files with local changes were modified." + ) + remote_modified = { + n: e for n, e in remote_modified.items() if n in can_update_files + } + remote_deleted = { + n: e for n, e in remote_deleted.items() if n in can_delete_files + } else: # fully overwrite, leave all files to be copied fmt.echo("Applying all incoming changes to local files.") @@ -129,13 +141,17 @@ def _get_dependency_system(dest_storage: FileStorage) -> str: def _list_verified_sources( repo_location: str, branch: str = None ) -> Dict[str, VerifiedSourceFiles]: - clone_storage = git.get_fresh_repo_files(repo_location, get_dlt_repos_dir(), branch=branch) + clone_storage = git.get_fresh_repo_files( + repo_location, get_dlt_repos_dir(), branch=branch + ) sources_storage = FileStorage(clone_storage.make_full_path(SOURCES_MODULE_NAME)) sources: Dict[str, VerifiedSourceFiles] = {} for source_name in files_ops.get_verified_source_names(sources_storage): try: - sources[source_name] = files_ops.get_verified_source_files(sources_storage, source_name) + sources[source_name] = files_ops.get_verified_source_files( + sources_storage, source_name + ) except Exception as ex: fmt.warning(f"Verified source {source_name} not available: {ex}") @@ -151,27 +167,35 @@ def _welcome_message( ) -> None: fmt.echo() if source_files.is_template: - fmt.echo("Your new pipeline %s is ready to be customized!" % fmt.bold(source_name)) + fmt.echo( + "Your new pipeline %s is ready to be customized!" % fmt.bold(source_name) + ) fmt.echo( "* Review and change how dlt loads your data in %s" % fmt.bold(source_files.dest_pipeline_script) ) else: if is_new_source: - fmt.echo("Verified source %s was added to your project!" % fmt.bold(source_name)) + fmt.echo( + "Verified source %s was added to your project!" % fmt.bold(source_name) + ) fmt.echo( "* See the usage examples and code snippets to copy from %s" % fmt.bold(source_files.dest_pipeline_script) ) else: fmt.echo( - "Verified source %s was updated to the newest version!" % fmt.bold(source_name) + "Verified source %s was updated to the newest version!" 
+                % fmt.bold(source_name)
             )
 
     if is_new_source:
         fmt.echo(
             "* Add credentials for %s and other secrets in %s"
-            % (fmt.bold(destination_type), fmt.bold(make_dlt_settings_path(SECRETS_TOML)))
+            % (
+                fmt.bold(destination_type),
+                fmt.bold(make_dlt_settings_path(SECRETS_TOML)),
+            )
         )
 
     if dependency_system:
@@ -180,7 +204,8 @@ def _welcome_message(
         for dep in compiled_requirements:
             fmt.echo(" " + fmt.bold(dep))
         fmt.echo(
-            " If the dlt dependency is already added, make sure you install the extra for %s to it"
+            " If the dlt dependency is already added, make sure you install the extra"
+            " for %s to it"
             % fmt.bold(destination_type)
         )
         if dependency_system == utils.REQUIREMENTS_TXT:
@@ -191,7 +216,9 @@ def _welcome_message(
             )
         elif dependency_system == utils.PYPROJECT_TOML:
             fmt.echo(" If you are using poetry you may issue the following command:")
-            fmt.echo(fmt.bold(" poetry add %s -E %s" % (DLT_PKG_NAME, destination_type)))
+            fmt.echo(
+                fmt.bold(" poetry add %s -E %s" % (DLT_PKG_NAME, destination_type))
+            )
         fmt.echo()
     else:
         fmt.echo(
@@ -213,7 +240,9 @@ def _welcome_message(
 
 def list_verified_sources_command(repo_location: str, branch: str = None) -> None:
     fmt.echo("Looking up for verified sources in %s..." % fmt.bold(repo_location))
-    for source_name, source_files in _list_verified_sources(repo_location, branch).items():
+    for source_name, source_files in _list_verified_sources(
+        repo_location, branch
+    ).items():
         reqs = source_files.requirements
         dlt_req_string = str(reqs.dlt_requirement_base)
         msg = "%s: %s" % (fmt.bold(source_name), source_files.doc)
@@ -234,14 +263,18 @@ def init_command(
     destination_spec = destination_reference.spec
 
     fmt.echo("Looking up the init scripts in %s..." % fmt.bold(repo_location))
-    clone_storage = git.get_fresh_repo_files(repo_location, get_dlt_repos_dir(), branch=branch)
+    clone_storage = git.get_fresh_repo_files(
+        repo_location, get_dlt_repos_dir(), branch=branch
+    )
     # copy init files from here
     init_storage = FileStorage(clone_storage.make_full_path(INIT_MODULE_NAME))
     # copy dlt source files from here
     sources_storage = FileStorage(clone_storage.make_full_path(SOURCES_MODULE_NAME))
     # load init module and get init files and script
     init_module = load_script_module(clone_storage.storage_path, INIT_MODULE_NAME)
-    pipeline_script, template_files = _get_template_files(init_module, use_generic_template)
+    pipeline_script, template_files = _get_template_files(
+        init_module, use_generic_template
+    )
     # prepare destination storage
     dest_storage = FileStorage(os.path.abspath("."))
     if not dest_storage.has_folder(get_dlt_settings_dir()):
@@ -279,7 +312,11 @@ def init_command(
         if conflict_modified or conflict_deleted:
             # select source files that can be copied/updated
             _, remote_modified, remote_deleted = _select_source_files(
-                source_name, remote_modified, remote_deleted, conflict_modified, conflict_deleted
+                source_name,
+                remote_modified,
+                remote_deleted,
+                conflict_modified,
+                conflict_deleted,
             )
         if not remote_deleted and not remote_modified:
             fmt.echo("No files to update, exiting")
@@ -287,8 +324,8 @@ def init_command(
 
         if remote_index["is_dirty"]:
             fmt.warning(
-                f"The verified sources repository is dirty. {source_name} source files may not"
-                " update correctly in the future."
+                f"The verified sources repository is dirty. {source_name} source files"
+                " may not update correctly in the future."
) # add template files source_files.files.extend(template_files) @@ -307,7 +344,9 @@ def init_command( "", ) if dest_storage.has_file(dest_pipeline_script): - fmt.warning("Pipeline script %s already exist, exiting" % dest_pipeline_script) + fmt.warning( + "Pipeline script %s already exist, exiting" % dest_pipeline_script + ) return # add .dlt/*.toml files to be copied @@ -327,7 +366,8 @@ def init_command( ) fmt.warning(msg) if not fmt.confirm( - "Would you like to continue anyway? (you can update dlt after this step)", default=True + "Would you like to continue anyway? (you can update dlt after this step)", + default=True, ): fmt.echo( "You can update dlt with: pip3 install -U" @@ -344,15 +384,16 @@ def init_command( if visitor.is_destination_imported: raise CliCommandException( "init", - f"The pipeline script {source_files.pipeline_script} import a destination from" - " dlt.destinations. You should specify destinations by name when calling dlt.pipeline" - " or dlt.run in init scripts.", + f"The pipeline script {source_files.pipeline_script} import a destination" + " from dlt.destinations. You should specify destinations by name when" + " calling dlt.pipeline or dlt.run in init scripts.", ) if n.PIPELINE not in visitor.known_calls: raise CliCommandException( "init", - f"The pipeline script {source_files.pipeline_script} does not seem to initialize" - " pipeline with dlt.pipeline. Please initialize pipeline explicitly in init scripts.", + f"The pipeline script {source_files.pipeline_script} does not seem to" + " initialize pipeline with dlt.pipeline. Please initialize pipeline" + " explicitly in init scripts.", ) # find all arguments in all calls to replace @@ -387,17 +428,18 @@ def init_command( ) # template sources are always in module starting with "pipeline" # for templates, place config and secrets into top level section - required_secrets, required_config, checked_sources = source_detection.detect_source_configs( - _SOURCES, "pipeline", () + required_secrets, required_config, checked_sources = ( + source_detection.detect_source_configs(_SOURCES, "pipeline", ()) ) # template has a strict rules where sources are placed for source_q_name, source_config in checked_sources.items(): if source_q_name not in visitor.known_sources_resources: raise CliCommandException( "init", - f"The pipeline script {source_files.pipeline_script} imports a source/resource" - f" {source_config.f.__name__} from module {source_config.module.__name__}. In" - " init scripts you must declare all sources and resources in single file.", + f"The pipeline script {source_files.pipeline_script} imports a" + f" source/resource {source_config.f.__name__} from module" + f" {source_config.module.__name__}. 
In init scripts you must" + " declare all sources and resources in single file.", ) # rename sources and resources transformed_nodes.extend( @@ -410,15 +452,17 @@ def init_command( ) # pipeline sources are in module with name starting from {pipeline_name} # for verified pipelines place in the specific source section - required_secrets, required_config, checked_sources = source_detection.detect_source_configs( - _SOURCES, source_name, (known_sections.SOURCES, source_name) + required_secrets, required_config, checked_sources = ( + source_detection.detect_source_configs( + _SOURCES, source_name, (known_sections.SOURCES, source_name) + ) ) if len(checked_sources) == 0: raise CliCommandException( "init", - f"The pipeline script {source_files.pipeline_script} is not creating or importing any" - " sources or resources", + f"The pipeline script {source_files.pipeline_script} is not creating or" + " importing any sources or resources", ) # add destination spec to required secrets @@ -440,8 +484,9 @@ def init_command( if is_new_source: if source_files.is_template: fmt.echo( - "A verified source %s was not found. Using a template to create a new source and" - " pipeline with name %s." % (fmt.bold(source_name), fmt.bold(source_name)) + "A verified source %s was not found. Using a template to create a new" + " source and pipeline with name %s." + % (fmt.bold(source_name), fmt.bold(source_name)) ) else: fmt.echo( @@ -449,12 +494,16 @@ def init_command( % (fmt.bold(source_name), source_files.doc) ) if use_generic_template: - fmt.warning("--generic parameter is meaningless if verified source is found") + fmt.warning( + "--generic parameter is meaningless if verified source is found" + ) if not fmt.confirm("Do you want to proceed?", default=True): raise CliCommandException("init", "Aborted") dependency_system = _get_dependency_system(dest_storage) - _welcome_message(source_name, destination_type, source_files, dependency_system, is_new_source) + _welcome_message( + source_name, destination_type, source_files, dependency_system, is_new_source + ) # copy files at the very end for file_name in source_files.files: @@ -489,7 +538,9 @@ def init_command( # generate tomls with comments secrets_prov = SecretsTomlProvider() # print(secrets_prov._toml) - write_values(secrets_prov._toml, required_secrets.values(), overwrite_existing=False) + write_values( + secrets_prov._toml, required_secrets.values(), overwrite_existing=False + ) config_prov = ConfigTomlProvider() write_values(config_prov._toml, required_config.values(), overwrite_existing=False) # write toml files diff --git a/dlt/cli/pipeline_command.py b/dlt/cli/pipeline_command.py index 0eb73ad7a8..356295600d 100644 --- a/dlt/cli/pipeline_command.py +++ b/dlt/cli/pipeline_command.py @@ -16,7 +16,9 @@ from dlt.cli import echo as fmt -DLT_PIPELINE_COMMAND_DOCS_URL = "https://dlthub.com/docs/reference/command-line-interface" +DLT_PIPELINE_COMMAND_DOCS_URL = ( + "https://dlthub.com/docs/reference/command-line-interface" +) def pipeline_command( @@ -71,7 +73,8 @@ def pipeline_command( # remote state was not found p._wipe_working_folder() fmt.error( - f"Pipeline {pipeline_name} was not found in dataset {dataset_name} in {destination}" + f"Pipeline {pipeline_name} was not found in dataset {dataset_name} in" + f" {destination}" ) return if operation == "sync": @@ -81,7 +84,8 @@ def _display_pending_packages() -> Tuple[Sequence[str], Sequence[str]]: extracted_packages = p.list_extracted_load_packages() if extracted_packages: fmt.echo( - "Has %s extracted packages ready 
to be normalized with following load ids:" + "Has %s extracted packages ready to be normalized with following load" + " ids:" % fmt.bold(str(len(extracted_packages))) ) for load_id in extracted_packages: @@ -98,12 +102,16 @@ def _display_pending_packages() -> Tuple[Sequence[str], Sequence[str]]: first_package_info = p.get_load_package_info(norm_packages[0]) if PackageStorage.is_package_partially_loaded(first_package_info): fmt.warning( - "This package is partially loaded. Data in the destination may be modified." + "This package is partially loaded. Data in the destination may be" + " modified." ) fmt.echo() return extracted_packages, norm_packages - fmt.echo("Found pipeline %s in %s" % (fmt.bold(p.pipeline_name), fmt.bold(p.pipelines_dir))) + fmt.echo( + "Found pipeline %s in %s" + % (fmt.bold(p.pipeline_name), fmt.bold(p.pipelines_dir)) + ) if operation == "show": from dlt.common.runtime import signals @@ -145,7 +153,9 @@ def _display_pending_packages() -> Tuple[Sequence[str], Sequence[str]]: if verbosity > 0: fmt.echo(json.dumps(sources_state, pretty=True)) else: - fmt.echo("Add -v option to see sources state. Note that it could be large.") + fmt.echo( + "Add -v option to see sources state. Note that it could be large." + ) fmt.echo() fmt.echo("Local state:") @@ -161,7 +171,9 @@ def _display_pending_packages() -> Tuple[Sequence[str], Sequence[str]]: fmt.echo("Resources in schema: %s" % fmt.bold(schema_name)) schema = p.schemas[schema_name] data_tables = {t["name"]: t for t in schema.data_tables()} - for resource_name, tables in group_tables_by_resource(data_tables).items(): + for resource_name, tables in group_tables_by_resource( + data_tables + ).items(): res_state_slots = 0 if sources_state: source_state = ( @@ -170,7 +182,9 @@ def _display_pending_packages() -> Tuple[Sequence[str], Sequence[str]]: else sources_state.get(schema_name) ) if source_state: - resource_state_ = resource_state(resource_name, source_state) + resource_state_ = resource_state( + resource_name, source_state + ) res_state_slots = len(resource_state_) fmt.echo( "%s with %s table(s) and %s resource state slot(s)" @@ -223,7 +237,10 @@ def _display_pending_packages() -> Tuple[Sequence[str], Sequence[str]]: fmt.bold(failed_job.job_file_info.table_name), ) ) - fmt.echo("JOB file type: %s" % fmt.bold(failed_job.job_file_info.file_format)) + fmt.echo( + "JOB file type: %s" + % fmt.bold(failed_job.job_file_info.file_format) + ) fmt.echo("JOB file path: %s" % fmt.bold(failed_job.file_path)) if verbosity > 0: fmt.echo(failed_job.asstr(verbosity)) @@ -242,8 +259,8 @@ def _display_pending_packages() -> Tuple[Sequence[str], Sequence[str]]: if operation == "sync": if fmt.confirm( - "About to drop the local state of the pipeline and reset all the schemas. The" - " destination state, data and schemas are left intact. Proceed?", + "About to drop the local state of the pipeline and reset all the schemas." + " The destination state, data and schemas are left intact. Proceed?", default=False, ): fmt.echo("Dropping local state") @@ -267,12 +284,15 @@ def _display_pending_packages() -> Tuple[Sequence[str], Sequence[str]]: package_info = p.get_load_package_info(load_id) fmt.echo( - "Package %s found in %s" % (fmt.bold(load_id), fmt.bold(package_info.package_path)) + "Package %s found in %s" + % (fmt.bold(load_id), fmt.bold(package_info.package_path)) ) fmt.echo(package_info.asstr(verbosity)) if len(package_info.schema_update) > 0: if verbosity == 0: - print("Add -v option to see schema update. 
Note that it could be large.") + print( + "Add -v option to see schema update. Note that it could be large." + ) else: tables = remove_defaults({"tables": package_info.schema_update}) # type: ignore fmt.echo(fmt.bold("Schema update:")) @@ -294,7 +314,9 @@ def _display_pending_packages() -> Tuple[Sequence[str], Sequence[str]]: remove_defaults_ = command_kwargs.get("remove_defaults") s = p.default_schema if format_ == "json": - schema_str = json.dumps(s.to_dict(remove_defaults=remove_defaults_), pretty=True) + schema_str = json.dumps( + s.to_dict(remove_defaults=remove_defaults_), pretty=True + ) else: schema_str = s.to_pretty_yaml(remove_defaults=remove_defaults_) fmt.echo(schema_str) @@ -303,8 +325,8 @@ def _display_pending_packages() -> Tuple[Sequence[str], Sequence[str]]: drop = DropCommand(p, **command_kwargs) if drop.is_empty: fmt.echo( - "Could not select any resources to drop and no resource/source state to reset. Use" - " the command below to inspect the pipeline:" + "Could not select any resources to drop and no resource/source state to" + " reset. Use the command below to inspect the pipeline:" ) fmt.echo(f"dlt pipeline -v {p.pipeline_name} info") if len(drop.info["warnings"]): @@ -320,7 +342,10 @@ def _display_pending_packages() -> Tuple[Sequence[str], Sequence[str]]: fmt.bold(p.destination.destination_name), ) ) - fmt.echo("%s: %s" % (fmt.style("Selected schema", fg="green"), drop.info["schema_name"])) + fmt.echo( + "%s: %s" + % (fmt.style("Selected schema", fg="green"), drop.info["schema_name"]) + ) fmt.echo( "%s: %s" % ( @@ -328,7 +353,9 @@ def _display_pending_packages() -> Tuple[Sequence[str], Sequence[str]]: drop.info["resource_names"], ) ) - fmt.echo("%s: %s" % (fmt.style("Table(s) to drop", fg="green"), drop.info["tables"])) + fmt.echo( + "%s: %s" % (fmt.style("Table(s) to drop", fg="green"), drop.info["tables"]) + ) fmt.echo( "%s: %s" % ( diff --git a/dlt/cli/pipeline_files.py b/dlt/cli/pipeline_files.py index 49c0f71b21..8b6458b6b7 100644 --- a/dlt/cli/pipeline_files.py +++ b/dlt/cli/pipeline_files.py @@ -19,7 +19,13 @@ SOURCES_INIT_INFO_ENGINE_VERSION = 1 SOURCES_INIT_INFO_FILE = ".sources" -IGNORE_FILES = ["*.py[cod]", "*$py.class", "__pycache__", "py.typed", "requirements.txt"] +IGNORE_FILES = [ + "*.py[cod]", + "*$py.class", + "__pycache__", + "py.typed", + "requirements.txt", +] IGNORE_SOURCES = [".*", "_*"] @@ -53,13 +59,19 @@ class TVerifiedSourcesFileIndex(TypedDict): def _save_dot_sources(index: TVerifiedSourcesFileIndex) -> None: - with open(make_dlt_settings_path(SOURCES_INIT_INFO_FILE), "w", encoding="utf-8") as f: - yaml.dump(index, f, allow_unicode=True, default_flow_style=False, sort_keys=False) + with open( + make_dlt_settings_path(SOURCES_INIT_INFO_FILE), "w", encoding="utf-8" + ) as f: + yaml.dump( + index, f, allow_unicode=True, default_flow_style=False, sort_keys=False + ) def _load_dot_sources() -> TVerifiedSourcesFileIndex: try: - with open(make_dlt_settings_path(SOURCES_INIT_INFO_FILE), "r", encoding="utf-8") as f: + with open( + make_dlt_settings_path(SOURCES_INIT_INFO_FILE), "r", encoding="utf-8" + ) as f: index: TVerifiedSourcesFileIndex = yaml.safe_load(f) if not index: raise FileNotFoundError(SOURCES_INIT_INFO_FILE) @@ -155,7 +167,10 @@ def get_verified_source_names(sources_storage: FileStorage) -> List[str]: if not any(fnmatch.fnmatch(n, ignore) for ignore in IGNORE_SOURCES) ]: # must contain at least one valid python script - if any(f.endswith(".py") for f in sources_storage.list_folder_files(name, to_root=False)): + if any( + 
f.endswith(".py") + for f in sources_storage.list_folder_files(name, to_root=False) + ): candidates.append(name) return candidates @@ -165,13 +180,15 @@ def get_verified_source_files( ) -> VerifiedSourceFiles: if not sources_storage.has_folder(source_name): raise VerifiedSourceRepoError( - f"Verified source {source_name} could not be found in the repository", source_name + f"Verified source {source_name} could not be found in the repository", + source_name, ) # find example script example_script = f"{source_name}_pipeline.py" if not sources_storage.has_file(example_script): raise VerifiedSourceRepoError( - f"Pipeline example script {example_script} could not be found in the repository", + f"Pipeline example script {example_script} could not be found in the" + " repository", source_name, ) # get all files recursively @@ -199,12 +216,20 @@ def get_verified_source_files( # read requirements requirements_path = os.path.join(source_name, utils.REQUIREMENTS_TXT) if sources_storage.has_file(requirements_path): - requirements = SourceRequirements.from_string(sources_storage.load(requirements_path)) + requirements = SourceRequirements.from_string( + sources_storage.load(requirements_path) + ) else: requirements = SourceRequirements([]) # find requirements return VerifiedSourceFiles( - False, sources_storage, example_script, example_script, files, requirements, docstring + False, + sources_storage, + example_script, + example_script, + files, + requirements, + docstring, ) @@ -267,7 +292,9 @@ def is_file_modified(file: str, entry: TVerifiedSourceFileEntry) -> bool: for file, entry in remote_modified.items(): if dest_storage.has_file(file): # if local file was changes and it is different from incoming - if is_file_modified(file, entry) and is_file_modified(file, local_index["files"][file]): + if is_file_modified(file, entry) and is_file_modified( + file, local_index["files"][file] + ): conflict_modified.append(file) else: # file was deleted but is modified on remote diff --git a/dlt/cli/source_detection.py b/dlt/cli/source_detection.py index 636615af61..0fb942341d 100644 --- a/dlt/cli/source_detection.py +++ b/dlt/cli/source_detection.py @@ -15,7 +15,9 @@ def find_call_arguments_to_replace( - visitor: PipelineScriptVisitor, replace_nodes: List[Tuple[str, str]], init_script_name: str + visitor: PipelineScriptVisitor, + replace_nodes: List[Tuple[str, str]], + init_script_name: str, ) -> List[Tuple[ast.AST, ast.AST]]: # the input tuple (call argument name, replacement value) # the returned tuple (node, replacement value, node type) @@ -27,14 +29,19 @@ def find_call_arguments_to_replace( for t_arg_name, t_value in replace_nodes: dn_node: ast.AST = args.arguments.get(t_arg_name) if dn_node is not None: - if not isinstance(dn_node, ast.Constant) or not isinstance(dn_node.value, str): + if not isinstance(dn_node, ast.Constant) or not isinstance( + dn_node.value, str + ): raise CliCommandException( "init", - f"The pipeline script {init_script_name} must pass the {t_arg_name} as" - f" string to '{arg_name}' function in line {dn_node.lineno}", + f"The pipeline script {init_script_name} must pass the" + f" {t_arg_name} as string to '{arg_name}' function in line" + f" {dn_node.lineno}", ) else: - transformed_nodes.append((dn_node, ast.Constant(value=t_value, kind=None))) + transformed_nodes.append( + (dn_node, ast.Constant(value=t_value, kind=None)) + ) replaced_args.add(t_arg_name) # there was at least one replacement @@ -43,8 +50,8 @@ def find_call_arguments_to_replace( raise CliCommandException( "init", f"The 
pipeline script {init_script_name} is not explicitly passing the" - f" '{t_arg_name}' argument to 'pipeline' or 'run' function. In init script the" - " default and configured values are not accepted.", + f" '{t_arg_name}' argument to 'pipeline' or 'run' function. In init" + " script the default and configured values are not accepted.", ) return transformed_nodes @@ -73,7 +80,11 @@ def find_source_calls_to_replace( def detect_source_configs( sources: Dict[str, SourceInfo], module_prefix: str, section: Tuple[str, ...] -) -> Tuple[Dict[str, WritableConfigValue], Dict[str, WritableConfigValue], Dict[str, SourceInfo]]: +) -> Tuple[ + Dict[str, WritableConfigValue], + Dict[str, WritableConfigValue], + Dict[str, SourceInfo], +]: # all detected secrets with sections required_secrets: Dict[str, WritableConfigValue] = {} # all detected configs with sections @@ -85,7 +96,9 @@ def detect_source_configs( # accept only sources declared in the `init` or `pipeline` modules if source_info.module.__name__.startswith(module_prefix): checked_sources[source_name] = source_info - source_config = source_info.SPEC() if source_info.SPEC else BaseConfiguration() + source_config = ( + source_info.SPEC() if source_info.SPEC else BaseConfiguration() + ) spec_fields = source_config.get_resolvable_fields() for field_name, field_type in spec_fields.items(): val_store = None @@ -94,7 +107,8 @@ def detect_source_configs( val_store = required_secrets # all configs that are required and do not have a default value must go to config.toml elif ( - not is_optional_type(field_type) and getattr(source_config, field_name) is None + not is_optional_type(field_type) + and getattr(source_config, field_name) is None ): val_store = required_config diff --git a/dlt/cli/telemetry_command.py b/dlt/cli/telemetry_command.py index bb451ea979..2506ffe6c6 100644 --- a/dlt/cli/telemetry_command.py +++ b/dlt/cli/telemetry_command.py @@ -8,7 +8,9 @@ from dlt.cli import echo as fmt from dlt.cli.utils import get_telemetry_status from dlt.cli.config_toml_writer import WritableConfigValue, write_values -from dlt.common.configuration.specs.config_providers_context import ConfigProvidersContext +from dlt.common.configuration.specs.config_providers_context import ( + ConfigProvidersContext, +) from dlt.common.runtime.segment import get_anonymous_id DLT_TELEMETRY_DOCS_URL = "https://dlthub.com/docs/reference/telemetry" @@ -25,7 +27,9 @@ def telemetry_status_command() -> None: def change_telemetry_status_command(enabled: bool) -> None: # value to write telemetry_value = [ - WritableConfigValue("dlthub_telemetry", bool, enabled, (RunConfiguration.__section__,)) + WritableConfigValue( + "dlthub_telemetry", bool, enabled, (RunConfiguration.__section__,) + ) ] # write local config config = ConfigTomlProvider(add_global_config=False) diff --git a/dlt/cli/utils.py b/dlt/cli/utils.py index 5ea4471d7e..2c5e100558 100644 --- a/dlt/cli/utils.py +++ b/dlt/cli/utils.py @@ -36,8 +36,8 @@ def parse_init_script( if len(visitor.mod_aliases) == 0: raise CliCommandException( command, - f"The pipeline script {init_script_name} does not import dlt and does not seem to run" - " any pipelines", + f"The pipeline script {init_script_name} does not import dlt and does not" + " seem to run any pipelines", ) return visitor @@ -51,13 +51,16 @@ def ensure_git_command(command: str) -> None: raise raise CliCommandException( command, - "'git' command is not available. Install and setup git with the following the guide %s" + "'git' command is not available. 
Install and setup git with the following" + " the guide %s" % "https://docs.github.com/en/get-started/quickstart/set-up-git", imp_ex, ) from imp_ex -def track_command(command: str, track_before: bool, *args: str) -> Callable[[TFun], TFun]: +def track_command( + command: str, track_before: bool, *args: str +) -> Callable[[TFun], TFun]: return with_telemetry("command", command, track_before, *args) diff --git a/dlt/common/arithmetics.py b/dlt/common/arithmetics.py index 56d8fcd49b..6192a4bd9b 100644 --- a/dlt/common/arithmetics.py +++ b/dlt/common/arithmetics.py @@ -36,7 +36,9 @@ def default_context(c: Context, precision: int) -> Context: @contextmanager -def numeric_default_context(precision: int = DEFAULT_NUMERIC_PRECISION) -> Iterator[Context]: +def numeric_default_context( + precision: int = DEFAULT_NUMERIC_PRECISION, +) -> Iterator[Context]: with localcontext() as c: yield default_context(c, precision) diff --git a/dlt/common/configuration/__init__.py b/dlt/common/configuration/__init__.py index 8de57f7799..35a4fe76ff 100644 --- a/dlt/common/configuration/__init__.py +++ b/dlt/common/configuration/__init__.py @@ -1,4 +1,9 @@ -from .specs.base_configuration import configspec, is_valid_hint, is_secret_hint, resolve_type +from .specs.base_configuration import ( + configspec, + is_valid_hint, + is_secret_hint, + resolve_type, +) from .specs import known_sections from .resolve import resolve_configuration, inject_section from .inject import with_config, last_config, get_fun_spec, create_resolved_partial diff --git a/dlt/common/configuration/accessors.py b/dlt/common/configuration/accessors.py index dfadc97fa3..7a7804744e 100644 --- a/dlt/common/configuration/accessors.py +++ b/dlt/common/configuration/accessors.py @@ -6,9 +6,14 @@ from dlt.common.configuration.container import Container from dlt.common.configuration.exceptions import ConfigFieldMissingException, LookupTrace from dlt.common.configuration.providers.provider import ConfigProvider -from dlt.common.configuration.specs import BaseConfiguration, is_base_configuration_inner_hint +from dlt.common.configuration.specs import ( + BaseConfiguration, + is_base_configuration_inner_hint, +) from dlt.common.configuration.utils import deserialize_value, log_traces, auto_cast -from dlt.common.configuration.specs.config_providers_context import ConfigProvidersContext +from dlt.common.configuration.specs.config_providers_context import ( + ConfigProvidersContext, +) from dlt.common.typing import AnyType, ConfigValue, TSecretValue DLT_SECRETS_VALUE = "secrets.value" @@ -60,7 +65,9 @@ def writable_provider(self) -> ConfigProvider: def _get_providers_from_context(self) -> Sequence[ConfigProvider]: return Container()[ConfigProvidersContext].providers - def _get_value(self, field: str, type_hint: Type[Any] = None) -> Tuple[Any, List[LookupTrace]]: + def _get_value( + self, field: str, type_hint: Type[Any] = None + ) -> Tuple[Any, List[LookupTrace]]: # get default hint type, in case of dlt.secrets it it TSecretValue type_hint = type_hint or self.default_type # split field into sections and a key @@ -126,7 +133,9 @@ def default_type(self) -> AnyType: def writable_provider(self) -> ConfigProvider: """find first writable provider that supports secrets - should be secrets.toml""" return next( - p for p in self._get_providers_from_context() if p.is_writable and p.supports_secrets + p + for p in self._get_providers_from_context() + if p.is_writable and p.supports_secrets ) value: ClassVar[Any] = ConfigValue diff --git a/dlt/common/configuration/container.py 
b/dlt/common/configuration/container.py index 441b0e21bc..3f8f342c10 100644 --- a/dlt/common/configuration/container.py +++ b/dlt/common/configuration/container.py @@ -32,7 +32,9 @@ class Container: _MAIN_THREAD_ID: ClassVar[int] = threading.get_ident() """A main thread id to which get item will fallback for contexts without default""" - thread_contexts: Dict[int, Dict[Type[ContainerInjectableContext], ContainerInjectableContext]] + thread_contexts: Dict[ + int, Dict[Type[ContainerInjectableContext], ContainerInjectableContext] + ] """A thread aware mapping of injection context """ _context_container_locks: Dict[str, threading.Lock] """Locks for container types on threads.""" diff --git a/dlt/common/configuration/exceptions.py b/dlt/common/configuration/exceptions.py index 1d8423057f..4102c9e5f8 100644 --- a/dlt/common/configuration/exceptions.py +++ b/dlt/common/configuration/exceptions.py @@ -35,15 +35,17 @@ class ConfigProviderException(ConfigurationException): class ConfigurationWrongTypeException(ConfigurationException): def __init__(self, _typ: type) -> None: super().__init__( - f"Invalid configuration instance type {_typ}. Configuration instances must derive from" - " BaseConfiguration." + f"Invalid configuration instance type {_typ}. Configuration instances must" + " derive from BaseConfiguration." ) class ConfigFieldMissingException(KeyError, ConfigurationException): """raises when not all required config fields are present""" - def __init__(self, spec_name: str, traces: Mapping[str, Sequence[LookupTrace]]) -> None: + def __init__( + self, spec_name: str, traces: Mapping[str, Sequence[LookupTrace]] + ) -> None: self.traces = traces self.spec_name = spec_name self.fields = list(traces.keys()) @@ -51,11 +53,14 @@ def __init__(self, spec_name: str, traces: Mapping[str, Sequence[LookupTrace]]) def __str__(self) -> str: msg = ( - f"Following fields are missing: {str(self.fields)} in configuration with spec" - f" {self.spec_name}\n" + f"Following fields are missing: {str(self.fields)} in configuration with" + f" spec {self.spec_name}\n" ) for f, field_traces in self.traces.items(): - msg += f'\tfor field "{f}" config providers and keys were tried in following order:\n' + msg += ( + f'\tfor field "{f}" config providers and keys were tried in following' + " order:\n" + ) for tr in field_traces: msg += f"\t\tIn {tr.provider} key {tr.key} was not found.\n" # check if entry point is run with path. 
this is common problem so warn the user @@ -66,13 +71,14 @@ def __str__(self) -> str: if abs_main_dir != os.getcwd(): # directory was specified msg += ( - "WARNING: dlt looks for .dlt folder in your current working directory and your" - " cwd (%s) is different from directory of your pipeline script (%s).\n" - % (os.getcwd(), abs_main_dir) + "WARNING: dlt looks for .dlt folder in your current working" + " directory and your cwd (%s) is different from directory of your" + " pipeline script (%s).\n" % (os.getcwd(), abs_main_dir) ) msg += ( - "If you keep your secret files in the same folder as your pipeline script but" - " run your script from some other folder, secrets/configs will not be found\n" + "If you keep your secret files in the same folder as your pipeline" + " script but run your script from some other folder," + " secrets/configs will not be found\n" ) msg += ( "Please refer to https://dlthub.com/docs/general-usage/credentials for more" @@ -100,9 +106,10 @@ def __init__(self, spec_name: str, field_names: Sequence[str]) -> None: f">>> {name}: Any" for name in field_names ) msg = ( - f"The config spec {spec_name} has dynamic type resolvers for fields: {field_names} but" - " these fields are not defined in the spec.\nWhen using @resolve_type() decorator, Add" - f" the fields with 'Any' or another common type hint, example:\n\n{example}" + f"The config spec {spec_name} has dynamic type resolvers for fields:" + f" {field_names} but these fields are not defined in the spec.\nWhen using" + " @resolve_type() decorator, Add the fields with 'Any' or another common" + f" type hint, example:\n\n{example}" ) super().__init__(msg) @@ -112,7 +119,8 @@ class FinalConfigFieldException(ConfigurationException): def __init__(self, spec_name: str, field: str) -> None: super().__init__( - f"Field {field} in spec {spec_name} is final but is being changed by a config provider" + f"Field {field} in spec {spec_name} is final but is being changed by a" + " config provider" ) @@ -124,7 +132,8 @@ def __init__(self, field_name: str, field_value: Any, hint: type) -> None: self.field_value = field_value self.hint = hint super().__init__( - "Configured value for field %s cannot be coerced into type %s" % (field_name, str(hint)) + "Configured value for field %s cannot be coerced into type %s" + % (field_name, str(hint)) ) @@ -152,7 +161,8 @@ def __init__(self, field_name: str, spec: Type[Any]) -> None: self.field_name = field_name self.typ_ = spec super().__init__( - f"Field {field_name} on configspec {spec} does not provide required type hint" + f"Field {field_name} on configspec {spec} does not provide required type" + " hint" ) @@ -163,7 +173,8 @@ def __init__(self, field_name: str, spec: Type[Any], typ_: Type[Any]) -> None: self.field_name = field_name self.typ_ = spec super().__init__( - f"Field {field_name} on configspec {spec} has hint with unsupported type {typ_}" + f"Field {field_name} on configspec {spec} has hint with unsupported type" + f" {typ_}" ) @@ -172,8 +183,8 @@ def __init__(self, provider_name: str, key: str) -> None: self.provider_name = provider_name self.key = key super().__init__( - f"Provider {provider_name} cannot hold secret values but key {key} with secret value is" - " present" + f"Provider {provider_name} cannot hold secret values but key {key} with" + " secret value is present" ) @@ -189,34 +200,41 @@ def __init__( self.native_value_type = native_value_type self.embedded_sections = embedded_sections self.inner_exception = inner_exception - inner_msg = f" {self.inner_exception}" if 
inner_exception is not ValueError else "" + inner_msg = ( + f" {self.inner_exception}" if inner_exception is not ValueError else "" + ) super().__init__( - f"{spec.__name__} cannot parse the configuration value provided. The value is of type" - f" {native_value_type.__name__} and comes from the" + f"{spec.__name__} cannot parse the configuration value provided. The value" + f" is of type {native_value_type.__name__} and comes from the" f" {embedded_sections} section(s).{inner_msg}" ) class ContainerInjectableContextMangled(ContainerException): - def __init__(self, spec: Type[Any], existing_config: Any, expected_config: Any) -> None: + def __init__( + self, spec: Type[Any], existing_config: Any, expected_config: Any + ) -> None: self.spec = spec self.existing_config = existing_config self.expected_config = expected_config super().__init__( - f"When restoring context {spec.__name__}, instance {expected_config} was expected," - f" instead instance {existing_config} was found." + f"When restoring context {spec.__name__}, instance {expected_config} was" + f" expected, instead instance {existing_config} was found." ) class ContextDefaultCannotBeCreated(ContainerException, KeyError): def __init__(self, spec: Type[Any]) -> None: self.spec = spec - super().__init__(f"Container cannot create the default value of context {spec.__name__}.") + super().__init__( + f"Container cannot create the default value of context {spec.__name__}." + ) class DuplicateConfigProviderException(ConfigProviderException): def __init__(self, provider_name: str) -> None: self.provider_name = provider_name super().__init__( - f"Provider with name {provider_name} already present in ConfigProvidersContext" + f"Provider with name {provider_name} already present in" + " ConfigProvidersContext" ) diff --git a/dlt/common/configuration/inject.py b/dlt/common/configuration/inject.py index 6699826ec8..21621712c5 100644 --- a/dlt/common/configuration/inject.py +++ b/dlt/common/configuration/inject.py @@ -98,10 +98,13 @@ def decorator(f: TFun) -> TFun: sig: Signature = inspect.signature(f) signature_fields: Dict[str, Any] kwargs_arg = next( - (p for p in sig.parameters.values() if p.kind == Parameter.VAR_KEYWORD), None + (p for p in sig.parameters.values() if p.kind == Parameter.VAR_KEYWORD), + None, ) if spec is None: - SPEC, signature_fields = spec_from_signature(f, sig, include_defaults, base=base) + SPEC, signature_fields = spec_from_signature( + f, sig, include_defaults, base=base + ) else: SPEC = spec signature_fields = SPEC.get_resolvable_fields() @@ -125,7 +128,9 @@ def decorator(f: TFun) -> TFun: if p.name == "pipeline_name" and auto_pipeline_section: # if argument has name pipeline_name and auto_section is used, use it to generate section context pipeline_name_arg = p - pipeline_name_arg_default = None if p.default == Parameter.empty else p.default + pipeline_name_arg_default = ( + None if p.default == Parameter.empty else p.default + ) def resolve_config(bound_args: inspect.BoundArguments) -> BaseConfiguration: """Resolve arguments using the provided spec""" @@ -161,7 +166,9 @@ def resolve_config(bound_args: inspect.BoundArguments) -> BaseConfiguration: ) # this may be called from many threads so section_context is thread affine - with inject_section(section_context, lock_context=lock_context_on_injection): + with inject_section( + section_context, lock_context=lock_context_on_injection + ): # print(f"RESOLVE CONF in inject: {f.__name__}: {section_context.sections} vs {sections}") return resolve_configuration( config or 
SPEC(), @@ -170,7 +177,10 @@ def resolve_config(bound_args: inspect.BoundArguments) -> BaseConfiguration: ) def update_bound_args( - bound_args: inspect.BoundArguments, config: BaseConfiguration, args: Any, kwargs: Any + bound_args: inspect.BoundArguments, + config: BaseConfiguration, + args: Any, + kwargs: Any, ) -> None: # overwrite or add resolved params resolved_params = dict(config) @@ -189,7 +199,9 @@ def update_bound_args( bound_args.arguments[kwargs_arg.name][_LAST_DLT_CONFIG] = config bound_args.arguments[kwargs_arg.name][_ORIGINAL_ARGS] = (args, kwargs) - def with_partially_resolved_config(config: Optional[BaseConfiguration] = None) -> Any: + def with_partially_resolved_config( + config: Optional[BaseConfiguration] = None, + ) -> Any: # creates a pre-resolved partial of the decorated function empty_bound_args = sig.bind_partial() if not config: @@ -203,8 +215,8 @@ def wrapped(*args: Any, **kwargs: Any) -> Any: from dlt.common import logger logger.warning( - "Spec argument is provided in kwargs, ignoring it for resolved partial" - " function." + "Spec argument is provided in kwargs, ignoring it for resolved" + " partial function." ) # we can still overwrite the config @@ -236,7 +248,9 @@ def _wrap(*args: Any, **kwargs: Any) -> Any: _FUNC_SPECS[id(_wrap)] = SPEC # add a method to create a pre-resolved partial - setattr(_wrap, "__RESOLVED_PARTIAL_FUNC__", with_partially_resolved_config) # noqa: B010 + setattr( + _wrap, "__RESOLVED_PARTIAL_FUNC__", with_partially_resolved_config + ) # noqa: B010 return _wrap # type: ignore @@ -247,8 +261,8 @@ def _wrap(*args: Any, **kwargs: Any) -> Any: if not callable(func): raise ValueError( - "First parameter to the with_config must be callable ie. by using it as function" - " decorator" + "First parameter to the with_config must be callable ie. by using it as" + " function decorator" ) # We're called as @with_config without parens. 
@@ -264,7 +278,9 @@ def get_orig_args(**kwargs: Any) -> Tuple[Tuple[Any], DictStrAny]: return kwargs[_ORIGINAL_ARGS] # type: ignore -def create_resolved_partial(f: AnyFun, config: Optional[BaseConfiguration] = None) -> AnyFun: +def create_resolved_partial( + f: AnyFun, config: Optional[BaseConfiguration] = None +) -> AnyFun: """Create a pre-resolved partial of the with_config decorated function""" if partial_func := getattr(f, "__RESOLVED_PARTIAL_FUNC__", None): return cast(AnyFun, partial_func(config)) diff --git a/dlt/common/configuration/providers/airflow.py b/dlt/common/configuration/providers/airflow.py index 99edf258d2..ad58c217eb 100644 --- a/dlt/common/configuration/providers/airflow.py +++ b/dlt/common/configuration/providers/airflow.py @@ -5,7 +5,9 @@ class AirflowSecretsTomlProvider(VaultTomlProvider): - def __init__(self, only_secrets: bool = False, only_toml_fragments: bool = False) -> None: + def __init__( + self, only_secrets: bool = False, only_toml_fragments: bool = False + ) -> None: super().__init__(only_secrets, only_toml_fragments) @property @@ -17,7 +19,9 @@ def _look_vault(self, full_key: str, hint: type) -> str: from airflow.models import Variable - with contextlib.redirect_stdout(io.StringIO()), contextlib.redirect_stderr(io.StringIO()): + with contextlib.redirect_stdout(io.StringIO()), contextlib.redirect_stderr( + io.StringIO() + ): return Variable.get(full_key, default_var=None) # type: ignore @property diff --git a/dlt/common/configuration/providers/google_secrets.py b/dlt/common/configuration/providers/google_secrets.py index 98cbbc4553..c2f123c867 100644 --- a/dlt/common/configuration/providers/google_secrets.py +++ b/dlt/common/configuration/providers/google_secrets.py @@ -56,7 +56,9 @@ def get_key_name(key: str, *sections: str) -> str: 4. Underscores. """ key = normalize_key(key) - normalized_sections = [normalize_key(section) for section in sections if section] + normalized_sections = [ + normalize_key(section) for section in sections if section + ] key_name = get_key_name(normalize_key(key), "-", *normalized_sections) return key_name @@ -76,10 +78,20 @@ def _look_vault(self, full_key: str, hint: type) -> str: ) from dlt.common import logger - resource_name = f"projects/{self.credentials.project_id}/secrets/{full_key}/versions/latest" - client = build("secretmanager", "v1", credentials=self.credentials.to_native_credentials()) + resource_name = ( + f"projects/{self.credentials.project_id}/secrets/{full_key}/versions/latest" + ) + client = build( + "secretmanager", "v1", credentials=self.credentials.to_native_credentials() + ) try: - response = client.projects().secrets().versions().access(name=resource_name).execute() + response = ( + client.projects() + .secrets() + .versions() + .access(name=resource_name) + .execute() + ) secret_value = response["payload"]["data"] decoded_value = base64.b64decode(secret_value).decode("utf-8") return decoded_value @@ -91,14 +103,15 @@ def _look_vault(self, full_key: str, hint: type) -> str: elif error.resp.status == 403: logger.warning( f"{self.credentials.client_email} does not have" - " roles/secretmanager.secretAccessor role. It also does not have read" - f" permission to {full_key} or the key is not found in Google Secrets:" - f" {error_doc['message']}[{error_doc['status']}]" + " roles/secretmanager.secretAccessor role. 
It also does not have" + f" read permission to {full_key} or the key is not found in Google" + f" Secrets: {error_doc['message']}[{error_doc['status']}]" ) return None elif error.resp.status == 400: logger.warning( - f"Unable to read {full_key} : {error_doc['message']}[{error_doc['status']}]" + f"Unable to read {full_key} :" + f" {error_doc['message']}[{error_doc['status']}]" ) return None raise diff --git a/dlt/common/configuration/providers/provider.py b/dlt/common/configuration/providers/provider.py index 405a42bcf0..64e21613a4 100644 --- a/dlt/common/configuration/providers/provider.py +++ b/dlt/common/configuration/providers/provider.py @@ -11,7 +11,9 @@ def get_value( ) -> Tuple[Optional[Any], str]: pass - def set_value(self, key: str, value: Any, pipeline_name: str, *sections: str) -> None: + def set_value( + self, key: str, value: Any, pipeline_name: str, *sections: str + ) -> None: raise NotImplementedError() @property diff --git a/dlt/common/configuration/providers/toml.py b/dlt/common/configuration/providers/toml.py index 7c856e8c27..824ec3d2e5 100644 --- a/dlt/common/configuration/providers/toml.py +++ b/dlt/common/configuration/providers/toml.py @@ -48,7 +48,9 @@ def get_value( except KeyError: return None, full_key - def set_value(self, key: str, value: Any, pipeline_name: str, *sections: str) -> None: + def set_value( + self, key: str, value: Any, pipeline_name: str, *sections: str + ) -> None: if pipeline_name: sections = (pipeline_name,) + sections @@ -165,7 +167,9 @@ def get_value( # generate auxiliary paths to get from vault for known_section in [known_sections.SOURCES, known_sections.DESTINATION]: - def _look_at_idx(idx: int, full_path: Tuple[str, ...], pipeline_name: str) -> None: + def _look_at_idx( + idx: int, full_path: Tuple[str, ...], pipeline_name: str + ) -> None: lookup_key = full_path[idx] lookup_sections = full_path[:idx] lookup_fk = self.get_key_name(lookup_key, *lookup_sections) @@ -185,7 +189,9 @@ def _lookup_paths(pipeline_name_: str, known_section_: str) -> None: _look_at_idx(idx + 1, full_path, pipeline_name_) # first query the shortest paths so the longer paths can override it - _lookup_paths(None, known_section) # check sources and sources. + _lookup_paths( + None, known_section + ) # check sources and sources. if pipeline_name: _lookup_paths( pipeline_name, known_section @@ -213,7 +219,12 @@ def _look_vault(self, full_key: str, hint: type) -> str: pass def _update_from_vault( - self, full_key: str, key: str, hint: type, pipeline_name: str, sections: Tuple[str, ...] 
+ self, + full_key: str, + key: str, + hint: type, + pipeline_name: str, + sections: Tuple[str, ...], ) -> None: if full_key in self._vault_lookups: return @@ -259,11 +270,15 @@ def _read_toml_file( try: project_toml = self._read_toml(self._toml_path) if add_global_config: - global_toml = self._read_toml(os.path.join(self.global_config_path(), file_name)) + global_toml = self._read_toml( + os.path.join(self.global_config_path(), file_name) + ) project_toml = update_dict_nested(global_toml, project_toml) return project_toml except Exception as ex: - raise TomlProviderReadException(self.name, file_name, self._toml_path, str(ex)) + raise TomlProviderReadException( + self.name, file_name, self._toml_path, str(ex) + ) @staticmethod def global_config_path() -> str: @@ -287,8 +302,12 @@ def _read_toml(toml_path: str) -> tomlkit.TOMLDocument: class ConfigTomlProvider(TomlFileProvider): - def __init__(self, project_dir: str = None, add_global_config: bool = False) -> None: - super().__init__(CONFIG_TOML, project_dir=project_dir, add_global_config=add_global_config) + def __init__( + self, project_dir: str = None, add_global_config: bool = False + ) -> None: + super().__init__( + CONFIG_TOML, project_dir=project_dir, add_global_config=add_global_config + ) @property def name(self) -> str: @@ -304,8 +323,12 @@ def is_writable(self) -> bool: class SecretsTomlProvider(TomlFileProvider): - def __init__(self, project_dir: str = None, add_global_config: bool = False) -> None: - super().__init__(SECRETS_TOML, project_dir=project_dir, add_global_config=add_global_config) + def __init__( + self, project_dir: str = None, add_global_config: bool = False + ) -> None: + super().__init__( + SECRETS_TOML, project_dir=project_dir, add_global_config=add_global_config + ) @property def name(self) -> str: diff --git a/dlt/common/configuration/resolve.py b/dlt/common/configuration/resolve.py index ebfa7b6b89..6e88786298 100644 --- a/dlt/common/configuration/resolve.py +++ b/dlt/common/configuration/resolve.py @@ -1,6 +1,16 @@ import itertools from collections.abc import Mapping as C_Mapping -from typing import Any, Dict, ContextManager, List, Optional, Sequence, Tuple, Type, TypeVar +from typing import ( + Any, + Dict, + ContextManager, + List, + Optional, + Sequence, + Tuple, + Type, + TypeVar, +) from dlt.common.configuration.providers.provider import ConfigProvider from dlt.common.typing import ( @@ -24,7 +34,9 @@ ) from dlt.common.configuration.specs.config_section_context import ConfigSectionContext from dlt.common.configuration.specs.exceptions import NativeValueError -from dlt.common.configuration.specs.config_providers_context import ConfigProvidersContext +from dlt.common.configuration.specs.config_providers_context import ( + ConfigProvidersContext, +) from dlt.common.configuration.container import Container from dlt.common.configuration.utils import log_traces, deserialize_value from dlt.common.configuration.exceptions import ( @@ -52,7 +64,9 @@ def resolve_configuration( # try to get the native representation of the top level configuration using the config section as a key # allows, for example, to store connection string or service.json in their native form in single env variable or under single vault key if config.__section__ and explicit_value is None: - initial_hint = TSecretValue if isinstance(config, CredentialsConfiguration) else AnyType + initial_hint = ( + TSecretValue if isinstance(config, CredentialsConfiguration) else AnyType + ) explicit_value, traces = _resolve_single_value( 
config.__section__, initial_hint, AnyType, None, sections, () ) @@ -60,7 +74,9 @@ def resolve_configuration( # mappings cannot be used as explicit values, we want to enumerate mappings and request the fields' values one by one explicit_value = None else: - log_traces(None, config.__section__, type(config), explicit_value, None, traces) + log_traces( + None, config.__section__, type(config), explicit_value, None, traces + ) return _resolve_configuration(config, sections, (), explicit_value, accept_partial) @@ -92,7 +108,9 @@ def initialize_credentials(hint: Any, initial_value: Any) -> CredentialsConfigur def inject_section( - section_context: ConfigSectionContext, merge_existing: bool = True, lock_context: bool = False + section_context: ConfigSectionContext, + merge_existing: bool = True, + lock_context: bool = False, ) -> ContextManager[ConfigSectionContext]: """Context manager that sets section specified in `section_context` to be used during configuration resolution. Optionally merges the context already in the container with the one provided @@ -121,13 +139,16 @@ def _maybe_parse_native_value( ) -> Any: # use initial value to resolve the whole configuration. if explicit value is a mapping it will be applied field by field later if explicit_value and ( - not isinstance(explicit_value, C_Mapping) or isinstance(explicit_value, BaseConfiguration) + not isinstance(explicit_value, C_Mapping) + or isinstance(explicit_value, BaseConfiguration) ): try: config.parse_native_representation(explicit_value) except ValueError as v_err: # provide generic exception - raise InvalidNativeValue(type(config), type(explicit_value), embedded_sections, v_err) + raise InvalidNativeValue( + type(config), type(explicit_value), embedded_sections, v_err + ) except NotImplementedError: pass # explicit value was consumed @@ -149,11 +170,17 @@ def _resolve_configuration( config.__exception__ = None try: try: - explicit_value = _maybe_parse_native_value(config, explicit_value, embedded_sections) + explicit_value = _maybe_parse_native_value( + config, explicit_value, embedded_sections + ) # if native representation didn't fully resolve the config, we try to resolve field by field if not config.is_resolved(): _resolve_config_fields( - config, explicit_value, explicit_sections, embedded_sections, accept_partial + config, + explicit_value, + explicit_sections, + embedded_sections, + accept_partial, ) # full configuration was resolved config.resolve() @@ -212,7 +239,9 @@ def _resolve_config_fields( ): current_value, traces = explicit_value, [] else: - specs_in_union = get_all_types_of_class_in_union(hint, BaseConfiguration) + specs_in_union = get_all_types_of_class_in_union( + hint, BaseConfiguration + ) if not current_value: if len(specs_in_union) > 1: for idx, alt_spec in enumerate(specs_in_union): @@ -269,7 +298,9 @@ def _resolve_config_fields( unmatched_hint_resolvers.append(field_name) if unmatched_hint_resolvers: - raise UnmatchedConfigHintResolversException(type(config).__name__, unmatched_hint_resolvers) + raise UnmatchedConfigHintResolversException( + type(config).__name__, unmatched_hint_resolvers + ) if unresolved_fields: raise ConfigFieldMissingException(type(config).__name__, unresolved_fields) @@ -318,14 +349,18 @@ def _resolve_config_field( # print(f"{embedded_config} IS RESOLVED with VALUE {value}") # injected context will be resolved if value is not None: - _maybe_parse_native_value(embedded_config, value, embedded_sections + (key,)) + _maybe_parse_native_value( + embedded_config, value, embedded_sections 
+ (key,) + ) value = embedded_config else: # only config with sections may look for initial values if embedded_config.__section__ and value is None: # config section becomes the key if the key does not start with, otherwise it keeps its original value - initial_key, initial_embedded = _apply_embedded_sections_to_config_sections( - embedded_config.__section__, embedded_sections + (key,) + initial_key, initial_embedded = ( + _apply_embedded_sections_to_config_sections( + embedded_config.__section__, embedded_sections + (key,) + ) ) # it must be a secret value is config is credentials initial_hint = ( @@ -334,7 +369,12 @@ def _resolve_config_field( else AnyType ) value, initial_traces = _resolve_single_value( - initial_key, initial_hint, AnyType, None, explicit_sections, initial_embedded + initial_key, + initial_hint, + AnyType, + None, + explicit_sections, + initial_embedded, ) if isinstance(value, C_Mapping): # mappings are not passed as initials @@ -507,4 +547,6 @@ def _apply_embedded_sections_to_config_sections( embedded_sections = embedded_sections[:-1] # remove all embedded ns starting with _ - return config_section, tuple(ns for ns in embedded_sections if not ns.startswith("_")) + return config_section, tuple( + ns for ns in embedded_sections if not ns.startswith("_") + ) diff --git a/dlt/common/configuration/specs/api_credentials.py b/dlt/common/configuration/specs/api_credentials.py index 918cd4ee45..4383671d00 100644 --- a/dlt/common/configuration/specs/api_credentials.py +++ b/dlt/common/configuration/specs/api_credentials.py @@ -1,7 +1,10 @@ from typing import ClassVar, List, Union, Optional from dlt.common.typing import TSecretValue -from dlt.common.configuration.specs.base_configuration import CredentialsConfiguration, configspec +from dlt.common.configuration.specs.base_configuration import ( + CredentialsConfiguration, + configspec, +) @configspec @@ -17,7 +20,9 @@ class OAuth2Credentials(CredentialsConfiguration): # add refresh_token when generating config samples __config_gen_annotations__: ClassVar[List[str]] = ["refresh_token"] - def auth(self, scopes: Union[str, List[str]] = None, redirect_url: str = None) -> None: + def auth( + self, scopes: Union[str, List[str]] = None, redirect_url: str = None + ) -> None: """Authorizes the client using the available credentials Uses the `refresh_token` grant if refresh token is available. Note that `scopes` and `redirect_url` are ignored in this flow. 
diff --git a/dlt/common/configuration/specs/aws_credentials.py b/dlt/common/configuration/specs/aws_credentials.py index ee49e79e40..1696e0ba4f 100644 --- a/dlt/common/configuration/specs/aws_credentials.py +++ b/dlt/common/configuration/specs/aws_credentials.py @@ -56,7 +56,11 @@ def on_partial(self) -> None: def to_session_credentials(self) -> Dict[str, str]: """Return configured or new aws session token""" - if self.aws_session_token and self.aws_access_key_id and self.aws_secret_access_key: + if ( + self.aws_session_token + and self.aws_access_key_id + and self.aws_secret_access_key + ): return dict( aws_access_key_id=self.aws_access_key_id, aws_secret_access_key=self.aws_secret_access_key, @@ -84,9 +88,15 @@ def _to_botocore_session(self) -> Any: if self.profile_name is not None: session.set_config_variable("profile", self.profile_name) - if self.aws_access_key_id or self.aws_secret_access_key or self.aws_session_token: + if ( + self.aws_access_key_id + or self.aws_secret_access_key + or self.aws_session_token + ): session.set_credentials( - self.aws_access_key_id, self.aws_secret_access_key, self.aws_session_token + self.aws_access_key_id, + self.aws_secret_access_key, + self.aws_session_token, ) if self.region_name is not None: session.set_config_variable("region", self.region_name) diff --git a/dlt/common/configuration/specs/base_configuration.py b/dlt/common/configuration/specs/base_configuration.py index 06fb97fcdd..9700ac5015 100644 --- a/dlt/common/configuration/specs/base_configuration.py +++ b/dlt/common/configuration/specs/base_configuration.py @@ -53,16 +53,22 @@ def is_base_configuration_inner_hint(inner_hint: Type[Any]) -> bool: def is_context_inner_hint(inner_hint: Type[Any]) -> bool: - return inspect.isclass(inner_hint) and issubclass(inner_hint, ContainerInjectableContext) + return inspect.isclass(inner_hint) and issubclass( + inner_hint, ContainerInjectableContext + ) def is_credentials_inner_hint(inner_hint: Type[Any]) -> bool: - return inspect.isclass(inner_hint) and issubclass(inner_hint, CredentialsConfiguration) + return inspect.isclass(inner_hint) and issubclass( + inner_hint, CredentialsConfiguration + ) def get_config_if_union_hint(hint: Type[Any]) -> Type[Any]: if is_union_type(hint): - return next((t for t in get_args(hint) if is_base_configuration_inner_hint(t)), None) + return next( + (t for t in get_args(hint) if is_base_configuration_inner_hint(t)), None + ) return None @@ -117,7 +123,9 @@ def configspec( ) -> Callable[[Type[TAnyClass]], Type[TAnyClass]]: ... -@dataclass_transform(eq_default=False, field_specifiers=(dataclasses.Field, dataclasses.field)) +@dataclass_transform( + eq_default=False, field_specifiers=(dataclasses.Field, dataclasses.field) +) def configspec( cls: Optional[Type[Any]] = None, init: bool = True ) -> Union[Type[TAnyClass], Callable[[Type[TAnyClass]], Type[TAnyClass]]]: @@ -153,8 +161,8 @@ def wrap(cls: Type[TAnyClass]) -> Type[TAnyClass]: for ann in cls.__annotations__: if not hasattr(cls, ann) and not ann.startswith(("__", "_abc_")): warnings.warn( - f"Missing default value for field {ann} on {cls.__name__}. None assumed. All" - " fields in configspec must have default." + f"Missing default value for field {ann} on {cls.__name__}. None" + " assumed. All fields in configspec must have default." 
) setattr(cls, ann, None) # get all attributes without corresponding annotations @@ -187,19 +195,28 @@ def wrap(cls: Type[TAnyClass]) -> Type[TAnyClass]: def default_factory(att_value=att_value): # type: ignore[no-untyped-def] return att_value.copy() - setattr(cls, att_name, dataclasses.field(default_factory=default_factory)) + setattr( + cls, + att_name, + dataclasses.field(default_factory=default_factory), + ) # We don't want to overwrite user's __init__ method # Create dataclass init only when not defined in the class # NOTE: any class without synthesized __init__ breaks the creation chain has_default_init = super(cls, cls).__init__ == cls.__init__ # type: ignore[misc] - base_params = getattr(cls, "__dataclass_params__", None) # cls.__init__ is object.__init__ - synth_init = init and ((not base_params or base_params.init) and has_default_init) + base_params = getattr( + cls, "__dataclass_params__", None + ) # cls.__init__ is object.__init__ + synth_init = init and ( + (not base_params or base_params.init) and has_default_init + ) if synth_init != init and has_default_init: warnings.warn( - f"__init__ method will not be generated on {cls.__name__} because bas class didn't" - " synthesize __init__. Please correct `init` flag in confispec decorator. You are" - " probably receiving incorrect __init__ signature for type checking" + f"__init__ method will not be generated on {cls.__name__} because bas" + " class didn't synthesize __init__. Please correct `init` flag in" + " confispec decorator. You are probably receiving incorrect __init__" + " signature for type checking" ) # do not generate repr as it may contain secret values return dataclasses.dataclass(cls, init=synth_init, eq=False, repr=False) # type: ignore @@ -213,7 +230,9 @@ def default_factory(att_value=att_value): # type: ignore[no-untyped-def] @configspec class BaseConfiguration(MutableMapping[str, Any]): - __is_resolved__: bool = dataclasses.field(default=False, init=False, repr=False, compare=False) + __is_resolved__: bool = dataclasses.field( + default=False, init=False, repr=False, compare=False + ) """True when all config fields were resolved and have a specified value type""" __exception__: Exception = dataclasses.field( default=None, init=False, repr=False, compare=False @@ -225,7 +244,9 @@ class BaseConfiguration(MutableMapping[str, Any]): """Additional annotations for config generator, currently holds a list of fields of interest that have defaults""" __dataclass_fields__: ClassVar[Dict[str, TDtcField]] """Typing for dataclass fields""" - __hint_resolvers__: ClassVar[Dict[str, Callable[["BaseConfiguration"], Type[Any]]]] = {} + __hint_resolvers__: ClassVar[ + Dict[str, Callable[["BaseConfiguration"], Type[Any]]] + ] = {} def parse_native_representation(self, native_value: Any) -> None: """Initialize the configuration fields by parsing the `native_value` which should be a native representation of the configuration @@ -314,7 +335,10 @@ def __iter__(self) -> Iterator[str]: """Iterator or valid key names""" return map( lambda field: field.name, - filter(lambda val: self.__is_valid_field(val), self.__dataclass_fields__.values()), + filter( + lambda val: self.__is_valid_field(val), + self.__dataclass_fields__.values(), + ), ) def __len__(self) -> int: @@ -434,7 +458,9 @@ def add_extras(self) -> None: THintResolver = Callable[[TSpec], Type[Any]] -def resolve_type(field_name: str) -> Callable[[THintResolver[TSpec]], THintResolver[TSpec]]: +def resolve_type( + field_name: str, +) -> Callable[[THintResolver[TSpec]], 
THintResolver[TSpec]]: def decorator(func: THintResolver[TSpec]) -> THintResolver[TSpec]: func.__hint_for_field__ = field_name # type: ignore[attr-defined] diff --git a/dlt/common/configuration/specs/config_providers_context.py b/dlt/common/configuration/specs/config_providers_context.py index 642634fb0a..021eb38e2f 100644 --- a/dlt/common/configuration/specs/config_providers_context.py +++ b/dlt/common/configuration/specs/config_providers_context.py @@ -105,7 +105,9 @@ def _extra_providers() -> List[ConfigProvider]: extra_providers.extend(_airflow_providers()) if providers_config.enable_google_secrets: extra_providers.append( - _google_secrets_provider(only_toml_fragments=providers_config.only_toml_fragments) + _google_secrets_provider( + only_toml_fragments=providers_config.only_toml_fragments + ) ) return extra_providers @@ -116,7 +118,8 @@ def _google_secrets_provider( from dlt.common.configuration.resolve import resolve_configuration c = resolve_configuration( - GcpServiceAccountCredentials(), sections=(known_sections.PROVIDERS, "google_secrets") + GcpServiceAccountCredentials(), + sections=(known_sections.PROVIDERS, "google_secrets"), ) return GoogleSecretsProvider( c, only_secrets=only_secrets, only_toml_fragments=only_toml_fragments @@ -140,10 +143,14 @@ def _airflow_providers() -> List[ConfigProvider]: try: # hide stdio. airflow typically dumps tons of warnings and deprecations to stdout and stderr - with contextlib.redirect_stdout(io.StringIO()), contextlib.redirect_stderr(io.StringIO()): + with contextlib.redirect_stdout(io.StringIO()), contextlib.redirect_stderr( + io.StringIO() + ): # try to get dlt secrets variable. many broken Airflow installations break here. in that case do not create from airflow.models import Variable, TaskInstance # noqa - from dlt.common.configuration.providers.airflow import AirflowSecretsTomlProvider + from dlt.common.configuration.providers.airflow import ( + AirflowSecretsTomlProvider, + ) # probe if Airflow variable containing all secrets is present from dlt.common.configuration.providers.toml import SECRETS_TOML_KEY @@ -164,8 +171,8 @@ def _airflow_providers() -> List[ConfigProvider]: f"Airflow variable '{SECRETS_TOML_KEY}' was not found. " + "This Airflow variable is a recommended place to hold the content of" " secrets.toml." - + "If you do not use Airflow variables to hold dlt configuration or use variables" - " with other names you can ignore this warning." + + "If you do not use Airflow variables to hold dlt configuration or use" + " variables with other names you can ignore this warning." 
) ti.log.warning(message) diff --git a/dlt/common/configuration/specs/config_section_context.py b/dlt/common/configuration/specs/config_section_context.py index 1e6cd56155..ab3cae22cd 100644 --- a/dlt/common/configuration/specs/config_section_context.py +++ b/dlt/common/configuration/specs/config_section_context.py @@ -1,7 +1,10 @@ from typing import Callable, List, Optional, Tuple, TYPE_CHECKING from dlt.common.configuration.specs import known_sections -from dlt.common.configuration.specs.base_configuration import ContainerInjectableContext, configspec +from dlt.common.configuration.specs.base_configuration import ( + ContainerInjectableContext, + configspec, +) @configspec @@ -20,28 +23,44 @@ def merge(self, existing: "ConfigSectionContext") -> None: def source_name(self) -> str: """Gets name of a source from `sections`""" - if self.sections and len(self.sections) == 3 and self.sections[0] == known_sections.SOURCES: + if ( + self.sections + and len(self.sections) == 3 + and self.sections[0] == known_sections.SOURCES + ): return self.sections[-1] raise ValueError(self.sections) def source_section(self) -> str: """Gets section of a source from `sections`""" - if self.sections and len(self.sections) == 3 and self.sections[0] == known_sections.SOURCES: + if ( + self.sections + and len(self.sections) == 3 + and self.sections[0] == known_sections.SOURCES + ): return self.sections[1] raise ValueError(self.sections) @staticmethod - def prefer_incoming(incoming: "ConfigSectionContext", existing: "ConfigSectionContext") -> None: + def prefer_incoming( + incoming: "ConfigSectionContext", existing: "ConfigSectionContext" + ) -> None: incoming.pipeline_name = incoming.pipeline_name or existing.pipeline_name incoming.sections = incoming.sections or existing.sections - incoming.source_state_key = incoming.source_state_key or existing.source_state_key + incoming.source_state_key = ( + incoming.source_state_key or existing.source_state_key + ) @staticmethod - def prefer_existing(incoming: "ConfigSectionContext", existing: "ConfigSectionContext") -> None: + def prefer_existing( + incoming: "ConfigSectionContext", existing: "ConfigSectionContext" + ) -> None: """Prefer existing section context when merging this context before injecting""" incoming.pipeline_name = existing.pipeline_name or incoming.pipeline_name incoming.sections = existing.sections or incoming.sections - incoming.source_state_key = existing.source_state_key or incoming.source_state_key + incoming.source_state_key = ( + existing.source_state_key or incoming.source_state_key + ) @staticmethod def resource_merge_style( @@ -60,10 +79,14 @@ def resource_merge_style( existing.sections[1] or incoming.sections[1], incoming.sections[2], ) - incoming.source_state_key = existing.source_state_key or incoming.source_state_key + incoming.source_state_key = ( + existing.source_state_key or incoming.source_state_key + ) else: incoming.sections = incoming.sections or existing.sections - incoming.source_state_key = incoming.source_state_key or existing.source_state_key + incoming.source_state_key = ( + incoming.source_state_key or existing.source_state_key + ) def __str__(self) -> str: return ( diff --git a/dlt/common/configuration/specs/connection_string_credentials.py b/dlt/common/configuration/specs/connection_string_credentials.py index 2691c5d886..ab688cc172 100644 --- a/dlt/common/configuration/specs/connection_string_credentials.py +++ b/dlt/common/configuration/specs/connection_string_credentials.py @@ -4,12 +4,17 @@ from dlt.common.libs.sql_alchemy 
import URL, make_url from dlt.common.configuration.specs.exceptions import InvalidConnectionString from dlt.common.typing import TSecretValue -from dlt.common.configuration.specs.base_configuration import CredentialsConfiguration, configspec +from dlt.common.configuration.specs.base_configuration import ( + CredentialsConfiguration, + configspec, +) @configspec class ConnectionStringCredentials(CredentialsConfiguration): - drivername: str = dataclasses.field(default=None, init=False, repr=False, compare=False) + drivername: str = dataclasses.field( + default=None, init=False, repr=False, compare=False + ) database: str = None password: Optional[TSecretValue] = None username: str = None diff --git a/dlt/common/configuration/specs/exceptions.py b/dlt/common/configuration/specs/exceptions.py index 7a0b283630..7cb02d0055 100644 --- a/dlt/common/configuration/specs/exceptions.py +++ b/dlt/common/configuration/specs/exceptions.py @@ -10,8 +10,8 @@ class OAuth2ScopesRequired(SpecException): def __init__(self, spec: type) -> None: self.spec = spec super().__init__( - "Scopes are required to retrieve refresh_token. Use 'openid' scope for a token without" - " any permissions to resources." + "Scopes are required to retrieve refresh_token. Use 'openid' scope for a" + " token without any permissions to resources." ) @@ -26,8 +26,9 @@ class InvalidConnectionString(NativeValueError): def __init__(self, spec: Type[Any], native_value: str, driver: str): driver = driver or "driver" msg = ( - f"The expected representation for {spec.__name__} is a standard database connection" - f" string with the following format: {driver}://username:password@host:port/database." + f"The expected representation for {spec.__name__} is a standard database" + " connection string with the following format:" + f" {driver}://username:password@host:port/database." ) super().__init__(spec, native_value, msg) @@ -35,9 +36,9 @@ def __init__(self, spec: Type[Any], native_value: str, driver: str): class InvalidGoogleNativeCredentialsType(NativeValueError): def __init__(self, spec: Type[Any], native_value: Any): msg = ( - f"Credentials {spec.__name__} accept a string with serialized credentials json file or" - " an instance of Credentials object from google.* namespace. The value passed is of" - f" type {type(native_value)}" + f"Credentials {spec.__name__} accept a string with serialized credentials" + " json file or an instance of Credentials object from google.* namespace." 
+ f" The value passed is of type {type(native_value)}" ) super().__init__(spec, native_value, msg) @@ -45,9 +46,9 @@ def __init__(self, spec: Type[Any], native_value: Any): class InvalidGoogleServicesJson(NativeValueError): def __init__(self, spec: Type[Any], native_value: Any): msg = ( - f"The expected representation for {spec.__name__} is a string with serialized service" - " account credentials, where at least 'project_id', 'private_key' and 'client_email`" - " keys are present" + f"The expected representation for {spec.__name__} is a string with" + " serialized service account credentials, where at least 'project_id'," + " 'private_key' and 'client_email` keys are present" ) super().__init__(spec, native_value, msg) @@ -55,8 +56,9 @@ def __init__(self, spec: Type[Any], native_value: Any): class InvalidGoogleOauth2Json(NativeValueError): def __init__(self, spec: Type[Any], native_value: Any): msg = ( - f"The expected representation for {spec.__name__} is a string with serialized oauth2" - " user info and may be wrapped in 'install'/'web' node - depending of oauth2 app type." + f"The expected representation for {spec.__name__} is a string with" + " serialized oauth2 user info and may be wrapped in 'install'/'web' node -" + " depending of oauth2 app type." ) super().__init__(spec, native_value, msg) @@ -64,7 +66,7 @@ def __init__(self, spec: Type[Any], native_value: Any): class InvalidBoto3Session(NativeValueError): def __init__(self, spec: Type[Any], native_value: Any): msg = ( - f"The expected representation for {spec.__name__} is and instance of boto3.Session" - " containing credentials" + f"The expected representation for {spec.__name__} is and instance of" + " boto3.Session containing credentials" ) super().__init__(spec, native_value, msg) diff --git a/dlt/common/configuration/specs/gcp_credentials.py b/dlt/common/configuration/specs/gcp_credentials.py index 4d81a493a3..31b073c252 100644 --- a/dlt/common/configuration/specs/gcp_credentials.py +++ b/dlt/common/configuration/specs/gcp_credentials.py @@ -24,10 +24,16 @@ @configspec class GcpCredentials(CredentialsConfiguration): token_uri: Final[str] = dataclasses.field( - default="https://oauth2.googleapis.com/token", init=False, repr=False, compare=False + default="https://oauth2.googleapis.com/token", + init=False, + repr=False, + compare=False, ) auth_uri: Final[str] = dataclasses.field( - default="https://accounts.google.com/o/oauth2/auth", init=False, repr=False, compare=False + default="https://accounts.google.com/o/oauth2/auth", + init=False, + repr=False, + compare=False, ) project_id: str = None @@ -64,7 +70,8 @@ def to_gcs_credentials(self) -> Dict[str, Any]: "project": self.project_id, "token": ( None - if isinstance(self, CredentialsWithDefault) and self.has_default_credentials() + if isinstance(self, CredentialsWithDefault) + and self.has_default_credentials() else dict(self) ), } @@ -82,14 +89,18 @@ def parse_native_representation(self, native_value: Any) -> None: """Accepts ServiceAccountCredentials as native value. 
In other case reverts to serialized services.json""" service_dict: DictStrAny = None try: - from google.oauth2.service_account import Credentials as ServiceAccountCredentials + from google.oauth2.service_account import ( + Credentials as ServiceAccountCredentials, + ) if isinstance(native_value, ServiceAccountCredentials): # extract credentials service_dict = { "project_id": native_value.project_id, "client_email": native_value.service_account_email, - "private_key": native_value, # keep native credentials in private key + "private_key": ( + native_value + ), # keep native credentials in private key } self.__is_resolved__ = True except ImportError: @@ -113,7 +124,9 @@ def on_resolved(self) -> None: def to_native_credentials(self) -> Any: """Returns google.oauth2.service_account.Credentials""" - from google.oauth2.service_account import Credentials as ServiceAccountCredentials + from google.oauth2.service_account import ( + Credentials as ServiceAccountCredentials, + ) if isinstance(self.private_key, ServiceAccountCredentials): # private key holds the native instance if it was passed to parse_native_representation @@ -170,7 +183,9 @@ def parse_native_representation(self, native_value: Any) -> None: def to_native_representation(self) -> str: return json.dumps(self._info_dict()) - def auth(self, scopes: Union[str, List[str]] = None, redirect_url: str = None) -> None: + def auth( + self, scopes: Union[str, List[str]] = None, redirect_url: str = None + ) -> None: if not self.refresh_token: self.add_scopes(scopes) if not self.scopes: @@ -199,7 +214,9 @@ def _get_access_token(self) -> TSecretValue: try: from requests_oauthlib import OAuth2Session except ModuleNotFoundError: - raise MissingDependencyException("GcpOAuthCredentials", ["requests_oauthlib"]) + raise MissingDependencyException( + "GcpOAuthCredentials", ["requests_oauthlib"] + ) google = OAuth2Session(client_id=self.client_id, scope=self.scopes) extra = {"client_id": self.client_id, "client_secret": self.client_secret} @@ -208,12 +225,18 @@ def _get_access_token(self) -> TSecretValue: )["access_token"] return TSecretValue(token) - def _get_refresh_token(self, redirect_url: str) -> Tuple[TSecretValue, TSecretValue]: + def _get_refresh_token( + self, redirect_url: str + ) -> Tuple[TSecretValue, TSecretValue]: try: from google_auth_oauthlib.flow import InstalledAppFlow except ModuleNotFoundError: - raise MissingDependencyException("GcpOAuthCredentials", ["google-auth-oauthlib"]) - flow = InstalledAppFlow.from_client_config(self._installed_dict(redirect_url), self.scopes) + raise MissingDependencyException( + "GcpOAuthCredentials", ["google-auth-oauthlib"] + ) + flow = InstalledAppFlow.from_client_config( + self._installed_dict(redirect_url), self.scopes + ) credentials = flow.run_local_server(port=0) return TSecretValue(credentials.refresh_token), TSecretValue(credentials.token) @@ -222,7 +245,9 @@ def to_native_credentials(self) -> Any: try: from google.oauth2.credentials import Credentials as GoogleOAuth2Credentials except ModuleNotFoundError: - raise MissingDependencyException("GcpOAuthCredentials", ["google-auth-oauthlib"]) + raise MissingDependencyException( + "GcpOAuthCredentials", ["google-auth-oauthlib"] + ) credentials = GoogleOAuth2Credentials.from_authorized_user_info(info=dict(self)) return credentials @@ -312,7 +337,9 @@ def parse_native_representation(self, native_value: Any) -> None: GcpDefaultCredentials.parse_native_representation(self, native_value) except NativeValueError: pass - 
GcpServiceAccountCredentialsWithoutDefaults.parse_native_representation(self, native_value) + GcpServiceAccountCredentialsWithoutDefaults.parse_native_representation( + self, native_value + ) @configspec @@ -322,4 +349,6 @@ def parse_native_representation(self, native_value: Any) -> None: GcpDefaultCredentials.parse_native_representation(self, native_value) except NativeValueError: pass - GcpOAuthCredentialsWithoutDefaults.parse_native_representation(self, native_value) + GcpOAuthCredentialsWithoutDefaults.parse_native_representation( + self, native_value + ) diff --git a/dlt/common/configuration/specs/run_configuration.py b/dlt/common/configuration/specs/run_configuration.py index b57b4abbdd..c2d51737cf 100644 --- a/dlt/common/configuration/specs/run_configuration.py +++ b/dlt/common/configuration/specs/run_configuration.py @@ -4,8 +4,15 @@ from typing import Any, ClassVar, Optional, IO from dlt.common.typing import TSecretStrValue -from dlt.common.utils import encoding_for_mode, main_module_file_path, reveal_pseudo_secret -from dlt.common.configuration.specs.base_configuration import BaseConfiguration, configspec +from dlt.common.utils import ( + encoding_for_mode, + main_module_file_path, + reveal_pseudo_secret, +) +from dlt.common.configuration.specs.base_configuration import ( + BaseConfiguration, + configspec, +) from dlt.common.configuration.exceptions import ConfigFileNotFoundException diff --git a/dlt/common/configuration/utils.py b/dlt/common/configuration/utils.py index 5a7330447b..e07a1031a1 100644 --- a/dlt/common/configuration/utils.py +++ b/dlt/common/configuration/utils.py @@ -9,7 +9,10 @@ from dlt.common.typing import AnyType, TAny from dlt.common.data_types import coerce_value, py_type_to_sc_type from dlt.common.configuration.providers import EnvironProvider -from dlt.common.configuration.exceptions import ConfigValueCannotBeCoercedException, LookupTrace +from dlt.common.configuration.exceptions import ( + ConfigValueCannotBeCoercedException, + LookupTrace, +) from dlt.common.configuration.specs.base_configuration import ( BaseConfiguration, is_base_configuration_inner_hint, @@ -149,7 +152,9 @@ def get_resolved_traces() -> Dict[str, ResolvedValueTrace]: return _RESOLVED_TRACES -def add_config_to_env(config: BaseConfiguration, sections: Tuple[str, ...] = ()) -> None: +def add_config_to_env( + config: BaseConfiguration, sections: Tuple[str, ...] = () +) -> None: """Writes values in configuration back into environment using the naming convention of EnvironProvider. Will descend recursively if embedded BaseConfiguration instances are found""" if config.__section__: sections += (config.__section__,) @@ -157,7 +162,9 @@ def add_config_to_env(config: BaseConfiguration, sections: Tuple[str, ...] = ()) def add_config_dict_to_env( - dict_: Mapping[str, Any], sections: Tuple[str, ...] = (), overwrite_keys: bool = False + dict_: Mapping[str, Any], + sections: Tuple[str, ...] = (), + overwrite_keys: bool = False, ) -> None: """Writes values in dict_ back into environment using the naming convention of EnvironProvider. Applies `sections` if specified. 
Does not overwrite existing keys by default""" for k, v in dict_.items(): diff --git a/dlt/common/data_writers/__init__.py b/dlt/common/data_writers/__init__.py index 04c5d04328..9c02e4c8bd 100644 --- a/dlt/common/data_writers/__init__.py +++ b/dlt/common/data_writers/__init__.py @@ -1,4 +1,8 @@ -from dlt.common.data_writers.writers import DataWriter, DataWriterMetrics, TLoaderFileFormat +from dlt.common.data_writers.writers import ( + DataWriter, + DataWriterMetrics, + TLoaderFileFormat, +) from dlt.common.data_writers.buffered import BufferedDataWriter, new_file_id from dlt.common.data_writers.escape import ( escape_redshift_literal, diff --git a/dlt/common/data_writers/buffered.py b/dlt/common/data_writers/buffered.py index b10b1d14b9..4bba867042 100644 --- a/dlt/common/data_writers/buffered.py +++ b/dlt/common/data_writers/buffered.py @@ -48,7 +48,9 @@ def __init__( _caps: DestinationCapabilitiesContext = None ): self.file_format = file_format - self._file_format_spec = DataWriter.data_format_from_file_format(self.file_format) + self._file_format_spec = DataWriter.data_format_from_file_format( + self.file_format + ) if self._file_format_spec.requires_destination_capabilities and not _caps: raise DestinationCapabilitiesRequired(file_format) self._caps = _caps @@ -56,7 +58,9 @@ def __init__( self.file_name_template = file_name_template self.closed_files: List[DataWriterMetrics] = [] # all fully processed files # buffered items must be less than max items in file - self.buffer_max_items = min(buffer_max_items, file_max_items or buffer_max_items) + self.buffer_max_items = min( + buffer_max_items, file_max_items or buffer_max_items + ) self.file_max_bytes = file_max_bytes self.file_max_items = file_max_items # the open function is either gzip.open or open @@ -125,7 +129,9 @@ def write_data_item(self, item: TDataItems, columns: TTableSchemaColumns) -> int if self.file_max_bytes and self._file.tell() >= self.file_max_bytes: self._rotate_file() # rotate on max items - elif self.file_max_items and self._writer.items_count >= self.file_max_items: + elif ( + self.file_max_items and self._writer.items_count >= self.file_max_items + ): self._rotate_file() return new_rows_count @@ -138,7 +144,9 @@ def write_empty_file(self, columns: TTableSchemaColumns) -> DataWriterMetrics: self._last_modified = time.time() return self._rotate_file(allow_empty_file=True) - def import_file(self, file_path: str, metrics: DataWriterMetrics) -> DataWriterMetrics: + def import_file( + self, file_path: str, metrics: DataWriterMetrics + ) -> DataWriterMetrics: """Import a file from `file_path` into items storage under a new file name. Does not check the imported file format. Uses counts from `metrics` as a base. Logically closes the imported file @@ -177,13 +185,17 @@ def closed(self) -> bool: def __enter__(self) -> "BufferedDataWriter[TWriter]": return self - def __exit__(self, exc_type: Type[BaseException], exc_val: BaseException, exc_tb: Any) -> None: + def __exit__( + self, exc_type: Type[BaseException], exc_val: BaseException, exc_tb: Any + ) -> None: self.close() def _rotate_file(self, allow_empty_file: bool = False) -> DataWriterMetrics: metrics = self._flush_and_close_file(allow_empty_file) self._file_name = ( - self.file_name_template % new_file_id() + "." + self._file_format_spec.file_extension + self.file_name_template % new_file_id() + + "." 
+ + self._file_format_spec.file_extension ) self._created = time.time() return metrics @@ -206,7 +218,9 @@ def _flush_items(self, allow_empty_file: bool = False) -> None: self._buffered_items.clear() self._buffered_items_count = 0 - def _flush_and_close_file(self, allow_empty_file: bool = False) -> DataWriterMetrics: + def _flush_and_close_file( + self, allow_empty_file: bool = False + ) -> DataWriterMetrics: # if any buffered items exist, flush them self._flush_items(allow_empty_file) # if writer exists then close it diff --git a/dlt/common/data_writers/escape.py b/dlt/common/data_writers/escape.py index 5460657253..898baf0a70 100644 --- a/dlt/common/data_writers/escape.py +++ b/dlt/common/data_writers/escape.py @@ -24,7 +24,9 @@ def _escape_extended( ) -> str: escape_dict = escape_dict or SQL_ESCAPE_DICT escape_re = escape_re or SQL_ESCAPE_RE - return "{}{}{}".format(prefix, escape_re.sub(lambda x: escape_dict[x.group(0)], v), "'") + return "{}{}{}".format( + prefix, escape_re.sub(lambda x: escape_dict[x.group(0)], v), "'" + ) def escape_redshift_literal(v: Any) -> Any: @@ -95,7 +97,10 @@ def escape_mssql_literal(v: Any) -> Any: return f"'{v.isoformat()}'" if isinstance(v, (list, dict)): return _escape_extended( - json.dumps(v), prefix="N'", escape_dict=MS_SQL_ESCAPE_DICT, escape_re=MS_SQL_ESCAPE_RE + json.dumps(v), + prefix="N'", + escape_dict=MS_SQL_ESCAPE_DICT, + escape_re=MS_SQL_ESCAPE_RE, ) if isinstance(v, bytes): from dlt.destinations.impl.mssql.mssql import VARBINARY_MAX_N @@ -144,7 +149,9 @@ def escape_databricks_literal(v: Any) -> Any: if isinstance(v, (datetime, date, time)): return f"'{v.isoformat()}'" if isinstance(v, (list, dict)): - return _escape_extended(json.dumps(v), prefix="'", escape_dict=DATABRICKS_ESCAPE_DICT) + return _escape_extended( + json.dumps(v), prefix="'", escape_dict=DATABRICKS_ESCAPE_DICT + ) if isinstance(v, bytes): return f"X'{v.hex()}'" if v is None: diff --git a/dlt/common/data_writers/exceptions.py b/dlt/common/data_writers/exceptions.py index d3a073cf4e..60346badaa 100644 --- a/dlt/common/data_writers/exceptions.py +++ b/dlt/common/data_writers/exceptions.py @@ -10,8 +10,8 @@ class InvalidFileNameTemplateException(DataWriterException, ValueError): def __init__(self, file_name_template: str): self.file_name_template = file_name_template super().__init__( - f"Wrong file name template {file_name_template}. File name template must contain" - " exactly one %s formatter" + f"Wrong file name template {file_name_template}. File name template must" + " contain exactly one %s formatter" ) @@ -25,5 +25,6 @@ class DestinationCapabilitiesRequired(DataWriterException, ValueError): def __init__(self, file_format: TLoaderFileFormat): self.file_format = file_format super().__init__( - f"Writer for {file_format} requires destination capabilities which were not provided." + f"Writer for {file_format} requires destination capabilities which were not" + " provided." 
) diff --git a/dlt/common/data_writers/writers.py b/dlt/common/data_writers/writers.py index 2aadb010e0..e14d5e6d4b 100644 --- a/dlt/common/data_writers/writers.py +++ b/dlt/common/data_writers/writers.py @@ -75,7 +75,9 @@ def write_data(self, rows: Sequence[Any]) -> None: def write_footer(self) -> None: pass - def write_all(self, columns_schema: TTableSchemaColumns, rows: Sequence[Any]) -> None: + def write_all( + self, columns_schema: TTableSchemaColumns, rows: Sequence[Any] + ) -> None: self.write_header(columns_schema) self.write_data(rows) self.write_footer() @@ -87,7 +89,10 @@ def data_format(cls) -> TFileFormatSpec: @classmethod def from_file_format( - cls, file_format: TLoaderFileFormat, f: IO[Any], caps: DestinationCapabilitiesContext = None + cls, + file_format: TLoaderFileFormat, + f: IO[Any], + caps: DestinationCapabilitiesContext = None, ) -> "DataWriter": return cls.class_factory(file_format)(f, caps) @@ -98,7 +103,9 @@ def from_destination_capabilities( return cls.class_factory(caps.preferred_loader_file_format)(f, caps) @classmethod - def data_format_from_file_format(cls, file_format: TLoaderFileFormat) -> TFileFormatSpec: + def data_format_from_file_format( + cls, file_format: TLoaderFileFormat + ) -> TFileFormatSpec: return cls.class_factory(file_format).data_format() @staticmethod @@ -284,7 +291,9 @@ def write_header(self, columns_schema: TTableSchemaColumns) -> None: [ pyarrow.field( name, - get_py_arrow_datatype(schema_item, self._caps, self.timestamp_timezone), + get_py_arrow_datatype( + schema_item, self._caps, self.timestamp_timezone + ), nullable=schema_item.get("nullable", True), ) for name, schema_item in columns_schema.items() @@ -351,7 +360,9 @@ def write_data(self, rows: Sequence[Any]) -> None: def write_footer(self) -> None: if not self.writer: - raise NotImplementedError("Arrow Writer does not support writing empty files") + raise NotImplementedError( + "Arrow Writer does not support writing empty files" + ) return super().write_footer() @classmethod diff --git a/dlt/common/destination/__init__.py b/dlt/common/destination/__init__.py index 00f129c69c..20341c5795 100644 --- a/dlt/common/destination/__init__.py +++ b/dlt/common/destination/__init__.py @@ -3,7 +3,11 @@ TLoaderFileFormat, ALL_SUPPORTED_FILE_FORMATS, ) -from dlt.common.destination.reference import TDestinationReferenceArg, Destination, TDestination +from dlt.common.destination.reference import ( + TDestinationReferenceArg, + Destination, + TDestination, +) __all__ = [ "DestinationCapabilitiesContext", diff --git a/dlt/common/destination/capabilities.py b/dlt/common/destination/capabilities.py index 7a64f32ea3..d2d0afa152 100644 --- a/dlt/common/destination/capabilities.py +++ b/dlt/common/destination/capabilities.py @@ -1,4 +1,14 @@ -from typing import Any, Callable, ClassVar, List, Literal, Optional, Tuple, Set, get_args +from typing import ( + Any, + Callable, + ClassVar, + List, + Literal, + Optional, + Tuple, + Set, + get_args, +) from dlt.common.configuration.utils import serialize_value from dlt.common.configuration import configspec @@ -55,7 +65,9 @@ class DestinationCapabilitiesContext(ContainerInjectableContext): insert_values_writer_type: str = "default" supports_multiple_statements: bool = True supports_clone_table: bool = False - max_table_nesting: Optional[int] = None # destination can overwrite max table nesting + max_table_nesting: Optional[int] = ( + None # destination can overwrite max table nesting + ) """Destination supports CREATE TABLE ... CLONE ... 
statements""" # do not allow to create default value, destination caps must be always explicitly inserted into container diff --git a/dlt/common/destination/exceptions.py b/dlt/common/destination/exceptions.py index 1b5423ff02..05f4a58f17 100644 --- a/dlt/common/destination/exceptions.py +++ b/dlt/common/destination/exceptions.py @@ -11,9 +11,15 @@ class UnknownDestinationModule(DestinationException): def __init__(self, destination_module: str) -> None: self.destination_module = destination_module if "." in destination_module: - msg = f"Destination module {destination_module} could not be found and imported" + msg = ( + f"Destination module {destination_module} could not be found and" + " imported" + ) else: - msg = f"Destination {destination_module} is not one of the standard dlt destinations" + msg = ( + f"Destination {destination_module} is not one of the standard dlt" + " destinations" + ) super().__init__(msg) @@ -39,13 +45,17 @@ class DestinationTransientException(DestinationException, TransientException): class DestinationLoadingViaStagingNotSupported(DestinationTerminalException): def __init__(self, destination: str) -> None: self.destination = destination - super().__init__(f"Destination {destination} does not support loading via staging.") + super().__init__( + f"Destination {destination} does not support loading via staging." + ) class DestinationLoadingWithoutStagingNotSupported(DestinationTerminalException): def __init__(self, destination: str) -> None: self.destination = destination - super().__init__(f"Destination {destination} does not support loading without staging.") + super().__init__( + f"Destination {destination} does not support loading without staging." + ) class DestinationNoStagingMode(DestinationTerminalException): @@ -56,7 +66,11 @@ def __init__(self, destination: str) -> None: class DestinationIncompatibleLoaderFileFormatException(DestinationTerminalException): def __init__( - self, destination: str, staging: str, file_format: str, supported_formats: Iterable[str] + self, + destination: str, + staging: str, + file_format: str, + supported_formats: Iterable[str], ) -> None: self.destination = destination self.staging = staging @@ -66,20 +80,20 @@ def __init__( if self.staging: if not supported_formats: msg = ( - f"Staging {staging} cannot be used with destination {destination} because they" - " have no file formats in common." + f"Staging {staging} cannot be used with destination" + f" {destination} because they have no file formats in common." ) else: msg = ( - f"Unsupported file format {file_format} for destination {destination} in" - f" combination with staging destination {staging}. Supported formats:" - f" {supported_formats_str}" + f"Unsupported file format {file_format} for destination" + f" {destination} in combination with staging destination {staging}." + f" Supported formats: {supported_formats_str}" ) else: msg = ( - f"Unsupported file format {file_format} destination {destination}. Supported" - f" formats: {supported_formats_str}. Check the staging option in the dlt.pipeline" - " for additional formats." + f"Unsupported file format {file_format} destination {destination}." + f" Supported formats: {supported_formats_str}. Check the staging option" + " in the dlt.pipeline for additional formats." 
) super().__init__(msg) @@ -103,7 +117,9 @@ def __init__( class DestinationHasFailedJobs(DestinationTerminalException): - def __init__(self, destination_name: str, load_id: str, failed_jobs: List[Any]) -> None: + def __init__( + self, destination_name: str, load_id: str, failed_jobs: List[Any] + ) -> None: self.destination_name = destination_name self.load_id = load_id self.failed_jobs = failed_jobs @@ -113,14 +129,17 @@ def __init__(self, destination_name: str, load_id: str, failed_jobs: List[Any]) class DestinationSchemaTampered(DestinationTerminalException): - def __init__(self, schema_name: str, version_hash: str, stored_version_hash: str) -> None: + def __init__( + self, schema_name: str, version_hash: str, stored_version_hash: str + ) -> None: self.version_hash = version_hash self.stored_version_hash = stored_version_hash super().__init__( - f"Schema {schema_name} content was changed - by a loader or by destination code - from" - " the moment it was retrieved by load package. Such schema cannot reliably be updated" - f" nor saved. Current version hash: {version_hash} != stored version hash" - f" {stored_version_hash}. If you are using destination client directly, without storing" - " schema in load package, you should first save it into schema storage. You can also" - " use schema._bump_version() in test code to remove modified flag." + f"Schema {schema_name} content was changed - by a loader or by destination" + " code - from the moment it was retrieved by load package. Such schema" + " cannot reliably be updated nor saved. Current version hash:" + f" {version_hash} != stored version hash {stored_version_hash}. If you are" + " using destination client directly, without storing schema in load" + " package, you should first save it into schema storage. You can also use" + " schema._bump_version() in test code to remove modified flag." ) diff --git a/dlt/common/destination/reference.py b/dlt/common/destination/reference.py index ddcc5d1146..43d2d3a305 100644 --- a/dlt/common/destination/reference.py +++ b/dlt/common/destination/reference.py @@ -49,10 +49,16 @@ from dlt.common.storages import FileStorage from dlt.common.storages.load_storage import ParsedLoadJobFileName -TLoaderReplaceStrategy = Literal["truncate-and-insert", "insert-from-staging", "staging-optimized"] -TDestinationConfig = TypeVar("TDestinationConfig", bound="DestinationClientConfiguration") +TLoaderReplaceStrategy = Literal[ + "truncate-and-insert", "insert-from-staging", "staging-optimized" +] +TDestinationConfig = TypeVar( + "TDestinationConfig", bound="DestinationClientConfiguration" +) TDestinationClient = TypeVar("TDestinationClient", bound="JobClientBase") -TDestinationDwhClient = TypeVar("TDestinationDwhClient", bound="DestinationClientDwhConfiguration") +TDestinationDwhClient = TypeVar( + "TDestinationDwhClient", bound="DestinationClientDwhConfiguration" +) class StorageSchemaInfo(NamedTuple): @@ -131,7 +137,10 @@ def normalize_dataset_name(self, schema: Schema) -> str: raise ValueError("schema_name is None or empty") # if default schema is None then suffix is not added - if self.default_schema_name is not None and schema.name != self.default_schema_name: + if ( + self.default_schema_name is not None + and schema.name != self.default_schema_name + ): # also normalize schema name. 
schema name is Python identifier and here convention may be different return schema.naming.normalize_table_identifier( (self.dataset_name or "") + "_" + schema.name @@ -295,7 +304,9 @@ def update_stored_schema( return expected_update @abstractmethod - def start_file_load(self, table: TTableSchema, file_path: str, load_id: str) -> LoadJob: + def start_file_load( + self, table: TTableSchema, file_path: str, load_id: str + ) -> LoadJob: """Creates and starts a load job for a particular `table` with content in `file_path`""" pass @@ -324,7 +335,10 @@ def __enter__(self) -> "JobClientBase": @abstractmethod def __exit__( - self, exc_type: Type[BaseException], exc_val: BaseException, exc_tb: TracebackType + self, + exc_type: Type[BaseException], + exc_val: BaseException, + exc_tb: TracebackType, ) -> None: pass @@ -347,7 +361,8 @@ def _verify_schema(self) -> None: if has_column_with_prop(table, "hard_delete"): if len(get_columns_names_with_prop(table, "hard_delete")) > 1: raise SchemaException( - f'Found multiple "hard_delete" column hints for table "{table_name}" in' + 'Found multiple "hard_delete" column hints for table' + f' "{table_name}" in' f' schema "{self.schema.name}" while only one is allowed:' f' {", ".join(get_columns_names_with_prop(table, "hard_delete"))}.' ) @@ -363,7 +378,8 @@ def _verify_schema(self) -> None: if has_column_with_prop(table, "dedup_sort"): if len(get_columns_names_with_prop(table, "dedup_sort")) > 1: raise SchemaException( - f'Found multiple "dedup_sort" column hints for table "{table_name}" in' + 'Found multiple "dedup_sort" column hints for table' + f' "{table_name}" in' f' schema "{self.schema.name}" while only one is allowed:' f' {", ".join(get_columns_names_with_prop(table, "dedup_sort"))}.' ) @@ -376,9 +392,9 @@ def _verify_schema(self) -> None: ' The "dedup_sort" column hint is only applied when using' ' the "merge" write disposition.' ) - if table.get("write_disposition") == "merge" and not has_column_with_prop( - table, "primary_key" - ): + if table.get( + "write_disposition" + ) == "merge" and not has_column_with_prop(table, "primary_key"): logger.warning( f"""The "dedup_sort" column hint for column "{get_first_column_name_with_prop(table, 'dedup_sort')}" """ f'in table "{table_name}" with write disposition' @@ -398,9 +414,10 @@ def _verify_schema(self) -> None: if not is_complete_column(column): logger.warning( f"A column {column_name} in table {table_name} in schema" - f" {self.schema.name} is incomplete. It was not bound to the data during" - " normalizations stage and its data type is unknown. Did you add this" - " column manually in code ie. as a merge key?" + f" {self.schema.name} is incomplete. It was not bound to the" + " data during normalizations stage and its data type is" + " unknown. Did you add this column manually in code ie. as a" + " merge key?" 
) def prepare_load_table( @@ -411,7 +428,9 @@ def prepare_load_table( table = deepcopy(self.schema.tables[table_name]) # add write disposition if not specified - in child tables if "write_disposition" not in table: - table["write_disposition"] = get_write_disposition(self.schema.tables, table_name) + table["write_disposition"] = get_write_disposition( + self.schema.tables, table_name + ) if "table_format" not in table: table["table_format"] = get_table_format(self.schema.tables, table_name) return table @@ -456,7 +475,9 @@ def should_load_data_to_staging_dataset_on_staging_destination( ) -> bool: return False - def should_truncate_table_before_load_on_staging_destination(self, table: TTableSchema) -> bool: + def should_truncate_table_before_load_on_staging_destination( + self, table: TTableSchema + ) -> bool: # the default is to truncate the tables on the staging destination... return True @@ -495,7 +516,9 @@ def capabilities(self) -> DestinationCapabilitiesContext: @property def destination_name(self) -> str: """The destination name will either be explicitly set while creating the destination or will be taken from the type""" - return self.config_params.get("destination_name") or self.to_name(self.destination_type) + return self.config_params.get("destination_name") or self.to_name( + self.destination_type + ) @property def destination_type(self) -> str: @@ -566,8 +589,8 @@ def from_reference( if isinstance(ref, Destination): if credentials or destination_name or environment: logger.warning( - "Cannot override credentials, destination_name or environment when passing a" - " Destination instance, these values will be ignored." + "Cannot override credentials, destination_name or environment when" + " passing a Destination instance, these values will be ignored." ) return ref if not isinstance(ref, str): @@ -579,9 +602,9 @@ def from_reference( raise UnknownDestinationModule(ref) from e try: - factory: Type[Destination[DestinationClientConfiguration, JobClientBase]] = getattr( - dest_module, attr_name - ) + factory: Type[ + Destination[DestinationClientConfiguration, JobClientBase] + ] = getattr(dest_module, attr_name) except AttributeError as e: raise UnknownDestinationModule(ref) from e if credentials: diff --git a/dlt/common/exceptions.py b/dlt/common/exceptions.py index fe526c53dc..df46451d04 100644 --- a/dlt/common/exceptions.py +++ b/dlt/common/exceptions.py @@ -34,7 +34,9 @@ def attrs(self) -> Dict[str, Any]: return { k: v for k, v in vars(self).items() - if not k.startswith("_") and not callable(v) and not hasattr(self.__class__, k) + if not k.startswith("_") + and not callable(v) + and not hasattr(self.__class__, k) } @@ -42,13 +44,15 @@ class UnsupportedProcessStartMethodException(DltException): def __init__(self, method: str) -> None: self.method = method super().__init__( - f"Process pool supports only fork start method, {method} not supported. Switch the pool" - " type to threading" + f"Process pool supports only fork start method, {method} not supported." 
+ " Switch the pool type to threading" ) class CannotInstallDependencies(DltException): - def __init__(self, dependencies: Sequence[str], interpreter: str, output: AnyStr) -> None: + def __init__( + self, dependencies: Sequence[str], interpreter: str, output: AnyStr + ) -> None: self.dependencies = dependencies self.interpreter = interpreter if isinstance(output, bytes): @@ -56,8 +60,8 @@ def __init__(self, dependencies: Sequence[str], interpreter: str, output: AnyStr else: str_output = output super().__init__( - f"Cannot install dependencies {', '.join(dependencies)} with {interpreter} and" - f" pip:\n{str_output}\n" + f"Cannot install dependencies {', '.join(dependencies)} with" + f" {interpreter} and pip:\n{str_output}\n" ) @@ -94,7 +98,9 @@ def __init__(self, signal_code: int) -> None: class DictValidationException(DltException): - def __init__(self, msg: str, path: str, field: str = None, value: Any = None) -> None: + def __init__( + self, msg: str, path: str, field: str = None, value: Any = None + ) -> None: self.path = path self.field = field self.value = value @@ -104,13 +110,18 @@ def __init__(self, msg: str, path: str, field: str = None, value: Any = None) -> class ArgumentsOverloadException(DltException): def __init__(self, msg: str, func_name: str, *args: str) -> None: self.func_name = func_name - msg = f"Arguments combination not allowed when calling function {func_name}: {msg}" + msg = ( + f"Arguments combination not allowed when calling function {func_name}:" + f" {msg}" + ) msg = "\n".join((msg, *args)) super().__init__(msg) class MissingDependencyException(DltException): - def __init__(self, caller: str, dependencies: Sequence[str], appendix: str = "") -> None: + def __init__( + self, caller: str, dependencies: Sequence[str], appendix: str = "" + ) -> None: self.caller = caller self.dependencies = dependencies super().__init__(self._get_msg(appendix)) @@ -144,17 +155,17 @@ class PipelineStateNotAvailable(PipelineException): def __init__(self, source_state_key: Optional[str] = None) -> None: if source_state_key: msg = ( - f"The source {source_state_key} requested the access to pipeline state but no" - " pipeline is active right now." + f"The source {source_state_key} requested the access to pipeline state" + " but no pipeline is active right now." ) else: msg = ( - "The resource you called requested the access to pipeline state but no pipeline is" - " active right now." + "The resource you called requested the access to pipeline state but no" + " pipeline is active right now." ) msg += ( - " Call dlt.pipeline(...) before you call the @dlt.source or @dlt.resource decorated" - " function." + " Call dlt.pipeline(...) before you call the @dlt.source or @dlt.resource" + " decorated function." ) self.source_state_key = source_state_key super().__init__(None, msg) @@ -164,16 +175,17 @@ class ResourceNameNotAvailable(PipelineException): def __init__(self) -> None: super().__init__( None, - "A resource state was requested but no active extract pipe context was found. Resource" - " state may be only requested from @dlt.resource decorated function or with explicit" - " resource name.", + "A resource state was requested but no active extract pipe context was" + " found. Resource state may be only requested from @dlt.resource decorated" + " function or with explicit resource name.", ) class SourceSectionNotAvailable(PipelineException): def __init__(self) -> None: msg = ( - "Access to state was requested without source section active. 
State should be requested" - " from within the @dlt.source and @dlt.resource decorated function." + "Access to state was requested without source section active. State should" + " be requested from within the @dlt.source and @dlt.resource decorated" + " function." ) super().__init__(None, msg) diff --git a/dlt/common/git.py b/dlt/common/git.py index c4f83a7398..572d2bc8ab 100644 --- a/dlt/common/git.py +++ b/dlt/common/git.py @@ -42,7 +42,11 @@ def is_clean_and_synced(repo: Repo) -> bool: status_lines = status.splitlines() first_line = status_lines[0] # we expect first status line is not ## main...origin/main [ahead 1] - return len(status_lines) == 1 and first_line.startswith("##") and not first_line.endswith("]") + return ( + len(status_lines) == 1 + and first_line.startswith("##") + and not first_line.endswith("]") + ) def is_dirty(repo: Repo) -> bool: @@ -84,7 +88,9 @@ def clone_repo( ) -> Repo: from git import Repo - repo = Repo.clone_from(repository_url, clone_path, env=dict(GIT_SSH_COMMAND=with_git_command)) + repo = Repo.clone_from( + repository_url, clone_path, env=dict(GIT_SSH_COMMAND=with_git_command) + ) if branch: repo.git.checkout(branch) return repo @@ -133,7 +139,9 @@ def get_fresh_repo_files( repo_name = url.name repo_path = os.path.join(working_dir, repo_name) try: - ensure_remote_head(repo_path, branch=branch, with_git_command=with_git_command) + ensure_remote_head( + repo_path, branch=branch, with_git_command=with_git_command + ) except GitError: force_clone_repo( repo_location, diff --git a/dlt/common/json/__init__.py b/dlt/common/json/__init__.py index 371c74e54a..b6a56f41ba 100644 --- a/dlt/common/json/__init__.py +++ b/dlt/common/json/__init__.py @@ -30,17 +30,23 @@ def dump( def typed_dump(self, obj: Any, fp: IO[bytes], pretty: bool = False) -> None: ... - def typed_dumps(self, obj: Any, sort_keys: bool = False, pretty: bool = False) -> str: ... + def typed_dumps( + self, obj: Any, sort_keys: bool = False, pretty: bool = False + ) -> str: ... def typed_loads(self, s: str) -> Any: ... - def typed_dumpb(self, obj: Any, sort_keys: bool = False, pretty: bool = False) -> bytes: ... + def typed_dumpb( + self, obj: Any, sort_keys: bool = False, pretty: bool = False + ) -> bytes: ... def typed_loadb(self, s: Union[bytes, bytearray, memoryview]) -> Any: ... def dumps(self, obj: Any, sort_keys: bool = False, pretty: bool = False) -> str: ... - def dumpb(self, obj: Any, sort_keys: bool = False, pretty: bool = False) -> bytes: ... + def dumpb( + self, obj: Any, sort_keys: bool = False, pretty: bool = False + ) -> bytes: ... def load(self, fp: Union[IO[bytes], IO[str]]) -> Any: ... 
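A small, hedged usage sketch of the json facade whose protocol methods are reformatted above; the `from dlt.common import json` import path is assumed from the wider codebase (escape.py in this same patch calls json.dumps the same way), and the sample data is made up.

# Minimal sketch: the facade mirrors the stdlib json API with extra flags.
from dlt.common import json  # resolves to the orjson or simplejson implementation

data = {"b": 2, "a": 1}
text = json.dumps(data, sort_keys=True)  # -> str with sorted keys
raw = json.dumpb(data, pretty=True)      # -> pretty-printed bytes
assert json.loads(text) == data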
diff --git a/dlt/common/json/_orjson.py b/dlt/common/json/_orjson.py index d2d960e6ce..dc82a34238 100644 --- a/dlt/common/json/_orjson.py +++ b/dlt/common/json/_orjson.py @@ -8,7 +8,11 @@ def _dumps( - obj: Any, sort_keys: bool, pretty: bool, default: AnyFun = custom_encode, options: int = 0 + obj: Any, + sort_keys: bool, + pretty: bool, + default: AnyFun = custom_encode, + options: int = 0, ) -> bytes: options = options | orjson.OPT_UTC_Z | orjson.OPT_NON_STR_KEYS if pretty: @@ -18,7 +22,9 @@ def _dumps( return orjson.dumps(obj, default=default, option=options) -def dump(obj: Any, fp: IO[bytes], sort_keys: bool = False, pretty: bool = False) -> None: +def dump( + obj: Any, fp: IO[bytes], sort_keys: bool = False, pretty: bool = False +) -> None: fp.write(_dumps(obj, sort_keys, pretty)) @@ -27,7 +33,9 @@ def typed_dump(obj: Any, fp: IO[bytes], pretty: bool = False) -> None: def typed_dumpb(obj: Any, sort_keys: bool = False, pretty: bool = False) -> bytes: - return _dumps(obj, sort_keys, pretty, custom_pua_encode, orjson.OPT_PASSTHROUGH_DATETIME) + return _dumps( + obj, sort_keys, pretty, custom_pua_encode, orjson.OPT_PASSTHROUGH_DATETIME + ) def typed_dumps(obj: Any, sort_keys: bool = False, pretty: bool = False) -> str: diff --git a/dlt/common/json/_simplejson.py b/dlt/common/json/_simplejson.py index 10ee17e2f6..5b00a93ce7 100644 --- a/dlt/common/json/_simplejson.py +++ b/dlt/common/json/_simplejson.py @@ -15,7 +15,9 @@ _impl_name = "simplejson" -def dump(obj: Any, fp: IO[bytes], sort_keys: bool = False, pretty: bool = False) -> None: +def dump( + obj: Any, fp: IO[bytes], sort_keys: bool = False, pretty: bool = False +) -> None: if pretty: indent = 2 else: diff --git a/dlt/common/jsonpath.py b/dlt/common/jsonpath.py index 7808d1c69c..065375c028 100644 --- a/dlt/common/jsonpath.py +++ b/dlt/common/jsonpath.py @@ -45,4 +45,6 @@ def resolve_paths(paths: TAnyJsonPath, data: DictStrAny) -> List[str]: """ paths = compile_paths(paths) p: JSONPath - return list(chain.from_iterable((str(r.full_path) for r in p.find(data)) for p in paths)) + return list( + chain.from_iterable((str(r.full_path) for r in p.find(data)) for p in paths) + ) diff --git a/dlt/common/libs/pyarrow.py b/dlt/common/libs/pyarrow.py index c1fbfbff85..ed8076933b 100644 --- a/dlt/common/libs/pyarrow.py +++ b/dlt/common/libs/pyarrow.py @@ -1,6 +1,16 @@ from datetime import datetime, date # noqa: I251 from pendulum.tz import UTC -from typing import Any, Tuple, Optional, Union, Callable, Iterable, Iterator, Sequence, Tuple +from typing import ( + Any, + Tuple, + Optional, + Union, + Callable, + Iterable, + Iterator, + Sequence, + Tuple, +) from dlt import version from dlt.common import pendulum @@ -20,7 +30,8 @@ raise MissingDependencyException( "dlt pyarrow helpers", [f"{version.DLT_PKG_NAME}[parquet]"], - "Install pyarrow to be allow to load arrow tables, panda frames and to use parquet files.", + "Install pyarrow to be allow to load arrow tables, panda frames and to use" + " parquet files.", ) @@ -38,7 +49,9 @@ def get_py_arrow_datatype( elif column_type == "bool": return pyarrow.bool_() elif column_type == "timestamp": - return get_py_arrow_timestamp(column.get("precision") or caps.timestamp_precision, tz) + return get_py_arrow_timestamp( + column.get("precision") or caps.timestamp_precision, tz + ) elif column_type == "bigint": return get_pyarrow_int(column.get("precision")) elif column_type == "binary": @@ -169,7 +182,9 @@ def remove_columns(item: TAnyArrowItem, columns: Sequence[str]) -> TAnyArrowItem return 
item.drop(columns) elif isinstance(item, pyarrow.RecordBatch): # NOTE: select is available in pyarrow 12 an up - return item.select([n for n in item.schema.names if n not in columns]) # reverse selection + return item.select( + [n for n in item.schema.names if n not in columns] + ) # reverse selection else: raise ValueError(item) @@ -187,7 +202,9 @@ def append_column(item: TAnyArrowItem, name: str, data: Any) -> TAnyArrowItem: raise ValueError(item) -def rename_columns(item: TAnyArrowItem, new_column_names: Sequence[str]) -> TAnyArrowItem: +def rename_columns( + item: TAnyArrowItem, new_column_names: Sequence[str] +) -> TAnyArrowItem: """Rename arrow columns on Table or RecordBatch, returns same data but with renamed schema""" if list(item.schema.names) == list(new_column_names): @@ -198,9 +215,12 @@ def rename_columns(item: TAnyArrowItem, new_column_names: Sequence[str]) -> TAny return item.rename_columns(new_column_names) elif isinstance(item, pyarrow.RecordBatch): new_fields = [ - field.with_name(new_name) for new_name, field in zip(new_column_names, item.schema) + field.with_name(new_name) + for new_name, field in zip(new_column_names, item.schema) ] - return pyarrow.RecordBatch.from_arrays(item.columns, schema=pyarrow.schema(new_fields)) + return pyarrow.RecordBatch.from_arrays( + item.columns, schema=pyarrow.schema(new_fields) + ) else: raise TypeError(f"Unsupported data item type {type(item)}") @@ -268,7 +288,9 @@ def normalize_py_arrow_schema( return item.__class__.from_arrays(new_columns, schema=pyarrow.schema(new_fields)) -def get_normalized_arrow_fields_mapping(item: TAnyArrowItem, naming: NamingConvention) -> StrStr: +def get_normalized_arrow_fields_mapping( + item: TAnyArrowItem, naming: NamingConvention +) -> StrStr: """Normalizes schema field names and returns mapping from original to normalized name. Raises on name clashes""" norm_f = naming.normalize_identifier name_mapping = {n.name: norm_f(n.name) for n in item.schema} @@ -330,13 +352,17 @@ def from_arrow_scalar(arrow_value: pyarrow.Scalar) -> Any: # datetimes as dates and keeping the exact time inside. 
probably a bug # but can be corrected this way if isinstance(row_value, date) and not isinstance(row_value, datetime): - row_value = pendulum.from_timestamp(arrow_value.cast(pyarrow.int64()).as_py() / 1000) + row_value = pendulum.from_timestamp( + arrow_value.cast(pyarrow.int64()).as_py() / 1000 + ) elif isinstance(row_value, datetime): row_value = pendulum.instance(row_value).in_tz("UTC") return row_value -TNewColumns = Sequence[Tuple[int, pyarrow.Field, Callable[[pyarrow.Table], Iterable[Any]]]] +TNewColumns = Sequence[ + Tuple[int, pyarrow.Field, Callable[[pyarrow.Table], Iterable[Any]]] +] """Sequence of tuples: (field index, field, generating function)""" diff --git a/dlt/common/libs/pydantic.py b/dlt/common/libs/pydantic.py index c4bf994cb9..3d8fb0a714 100644 --- a/dlt/common/libs/pydantic.py +++ b/dlt/common/libs/pydantic.py @@ -18,7 +18,9 @@ from dlt.common.exceptions import MissingDependencyException from dlt.common.schema import DataValidationError from dlt.common.schema.typing import TSchemaEvolutionMode, TTableSchemaColumns -from dlt.common.normalizers.naming.snake_case import NamingConvention as SnakeCaseNamingConvention +from dlt.common.normalizers.naming.snake_case import ( + NamingConvention as SnakeCaseNamingConvention, +) from dlt.common.typing import ( TDataItem, TDataItems, @@ -204,8 +206,8 @@ def apply_schema_contract_to_model( model = create_model(model.__name__ + "Any", **{n: (Any, None) for n in model.__fields__}) # type: ignore[call-overload, attr-defined] elif data_mode == "discard_value": raise NotImplementedError( - "data_mode is discard_value. Cannot discard defined fields with validation errors using" - " Pydantic models." + "data_mode is discard_value. Cannot discard defined fields with validation" + " errors using Pydantic models." 
) extra = column_mode_to_extra(column_mode) @@ -349,7 +351,8 @@ def validate_items( deleted.add(err_idx) else: raise NotImplementedError( - f"{column_mode} column mode not implemented for Pydantic validation" + f"{column_mode} column mode not implemented for Pydantic" + " validation" ) else: if data_mode == "freeze": @@ -368,7 +371,8 @@ def validate_items( deleted.add(err_idx) else: raise NotImplementedError( - f"{column_mode} column mode not implemented for Pydantic validation" + f"{column_mode} column mode not implemented for Pydantic" + " validation" ) # validate again with error items removed diff --git a/dlt/common/libs/sql_alchemy.py b/dlt/common/libs/sql_alchemy.py index 2f3b51ec0d..d124ddf2ab 100644 --- a/dlt/common/libs/sql_alchemy.py +++ b/dlt/common/libs/sql_alchemy.py @@ -165,7 +165,8 @@ def _assert_value( return tuple(_assert_value(elem) for elem in val) else: raise TypeError( - "Query dictionary values must be strings or sequences of strings" + "Query dictionary values must be strings or sequences of" + " strings" ) def _assert_str(v: str) -> str: @@ -250,7 +251,9 @@ def update_query_pairs( new_keys[key] = to_list(new_keys[key]) cast("List[str]", new_keys[key]).append(cast(str, value)) else: - new_keys[key] = to_list(value) if isinstance(value, (list, tuple)) else value + new_keys[key] = ( + to_list(value) if isinstance(value, (list, tuple)) else value + ) new_query: Mapping[str, Union[str, Sequence[str]]] if append: @@ -258,18 +261,26 @@ def update_query_pairs( for k in new_keys: if k in existing_query: - new_query[k] = tuple(to_list(existing_query[k]) + to_list(new_keys[k])) + new_query[k] = tuple( + to_list(existing_query[k]) + to_list(new_keys[k]) + ) else: new_query[k] = new_keys[k] new_query.update( - {k: existing_query[k] for k in set(existing_query).difference(new_keys)} + { + k: existing_query[k] + for k in set(existing_query).difference(new_keys) + } ) else: new_query = ImmutableDict( { **self.query, - **{k: tuple(v) if isinstance(v, list) else v for k, v in new_keys.items()}, + **{ + k: tuple(v) if isinstance(v, list) else v + for k, v in new_keys.items() + }, } ) return self.set(query=new_query) @@ -287,7 +298,9 @@ def render_as_string(self, hide_password: bool = True) -> str: if self.username is not None: s += quote(self.username, safe=" +") if self.password is not None: - s += ":" + ("***" if hide_password else quote(str(self.password), safe=" +")) + s += ":" + ( + "***" if hide_password else quote(str(self.password), safe=" +") + ) s += "@" if self.host is not None: if ":" in self.host: diff --git a/dlt/common/logger.py b/dlt/common/logger.py index 02412248c3..2ac910f327 100644 --- a/dlt/common/logger.py +++ b/dlt/common/logger.py @@ -93,7 +93,9 @@ class _CustomJsonFormatter(json_logging.JSONLogFormatter): version: Mapping[str, str] = None def _format_log_object(self, record: LogRecord) -> Any: - json_log_object = super(_CustomJsonFormatter, self)._format_log_object(record) + json_log_object = super(_CustomJsonFormatter, self)._format_log_object( + record + ) if self.version: json_log_object.update({"version": self.version}) return json_log_object diff --git a/dlt/common/normalizers/configuration.py b/dlt/common/normalizers/configuration.py index 54b725db1f..fd0f1c466b 100644 --- a/dlt/common/normalizers/configuration.py +++ b/dlt/common/normalizers/configuration.py @@ -14,7 +14,9 @@ class NormalizersConfiguration(BaseConfiguration): naming: Optional[str] = None json_normalizer: Optional[DictStrAny] = None - destination_capabilities: 
Optional[DestinationCapabilitiesContext] = None # injectable + destination_capabilities: Optional[DestinationCapabilitiesContext] = ( + None # injectable + ) def on_resolved(self) -> None: # get naming from capabilities if not present diff --git a/dlt/common/normalizers/json/__init__.py b/dlt/common/normalizers/json/__init__.py index a13bab15f4..683c2399af 100644 --- a/dlt/common/normalizers/json/__init__.py +++ b/dlt/common/normalizers/json/__init__.py @@ -1,5 +1,14 @@ import abc -from typing import Any, Generic, Type, Generator, Tuple, Protocol, TYPE_CHECKING, TypeVar +from typing import ( + Any, + Generic, + Type, + Generator, + Tuple, + Protocol, + TYPE_CHECKING, + TypeVar, +) from dlt.common.typing import DictStrAny, TDataItem, StrAny @@ -38,7 +47,9 @@ def extend_table(self, table_name: str) -> None: @classmethod @abc.abstractmethod - def update_normalizer_config(cls, schema: Schema, config: TNormalizerConfig) -> None: + def update_normalizer_config( + cls, schema: Schema, config: TNormalizerConfig + ) -> None: pass @classmethod diff --git a/dlt/common/normalizers/json/relational.py b/dlt/common/normalizers/json/relational.py index e33bf2ab35..455d0b76f0 100644 --- a/dlt/common/normalizers/json/relational.py +++ b/dlt/common/normalizers/json/relational.py @@ -95,7 +95,9 @@ def _flatten( out_rec_list: Dict[Tuple[str, ...], Sequence[Any]] = {} schema_naming = self.schema.naming - def norm_row_dicts(dict_row: StrAny, __r_lvl: int, path: Tuple[str, ...] = ()) -> None: + def norm_row_dicts( + dict_row: StrAny, __r_lvl: int, path: Tuple[str, ...] = () + ) -> None: for k, v in dict_row.items(): if k.strip(): norm_k = schema_naming.normalize_identifier(k) @@ -105,7 +107,9 @@ def norm_row_dicts(dict_row: StrAny, __r_lvl: int, path: Tuple[str, ...] = ()) - # if norm_k != k: # print(f"{k} -> {norm_k}") child_name = ( - norm_k if path == () else schema_naming.shorten_fragments(*path, norm_k) + norm_k + if path == () + else schema_naming.shorten_fragments(*path, norm_k) ) # for lists and dicts we must check if type is possibly complex if isinstance(v, (dict, list)): @@ -116,7 +120,9 @@ def norm_row_dicts(dict_row: StrAny, __r_lvl: int, path: Tuple[str, ...] = ()) - norm_row_dicts(v, __r_lvl + 1, path + (norm_k,)) else: # pass the list to out_rec_list - out_rec_list[path + (schema_naming.normalize_table_identifier(k),)] = v + out_rec_list[ + path + (schema_naming.normalize_table_identifier(k),) + ] = v continue else: # pass the complex value to out_rec_row @@ -131,10 +137,14 @@ def norm_row_dicts(dict_row: StrAny, __r_lvl: int, path: Tuple[str, ...] 
= ()) - def _get_child_row_hash(parent_row_id: str, child_table: str, list_idx: int) -> str: # create deterministic unique id of the child row taking into account that all lists are ordered # and all child tables must be lists - return digest128(f"{parent_row_id}_{child_table}_{list_idx}", DLT_ID_LENGTH_BYTES) + return digest128( + f"{parent_row_id}_{child_table}_{list_idx}", DLT_ID_LENGTH_BYTES + ) @staticmethod - def _link_row(row: TDataItemRowChild, parent_row_id: str, list_idx: int) -> TDataItemRowChild: + def _link_row( + row: TDataItemRowChild, parent_row_id: str, list_idx: int + ) -> TDataItemRowChild: assert parent_row_id row["_dlt_parent_id"] = parent_row_id row["_dlt_list_idx"] = list_idx @@ -154,13 +164,19 @@ def _add_row_id( primary_key = self.schema.filter_row_with_hint(table, "primary_key", row) if not primary_key: # child table row deterministic hash - row_id = DataItemNormalizer._get_child_row_hash(parent_row_id, table, pos) + row_id = DataItemNormalizer._get_child_row_hash( + parent_row_id, table, pos + ) # link to parent table - DataItemNormalizer._link_row(cast(TDataItemRowChild, row), parent_row_id, pos) + DataItemNormalizer._link_row( + cast(TDataItemRowChild, row), parent_row_id, pos + ) row["_dlt_id"] = row_id return row_id - def _get_propagated_values(self, table: str, row: TDataItemRow, _r_lvl: int) -> StrAny: + def _get_propagated_values( + self, table: str, row: TDataItemRow, _r_lvl: int + ) -> StrAny: extend: DictStrAny = {} config = self.propagation_config @@ -200,11 +216,19 @@ def _normalize_list( elif isinstance(v, list): # to normalize lists of lists, we must create a tracking intermediary table by creating a mock row yield from self._normalize_row( - {"list": v}, extend, ident_path, parent_path, parent_row_id, idx, _r_lvl + 1 + {"list": v}, + extend, + ident_path, + parent_path, + parent_row_id, + idx, + _r_lvl + 1, ) else: # list of simple types - child_row_hash = DataItemNormalizer._get_child_row_hash(parent_row_id, table, idx) + child_row_hash = DataItemNormalizer._get_child_row_hash( + parent_row_id, table, idx + ) wrap_v = wrap_in_dict(v) wrap_v["_dlt_id"] = child_row_hash e = DataItemNormalizer._link_row(wrap_v, parent_row_id, idx) @@ -237,20 +261,29 @@ def _normalize_row( extend.update(self._get_propagated_values(table, flattened_row, _r_lvl)) # yield parent table first - should_descend = yield (table, schema.naming.shorten_fragments(*parent_path)), flattened_row + should_descend = ( + yield (table, schema.naming.shorten_fragments(*parent_path)), + flattened_row, + ) if should_descend is False: return # normalize and yield lists for list_path, list_content in lists.items(): yield from self._normalize_list( - list_content, extend, list_path, parent_path + ident_path, row_id, _r_lvl + 1 + list_content, + extend, + list_path, + parent_path + ident_path, + row_id, + _r_lvl + 1, ) def extend_schema(self) -> None: # validate config config = cast( - RelationalNormalizerConfig, self.schema._normalizers_config["json"].get("config") or {} + RelationalNormalizerConfig, + self.schema._normalizers_config["json"].get("config") or {}, ) DataItemNormalizer._validate_normalizer_config(self.schema, config) @@ -283,7 +316,11 @@ def extend_table(self, table_name: str) -> None: if not table.get("parent") and table.get("write_disposition") == "merge": DataItemNormalizer.update_normalizer_config( self.schema, - {"propagation": {"tables": {table_name: {"_dlt_id": TColumnName("_dlt_root_id")}}}}, + { + "propagation": { + "tables": {table_name: {"_dlt_id": 
TColumnName("_dlt_root_id")}} + } + }, ) def normalize_data_item( @@ -310,7 +347,9 @@ def ensure_this_normalizer(cls, norm_config: TJSONNormalizer) -> None: raise InvalidJsonNormalizer(__name__, present_normalizer) @classmethod - def update_normalizer_config(cls, schema: Schema, config: RelationalNormalizerConfig) -> None: + def update_normalizer_config( + cls, schema: Schema, config: RelationalNormalizerConfig + ) -> None: cls._validate_normalizer_config(schema, config) norm_config = schema._normalizers_config["json"] cls.ensure_this_normalizer(norm_config) @@ -326,7 +365,9 @@ def get_normalizer_config(cls, schema: Schema) -> RelationalNormalizerConfig: return cast(RelationalNormalizerConfig, norm_config.get("config", {})) @staticmethod - def _validate_normalizer_config(schema: Schema, config: RelationalNormalizerConfig) -> None: + def _validate_normalizer_config( + schema: Schema, config: RelationalNormalizerConfig + ) -> None: validate_dict( RelationalNormalizerConfig, config, diff --git a/dlt/common/normalizers/naming/direct.py b/dlt/common/normalizers/naming/direct.py index 0998650852..f11c0bb31b 100644 --- a/dlt/common/normalizers/naming/direct.py +++ b/dlt/common/normalizers/naming/direct.py @@ -1,6 +1,8 @@ from typing import Any, Sequence -from dlt.common.normalizers.naming.naming import NamingConvention as BaseNamingConvention +from dlt.common.normalizers.naming.naming import ( + NamingConvention as BaseNamingConvention, +) class NamingConvention(BaseNamingConvention): diff --git a/dlt/common/normalizers/naming/duck_case.py b/dlt/common/normalizers/naming/duck_case.py index 063482a799..199a5fc233 100644 --- a/dlt/common/normalizers/naming/duck_case.py +++ b/dlt/common/normalizers/naming/duck_case.py @@ -1,7 +1,9 @@ import re from functools import lru_cache -from dlt.common.normalizers.naming.snake_case import NamingConvention as SnakeCaseNamingConvention +from dlt.common.normalizers.naming.snake_case import ( + NamingConvention as SnakeCaseNamingConvention, +) class NamingConvention(SnakeCaseNamingConvention): @@ -17,5 +19,7 @@ def _normalize_identifier(identifier: str, max_length: int) -> str: # shorten identifier return NamingConvention.shorten_identifier( - NamingConvention._RE_UNDERSCORES.sub("_", normalized_ident), identifier, max_length + NamingConvention._RE_UNDERSCORES.sub("_", normalized_ident), + identifier, + max_length, ) diff --git a/dlt/common/normalizers/naming/exceptions.py b/dlt/common/normalizers/naming/exceptions.py index 572fc7e0d0..1bf8bdc28c 100644 --- a/dlt/common/normalizers/naming/exceptions.py +++ b/dlt/common/normalizers/naming/exceptions.py @@ -11,7 +11,10 @@ def __init__(self, naming_module: str) -> None: if "." 
in naming_module: msg = f"Naming module {naming_module} could not be found and imported" else: - msg = f"Naming module {naming_module} is not one of the standard dlt naming convention" + msg = ( + f"Naming module {naming_module} is not one of the standard dlt naming" + " convention" + ) super().__init__(msg) @@ -19,7 +22,7 @@ class InvalidNamingModule(NormalizersException): def __init__(self, naming_module: str) -> None: self.naming_module = naming_module msg = ( - f"Naming module {naming_module} does not implement required SupportsNamingConvention" - " protocol" + f"Naming module {naming_module} does not implement required" + " SupportsNamingConvention protocol" ) super().__init__(msg) diff --git a/dlt/common/normalizers/naming/naming.py b/dlt/common/normalizers/naming/naming.py index fccb147981..c2e5fd2eca 100644 --- a/dlt/common/normalizers/naming/naming.py +++ b/dlt/common/normalizers/naming/naming.py @@ -39,9 +39,13 @@ def break_path(self, path: str) -> Sequence[str]: def normalize_path(self, path: str) -> str: """Breaks path into identifiers, normalizes components, reconstitutes and shortens the path""" - normalized_idents = [self.normalize_identifier(ident) for ident in self.break_path(path)] + normalized_idents = [ + self.normalize_identifier(ident) for ident in self.break_path(path) + ] # shorten the whole path - return self.shorten_identifier(self.make_path(*normalized_idents), path, self.max_length) + return self.shorten_identifier( + self.make_path(*normalized_idents), path, self.max_length + ) def normalize_tables_path(self, path: str) -> str: """Breaks path of table identifiers, normalizes components, reconstitutes and shortens the path""" @@ -49,7 +53,9 @@ def normalize_tables_path(self, path: str) -> str: self.normalize_table_identifier(ident) for ident in self.break_path(path) ] # shorten the whole path - return self.shorten_identifier(self.make_path(*normalized_idents), path, self.max_length) + return self.shorten_identifier( + self.make_path(*normalized_idents), path, self.max_length + ) def shorten_fragments(self, *normalized_idents: str) -> str: """Reconstitutes and shortens the path of normalized identifiers""" @@ -70,7 +76,9 @@ def shorten_identifier( if max_length and len(normalized_ident) > max_length: # use original identifier to compute tag tag = NamingConvention._compute_tag(identifier, collision_prob) - normalized_ident = NamingConvention._trim_and_tag(normalized_ident, tag, max_length) + normalized_ident = NamingConvention._trim_and_tag( + normalized_ident, tag, max_length + ) return normalized_ident @@ -80,7 +88,9 @@ def _compute_tag(identifier: str, collision_prob: float) -> str: # take into account that we are case insensitive in base64 so we need ~1.5x more bits (2+1) tl_bytes = int(((2 + 1) * math.log2(1 / (collision_prob)) // 8) + 1) tag = ( - base64.b64encode(hashlib.shake_128(identifier.encode("utf-8")).digest(tl_bytes)) + base64.b64encode( + hashlib.shake_128(identifier.encode("utf-8")).digest(tl_bytes) + ) .rstrip(b"=") .translate(NamingConvention._TR_TABLE) .lower() diff --git a/dlt/common/normalizers/naming/snake_case.py b/dlt/common/normalizers/naming/snake_case.py index b3c65e9b8d..6e3af18011 100644 --- a/dlt/common/normalizers/naming/snake_case.py +++ b/dlt/common/normalizers/naming/snake_case.py @@ -2,7 +2,9 @@ from typing import Any, List, Sequence from functools import lru_cache -from dlt.common.normalizers.naming.naming import NamingConvention as BaseNamingConvention +from dlt.common.normalizers.naming.naming import ( + NamingConvention as 
BaseNamingConvention, +) class NamingConvention(BaseNamingConvention): @@ -36,7 +38,9 @@ def _normalize_identifier(identifier: str, max_length: int) -> str: """Normalizes the identifier according to naming convention represented by this function""" # all characters that are not letters digits or a few special chars are replaced with underscore normalized_ident = identifier.translate(NamingConvention._TR_REDUCE_ALPHABET) - normalized_ident = NamingConvention._RE_NON_ALPHANUMERIC.sub("_", normalized_ident) + normalized_ident = NamingConvention._RE_NON_ALPHANUMERIC.sub( + "_", normalized_ident + ) # shorten identifier return NamingConvention.shorten_identifier( diff --git a/dlt/common/normalizers/utils.py b/dlt/common/normalizers/utils.py index 645bad2bea..a944c41d33 100644 --- a/dlt/common/normalizers/utils.py +++ b/dlt/common/normalizers/utils.py @@ -7,7 +7,10 @@ from dlt.common.normalizers.configuration import NormalizersConfiguration from dlt.common.normalizers.json import SupportsDataItemNormalizer, DataItemNormalizer from dlt.common.normalizers.naming import NamingConvention, SupportsNamingConvention -from dlt.common.normalizers.naming.exceptions import UnknownNamingModule, InvalidNamingModule +from dlt.common.normalizers.naming.exceptions import ( + UnknownNamingModule, + InvalidNamingModule, +) from dlt.common.normalizers.typing import TJSONNormalizer, TNormalizersConfig from dlt.common.utils import uniq_id_base64, many_uniq_ids_base64 @@ -49,7 +52,8 @@ def import_normalizers( else: # from known location naming_module = cast( - SupportsNamingConvention, import_module(f"dlt.common.normalizers.naming.{names}") + SupportsNamingConvention, + import_module(f"dlt.common.normalizers.naming.{names}"), ) except ImportError: raise UnknownNamingModule(names) @@ -63,7 +67,9 @@ def import_normalizers( ) else: max_length = None - json_module = cast(SupportsDataItemNormalizer, import_module(item_normalizer["module"])) + json_module = cast( + SupportsDataItemNormalizer, import_module(item_normalizer["module"]) + ) return ( normalizers_config, diff --git a/dlt/common/pipeline.py b/dlt/common/pipeline.py index 7c117d4612..4c926ac7b2 100644 --- a/dlt/common/pipeline.py +++ b/dlt/common/pipeline.py @@ -36,7 +36,12 @@ from dlt.common.destination.exceptions import DestinationHasFailedJobs from dlt.common.exceptions import PipelineStateNotAvailable, SourceSectionNotAvailable from dlt.common.schema import Schema -from dlt.common.schema.typing import TColumnNames, TColumnSchema, TWriteDisposition, TSchemaContract +from dlt.common.schema.typing import ( + TColumnNames, + TColumnSchema, + TWriteDisposition, + TSchemaContract, +) from dlt.common.source import get_current_pipe_name from dlt.common.storages.load_storage import LoadPackageInfo from dlt.common.time import ensure_pendulum_datetime, precise_time @@ -104,7 +109,9 @@ def asdict(self) -> DictStrAny: # to be mixed with NamedTuple step_info: DictStrAny = self._asdict() # type: ignore step_info["pipeline"] = {"pipeline_name": self.pipeline.pipeline_name} - step_info["load_packages"] = [package.asdict() for package in self.load_packages] + step_info["load_packages"] = [ + package.asdict() for package in self.load_packages + ] if self.metrics: step_info["started_at"] = self.started_at step_info["finished_at"] = self.finished_at @@ -120,23 +127,36 @@ def __str__(self) -> str: return self.asstr(verbosity=0) @staticmethod - def _load_packages_asstr(load_packages: List[LoadPackageInfo], verbosity: int) -> str: + def _load_packages_asstr( + load_packages: 
List[LoadPackageInfo], verbosity: int + ) -> str: msg: str = "" for load_package in load_packages: cstr = ( load_package.state.upper() if load_package.completed_at - else f"{load_package.state.upper()} and NOT YET LOADED to the destination" + else ( + f"{load_package.state.upper()} and NOT YET LOADED to the" + " destination" + ) ) # now enumerate all complete loads if we have any failed packages # complete but failed job will not raise any exceptions failed_jobs = load_package.jobs["failed_jobs"] - jobs_str = "no failed jobs" if not failed_jobs else f"{len(failed_jobs)} FAILED job(s)!" - msg += f"\nLoad package {load_package.load_id} is {cstr} and contains {jobs_str}" + jobs_str = ( + "no failed jobs" + if not failed_jobs + else f"{len(failed_jobs)} FAILED job(s)!" + ) + msg += ( + f"\nLoad package {load_package.load_id} is {cstr} and contains" + f" {jobs_str}" + ) if verbosity > 0: for failed_job in failed_jobs: msg += ( - f"\n\t[{failed_job.job_file_info.job_id()}]: {failed_job.failed_message}\n" + f"\n\t[{failed_job.job_file_info.job_id()}]:" + f" {failed_job.failed_message}\n" ) if verbosity > 1: msg += "\nPackage details:\n" @@ -145,7 +165,9 @@ def _load_packages_asstr(load_packages: List[LoadPackageInfo], verbosity: int) - @staticmethod def job_metrics_asdict( - job_metrics: Dict[str, DataWriterMetrics], key_name: str = "job_id", extend: StrAny = None + job_metrics: Dict[str, DataWriterMetrics], + key_name: str = "job_id", + extend: StrAny = None, ) -> List[DictStrAny]: jobs = [] for job_id, metrics in job_metrics.items(): @@ -228,7 +250,9 @@ def asdict(self) -> DictStrAny: ) load_metrics["resource_metrics"].extend( self.job_metrics_asdict( - metrics["resource_metrics"], key_name="resource_name", extend=extend + metrics["resource_metrics"], + key_name="resource_name", + extend=extend, ) ) load_metrics["dag"].extend( @@ -279,9 +303,12 @@ def row_counts(self) -> RowCounts: return {} counts: RowCounts = {} for metrics in self.metrics.values(): - assert len(metrics) == 1, "Cannot deal with more than 1 normalize metric per load_id" + assert ( + len(metrics) == 1 + ), "Cannot deal with more than 1 normalize metric per load_id" merge_row_counts( - counts, {t: m.items_count for t, m in metrics[0]["table_metrics"].items()} + counts, + {t: m.items_count for t, m in metrics[0]["table_metrics"].items()}, ) return counts @@ -419,8 +446,8 @@ def _step_info_start_load_id(self, load_id: str) -> None: def _step_info_complete_load_id(self, load_id: str, metrics: TStepMetrics) -> None: assert self._current_load_id == load_id, ( - f"Current load id mismatch {self._current_load_id} != {load_id} when completing step" - " info" + f"Current load id mismatch {self._current_load_id} != {load_id} when" + " completing step info" ) metrics["started_at"] = ensure_pendulum_datetime(self._current_load_started) metrics["finished_at"] = ensure_pendulum_datetime(precise_time()) @@ -568,8 +595,8 @@ def pipeline(self) -> SupportsPipeline: if not self._pipeline: # delayed pipeline creation assert self._deferred_pipeline is not None, ( - "Deferred pipeline creation function not provided to PipelineContext. Are you" - " calling dlt.pipeline() from another thread?" + "Deferred pipeline creation function not provided to PipelineContext." + " Are you calling dlt.pipeline() from another thread?" 
) self.activate(self._deferred_pipeline()) return self._pipeline @@ -590,7 +617,9 @@ def deactivate(self) -> None: self._pipeline._set_context(False) self._pipeline = None - def __init__(self, deferred_pipeline: Callable[..., SupportsPipeline] = None) -> None: + def __init__( + self, deferred_pipeline: Callable[..., SupportsPipeline] = None + ) -> None: """Initialize the context with a function returning the Pipeline object to allow creation on first use""" self._deferred_pipeline = deferred_pipeline @@ -745,7 +774,9 @@ def resource_state( return state_.setdefault("resources", {}).setdefault(resource_name, {}) # type: ignore -def reset_resource_state(resource_name: str, source_state_: Optional[DictStrAny] = None, /) -> None: +def reset_resource_state( + resource_name: str, source_state_: Optional[DictStrAny] = None, / +) -> None: """Resets the resource state with name `resource_name` by removing it from `source_state` Args: diff --git a/dlt/common/reflection/spec.py b/dlt/common/reflection/spec.py index 5c39199f63..84006eacf8 100644 --- a/dlt/common/reflection/spec.py +++ b/dlt/common/reflection/spec.py @@ -22,7 +22,10 @@ def _get_spec_name_from_f(f: AnyFun) -> str: def _first_up(s: str) -> str: return s[0].upper() + s[1:] - return "".join(map(_first_up, _SLEEPING_CAT_SPLIT.findall(func_name))) + "Configuration" + return ( + "".join(map(_first_up, _SLEEPING_CAT_SPLIT.findall(func_name))) + + "Configuration" + ) def spec_from_signature( @@ -81,7 +84,10 @@ def dlt_config_literal_to_type(arg_name: str) -> AnyType: for p in sig.parameters.values(): # skip *args and **kwargs, skip typical method params - if p.kind not in (Parameter.VAR_KEYWORD, Parameter.VAR_POSITIONAL) and p.name not in [ + if p.kind not in ( + Parameter.VAR_KEYWORD, + Parameter.VAR_POSITIONAL, + ) and p.name not in [ "self", "cls", ]: diff --git a/dlt/common/reflection/utils.py b/dlt/common/reflection/utils.py index 9bd3cb6775..a38c2c4c6f 100644 --- a/dlt/common/reflection/utils.py +++ b/dlt/common/reflection/utils.py @@ -64,7 +64,9 @@ def set_ast_parents(tree: ast.AST) -> None: child.parent = node if node is not tree else None # type: ignore -def creates_func_def_name_node(func_def: ast.FunctionDef, source_lines: Sequence[str]) -> ast.Name: +def creates_func_def_name_node( + func_def: ast.FunctionDef, source_lines: Sequence[str] +) -> ast.Name: """Recreate function name as a ast.Name with known source code location""" func_name = ast.Name(func_def.name) func_name.lineno = func_name.end_lineno = func_def.lineno @@ -84,7 +86,9 @@ def rewrite_python_script( last_line = -1 last_offset = -1 # sort transformed nodes by line and offset - for node, t_value in sorted(transformed_nodes, key=lambda n: (n[0].lineno, n[0].col_offset)): + for node, t_value in sorted( + transformed_nodes, key=lambda n: (n[0].lineno, n[0].col_offset) + ): # do we have a line changed if last_line != node.lineno - 1: # add remainder from the previous line @@ -96,7 +100,9 @@ def rewrite_python_script( script_lines.append(source_script_lines[node.lineno - 1][: node.col_offset]) elif last_offset >= 0: # no line change, add the characters from the end of previous node to the current - script_lines.append(source_script_lines[last_line][last_offset : node.col_offset]) + script_lines.append( + source_script_lines[last_line][last_offset : node.col_offset] + ) # replace node value script_lines.append(astunparse.unparse(t_value).strip()) diff --git a/dlt/common/runners/pool_runner.py b/dlt/common/runners/pool_runner.py index 491c74cd18..996dc1040a 100644 --- 
a/dlt/common/runners/pool_runner.py +++ b/dlt/common/runners/pool_runner.py @@ -24,7 +24,9 @@ class NullExecutor(Executor): Provides a uniform interface for `None` pool type """ - def submit(self, fn: Callable[P, T], *args: P.args, **kwargs: P.kwargs) -> Future[T]: + def submit( + self, fn: Callable[P, T], *args: P.args, **kwargs: P.kwargs + ) -> Future[T]: """Run the job and return a Future""" fut: Future[T] = Future() try: @@ -53,7 +55,8 @@ def create_pool(config: PoolRunnerConfiguration) -> Executor: ) elif config.pool_type == "thread": return ThreadPoolExecutor( - max_workers=config.workers, thread_name_prefix=Container.thread_pool_prefix() + max_workers=config.workers, + thread_name_prefix=Container.thread_pool_prefix(), ) # no pool - single threaded return NullExecutor() @@ -66,12 +69,16 @@ def run_pool( # validate the run function if not isinstance(run_f, Runnable) and not callable(run_f): raise ValueError( - run_f, "Pool runner entry point must be a function f(pool: TPool) or Runnable" + run_f, + "Pool runner entry point must be a function f(pool: TPool) or Runnable", ) # start pool pool = create_pool(config) - logger.info(f"Created {config.pool_type} pool with {config.workers or 'default no.'} workers") + logger.info( + f"Created {config.pool_type} pool with" + f" {config.workers or 'default no.'} workers" + ) runs_count = 1 def _run_func() -> bool: diff --git a/dlt/common/runners/stdout.py b/dlt/common/runners/stdout.py index 8ddfb45ee4..b084144a87 100644 --- a/dlt/common/runners/stdout.py +++ b/dlt/common/runners/stdout.py @@ -56,7 +56,9 @@ def _r_stderr() -> None: # we fail iterator if exit code is not 0 if exit_code != 0: - raise CalledProcessError(exit_code, command, output=line, stderr="".join(stderr)) + raise CalledProcessError( + exit_code, command, output=line, stderr="".join(stderr) + ) def iter_stdout_with_result( @@ -80,7 +82,9 @@ def iter_stdout_with_result( # try to find last object in stderr if cpe.stderr: # if exception was decoded from stderr - exception = decode_last_obj(cpe.stderr.split("\n"), ignore_pickle_errors=False) + exception = decode_last_obj( + cpe.stderr.split("\n"), ignore_pickle_errors=False + ) if isinstance(exception, Exception): raise exception from cpe else: diff --git a/dlt/common/runners/venv.py b/dlt/common/runners/venv.py index 9a92b30326..db1f19eceb 100644 --- a/dlt/common/runners/venv.py +++ b/dlt/common/runners/venv.py @@ -71,7 +71,10 @@ def __enter__(self) -> "Venv": return self def __exit__( - self, exc_type: Type[BaseException], exc_val: BaseException, exc_tb: types.TracebackType + self, + exc_type: Type[BaseException], + exc_val: BaseException, + exc_tb: types.TracebackType, ) -> None: self.delete_environment() diff --git a/dlt/common/runtime/collector.py b/dlt/common/runtime/collector.py index e478d713b2..b58b612844 100644 --- a/dlt/common/runtime/collector.py +++ b/dlt/common/runtime/collector.py @@ -21,7 +21,11 @@ if TYPE_CHECKING: from tqdm import tqdm import enlighten - from enlighten import Counter as EnlCounter, StatusBar as EnlStatusBar, Manager as EnlManager + from enlighten import ( + Counter as EnlCounter, + StatusBar as EnlStatusBar, + Manager as EnlManager, + ) from alive_progress import alive_bar else: tqdm = EnlCounter = EnlStatusBar = EnlManager = Any @@ -37,7 +41,12 @@ class Collector(ABC): @abstractmethod def update( - self, name: str, inc: int = 1, total: int = None, message: str = None, label: str = None + self, + name: str, + inc: int = 1, + total: int = None, + message: str = None, + label: str = None, ) -> None: 
"""Creates or updates a counter @@ -72,7 +81,9 @@ def __enter__(self: TCollector) -> TCollector: self._start(self.step) return self - def __exit__(self, exc_type: Type[BaseException], exc_val: BaseException, exc_tb: Any) -> None: + def __exit__( + self, exc_type: Type[BaseException], exc_val: BaseException, exc_tb: Any + ) -> None: self._stop() @@ -80,7 +91,12 @@ class NullCollector(Collector): """A default counter that does not count anything.""" def update( - self, name: str, inc: int = 1, total: int = None, message: str = None, label: str = None + self, + name: str, + inc: int = 1, + total: int = None, + message: str = None, + label: str = None, ) -> None: pass @@ -98,7 +114,12 @@ def __init__(self) -> None: self.counters: DefaultDict[str, int] = None def update( - self, name: str, inc: int = 1, total: int = None, message: str = None, label: str = None + self, + name: str, + inc: int = 1, + total: int = None, + message: str = None, + label: str = None, ) -> None: assert not label, "labels not supported in dict collector" self.counters[name] += inc @@ -149,16 +170,21 @@ def __init__( except ImportError: self._log( logging.WARNING, - "psutil dependency is not installed and mem stats will not be available. add" - " psutil to your environment or pass dump_system_stats argument as False to" - " disable warning.", + "psutil dependency is not installed and mem stats will not be" + " available. add psutil to your environment or pass" + " dump_system_stats argument as False to disable warning.", ) dump_system_stats = False self.dump_system_stats = dump_system_stats self.last_log_time: float = None def update( - self, name: str, inc: int = 1, total: int = None, message: str = None, label: str = None + self, + name: str, + inc: int = 1, + total: int = None, + message: str = None, + label: str = None, ) -> None: counter_key = f"{name}_{label}" if label else name @@ -178,7 +204,10 @@ def update( def maybe_log(self) -> None: current_time = time.time() - if self.last_log_time is None or current_time - self.last_log_time >= self.log_period: + if ( + self.last_log_time is None + or current_time - self.last_log_time >= self.log_period + ): self.dump_counters() self.last_log_time = current_time @@ -198,11 +227,13 @@ def dump_counters(self) -> None: percentage = f"({count / info.total * 100:.1f}%)" if info.total else "" elapsed_time_str = f"{elapsed_time:.2f}s" items_per_second_str = f"{items_per_second:.2f}/s" - message = f"[{self.messages[name]}]" if self.messages[name] is not None else "" + message = ( + f"[{self.messages[name]}]" if self.messages[name] is not None else "" + ) counter_line = ( - f"{info.description}: {progress} {percentage} | Time: {elapsed_time_str} | Rate:" - f" {items_per_second_str} {message}" + f"{info.description}: {progress} {percentage} | Time:" + f" {elapsed_time_str} | Rate: {items_per_second_str} {message}" ) log_lines.append(counter_line.strip()) @@ -263,7 +294,12 @@ def __init__(self, single_bar: bool = False, **tqdm_kwargs: Any) -> None: self.tqdm_kwargs = tqdm_kwargs or {} def update( - self, name: str, inc: int = 1, total: int = None, message: str = None, label: str = "" + self, + name: str, + inc: int = 1, + total: int = None, + message: str = None, + label: str = "", ) -> None: key = f"{name}_{label}" bar = self._bars.get(key) @@ -315,7 +351,12 @@ def __init__(self, single_bar: bool = True, **alive_kwargs: Any) -> None: self.alive_kwargs = alive_kwargs or {} def update( - self, name: str, inc: int = 1, total: int = None, message: str = None, label: str = "" + self, + 
name: str, + inc: int = 1, + total: int = None, + message: str = None, + label: str = "", ) -> None: key = f"{name}_{label}" bar = self._bars.get(key) @@ -369,13 +410,19 @@ def __init__(self, single_bar: bool = False, **enlighten_kwargs: Any) -> None: raise MissingDependencyException( "EnlightenCollector", ["enlighten"], - "We need enlighten to display progress bars with a space for log messages.", + "We need enlighten to display progress bars with a space for log" + " messages.", ) self.single_bar = single_bar self.enlighten_kwargs = enlighten_kwargs def update( - self, name: str, inc: int = 1, total: int = None, message: str = None, label: str = "" + self, + name: str, + inc: int = 1, + total: int = None, + message: str = None, + label: str = "", ) -> None: key = f"{name}_{label}" bar = self._bars.get(key) diff --git a/dlt/common/runtime/exec_info.py b/dlt/common/runtime/exec_info.py index 3aa19c83ab..e28936d48f 100644 --- a/dlt/common/runtime/exec_info.py +++ b/dlt/common/runtime/exec_info.py @@ -79,7 +79,9 @@ def is_colab() -> bool: def airflow_info() -> StrAny: try: - with contextlib.redirect_stdout(io.StringIO()), contextlib.redirect_stderr(io.StringIO()): + with contextlib.redirect_stdout(io.StringIO()), contextlib.redirect_stderr( + io.StringIO() + ): from airflow.operators.python import get_current_context get_current_context() @@ -90,7 +92,9 @@ def airflow_info() -> StrAny: def is_airflow_installed() -> bool: try: - with contextlib.redirect_stdout(io.StringIO()), contextlib.redirect_stderr(io.StringIO()): + with contextlib.redirect_stdout(io.StringIO()), contextlib.redirect_stderr( + io.StringIO() + ): import airflow return True except Exception: @@ -99,7 +103,9 @@ def is_airflow_installed() -> bool: def is_running_in_airflow_task() -> bool: try: - with contextlib.redirect_stdout(io.StringIO()), contextlib.redirect_stderr(io.StringIO()): + with contextlib.redirect_stdout(io.StringIO()), contextlib.redirect_stderr( + io.StringIO() + ): from airflow.operators.python import get_current_context context = get_current_context() @@ -124,7 +130,9 @@ def kube_pod_info() -> StrStr: def github_info() -> StrStr: """Extracts github info""" - info = filter_env_vars(["GITHUB_USER", "GITHUB_REPOSITORY", "GITHUB_REPOSITORY_OWNER"]) + info = filter_env_vars( + ["GITHUB_USER", "GITHUB_REPOSITORY", "GITHUB_REPOSITORY_OWNER"] + ) # set GITHUB_REPOSITORY_OWNER as github user if not present. 
GITHUB_REPOSITORY_OWNER is available in github action context if "github_user" not in info and "github_repository_owner" in info: info["github_user"] = info["github_repository_owner"] # type: ignore diff --git a/dlt/common/runtime/init.py b/dlt/common/runtime/init.py index 5354dee4ff..022ca15514 100644 --- a/dlt/common/runtime/init.py +++ b/dlt/common/runtime/init.py @@ -11,7 +11,11 @@ def init_logging(config: RunConfiguration) -> None: version = dlt_version_info(config.pipeline_name) logger.LOGGER = logger._init_logging( - logger.DLT_LOGGER_NAME, config.log_level, config.log_format, config.pipeline_name, version + logger.DLT_LOGGER_NAME, + config.log_level, + config.log_format, + config.pipeline_name, + version, ) diff --git a/dlt/common/runtime/json_logging.py b/dlt/common/runtime/json_logging.py index 042236a093..92e0b45859 100644 --- a/dlt/common/runtime/json_logging.py +++ b/dlt/common/runtime/json_logging.py @@ -74,7 +74,8 @@ def init(custom_formatter: Type[logging.Formatter] = None) -> None: if custom_formatter: if not issubclass(custom_formatter, logging.Formatter): raise ValueError( - "custom_formatter is not subclass of logging.Formatter", custom_formatter + "custom_formatter is not subclass of logging.Formatter", + custom_formatter, ) _default_formatter = custom_formatter if custom_formatter else JSONLogFormatter @@ -195,7 +196,10 @@ def update_formatter_for_loggers( def epoch_nano_second(datetime_: datetime) -> int: - return int((datetime_ - _epoch).total_seconds()) * 1000000000 + datetime_.microsecond * 1000 + return ( + int((datetime_ - _epoch).total_seconds()) * 1000000000 + + datetime_.microsecond * 1000 + ) def iso_time_format(datetime_: datetime) -> str: diff --git a/dlt/common/runtime/prometheus.py b/dlt/common/runtime/prometheus.py index 07c960efe7..c715fb8f3f 100644 --- a/dlt/common/runtime/prometheus.py +++ b/dlt/common/runtime/prometheus.py @@ -24,7 +24,9 @@ def get_metrics_from_prometheus(gauges: Iterable[MetricWrapperBase]) -> StrAny: if g._is_parent(): # for gauges containing many label values, enumerate all metrics.update( - get_metrics_from_prometheus([g.labels(*label) for label in g._metrics.keys()]) + get_metrics_from_prometheus( + [g.labels(*label) for label in g._metrics.keys()] + ) ) continue # for gauges with labels: add the label to the name and enumerate samples diff --git a/dlt/common/runtime/segment.py b/dlt/common/runtime/segment.py index 70b81fb4f4..9443160a9f 100644 --- a/dlt/common/runtime/segment.py +++ b/dlt/common/runtime/segment.py @@ -51,7 +51,9 @@ def disable_segment() -> None: _at_exit_cleanup() -def track(event_category: TEventCategory, event_name: str, properties: DictStrAny) -> None: +def track( + event_category: TEventCategory, event_name: str, properties: DictStrAny +) -> None: """Tracks a telemetry event. 
The segment event name will be created as "{event_category}_{event_name} @@ -67,7 +69,9 @@ def track(event_category: TEventCategory, event_name: str, properties: DictStrAn properties.update({"event_category": event_category, "event_name": event_name}) try: - _send_event(f"{event_category}_{event_name}", properties, _default_context_fields()) + _send_event( + f"{event_category}_{event_name}", properties, _default_context_fields() + ) except Exception as e: logger.debug(f"Skipping telemetry reporting: {e}") raise @@ -119,7 +123,9 @@ def get_anonymous_id() -> str: return anonymous_id -def _segment_request_payload(event_name: str, properties: StrAny, context: StrAny) -> DictStrAny: +def _segment_request_payload( + event_name: str, properties: StrAny, context: StrAny +) -> DictStrAny: """Compose a valid payload for the segment API. Args: @@ -183,7 +189,10 @@ def _future_send() -> None: # import time # start_ts = time.time() resp = _SESSION.post( - _SEGMENT_ENDPOINT, headers=headers, json=payload, timeout=_SEGMENT_REQUEST_TIMEOUT + _SEGMENT_ENDPOINT, + headers=headers, + json=payload, + timeout=_SEGMENT_REQUEST_TIMEOUT, ) # print(f"SENDING TO Segment done {resp.status_code} {time.time() - start_ts} {base64.b64decode(_WRITE_KEY)}") # handle different failure cases @@ -195,6 +204,8 @@ def _future_send() -> None: else: data = resp.json() if not data.get("success"): - logger.debug(f"Segment telemetry request returned a failure. Response: {data}") + logger.debug( + f"Segment telemetry request returned a failure. Response: {data}" + ) _THREAD_POOL.thread_pool.submit(_future_send) diff --git a/dlt/common/runtime/sentry.py b/dlt/common/runtime/sentry.py index 7ea45affc0..c55a30543e 100644 --- a/dlt/common/runtime/sentry.py +++ b/dlt/common/runtime/sentry.py @@ -11,7 +11,8 @@ raise MissingDependencyException( "sentry telemetry", ["sentry-sdk"], - "Please install sentry-sdk if you have `sentry_dsn` set in your RuntimeConfiguration", + "Please install sentry-sdk if you have `sentry_dsn` set in your" + " RuntimeConfiguration", ) from dlt.common.typing import DictStrAny, Any, StrAny @@ -56,7 +57,9 @@ def disable_sentry() -> None: sentry_sdk.init() -def before_send(event: DictStrAny, _unused_hint: Optional[StrAny] = None) -> Optional[DictStrAny]: +def before_send( + event: DictStrAny, _unused_hint: Optional[StrAny] = None +) -> Optional[DictStrAny]: """Called by sentry before sending event. 
Does nothing, patch this function in the module for custom behavior""" return event diff --git a/dlt/common/runtime/slack.py b/dlt/common/runtime/slack.py index b1e090098d..0608fbd8fe 100644 --- a/dlt/common/runtime/slack.py +++ b/dlt/common/runtime/slack.py @@ -1,7 +1,9 @@ import requests -def send_slack_message(incoming_hook: str, message: str, is_markdown: bool = True) -> None: +def send_slack_message( + incoming_hook: str, message: str, is_markdown: bool = True +) -> None: from dlt.common import json, logger """Sends a `message` to Slack `incoming_hook`, by default formatted as markdown.""" diff --git a/dlt/common/runtime/telemetry.py b/dlt/common/runtime/telemetry.py index e03bc04d79..5b467686f3 100644 --- a/dlt/common/runtime/telemetry.py +++ b/dlt/common/runtime/telemetry.py @@ -6,7 +6,12 @@ from dlt.common.configuration.specs import RunConfiguration from dlt.common.typing import TFun from dlt.common.configuration import resolve_configuration -from dlt.common.runtime.segment import TEventCategory, init_segment, disable_segment, track +from dlt.common.runtime.segment import ( + TEventCategory, + init_segment, + disable_segment, + track, +) _TELEMETRY_STARTED = False @@ -62,7 +67,9 @@ def decorator(f: TFun) -> TFun: def _wrap(*f_args: Any, **f_kwargs: Any) -> Any: # look for additional arguments bound_args = sig.bind(*f_args, **f_kwargs) - props = {p: bound_args.arguments[p] for p in args if p in bound_args.arguments} + props = { + p: bound_args.arguments[p] for p in args if p in bound_args.arguments + } start_ts = time.time() def _track(success: bool) -> None: diff --git a/dlt/common/schema/exceptions.py b/dlt/common/schema/exceptions.py index 96341ab8b4..5af56723d6 100644 --- a/dlt/common/schema/exceptions.py +++ b/dlt/common/schema/exceptions.py @@ -19,10 +19,10 @@ class InvalidSchemaName(ValueError, SchemaException): def __init__(self, name: str) -> None: self.name = name super().__init__( - f"{name} is an invalid schema/source name. The source or schema name must be a valid" - " Python identifier ie. a snake case function name and have maximum" - f" {self.MAXIMUM_SCHEMA_NAME_LENGTH} characters. Ideally should contain only small" - " letters, numbers and underscores." + f"{name} is an invalid schema/source name. The source or schema name must" + " be a valid Python identifier ie. a snake case function name and have" + f" maximum {self.MAXIMUM_SCHEMA_NAME_LENGTH} characters. Ideally should" + " contain only small letters, numbers and underscores." ) @@ -30,8 +30,9 @@ class InvalidDatasetName(ValueError, SchemaException): def __init__(self, destination_name: str) -> None: self.destination_name = destination_name super().__init__( - f"Destination {destination_name} does not accept empty datasets. Please pass the" - " dataset name to the destination configuration ie. via dlt pipeline." + f"Destination {destination_name} does not accept empty datasets. Please" + " pass the dataset name to the destination configuration ie. via dlt" + " pipeline." 
) @@ -50,8 +51,8 @@ def __init__( self.to_type = to_type self.coerced_value = coerced_value super().__init__( - f"Cannot coerce type in table {table_name} column {column_name} existing type" - f" {from_type} coerced type {to_type} value: {coerced_value}" + f"Cannot coerce type in table {table_name} column {column_name} existing" + f" type {from_type} coerced type {to_type} value: {coerced_value}" ) @@ -62,13 +63,15 @@ def __init__(self, table_name: str, prop_name: str, val1: str, val2: str): self.val1 = val1 self.val2 = val2 super().__init__( - f"Cannot merge partial tables for {table_name} due to property {prop_name}: {val1} !=" - f" {val2}" + f"Cannot merge partial tables for {table_name} due to property {prop_name}:" + f" {val1} != {val2}" ) class ParentTableNotFoundException(SchemaException): - def __init__(self, table_name: str, parent_table_name: str, explanation: str = "") -> None: + def __init__( + self, table_name: str, parent_table_name: str, explanation: str = "" + ) -> None: self.table_name = table_name self.parent_table_name = parent_table_name super().__init__( @@ -80,7 +83,8 @@ def __init__(self, table_name: str, parent_table_name: str, explanation: str = " class CannotCoerceNullException(SchemaException): def __init__(self, table_name: str, column_name: str) -> None: super().__init__( - f"Cannot coerce NULL in table {table_name} column {column_name} which is not nullable" + f"Cannot coerce NULL in table {table_name} column {column_name} which is" + " not nullable" ) @@ -97,10 +101,10 @@ def __init__( self.from_engine = from_engine self.to_engine = to_engine super().__init__( - f"No engine upgrade path in schema {schema_name} from {init_engine} to {to_engine}," - f" stopped at {from_engine}. You possibly tried to run an older dlt" - " version against a destination you have previously loaded data to with a newer dlt" - " version." + f"No engine upgrade path in schema {schema_name} from {init_engine} to" + f" {to_engine}, stopped at {from_engine}. You possibly tried to run an" + " older dlt version against a destination you have previously loaded data" + " to with a newer dlt version." 
) diff --git a/dlt/common/schema/migrations.py b/dlt/common/schema/migrations.py index 9b206d61a6..8fa5af0b48 100644 --- a/dlt/common/schema/migrations.py +++ b/dlt/common/schema/migrations.py @@ -17,7 +17,9 @@ from dlt.common.schema.utils import new_table, version_table, load_table -def migrate_schema(schema_dict: DictStrAny, from_engine: int, to_engine: int) -> TStoredSchema: +def migrate_schema( + schema_dict: DictStrAny, from_engine: int, to_engine: int +) -> TStoredSchema: if from_engine == to_engine: return cast(TStoredSchema, schema_dict) diff --git a/dlt/common/schema/schema.py b/dlt/common/schema/schema.py index c738f1753e..3db4216c90 100644 --- a/dlt/common/schema/schema.py +++ b/dlt/common/schema/schema.py @@ -1,5 +1,16 @@ from copy import copy, deepcopy -from typing import ClassVar, Dict, List, Mapping, Optional, Sequence, Tuple, Any, cast, Literal +from typing import ( + ClassVar, + Dict, + List, + Mapping, + Optional, + Sequence, + Tuple, + Any, + cast, + Literal, +) from dlt.common import json from dlt.common.schema.migrations import migrate_schema @@ -12,7 +23,11 @@ VARIANT_FIELD_FORMAT, TDataItem, ) -from dlt.common.normalizers import TNormalizersConfig, explicit_normalizers, import_normalizers +from dlt.common.normalizers import ( + TNormalizersConfig, + explicit_normalizers, + import_normalizers, +) from dlt.common.normalizers.naming import NamingConvention from dlt.common.normalizers.json import DataItemNormalizer, TNormalizedRowIterator from dlt.common.schema import utils @@ -77,7 +92,9 @@ class Schema: _dlt_tables_prefix: str _stored_version: int # version at load time _stored_version_hash: str # version hash at load time - _stored_previous_hashes: Optional[List[str]] # list of ancestor hashes of the schema + _stored_previous_hashes: Optional[ + List[str] + ] # list of ancestor hashes of the schema _imported_version_hash: str # version hash of recently imported schema _schema_description: str # optional schema description _schema_tables: TSchemaTables @@ -142,7 +159,9 @@ def replace_schema_content( self._reset_schema(schema.name, schema._normalizers_config) self._from_stored_schema(stored_schema) - def to_dict(self, remove_defaults: bool = False, bump_version: bool = True) -> TStoredSchema: + def to_dict( + self, remove_defaults: bool = False, bump_version: bool = True + ) -> TStoredSchema: stored_schema: TStoredSchema = { "version": self._stored_version, "version_hash": self._stored_version_hash, @@ -263,7 +282,8 @@ def apply_schema_contract( data_item: TDataItem = None, raise_on_freeze: bool = True, ) -> Tuple[ - TPartialTableSchema, List[Tuple[TSchemaContractEntities, str, TSchemaEvolutionMode]] + TPartialTableSchema, + List[Tuple[TSchemaContractEntities, str, TSchemaEvolutionMode]], ]: """ Checks if `schema_contract` allows for the `partial_table` to update the schema. 
It applies the contract dropping @@ -316,7 +336,10 @@ def apply_schema_contract( # filter tables with name below return None, [("tables", table_name, schema_contract["tables"])] - column_mode, data_mode = schema_contract["columns"], schema_contract["data_type"] + column_mode, data_mode = ( + schema_contract["columns"], + schema_contract["data_type"], + ) # allow to add new columns when table is new or if columns are allowed to evolve once if is_new_table or existing_table.get("x-normalizer", {}).get("evolve-columns-once", False): # type: ignore[attr-defined] column_mode = "evolve" @@ -340,8 +363,8 @@ def apply_schema_contract( existing_table, schema_contract, data_item, - f"Trying to add column {column_name} to table {table_name} but columns are" - " frozen.", + f"Trying to add column {column_name} to table {table_name} but" + " columns are frozen.", ) # filter column with name below filters.append(("columns", column_name, column_mode)) @@ -376,9 +399,12 @@ def expand_schema_contract_settings( ) -> TSchemaContractDict: """Expand partial or shorthand settings into full settings dictionary using `default` for unset entities""" if isinstance(settings, str): - settings = TSchemaContractDict(tables=settings, columns=settings, data_type=settings) + settings = TSchemaContractDict( + tables=settings, columns=settings, data_type=settings + ) return cast( - TSchemaContractDict, {**(default or DEFAULT_SCHEMA_CONTRACT_MODE), **(settings or {})} + TSchemaContractDict, + {**(default or DEFAULT_SCHEMA_CONTRACT_MODE), **(settings or {})}, ) def resolve_contract_settings_for_table( @@ -413,9 +439,9 @@ def update_table(self, partial_table: TPartialTableSchema) -> TPartialTableSchem raise ParentTableNotFoundException( table_name, parent_table_name, - " This may be due to misconfigured excludes filter that fully deletes content" - f" of the {parent_table_name}. Add includes that will preserve the parent" - " table.", + " This may be due to misconfigured excludes filter that fully" + f" deletes content of the {parent_table_name}. 
Add includes that" + " will preserve the parent table.", ) table = self._schema_tables.get(table_name) if table is None: @@ -439,7 +465,9 @@ def update_schema(self, schema: "Schema") -> None: self._settings = deepcopy(schema.settings) self._compile_settings() - def filter_row_with_hint(self, table_name: str, hint_type: TColumnHint, row: StrAny) -> StrAny: + def filter_row_with_hint( + self, table_name: str, hint_type: TColumnHint, row: StrAny + ) -> StrAny: rv_row: DictStrAny = {} column_prop: TColumnProp = utils.hint_to_column_prop(hint_type) try: @@ -457,7 +485,9 @@ def filter_row_with_hint(self, table_name: str, hint_type: TColumnHint, row: Str # dicts are ordered and we will return the rows with hints in the same order as they appear in the columns return rv_row - def merge_hints(self, new_hints: Mapping[TColumnHint, Sequence[TSimpleRegex]]) -> None: + def merge_hints( + self, new_hints: Mapping[TColumnHint, Sequence[TSimpleRegex]] + ) -> None: # validate regexes validate_dict( TSchemaSettings, @@ -561,7 +591,9 @@ def data_table_names(self) -> List[str]: def dlt_tables(self) -> List[TTableSchema]: """Gets dlt tables""" return [ - t for t in self._schema_tables.values() if t["name"].startswith(self._dlt_tables_prefix) + t + for t in self._schema_tables.values() + if t["name"].startswith(self._dlt_tables_prefix) ] def dlt_table_names(self) -> List[str]: @@ -569,7 +601,10 @@ def dlt_table_names(self) -> List[str]: return [t["name"] for t in self.dlt_tables()] def get_preferred_type(self, col_name: str) -> Optional[TDataType]: - return next((m[1] for m in self._compiled_preferred_types if m[0].search(col_name)), None) + return next( + (m[1] for m in self._compiled_preferred_types if m[0].search(col_name)), + None, + ) def is_new_table(self, table_name: str) -> bool: """Returns true if this table does not exist OR is incomplete (has only incomplete columns) and therefore new""" @@ -649,7 +684,9 @@ def to_pretty_yaml(self, remove_defaults: bool = True) -> str: d = self.to_dict(remove_defaults=remove_defaults) return utils.to_pretty_yaml(d) - def clone(self, with_name: str = None, update_normalizers: bool = False) -> "Schema": + def clone( + self, with_name: str = None, update_normalizers: bool = False + ) -> "Schema": """Make a deep copy of the schema, optionally changing the name, and updating normalizers and identifiers in the schema if `update_normalizers` is True Note that changing of name will set the schema as new @@ -749,7 +786,11 @@ def _coerce_non_null_value( if is_variant: # this is final call: we cannot generate any more auto-variants raise CannotCoerceColumnException( - table_name, col_name, py_type, table_columns[col_name]["data_type"], v + table_name, + col_name, + py_type, + table_columns[col_name]["data_type"], + v, ) # otherwise we must create variant extension to the table # pass final=True so no more auto-variants can be created recursively @@ -771,7 +812,11 @@ def _coerce_non_null_value( col_name, VARIANT_FIELD_FORMAT % coerced_v[0] ) return self._coerce_non_null_value( - table_columns, table_name, variant_col_name, coerced_v[1], is_variant=True + table_columns, + table_name, + variant_col_name, + coerced_v[1], + is_variant=True, ) if not existing_column: @@ -787,7 +832,9 @@ def _coerce_non_null_value( return col_name, new_column, coerced_v - def _infer_column_type(self, v: Any, col_name: str, skip_preferred: bool = False) -> TDataType: + def _infer_column_type( + self, v: Any, col_name: str, skip_preferred: bool = False + ) -> TDataType: tv = type(v) # try to autodetect 
data type mapped_type = utils.autodetect_sc_type(self._type_detections, tv, v) @@ -814,8 +861,8 @@ def _bump_version(self) -> Tuple[int, str]: Returns: Tuple[int, str]: Current (``stored_version``, ``stored_version_hash``) tuple """ - self._stored_version, self._stored_version_hash, _, _ = utils.bump_version_if_modified( - self.to_dict(bump_version=False) + self._stored_version, self._stored_version_hash, _, _ = ( + utils.bump_version_if_modified(self.to_dict(bump_version=False)) ) return self._stored_version, self._stored_version_hash @@ -846,8 +893,8 @@ def _add_standard_hints(self) -> None: def _configure_normalizers(self, normalizers: TNormalizersConfig) -> None: # import desired modules - self._normalizers_config, naming_module, item_normalizer_class = import_normalizers( - normalizers + self._normalizers_config, naming_module, item_normalizer_class = ( + import_normalizers(normalizers) ) # print(f"{self.name}: {type(self.naming)} {type(naming_module)}") if self.naming and type(self.naming) is not type(naming_module): @@ -859,8 +906,12 @@ def _configure_normalizers(self, normalizers: TNormalizersConfig) -> None: # name normalization functions self.naming = naming_module - self._dlt_tables_prefix = self.naming.normalize_table_identifier(DLT_NAME_PREFIX) - self.version_table_name = self.naming.normalize_table_identifier(VERSION_TABLE_NAME) + self._dlt_tables_prefix = self.naming.normalize_table_identifier( + DLT_NAME_PREFIX + ) + self.version_table_name = self.naming.normalize_table_identifier( + VERSION_TABLE_NAME + ) self.loads_table_name = self.naming.normalize_table_identifier(LOADS_TABLE_NAME) self.state_table_name = self.naming.normalize_table_identifier(STATE_TABLE_NAME) # data item normalization function @@ -903,9 +954,13 @@ def _reset_schema(self, name: str, normalizers: TNormalizersConfig = None) -> No def _from_stored_schema(self, stored_schema: TStoredSchema) -> None: self._schema_tables = stored_schema.get("tables") or {} if self.version_table_name not in self._schema_tables: - raise SchemaCorruptedException(f"Schema must contain table {self.version_table_name}") + raise SchemaCorruptedException( + f"Schema must contain table {self.version_table_name}" + ) if self.loads_table_name not in self._schema_tables: - raise SchemaCorruptedException(f"Schema must contain table {self.loads_table_name}") + raise SchemaCorruptedException( + f"Schema must contain table {self.loads_table_name}" + ) self._stored_version = stored_schema["version"] self._stored_version_hash = stored_schema["version_hash"] self._imported_version_hash = stored_schema.get("imported_version_hash") @@ -924,20 +979,28 @@ def _compile_settings(self) -> None: # if self._settings: for pattern, dt in self._settings.get("preferred_types", {}).items(): # add tuples to be searched in coercions - self._compiled_preferred_types.append((utils.compile_simple_regex(pattern), dt)) + self._compiled_preferred_types.append( + (utils.compile_simple_regex(pattern), dt) + ) for hint_name, hint_list in self._settings.get("default_hints", {}).items(): # compile hints which are column matching regexes - self._compiled_hints[hint_name] = list(map(utils.compile_simple_regex, hint_list)) + self._compiled_hints[hint_name] = list( + map(utils.compile_simple_regex, hint_list) + ) if self._schema_tables: for table in self._schema_tables.values(): if "filters" in table: if "excludes" in table["filters"]: self._compiled_excludes[table["name"]] = list( - map(utils.compile_simple_regex, table["filters"]["excludes"]) + map( + 
utils.compile_simple_regex, table["filters"]["excludes"] + ) ) if "includes" in table["filters"]: self._compiled_includes[table["name"]] = list( - map(utils.compile_simple_regex, table["filters"]["includes"]) + map( + utils.compile_simple_regex, table["filters"]["includes"] + ) ) # look for auto-detections in settings and then normalizer self._type_detections = self._settings.get("detections") or self._normalizers_config.get("detections") or [] # type: ignore diff --git a/dlt/common/schema/typing.py b/dlt/common/schema/typing.py index ec60e4c365..f55c2f3c92 100644 --- a/dlt/common/schema/typing.py +++ b/dlt/common/schema/typing.py @@ -67,7 +67,12 @@ TWriteDisposition = Literal["skip", "append", "replace", "merge"] TTableFormat = Literal["iceberg", "parquet", "jsonl"] TTypeDetections = Literal[ - "timestamp", "iso_timestamp", "iso_date", "large_integer", "hexbytes_to_text", "wei_to_double" + "timestamp", + "iso_timestamp", + "iso_date", + "large_integer", + "hexbytes_to_text", + "wei_to_double", ] TTypeDetectionFunc = Callable[[Type[Any], Any], Optional[TDataType]] TColumnNames = Union[str, Sequence[str]] @@ -124,7 +129,10 @@ class TColumnSchema(TColumnSchemaBase, total=False): TAnySchemaColumns = Union[ - TTableSchemaColumns, Sequence[TColumnSchema], _PydanticBaseModel, Type[_PydanticBaseModel] + TTableSchemaColumns, + Sequence[TColumnSchema], + _PydanticBaseModel, + Type[_PydanticBaseModel], ] TSimpleRegex = NewType("TSimpleRegex", str) diff --git a/dlt/common/schema/utils.py b/dlt/common/schema/utils.py index 0a4e00759d..bd45bae57e 100644 --- a/dlt/common/schema/utils.py +++ b/dlt/common/schema/utils.py @@ -3,7 +3,18 @@ import hashlib import yaml from copy import deepcopy, copy -from typing import Dict, List, Sequence, Tuple, Type, Any, cast, Iterable, Optional, Union +from typing import ( + Dict, + List, + Sequence, + Tuple, + Type, + Any, + cast, + Iterable, + Optional, + Union, +) from dlt.common import json from dlt.common.data_types import TDataType @@ -160,7 +171,9 @@ def add_column_defaults(column: TColumnSchemaBase) -> TColumnSchema: # return copy(column) # type: ignore -def bump_version_if_modified(stored_schema: TStoredSchema) -> Tuple[int, str, str, Sequence[str]]: +def bump_version_if_modified( + stored_schema: TStoredSchema, +) -> Tuple[int, str, str, Sequence[str]]: """Bumps the `stored_schema` version and version hash if content modified, returns (new version, new hash, old hash, 10 last hashes) tuple""" hash_ = generate_version_hash(stored_schema) previous_hash = stored_schema.get("version_hash") @@ -174,7 +187,12 @@ def bump_version_if_modified(stored_schema: TStoredSchema) -> Tuple[int, str, st store_prev_hash(stored_schema, previous_hash) stored_schema["version_hash"] = hash_ - return stored_schema["version"], hash_, previous_hash, stored_schema["previous_hashes"] + return ( + stored_schema["version"], + hash_, + previous_hash, + stored_schema["previous_hashes"], + ) def store_prev_hash( @@ -183,7 +201,9 @@ def store_prev_hash( # unshift previous hash to previous_hashes and limit array to 10 entries if previous_hash not in stored_schema["previous_hashes"]: stored_schema["previous_hashes"].insert(0, previous_hash) - stored_schema["previous_hashes"] = stored_schema["previous_hashes"][:max_history_len] + stored_schema["previous_hashes"] = stored_schema["previous_hashes"][ + :max_history_len + ] def generate_version_hash(stored_schema: TStoredSchema) -> str: @@ -227,8 +247,8 @@ def simple_regex_validator(path: str, pk: str, pv: Any, t: Any) -> bool: if t is TSimpleRegex: 
if not isinstance(pv, str): raise DictValidationException( - f"In {path}: field {pk} value {pv} has invalid type {type(pv).__name__} while str" - " is expected", + f"In {path}: field {pk} value {pv} has invalid type" + f" {type(pv).__name__} while str is expected", path, pk, pv, @@ -239,7 +259,8 @@ def simple_regex_validator(path: str, pk: str, pv: Any, t: Any) -> bool: re.compile(pv[3:]) except Exception as e: raise DictValidationException( - f"In {path}: field {pk} value {pv[3:]} does not compile as regex: {str(e)}", + f"In {path}: field {pk} value {pv[3:]} does not compile as regex:" + f" {str(e)}", path, pk, pv, @@ -247,7 +268,8 @@ def simple_regex_validator(path: str, pk: str, pv: Any, t: Any) -> bool: else: if RE_NON_ALPHANUMERIC_UNDERSCORE.match(pv): raise DictValidationException( - f"In {path}: field {pk} value {pv} looks like a regex, please prefix with re:", + f"In {path}: field {pk} value {pv} looks like a regex, please" + " prefix with re:", path, pk, pv, @@ -264,8 +286,8 @@ def validator(path: str, pk: str, pv: Any, t: Any) -> bool: if t is TColumnName: if not isinstance(pv, str): raise DictValidationException( - f"In {path}: field {pk} value {pv} has invalid type {type(pv).__name__} while" - " str is expected", + f"In {path}: field {pk} value {pv} has invalid type" + f" {type(pv).__name__} while str is expected", path, pk, pv, @@ -273,11 +295,17 @@ def validator(path: str, pk: str, pv: Any, t: Any) -> bool: try: if naming.normalize_path(pv) != pv: raise DictValidationException( - f"In {path}: field {pk}: {pv} is not a valid column name", path, pk, pv + f"In {path}: field {pk}: {pv} is not a valid column name", + path, + pk, + pv, ) except ValueError: raise DictValidationException( - f"In {path}: field {pk}: {pv} is not a valid column name", path, pk, pv + f"In {path}: field {pk}: {pv} is not a valid column name", + path, + pk, + pv, ) return True else: @@ -309,7 +337,10 @@ def compile_simple_regexes(r: Iterable[TSimpleRegex]) -> REPattern: def validate_stored_schema(stored_schema: TStoredSchema) -> None: # use lambda to verify only non extra fields validate_dict_ignoring_xkeys( - spec=TStoredSchema, doc=stored_schema, path=".", validator_f=simple_regex_validator + spec=TStoredSchema, + doc=stored_schema, + path=".", + validator_f=simple_regex_validator, ) # check child parent relationships for table_name, table in stored_schema["tables"].items(): @@ -319,7 +350,9 @@ def validate_stored_schema(stored_schema: TStoredSchema) -> None: raise ParentTableNotFoundException(table_name, parent_table_name) -def autodetect_sc_type(detection_fs: Sequence[TTypeDetections], t: Type[Any], v: Any) -> TDataType: +def autodetect_sc_type( + detection_fs: Sequence[TTypeDetections], t: Type[Any], v: Any +) -> TDataType: if detection_fs: for detection_fn in detection_fs: # the method must exist in the module @@ -430,7 +463,9 @@ def diff_table(tab_a: TTableSchema, tab_b: TPartialTableSchema) -> TPartialTable # return False -def merge_table(table: TTableSchema, partial_table: TPartialTableSchema) -> TPartialTableSchema: +def merge_table( + table: TTableSchema, partial_table: TPartialTableSchema +) -> TPartialTableSchema: """Merges "partial_table" into "table". `table` is merged in place. Returns the diff partial table. `table` and `partial_table` names must be identical. 
A table diff is generated and applied to `table`: @@ -472,19 +507,24 @@ def hint_to_column_prop(h: TColumnHint) -> TColumnProp: def get_columns_names_with_prop( - table: TTableSchema, column_prop: Union[TColumnProp, str], include_incomplete: bool = False + table: TTableSchema, + column_prop: Union[TColumnProp, str], + include_incomplete: bool = False, ) -> List[str]: # column_prop: TColumnProp = hint_to_column_prop(hint_type) # default = column_prop != "nullable" # default is true, only for nullable false return [ c["name"] for c in table["columns"].values() - if bool(c.get(column_prop, False)) and (include_incomplete or is_complete_column(c)) + if bool(c.get(column_prop, False)) + and (include_incomplete or is_complete_column(c)) ] def get_first_column_name_with_prop( - table: TTableSchema, column_prop: Union[TColumnProp, str], include_incomplete: bool = False + table: TTableSchema, + column_prop: Union[TColumnProp, str], + include_incomplete: bool = False, ) -> Optional[str]: """Returns name of first column in `table` schema with property `column_prop` or None if no such column exists.""" column_names = get_columns_names_with_prop(table, column_prop, include_incomplete) @@ -494,7 +534,9 @@ def get_first_column_name_with_prop( def has_column_with_prop( - table: TTableSchema, column_prop: Union[TColumnProp, str], include_incomplete: bool = False + table: TTableSchema, + column_prop: Union[TColumnProp, str], + include_incomplete: bool = False, ) -> bool: """Checks if `table` schema contains column with property `column_prop`.""" return len(get_columns_names_with_prop(table, column_prop, include_incomplete)) > 0 @@ -509,7 +551,9 @@ def get_dedup_sort_tuple( Returns None if "dedup_sort" hint was not provided. """ - dedup_sort_col = get_first_column_name_with_prop(table, "dedup_sort", include_incomplete) + dedup_sort_col = get_first_column_name_with_prop( + table, "dedup_sort", include_incomplete + ) if dedup_sort_col is None: return None dedup_sort_order = table["columns"][dedup_sort_col]["dedup_sort"] @@ -522,13 +566,18 @@ def merge_schema_updates(schema_updates: Sequence[TSchemaUpdate]) -> TSchemaTabl for table_name, table_updates in schema_update.items(): for partial_table in table_updates: # aggregate schema updates - aggregated_table = aggregated_update.setdefault(table_name, partial_table) + aggregated_table = aggregated_update.setdefault( + table_name, partial_table + ) aggregated_table["columns"].update(partial_table["columns"]) return aggregated_update def get_inherited_table_hint( - tables: TSchemaTables, table_name: str, table_hint_name: str, allow_none: bool = False + tables: TSchemaTables, + table_name: str, + table_hint_name: str, + allow_none: bool = False, ) -> Any: table = tables.get(table_name, {}) hint = table.get(table_hint_name) @@ -543,7 +592,8 @@ def get_inherited_table_hint( return None raise ValueError( - f"No table hint '{table_hint_name} found in the chain of tables for '{table_name}'." + f"No table hint '{table_hint_name} found in the chain of tables for" + f" '{table_name}'." ) @@ -551,13 +601,16 @@ def get_write_disposition(tables: TSchemaTables, table_name: str) -> TWriteDispo """Returns table hint of a table if present. 
If not, looks up into parent table""" return cast( TWriteDisposition, - get_inherited_table_hint(tables, table_name, "write_disposition", allow_none=False), + get_inherited_table_hint( + tables, table_name, "write_disposition", allow_none=False + ), ) def get_table_format(tables: TSchemaTables, table_name: str) -> TTableFormat: return cast( - TTableFormat, get_inherited_table_hint(tables, table_name, "table_format", allow_none=True) + TTableFormat, + get_inherited_table_hint(tables, table_name, "table_format", allow_none=True), ) @@ -741,4 +794,6 @@ def to_pretty_json(stored_schema: TStoredSchema) -> str: def to_pretty_yaml(stored_schema: TStoredSchema) -> str: - return yaml.dump(stored_schema, allow_unicode=True, default_flow_style=False, sort_keys=False) + return yaml.dump( + stored_schema, allow_unicode=True, default_flow_style=False, sort_keys=False + ) diff --git a/dlt/common/storages/configuration.py b/dlt/common/storages/configuration.py index d0100c335d..5948b450bd 100644 --- a/dlt/common/storages/configuration.py +++ b/dlt/common/storages/configuration.py @@ -1,5 +1,15 @@ import os -from typing import TYPE_CHECKING, Any, Literal, Optional, Type, get_args, ClassVar, Dict, Union +from typing import ( + TYPE_CHECKING, + Any, + Literal, + Optional, + Type, + get_args, + ClassVar, + Dict, + Union, +) from urllib.parse import urlparse from dlt.common.configuration import configspec, resolve_type @@ -24,9 +34,13 @@ @configspec class SchemaStorageConfiguration(BaseConfiguration): schema_volume_path: str = None # path to volume with default schemas - import_schema_path: Optional[str] = None # path from which to import a schema into storage + import_schema_path: Optional[str] = ( + None # path from which to import a schema into storage + ) export_schema_path: Optional[str] = None # path to which export schema from storage - external_schema_format: TSchemaFileFormat = "yaml" # format in which to expect external schema + external_schema_format: TSchemaFileFormat = ( # format in which to expect external schema + "yaml" + ) external_schema_format_remove_defaults: bool = ( True # remove default values when exporting schema ) @@ -34,7 +48,9 @@ class SchemaStorageConfiguration(BaseConfiguration): @configspec class NormalizeStorageConfiguration(BaseConfiguration): - normalize_volume_path: str = None # path to volume where normalized loader files will be stored + normalize_volume_path: str = ( + None # path to volume where normalized loader files will be stored + ) @configspec @@ -89,7 +105,9 @@ def protocol(self) -> str: """`bucket_url` protocol""" url = urlparse(self.bucket_url) # this prevents windows absolute paths to be recognized as schemas - if not url.scheme or (os.path.isabs(self.bucket_url) and "\\" in self.bucket_url): + if not url.scheme or ( + os.path.isabs(self.bucket_url) and "\\" in self.bucket_url + ): return "file" else: return url.scheme @@ -98,8 +116,9 @@ def on_resolved(self) -> None: url = urlparse(self.bucket_url) if not url.path and not url.netloc: raise ConfigurationValueError( - "File path or netloc missing. Field bucket_url of FilesystemClientConfiguration" - " must contain valid url with a path or host:password component." + "File path or netloc missing. Field bucket_url of" + " FilesystemClientConfiguration must contain valid url with a path or" + " host:password component." 
) # this is just a path in a local file system if url.path == self.bucket_url: diff --git a/dlt/common/storages/data_item_storage.py b/dlt/common/storages/data_item_storage.py index 816c6bc494..e150099dce 100644 --- a/dlt/common/storages/data_item_storage.py +++ b/dlt/common/storages/data_item_storage.py @@ -41,7 +41,11 @@ def write_data_item( return writer.write_data_item(item, columns) def write_empty_items_file( - self, load_id: str, schema_name: str, table_name: str, columns: TTableSchemaColumns + self, + load_id: str, + schema_name: str, + table_name: str, + columns: TTableSchemaColumns, ) -> DataWriterMetrics: """Writes empty file: only header and footer without actual items. Closed the empty file and returns metrics. Mind that header and footer will be written.""" @@ -70,8 +74,8 @@ def close_writers(self, load_id: str) -> None: for name, writer in self.buffered_writers.items(): if name.startswith(load_id) and not writer.closed: logger.debug( - f"Closing writer for {name} with file {writer._file} and actual name" - f" {writer._file_name}" + f"Closing writer for {name} with file {writer._file} and actual" + f" name {writer._file_name}" ) writer.close() @@ -112,6 +116,8 @@ def _write_temp_job_file( return Path(file_name).name @abstractmethod - def _get_data_item_path_template(self, load_id: str, schema_name: str, table_name: str) -> str: + def _get_data_item_path_template( + self, load_id: str, schema_name: str, table_name: str + ) -> str: """Returns a file template for item writer. note: use %s for file id to create required template format""" pass diff --git a/dlt/common/storages/exceptions.py b/dlt/common/storages/exceptions.py index f4288719c1..684ddc39a0 100644 --- a/dlt/common/storages/exceptions.py +++ b/dlt/common/storages/exceptions.py @@ -23,8 +23,8 @@ def __init__( self.migrated_version = migrated_version self.target_version = target_version super().__init__( - f"Could not find migration path for {storage_path} from v {initial_version} to" - f" {target_version}, stopped at {migrated_version}" + f"Could not find migration path for {storage_path} from v" + f" {initial_version} to {target_version}, stopped at {migrated_version}" ) @@ -39,7 +39,8 @@ def __init__( self.initial_version = initial_version self.target_version = target_version super().__init__( - f"Expected storage {storage_path} with v {target_version} but found {initial_version}" + f"Expected storage {storage_path} with v {target_version} but found" + f" {initial_version}" ) @@ -55,7 +56,8 @@ def __init__( self.from_version = from_version self.target_version = target_version super().__init__( - f"Storage {storage_path} with target v {target_version} at {from_version}: " + info + f"Storage {storage_path} with target v {target_version} at {from_version}: " + + info ) @@ -65,14 +67,17 @@ class LoadStorageException(StorageException): class JobWithUnsupportedWriterException(LoadStorageException, TerminalValueError): def __init__( - self, load_id: str, expected_file_formats: Iterable[TLoaderFileFormat], wrong_job: str + self, + load_id: str, + expected_file_formats: Iterable[TLoaderFileFormat], + wrong_job: str, ) -> None: self.load_id = load_id self.expected_file_formats = expected_file_formats self.wrong_job = wrong_job super().__init__( - f"Job {wrong_job} for load id {load_id} requires loader file format that is not one of" - f" {expected_file_formats}" + f"Job {wrong_job} for load id {load_id} requires loader file format that is" + f" not one of {expected_file_formats}" ) @@ -89,9 +94,9 @@ class 
SchemaStorageException(StorageException): class InStorageSchemaModified(SchemaStorageException): def __init__(self, schema_name: str, storage_path: str) -> None: msg = ( - f"Schema {schema_name} in {storage_path} was externally modified. This is not allowed" - " as that would prevent correct version tracking. Use import/export capabilities of" - " dlt to provide external changes." + f"Schema {schema_name} in {storage_path} was externally modified. This is" + " not allowed as that would prevent correct version tracking. Use" + " import/export capabilities of dlt to provide external changes." ) super().__init__(msg) @@ -113,14 +118,15 @@ def __init__( class UnexpectedSchemaName(SchemaStorageException, ValueError): def __init__(self, schema_name: str, storage_path: str, stored_name: str) -> None: super().__init__( - f"A schema file name '{schema_name}' in {storage_path} does not correspond to the name" - f" of schema in the file {stored_name}" + f"A schema file name '{schema_name}' in {storage_path} does not correspond" + f" to the name of schema in the file {stored_name}" ) class CurrentLoadPackageStateNotAvailable(StorageException): def __init__(self) -> None: super().__init__( - "State of the current load package is not available. Current load package state is" - " only available in a function decorated with @dlt.destination during loading." + "State of the current load package is not available. Current load package" + " state is only available in a function decorated with @dlt.destination" + " during loading." ) diff --git a/dlt/common/storages/file_storage.py b/dlt/common/storages/file_storage.py index 7fe62a9728..ee07d06b61 100644 --- a/dlt/common/storages/file_storage.py +++ b/dlt/common/storages/file_storage.py @@ -16,7 +16,9 @@ class FileStorage: - def __init__(self, storage_path: str, file_type: str = "t", makedirs: bool = False) -> None: + def __init__( + self, storage_path: str, file_type: str = "t", makedirs: bool = False + ) -> None: # make it absolute path self.storage_path = os.path.realpath(storage_path) # os.path.join(, '') self.file_type = file_type @@ -24,10 +26,14 @@ def __init__(self, storage_path: str, file_type: str = "t", makedirs: bool = Fal os.makedirs(storage_path, exist_ok=True) def save(self, relative_path: str, data: Any) -> str: - return self.save_atomic(self.storage_path, relative_path, data, file_type=self.file_type) + return self.save_atomic( + self.storage_path, relative_path, data, file_type=self.file_type + ) @staticmethod - def save_atomic(storage_path: str, relative_path: str, data: Any, file_type: str = "t") -> str: + def save_atomic( + storage_path: str, relative_path: str, data: Any, file_type: str = "t" + ) -> str: mode = "w" + file_type with tempfile.NamedTemporaryFile( dir=storage_path, mode=mode, delete=False, encoding=encoding_for_mode(mode) @@ -114,12 +120,19 @@ def open_file(self, relative_path: str, mode: str = "r") -> IO[Any]: mode = mode + self.file_type if "r" in mode: return FileStorage.open_zipsafe_ro(self.make_full_path(relative_path), mode) - return open(self.make_full_path(relative_path), mode, encoding=encoding_for_mode(mode)) + return open( + self.make_full_path(relative_path), mode, encoding=encoding_for_mode(mode) + ) - def open_temp(self, delete: bool = False, mode: str = "w", file_type: str = None) -> IO[Any]: + def open_temp( + self, delete: bool = False, mode: str = "w", file_type: str = None + ) -> IO[Any]: mode = mode + file_type or self.file_type return tempfile.NamedTemporaryFile( - dir=self.storage_path, mode=mode, 
delete=delete, encoding=encoding_for_mode(mode) + dir=self.storage_path, + mode=mode, + delete=delete, + encoding=encoding_for_mode(mode), ) def has_file(self, relative_path: str) -> bool: @@ -142,7 +155,9 @@ def list_folder_files(self, relative_path: str, to_root: bool = True) -> List[st if to_root: # list files in relative path, returning paths relative to storage root return [ - os.path.join(relative_path, e.name) for e in os.scandir(scan_path) if e.is_file() + os.path.join(relative_path, e.name) + for e in os.scandir(scan_path) + if e.is_file() ] else: # or to the folder @@ -154,7 +169,9 @@ def list_folder_dirs(self, relative_path: str, to_root: bool = True) -> List[str if to_root: # list folders in relative path, returning paths relative to storage root return [ - os.path.join(relative_path, e.name) for e in os.scandir(scan_path) if e.is_dir() + os.path.join(relative_path, e.name) + for e in os.scandir(scan_path) + if e.is_dir() ] else: # or to the folder @@ -165,7 +182,10 @@ def create_folder(self, relative_path: str, exists_ok: bool = False) -> None: def link_hard(self, from_relative_path: str, to_relative_path: str) -> None: # note: some interesting stuff on links https://lightrun.com/answers/conan-io-conan-research-investigate-symlinks-and-hard-links - os.link(self.make_full_path(from_relative_path), self.make_full_path(to_relative_path)) + os.link( + self.make_full_path(from_relative_path), + self.make_full_path(to_relative_path), + ) @staticmethod def link_hard_with_fallback(external_file_path: str, to_file_path: str) -> None: @@ -188,7 +208,10 @@ def atomic_rename(self, from_relative_path: str, to_relative_path: str) -> None: 3. All buckets mapped with FUSE are not atomic """ - os.rename(self.make_full_path(from_relative_path), self.make_full_path(to_relative_path)) + os.rename( + self.make_full_path(from_relative_path), + self.make_full_path(to_relative_path), + ) def rename_tree(self, from_relative_path: str, to_relative_path: str) -> None: """Renames a tree using os.rename if possible making it atomic @@ -228,7 +251,10 @@ def rename_tree_files(self, from_relative_path: str, to_relative_path: str) -> N os.rmdir(root) def atomic_import( - self, external_file_path: str, to_folder: str, new_file_name: Optional[str] = None + self, + external_file_path: str, + to_folder: str, + new_file_name: Optional[str] = None, ) -> str: """Moves a file at `external_file_path` into the `to_folder` effectively importing file into storage @@ -294,7 +320,9 @@ def validate_file_name_component(name: str) -> None: # component cannot contain "." if FILE_COMPONENT_INVALID_CHARACTERS.search(name): raise pathvalidate.error.InvalidCharError( - description="Component name cannot contain the following characters: . % { }" + description=( + "Component name cannot contain the following characters: . 
% { }" + ) ) @staticmethod diff --git a/dlt/common/storages/fsspec_filesystem.py b/dlt/common/storages/fsspec_filesystem.py index b1cbc11bf9..d1f925568a 100644 --- a/dlt/common/storages/fsspec_filesystem.py +++ b/dlt/common/storages/fsspec_filesystem.py @@ -31,7 +31,10 @@ AzureCredentials, ) from dlt.common.exceptions import MissingDependencyException -from dlt.common.storages.configuration import FileSystemCredentials, FilesystemConfiguration +from dlt.common.storages.configuration import ( + FileSystemCredentials, + FilesystemConfiguration, +) from dlt.common.time import ensure_pendulum_datetime from dlt.common.typing import DictStrAny @@ -67,13 +70,21 @@ class FileItem(TypedDict, total=False): # Map of protocol to a filesystem type CREDENTIALS_DISPATCH: Dict[str, Callable[[FilesystemConfiguration], DictStrAny]] = { "s3": lambda config: cast(AwsCredentials, config.credentials).to_s3fs_credentials(), - "adl": lambda config: cast(AzureCredentials, config.credentials).to_adlfs_credentials(), - "az": lambda config: cast(AzureCredentials, config.credentials).to_adlfs_credentials(), + "adl": lambda config: cast( + AzureCredentials, config.credentials + ).to_adlfs_credentials(), + "az": lambda config: cast( + AzureCredentials, config.credentials + ).to_adlfs_credentials(), "gcs": lambda config: cast(GcpCredentials, config.credentials).to_gcs_credentials(), "gs": lambda config: cast(GcpCredentials, config.credentials).to_gcs_credentials(), "gdrive": lambda config: {"credentials": cast(GcpCredentials, config.credentials)}, - "abfs": lambda config: cast(AzureCredentials, config.credentials).to_adlfs_credentials(), - "azure": lambda config: cast(AzureCredentials, config.credentials).to_adlfs_credentials(), + "abfs": lambda config: cast( + AzureCredentials, config.credentials + ).to_adlfs_credentials(), + "azure": lambda config: cast( + AzureCredentials, config.credentials + ).to_adlfs_credentials(), } @@ -94,7 +105,9 @@ def fsspec_filesystem( also see filesystem_from_config """ return fsspec_from_config( - FilesystemConfiguration(protocol, credentials, kwargs=kwargs, client_kwargs=client_kwargs) + FilesystemConfiguration( + protocol, credentials, kwargs=kwargs, client_kwargs=client_kwargs + ) ) @@ -115,7 +128,9 @@ def prepare_fsspec_args(config: FilesystemConfiguration) -> DictStrAny: if protocol == "gdrive": from dlt.common.storages.fsspecs.google_drive import GoogleDriveFileSystem - register_implementation("gdrive", GoogleDriveFileSystem, "GoogleDriveFileSystem") + register_implementation( + "gdrive", GoogleDriveFileSystem, "GoogleDriveFileSystem" + ) if config.kwargs is not None: fs_kwargs.update(config.kwargs) @@ -129,7 +144,9 @@ def prepare_fsspec_args(config: FilesystemConfiguration) -> DictStrAny: return fs_kwargs -def fsspec_from_config(config: FilesystemConfiguration) -> Tuple[AbstractFileSystem, str]: +def fsspec_from_config( + config: FilesystemConfiguration, +) -> Tuple[AbstractFileSystem, str]: """Instantiates an authenticated fsspec `FileSystem` from `config` argument. 
Authenticates following filesystems: @@ -210,8 +227,10 @@ def open( # noqa: A003 elif compression == "disable": compression_arg = None else: - raise ValueError("""The argument `compression` must have one of the following values: - "auto", "enable", "disable".""") + raise ValueError( + """The argument `compression` must have one of the following values: + "auto", "enable", "disable".""" + ) opened_file: IO[Any] # if the user has already extracted the content, we use it so there is no need to @@ -227,7 +246,9 @@ def open( # noqa: A003 if "t" not in mode: return bytes_io text_kwargs = { - k: kwargs.pop(k) for k in ["encoding", "errors", "newline"] if k in kwargs + k: kwargs.pop(k) + for k in ["encoding", "errors", "newline"] + if k in kwargs } return io.TextIOWrapper( bytes_io, @@ -256,7 +277,9 @@ def guess_mime_type(file_name: str) -> Sequence[str]: type_ = list(mimetypes.guess_type(posixpath.basename(file_name), strict=False)) if not type_[0]: - type_[0] = "application/" + (posixpath.splitext(file_name)[1][1:] or "octet-stream") + type_[0] = "application/" + ( + posixpath.splitext(file_name)[1][1:] or "octet-stream" + ) return type_ @@ -278,21 +301,25 @@ def glob_files( bucket_url_parsed = urlparse(bucket_url) # if this is a file path without a scheme - if not bucket_url_parsed.scheme or (os.path.isabs(bucket_url) and "\\" in bucket_url): + if not bucket_url_parsed.scheme or ( + os.path.isabs(bucket_url) and "\\" in bucket_url + ): # this is a file so create a proper file url bucket_url = pathlib.Path(bucket_url).absolute().as_uri() bucket_url_parsed = urlparse(bucket_url) bucket_url_no_schema = bucket_url_parsed._replace(scheme="", query="").geturl() bucket_url_no_schema = ( - bucket_url_no_schema[2:] if bucket_url_no_schema.startswith("//") else bucket_url_no_schema + bucket_url_no_schema[2:] + if bucket_url_no_schema.startswith("//") + else bucket_url_no_schema ) filter_url = posixpath.join(bucket_url_no_schema, file_glob) glob_result = fs_client.glob(filter_url, detail=True) if isinstance(glob_result, list): raise NotImplementedError( - "Cannot request details when using fsspec.glob. For adlfs (Azure) please use version" - " 2023.9.0 or later" + "Cannot request details when using fsspec.glob. For adlfs (Azure) please" + " use version 2023.9.0 or later" ) for file, md in glob_result.items(): diff --git a/dlt/common/storages/fsspecs/google_drive.py b/dlt/common/storages/fsspecs/google_drive.py index 3bc4b1d7d7..f851c6a871 100644 --- a/dlt/common/storages/fsspecs/google_drive.py +++ b/dlt/common/storages/fsspecs/google_drive.py @@ -11,7 +11,9 @@ from googleapiclient.discovery import build from googleapiclient.errors import HttpError except ModuleNotFoundError: - raise MissingDependencyException("GoogleDriveFileSystem", ["google-api-python-client"]) + raise MissingDependencyException( + "GoogleDriveFileSystem", ["google-api-python-client"] + ) try: from google.auth.credentials import AnonymousCredentials @@ -170,10 +172,15 @@ def _trash(self, file_id: str) -> None: file_id (str): The ID of the file to trash. """ file_metadata = {"trashed": True} - self.service.update(fileId=file_id, supportsAllDrives=True, body=file_metadata).execute() + self.service.update( + fileId=file_id, supportsAllDrives=True, body=file_metadata + ).execute() def rm( - self, path: str, recursive: Optional[bool] = True, maxdepth: Optional[int] = None + self, + path: str, + recursive: Optional[bool] = True, + maxdepth: Optional[int] = None, ) -> None: """Remove files or directories. 
@@ -219,7 +226,9 @@ def info(self, path: str, **kwargs: Any) -> Dict[str, Any]: file_id = self.path_to_file_id(path) return self._info_by_id(file_id, self._parent(path)) - def _info_by_id(self, file_id: str, path_prefix: Optional[str] = None) -> Dict[str, Any]: + def _info_by_id( + self, file_id: str, path_prefix: Optional[str] = None + ) -> Dict[str, Any]: response = self.service.get( fileId=file_id, fields=FILE_INFO_FIELDS, @@ -341,10 +350,14 @@ def path_to_file_id( else: sub_path = posixpath.join(*descendants) return self.path_to_file_id( - sub_path, parent_id=top_file_id, parent_path=posixpath.join(parent_path, file_name) + sub_path, + parent_id=top_file_id, + parent_path=posixpath.join(parent_path, file_name), ) - def _find_file_id_in_dir(self, file_name: str, dir_file_id: str, dir_path: str) -> Any: + def _find_file_id_in_dir( + self, file_name: str, dir_file_id: str, dir_path: str + ) -> Any: """Get the file ID of a file with a given name in a directory. Args: @@ -362,7 +375,9 @@ def _find_file_id_in_dir(self, file_name: str, dir_file_id: str, dir_path: str) possible_children.append(child["id"]) if len(possible_children) == 0: - raise FileNotFoundError(f"Directory {dir_file_id} has no child named {file_name}") + raise FileNotFoundError( + f"Directory {dir_file_id} has no child named {file_name}" + ) if len(possible_children) == 1: return possible_children[0] else: @@ -372,7 +387,9 @@ def _find_file_id_in_dir(self, file_name: str, dir_file_id: str, dir_path: str) "to file_id." ) - def _open(self, path: str, mode: Optional[str] = "rb", **kwargs: Any) -> "GoogleDriveFile": + def _open( + self, path: str, mode: Optional[str] = "rb", **kwargs: Any + ) -> "GoogleDriveFile": """Open a file. Args: @@ -387,7 +404,9 @@ def _open(self, path: str, mode: Optional[str] = "rb", **kwargs: Any) -> "Google return GoogleDriveFile(self, path, mode=mode, **kwargs) @staticmethod - def _file_info_from_response(file: Dict[str, Any], path_prefix: str = None) -> Dict[str, Any]: + def _file_info_from_response( + file: Dict[str, Any], path_prefix: str = None + ) -> Dict[str, Any]: """Create fsspec compatible file info""" ftype = "directory" if file.get("mimeType") == DIR_MIME_TYPE else "file" if path_prefix: @@ -436,7 +455,9 @@ def __init__( self.parent_id: str = None self.location = None - def _fetch_range(self, start: Optional[int] = None, end: Optional[int] = None) -> Any: + def _fetch_range( + self, start: Optional[int] = None, end: Optional[int] = None + ) -> Any: """Read data from Google Drive. 
Args: @@ -488,8 +509,13 @@ def _upload_chunk(self, final: Optional[bool] = False) -> bool: head["Content-Range"] = "bytes */%i" % self.offset data = None else: - head["Content-Range"] = "bytes %i-%i/*" % (self.offset, self.offset + length - 1) - head.update({"Content-Type": "application/octet-stream", "Content-Length": str(length)}) + head["Content-Range"] = "bytes %i-%i/*" % ( + self.offset, + self.offset + length - 1, + ) + head.update( + {"Content-Type": "application/octet-stream", "Content-Length": str(length)} + ) req = self.fs.service._http.request head, body = req(self.location, method="PUT", body=data, headers=head) status = int(head["status"]) diff --git a/dlt/common/storages/load_package.py b/dlt/common/storages/load_package.py index 3b8af424ee..b3c0be7587 100644 --- a/dlt/common/storages/load_package.py +++ b/dlt/common/storages/load_package.py @@ -38,7 +38,10 @@ from dlt.common.schema import Schema, TSchemaTables from dlt.common.schema.typing import TStoredSchema, TTableSchemaColumns from dlt.common.storages import FileStorage -from dlt.common.storages.exceptions import LoadPackageNotFound, CurrentLoadPackageStateNotAvailable +from dlt.common.storages.exceptions import ( + LoadPackageNotFound, + CurrentLoadPackageStateNotAvailable, +) from dlt.common.typing import DictStrAny, SupportsHumanize from dlt.common.utils import flatten_list_or_items from dlt.common.versioned_state import ( @@ -74,7 +77,9 @@ def generate_loadpackage_state_version_hash(state: TLoadPackageState) -> str: return generate_state_version_hash(state) -def bump_loadpackage_state_version_if_modified(state: TLoadPackageState) -> Tuple[int, str, str]: +def bump_loadpackage_state_version_if_modified( + state: TLoadPackageState, +) -> Tuple[int, str, str]: return bump_state_version_if_modified(state) @@ -167,7 +172,9 @@ def asdict(self) -> DictStrAny: def asstr(self, verbosity: int = 0) -> str: failed_msg = ( - "The job FAILED TERMINALLY and cannot be restarted." if self.failed_message else "" + "The job FAILED TERMINALLY and cannot be restarted." + if self.failed_message + else "" ) elapsed_msg = ( humanize.precisedelta(pendulum.duration(seconds=self.elapsed)) @@ -175,8 +182,8 @@ def asstr(self, verbosity: int = 0) -> str: else "---" ) msg = ( - f"Job: {self.job_file_info.job_id()}, table: {self.job_file_info.table_name} in" - f" {self.state}. " + f"Job: {self.job_file_info.job_id()}, table:" + f" {self.job_file_info.table_name} in {self.state}. " ) msg += ( f"File type: {self.job_file_info.file_format}, size:" @@ -244,9 +251,9 @@ def asstr(self, verbosity: int = 0) -> str: else "The package is NOT YET LOADED to the destination" ) msg = ( - f"The package with load id {self.load_id} for schema {self.schema_name} is in" - f" {self.state.upper()} state. It updated schema for {len(self.schema_update)} tables." - f" {completed_msg}.\n" + f"The package with load id {self.load_id} for schema {self.schema_name} is" + f" in {self.state.upper()} state. It updated schema for" + f" {len(self.schema_update)} tables. 
{completed_msg}.\n" ) msg += "Jobs details:\n" msg += "\n".join(job.asstr(verbosity) for job in flatten_list_or_items(iter(self.jobs.values()))) # type: ignore @@ -319,7 +326,9 @@ def list_failed_jobs(self, load_id: str) -> Sequence[str]: self.get_job_folder_path(load_id, PackageStorage.FAILED_JOBS_FOLDER) ) - def list_jobs_for_table(self, load_id: str, table_name: str) -> Sequence[LoadJobInfo]: + def list_jobs_for_table( + self, load_id: str, table_name: str + ) -> Sequence[LoadJobInfo]: return self.filter_jobs_for_table(self.list_all_jobs(load_id), table_name) def list_all_jobs(self, load_id: str) -> Sequence[LoadJobInfo]: @@ -333,7 +342,9 @@ def list_failed_jobs_infos(self, load_id: str) -> Sequence[LoadJobInfo]: package_created_at = pendulum.from_timestamp( os.path.getmtime( self.storage.make_full_path( - os.path.join(package_path, PackageStorage.PACKAGE_COMPLETED_FILE_NAME) + os.path.join( + package_path, PackageStorage.PACKAGE_COMPLETED_FILE_NAME + ) ) ) ) @@ -352,14 +363,21 @@ def import_job( self, load_id: str, job_file_path: str, job_state: TJobState = "new_jobs" ) -> None: """Adds new job by moving the `job_file_path` into `new_jobs` of package `load_id`""" - self.storage.atomic_import(job_file_path, self.get_job_folder_path(load_id, job_state)) + self.storage.atomic_import( + job_file_path, self.get_job_folder_path(load_id, job_state) + ) def start_job(self, load_id: str, file_name: str) -> str: return self._move_job( - load_id, PackageStorage.NEW_JOBS_FOLDER, PackageStorage.STARTED_JOBS_FOLDER, file_name + load_id, + PackageStorage.NEW_JOBS_FOLDER, + PackageStorage.STARTED_JOBS_FOLDER, + file_name, ) - def fail_job(self, load_id: str, file_name: str, failed_message: Optional[str]) -> str: + def fail_job( + self, load_id: str, file_name: str, failed_message: Optional[str] + ) -> str: # save the exception to failed jobs if failed_message: self.storage.save( @@ -404,22 +422,33 @@ def complete_job(self, load_id: str, file_name: str) -> str: def create_package(self, load_id: str) -> None: self.storage.create_folder(load_id) # create processing directories - self.storage.create_folder(os.path.join(load_id, PackageStorage.NEW_JOBS_FOLDER)) - self.storage.create_folder(os.path.join(load_id, PackageStorage.COMPLETED_JOBS_FOLDER)) - self.storage.create_folder(os.path.join(load_id, PackageStorage.FAILED_JOBS_FOLDER)) - self.storage.create_folder(os.path.join(load_id, PackageStorage.STARTED_JOBS_FOLDER)) + self.storage.create_folder( + os.path.join(load_id, PackageStorage.NEW_JOBS_FOLDER) + ) + self.storage.create_folder( + os.path.join(load_id, PackageStorage.COMPLETED_JOBS_FOLDER) + ) + self.storage.create_folder( + os.path.join(load_id, PackageStorage.FAILED_JOBS_FOLDER) + ) + self.storage.create_folder( + os.path.join(load_id, PackageStorage.STARTED_JOBS_FOLDER) + ) # ensure created timestamp is set in state when load package is created state = self.get_load_package_state(load_id) if not state.get("created_at"): state["created_at"] = pendulum.now().to_iso8601_string() self.save_load_package_state(load_id, state) - def complete_loading_package(self, load_id: str, load_state: TLoadPackageStatus) -> str: + def complete_loading_package( + self, load_id: str, load_state: TLoadPackageStatus + ) -> str: """Completes loading the package by writing marker file with`package_state. 
Returns path to the completed package""" load_path = self.get_package_path(load_id) # save marker file self.storage.save( - os.path.join(load_path, PackageStorage.PACKAGE_COMPLETED_FILE_NAME), load_state + os.path.join(load_path, PackageStorage.PACKAGE_COMPLETED_FILE_NAME), + load_state, ) return load_path @@ -452,7 +481,9 @@ def schema_name(self, load_id: str) -> str: def save_schema(self, load_id: str, schema: Schema) -> str: # save a schema to a temporary load package dump = json.dumps(schema.to_dict()) - return self.storage.save(os.path.join(load_id, PackageStorage.SCHEMA_FILE_NAME), dump) + return self.storage.save( + os.path.join(load_id, PackageStorage.SCHEMA_FILE_NAME), dump + ) def save_schema_updates(self, load_id: str, schema_update: TSchemaTables) -> None: with self.storage.open_file( @@ -505,7 +536,9 @@ def get_load_package_info(self, load_id: str) -> LoadPackageInfo: applied_update: TSchemaTables = {} # check if package completed - completed_file_path = os.path.join(package_path, PackageStorage.PACKAGE_COMPLETED_FILE_NAME) + completed_file_path = os.path.join( + package_path, PackageStorage.PACKAGE_COMPLETED_FILE_NAME + ) if self.storage.has_file(completed_file_path): package_created_at = pendulum.from_timestamp( os.path.getmtime(self.storage.make_full_path(completed_file_path)) @@ -526,9 +559,13 @@ def get_load_package_info(self, load_id: str) -> LoadPackageInfo: jobs: List[LoadJobInfo] = [] with contextlib.suppress(FileNotFoundError): # we ignore if load package lacks one of working folders. completed_jobs may be deleted on archiving - for file in self.storage.list_folder_files(os.path.join(package_path, state)): + for file in self.storage.list_folder_files( + os.path.join(package_path, state) + ): if not file.endswith(".exception"): - jobs.append(self._read_job_file_info(state, file, package_created_at)) + jobs.append( + self._read_job_file_info(state, file, package_created_at) + ) all_jobs[state] = jobs return LoadPackageInfo( @@ -541,7 +578,9 @@ def get_load_package_info(self, load_id: str) -> LoadPackageInfo: all_jobs, ) - def _read_job_file_info(self, state: TJobState, file: str, now: DateTime = None) -> LoadJobInfo: + def _read_job_file_info( + self, state: TJobState, file: str, now: DateTime = None + ) -> LoadJobInfo: try: failed_message = self.storage.load(file + ".exception") except FileNotFoundError: @@ -553,7 +592,9 @@ def _read_job_file_info(self, state: TJobState, file: str, now: DateTime = None) full_path, st.st_size, pendulum.from_timestamp(st.st_mtime), - PackageStorage._job_elapsed_time_seconds(full_path, now.timestamp() if now else None), + PackageStorage._job_elapsed_time_seconds( + full_path, now.timestamp() if now else None + ), ParsedLoadJobFileName.parse(file), failed_message, ) @@ -574,7 +615,9 @@ def _move_job( assert file_name == FileStorage.get_file_name_from_file_path(file_name) load_path = self.get_package_path(load_id) dest_path = os.path.join(load_path, dest_folder, new_file_name or file_name) - self.storage.atomic_rename(os.path.join(load_path, source_folder, file_name), dest_path) + self.storage.atomic_rename( + os.path.join(load_path, source_folder, file_name), dest_path + ) # print(f"{join(load_path, source_folder, file_name)} -> {dest_path}") return self.storage.make_full_path(dest_path) diff --git a/dlt/common/storages/load_storage.py b/dlt/common/storages/load_storage.py index ffd55e7f29..8dc04594c9 100644 --- a/dlt/common/storages/load_storage.py +++ b/dlt/common/storages/load_storage.py @@ -21,12 +21,17 @@ TJobState, 
TLoadPackageState, ) -from dlt.common.storages.exceptions import JobWithUnsupportedWriterException, LoadPackageNotFound +from dlt.common.storages.exceptions import ( + JobWithUnsupportedWriterException, + LoadPackageNotFound, +) class LoadStorage(DataItemStorage, VersionedStorage): STORAGE_VERSION = "1.0.0" - NORMALIZED_FOLDER = "normalized" # folder within the volume where load packages are stored + NORMALIZED_FOLDER = ( # folder within the volume where load packages are stored + "normalized" + ) LOADED_FOLDER = "loaded" # folder to keep the loads that were completely processed NEW_PACKAGES_FOLDER = "new" # folder where new packages are created @@ -45,9 +50,14 @@ def __init__( supported_file_formats = list(supported_file_formats) supported_file_formats.append("jsonl") - if not LoadStorage.ALL_SUPPORTED_FILE_FORMATS.issuperset(supported_file_formats): + if not LoadStorage.ALL_SUPPORTED_FILE_FORMATS.issuperset( + supported_file_formats + ): raise TerminalValueError(supported_file_formats) - if preferred_file_format and preferred_file_format not in supported_file_formats: + if ( + preferred_file_format + and preferred_file_format not in supported_file_formats + ): raise TerminalValueError(preferred_file_format) self.supported_file_formats = supported_file_formats self.config = config @@ -61,13 +71,16 @@ def __init__( self.initialize_storage() # create package storages self.new_packages = PackageStorage( - FileStorage(join(config.load_volume_path, LoadStorage.NEW_PACKAGES_FOLDER)), "new" + FileStorage(join(config.load_volume_path, LoadStorage.NEW_PACKAGES_FOLDER)), + "new", ) self.normalized_packages = PackageStorage( - FileStorage(join(config.load_volume_path, LoadStorage.NORMALIZED_FOLDER)), "normalized" + FileStorage(join(config.load_volume_path, LoadStorage.NORMALIZED_FOLDER)), + "normalized", ) self.loaded_packages = PackageStorage( - FileStorage(join(config.load_volume_path, LoadStorage.LOADED_FOLDER)), "loaded" + FileStorage(join(config.load_volume_path, LoadStorage.LOADED_FOLDER)), + "loaded", ) def initialize_storage(self) -> None: @@ -75,7 +88,9 @@ def initialize_storage(self) -> None: self.storage.create_folder(LoadStorage.NORMALIZED_FOLDER, exists_ok=True) self.storage.create_folder(LoadStorage.LOADED_FOLDER, exists_ok=True) - def _get_data_item_path_template(self, load_id: str, _: str, table_name: str) -> str: + def _get_data_item_path_template( + self, load_id: str, _: str, table_name: str + ) -> str: # implements DataItemStorage._get_data_item_path_template file_name = PackageStorage.build_job_file_name(table_name, "%s") file_path = self.new_packages.get_job_file_path( @@ -91,17 +106,21 @@ def list_new_jobs(self, load_id: str) -> Sequence[str]: ( j for j in new_jobs - if ParsedLoadJobFileName.parse(j).file_format not in self.supported_file_formats + if ParsedLoadJobFileName.parse(j).file_format + not in self.supported_file_formats ), None, ) if wrong_job is not None: - raise JobWithUnsupportedWriterException(load_id, self.supported_file_formats, wrong_job) + raise JobWithUnsupportedWriterException( + load_id, self.supported_file_formats, wrong_job + ) return new_jobs def commit_new_load_package(self, load_id: str) -> None: self.storage.rename_tree( - self.get_new_package_path(load_id), self.get_normalized_package_path(load_id) + self.get_new_package_path(load_id), + self.get_normalized_package_path(load_id), ) def list_normalized_packages(self) -> Sequence[str]: @@ -122,7 +141,9 @@ def begin_schema_update(self, load_id: str) -> Optional[TSchemaTables]: raise 
FileNotFoundError(package_path) schema_update_file = join(package_path, PackageStorage.SCHEMA_UPDATES_FILE_NAME) if self.storage.has_file(schema_update_file): - schema_update: TSchemaTables = json.loads(self.storage.load(schema_update_file)) + schema_update: TSchemaTables = json.loads( + self.storage.load(schema_update_file) + ) return schema_update else: return None @@ -161,7 +182,9 @@ def complete_load_package(self, load_id: str, aborted: bool) -> None: ) # move to completed completed_path = self.get_loaded_package_path(load_id) - self.storage.rename_tree(self.get_normalized_package_path(load_id), completed_path) + self.storage.rename_tree( + self.get_normalized_package_path(load_id), completed_path + ) def maybe_remove_completed_jobs(self, load_id: str) -> None: """Deletes completed jobs if delete_completed_jobs config flag is set. If package has failed jobs, nothing gets deleted.""" @@ -175,15 +198,20 @@ def wipe_normalized_packages(self) -> None: self.storage.delete_folder(self.NORMALIZED_FOLDER, recursively=True) def get_new_package_path(self, load_id: str) -> str: - return join(LoadStorage.NEW_PACKAGES_FOLDER, self.new_packages.get_package_path(load_id)) + return join( + LoadStorage.NEW_PACKAGES_FOLDER, self.new_packages.get_package_path(load_id) + ) def get_normalized_package_path(self, load_id: str) -> str: return join( - LoadStorage.NORMALIZED_FOLDER, self.normalized_packages.get_package_path(load_id) + LoadStorage.NORMALIZED_FOLDER, + self.normalized_packages.get_package_path(load_id), ) def get_loaded_package_path(self, load_id: str) -> str: - return join(LoadStorage.LOADED_FOLDER, self.loaded_packages.get_package_path(load_id)) + return join( + LoadStorage.LOADED_FOLDER, self.loaded_packages.get_package_path(load_id) + ) def get_load_package_info(self, load_id: str) -> LoadPackageInfo: """Gets information on normalized OR loaded package with given load_id, all jobs and their statuses.""" diff --git a/dlt/common/storages/normalize_storage.py b/dlt/common/storages/normalize_storage.py index 2b90b7c088..68ff31f435 100644 --- a/dlt/common/storages/normalize_storage.py +++ b/dlt/common/storages/normalize_storage.py @@ -17,11 +17,15 @@ class NormalizeStorage(VersionedStorage): STORAGE_VERSION: ClassVar[str] = "1.0.1" - EXTRACTED_FOLDER: ClassVar[str] = ( - "extracted" # folder within the volume where extracted files to be normalized are stored + EXTRACTED_FOLDER: ClassVar[ + str + ] = ( # folder within the volume where extracted files to be normalized are stored + "extracted" ) - @with_config(spec=NormalizeStorageConfiguration, sections=(known_sections.NORMALIZE,)) + @with_config( + spec=NormalizeStorageConfiguration, sections=(known_sections.NORMALIZE,) + ) def __init__( self, is_owner: bool, config: NormalizeStorageConfiguration = config.value ) -> None: @@ -34,7 +38,11 @@ def __init__( if is_owner: self.initialize_storage() self.extracted_packages = PackageStorage( - FileStorage(os.path.join(self.storage.storage_path, NormalizeStorage.EXTRACTED_FOLDER)), + FileStorage( + os.path.join( + self.storage.storage_path, NormalizeStorage.EXTRACTED_FOLDER + ) + ), "extracted", ) @@ -43,7 +51,9 @@ def initialize_storage(self) -> None: def list_files_to_normalize_sorted(self) -> Sequence[str]: """Gets all data files in extracted packages storage. 
This method is compatible with current and all past storages""" - root_dir = os.path.join(self.storage.storage_path, NormalizeStorage.EXTRACTED_FOLDER) + root_dir = os.path.join( + self.storage.storage_path, NormalizeStorage.EXTRACTED_FOLDER + ) with set_working_dir(root_dir): files = glob.glob("**/*", recursive=True) # return all files that are not schema files @@ -57,7 +67,9 @@ def list_files_to_normalize_sorted(self) -> Sequence[str]: ] ) - def migrate_storage(self, from_version: VersionInfo, to_version: VersionInfo) -> None: + def migrate_storage( + self, from_version: VersionInfo, to_version: VersionInfo + ) -> None: if from_version == "1.0.0" and from_version < to_version: # get files in storage if len(self.list_files_to_normalize_sorted()) > 0: @@ -65,8 +77,9 @@ def migrate_storage(self, from_version: VersionInfo, to_version: VersionInfo) -> self.storage.storage_path, from_version, to_version, - f"There are extract files in {NormalizeStorage.EXTRACTED_FOLDER} folder." - " Storage will not migrate automatically duo to possible data loss. Delete the" + "There are extract files in" + f" {NormalizeStorage.EXTRACTED_FOLDER} folder. Storage will not" + " migrate automatically duo to possible data loss. Delete the" " files or normalize it with dlt 0.3.x", ) from_version = semver.VersionInfo.parse("1.0.1") diff --git a/dlt/common/storages/schema_storage.py b/dlt/common/storages/schema_storage.py index 23b695b839..76370d9d52 100644 --- a/dlt/common/storages/schema_storage.py +++ b/dlt/common/storages/schema_storage.py @@ -59,7 +59,9 @@ def save_schema(self, schema: Schema) -> str: # check if there's schema to import if self.config.import_schema_path: try: - imported_schema = Schema.from_dict(self._load_import_schema(schema.name)) + imported_schema = Schema.from_dict( + self._load_import_schema(schema.name) + ) # link schema being saved to current imported schema so it will not overwrite this save when loaded schema._imported_version_hash = imported_schema.stored_version_hash except FileNotFoundError: @@ -100,7 +102,9 @@ def __iter__(self) -> Iterator[str]: def __contains__(self, name: str) -> bool: # type: ignore return name in self.list_schemas() - def _maybe_import_schema(self, name: str, storage_schema: DictStrAny = None) -> Schema: + def _maybe_import_schema( + self, name: str, storage_schema: DictStrAny = None + ) -> Schema: rv_schema: Schema = None try: imported_schema = self._load_import_schema(name) @@ -110,9 +114,9 @@ def _maybe_import_schema(self, name: str, storage_schema: DictStrAny = None) -> # store import hash to self to track changes rv_schema._imported_version_hash = rv_schema.version_hash logger.info( - f"Schema {name} not present in {self.storage.storage_path} and got imported" - f" with version {rv_schema.stored_version} and imported hash" - f" {rv_schema._imported_version_hash}" + f"Schema {name} not present in {self.storage.storage_path} and got" + f" imported with version {rv_schema.stored_version} and imported" + f" hash {rv_schema._imported_version_hash}" ) # if schema was imported, overwrite storage schema self._save_schema(rv_schema) @@ -126,9 +130,9 @@ def _maybe_import_schema(self, name: str, storage_schema: DictStrAny = None) -> rv_schema.replace_schema_content(i_s, link_to_replaced_schema=True) rv_schema._imported_version_hash = i_s.version_hash logger.info( - f"Schema {name} was present in {self.storage.storage_path} but is" - f" overwritten with imported schema version {i_s.version} and" - f" imported hash {i_s.version_hash}" + f"Schema {name} was present in 
{self.storage.storage_path} but" + " is overwritten with imported schema version" + f" {i_s.version} and imported hash {i_s.version_hash}" ) # if schema was imported, overwrite storage schema self._save_schema(rv_schema) @@ -165,7 +169,9 @@ def _export_schema(self, schema: Schema, export_path: str) -> None: raise ValueError(self.config.external_schema_format) export_storage = FileStorage(export_path, makedirs=True) - schema_file = self._file_name_in_store(schema.name, self.config.external_schema_format) + schema_file = self._file_name_in_store( + schema.name, self.config.external_schema_format + ) export_storage.save(schema_file, exported_schema_s) logger.info( f"Schema {schema.name} exported to {export_path} with version" @@ -185,13 +191,17 @@ def _save_schema(self, schema: Schema) -> str: @staticmethod def load_schema_file( - path: str, name: str, extensions: Tuple[TSchemaFileFormat, ...] = SchemaFileExtensions + path: str, + name: str, + extensions: Tuple[TSchemaFileFormat, ...] = SchemaFileExtensions, ) -> Schema: storage = FileStorage(path) for extension in extensions: file = SchemaStorage._file_name_in_store(name, extension) if storage.has_file(file): - parsed_schema = SchemaStorage._parse_schema_str(storage.load(file), extension) + parsed_schema = SchemaStorage._parse_schema_str( + storage.load(file), extension + ) schema = Schema.from_dict(parsed_schema) if schema.name != name: raise UnexpectedSchemaName(name, path, schema.name) diff --git a/dlt/common/storages/transactional_file.py b/dlt/common/storages/transactional_file.py index e5ee220904..c5a6092426 100644 --- a/dlt/common/storages/transactional_file.py +++ b/dlt/common/storages/transactional_file.py @@ -56,13 +56,16 @@ def __init__(self, path: str, fs: fsspec.AbstractFileSystem) -> None: path: The path to lock. fs: The fsspec file system. """ - proto = fs.protocol[0] if isinstance(fs.protocol, (list, tuple)) else fs.protocol + proto = ( + fs.protocol[0] if isinstance(fs.protocol, (list, tuple)) else fs.protocol + ) self.extract_mtime = MTIME_DISPATCH.get(proto, MTIME_DISPATCH["file"]) parsed_path = Path(path) if not parsed_path.is_absolute(): raise ValueError( - f"{path} is not absolute. Please pass only absolute paths to TransactionalFile" + f"{path} is not absolute. 
Please pass only absolute paths to" + " TransactionalFile" ) self.path = path if proto == "file": @@ -98,7 +101,9 @@ def _sync_locks(self) -> t.List[str]: output = [] now = pendulum.now() - for lock in self._fs.ls(posixpath.dirname(self.lock_path), refresh=True, detail=True): + for lock in self._fs.ls( + posixpath.dirname(self.lock_path), refresh=True, detail=True + ): name = lock["name"] if not name.startswith(self.lock_prefix): continue @@ -114,8 +119,8 @@ def _sync_locks(self) -> t.List[str]: output.append(name) if not output: raise RuntimeError( - f"When syncing locks for path {self.path} and lock {self.lock_path} no lock file" - " was found" + f"When syncing locks for path {self.path} and lock {self.lock_path} no" + " lock file was found" ) return output @@ -172,7 +177,9 @@ def acquire_lock( return True if jitter_mean > 0: - time.sleep(random.random() * jitter_mean) # Add jitter to avoid thundering herd + time.sleep( + random.random() * jitter_mean + ) # Add jitter to avoid thundering herd self.lock_path = f"{self.lock_prefix}.{lock_id()}" self._fs.touch(self.lock_path) locks = self._sync_locks() diff --git a/dlt/common/storages/versioned_storage.py b/dlt/common/storages/versioned_storage.py index 8e9a3eb88d..3e55f11194 100644 --- a/dlt/common/storages/versioned_storage.py +++ b/dlt/common/storages/versioned_storage.py @@ -3,14 +3,20 @@ import semver from dlt.common.storages.file_storage import FileStorage -from dlt.common.storages.exceptions import NoMigrationPathException, WrongStorageVersionException +from dlt.common.storages.exceptions import ( + NoMigrationPathException, + WrongStorageVersionException, +) class VersionedStorage: VERSION_FILE = ".version" def __init__( - self, version: Union[semver.VersionInfo, str], is_owner: bool, storage: FileStorage + self, + version: Union[semver.VersionInfo, str], + is_owner: bool, + storage: FileStorage, ) -> None: if isinstance(version, str): version = semver.VersionInfo.parse(version) @@ -22,7 +28,10 @@ def __init__( if existing_version > version: # version cannot be downgraded raise NoMigrationPathException( - storage.storage_path, existing_version, existing_version, version + storage.storage_path, + existing_version, + existing_version, + version, ) if is_owner: # only owner can migrate storage @@ -31,7 +40,10 @@ def __init__( migrated_version = self._load_version() if version != migrated_version: raise NoMigrationPathException( - storage.storage_path, existing_version, migrated_version, version + storage.storage_path, + existing_version, + migrated_version, + version, ) else: # we cannot use storage and we must wait for owner to upgrade it diff --git a/dlt/common/time.py b/dlt/common/time.py index d3c8f9746c..98e5d93a11 100644 --- a/dlt/common/time.py +++ b/dlt/common/time.py @@ -43,7 +43,9 @@ def timestamp_before(timestamp: float, max_inclusive: Optional[float]) -> bool: return timestamp <= (max_inclusive or FUTURE_TIMESTAMP) -def parse_iso_like_datetime(value: Any) -> Union[pendulum.DateTime, pendulum.Date, pendulum.Time]: +def parse_iso_like_datetime( + value: Any, +) -> Union[pendulum.DateTime, pendulum.Date, pendulum.Time]: """Parses ISO8601 string into pendulum datetime, date or time. Preserves timezone info. 
Note: naive datetimes will generated from string without timezone diff --git a/dlt/common/typing.py b/dlt/common/typing.py index 99c2604cdf..a402d21cbd 100644 --- a/dlt/common/typing.py +++ b/dlt/common/typing.py @@ -238,12 +238,16 @@ def extract_inner_type(hint: Type[Any], preserve_new_types: bool = False) -> Typ return hint -def get_all_types_of_class_in_union(hint: Type[Any], cls: Type[TAny]) -> List[Type[TAny]]: +def get_all_types_of_class_in_union( + hint: Type[Any], cls: Type[TAny] +) -> List[Type[TAny]]: # hint is an Union that contains classes, return all classes that are a subclass or superclass of cls return [ t for t in get_args(hint) - if not is_typeddict(t) and inspect.isclass(t) and (issubclass(t, cls) or issubclass(cls, t)) + if not is_typeddict(t) + and inspect.isclass(t) + and (issubclass(t, cls) or issubclass(cls, t)) ] diff --git a/dlt/common/utils.py b/dlt/common/utils.py index 4ddde87758..96b5cb3adc 100644 --- a/dlt/common/utils.py +++ b/dlt/common/utils.py @@ -178,7 +178,9 @@ def flatten_list_of_str_or_dicts(seq: Sequence[Union[StrAny, str]]) -> DictStrAn # return dicts -def flatten_list_or_items(_iter: Union[Iterable[TAny], Iterable[List[TAny]]]) -> Iterator[TAny]: +def flatten_list_or_items( + _iter: Union[Iterable[TAny], Iterable[List[TAny]]] +) -> Iterator[TAny]: for items in _iter: if isinstance(items, List): yield from items @@ -186,7 +188,9 @@ def flatten_list_or_items(_iter: Union[Iterable[TAny], Iterable[List[TAny]]]) -> yield items -def concat_strings_with_limit(strings: List[str], separator: str, limit: int) -> Iterator[str]: +def concat_strings_with_limit( + strings: List[str], separator: str, limit: int +) -> Iterator[str]: """ Generator function to concatenate strings. @@ -218,7 +222,9 @@ def concat_strings_with_limit(strings: List[str], separator: str, limit: int) -> start = i current_length = len(strings[i]) else: - current_length += len(strings[i]) + sep_len # accounts for the length of separator + current_length += ( + len(strings[i]) + sep_len + ) # accounts for the length of separator yield separator.join(strings[start:]) @@ -447,7 +453,9 @@ def is_inner_callable(f: AnyFun) -> bool: def obfuscate_pseudo_secret(pseudo_secret: str, pseudo_key: bytes) -> str: return base64.b64encode( - bytes([_a ^ _b for _a, _b in zip(pseudo_secret.encode("utf-8"), pseudo_key * 250)]) + bytes( + [_a ^ _b for _a, _b in zip(pseudo_secret.encode("utf-8"), pseudo_key * 250)] + ) ).decode() @@ -456,7 +464,8 @@ def reveal_pseudo_secret(obfuscated_secret: str, pseudo_key: bytes) -> str: [ _a ^ _b for _a, _b in zip( - base64.b64decode(obfuscated_secret.encode("ascii"), validate=True), pseudo_key * 250 + base64.b64decode(obfuscated_secret.encode("ascii"), validate=True), + pseudo_key * 250, ) ] ).decode("utf-8") @@ -500,10 +509,14 @@ def merge_row_counts(row_counts_1: RowCounts, row_counts_2: RowCounts) -> None: """merges row counts_2 into row_counts_1""" # only keys present in row_counts_2 are modifed for counter_name in row_counts_2.keys(): - row_counts_1[counter_name] = row_counts_1.get(counter_name, 0) + row_counts_2[counter_name] + row_counts_1[counter_name] = ( + row_counts_1.get(counter_name, 0) + row_counts_2[counter_name] + ) -def extend_list_deduplicated(original_list: List[Any], extending_list: Iterable[Any]) -> List[Any]: +def extend_list_deduplicated( + original_list: List[Any], extending_list: Iterable[Any] +) -> List[Any]: """extends the first list by the second, but does not add duplicates""" list_keys = set(original_list) for item in extending_list: @@ -538,7 
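# Aside: a usage sketch for concat_strings_with_limit as reformatted in this
# hunk, based on the (strings, separator, limit) signature shown above; the
# statement list is made up for illustration.
from dlt.common.utils import concat_strings_with_limit

statements = [f"DROP TABLE IF EXISTS table_{i};" for i in range(1000)]
# yields the statements joined with the separator in batches, each batch kept
# roughly within the character limit (e.g. a destination's max query length)
for batch in concat_strings_with_limit(statements, "\n", 16 * 1024):
    print(len(batch))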
+551,10 @@ def get_full_class_name(obj: Any) -> str: def get_exception_trace(exc: BaseException) -> ExceptionTrace: """Get exception trace and additional information for DltException(s)""" - trace: ExceptionTrace = {"message": str(exc), "exception_type": get_full_class_name(exc)} + trace: ExceptionTrace = { + "message": str(exc), + "exception_type": get_full_class_name(exc), + } if exc.__traceback__: tb_extract = traceback.extract_tb(exc.__traceback__) trace["stack_trace"] = traceback.format_list(tb_extract) @@ -563,7 +579,13 @@ def get_exception_trace(exc: BaseException) -> ExceptionTrace: except Exception: continue # extract special attrs - if k in ["load_id", "pipeline_name", "source_name", "resource_name", "job_id"]: + if k in [ + "load_id", + "pipeline_name", + "source_name", + "resource_name", + "job_id", + ]: trace[k] = v # type: ignore[literal-required] trace["exception_attrs"] = str_attrs diff --git a/dlt/common/validation.py b/dlt/common/validation.py index 4b54d6a29e..cb84e143de 100644 --- a/dlt/common/validation.py +++ b/dlt/common/validation.py @@ -98,8 +98,9 @@ def verify_prop(pk: str, pv: Any, t: Any) -> None: for ut in union_types ] raise DictValidationException( - f"In {path}: field {pk} value {pv} has invalid type {type(pv).__name__}." - f" One of these types expected: {', '.join(type_names)}.", + f"In {path}: field {pk} value {pv} has invalid type" + f" {type(pv).__name__}. One of these types expected:" + f" {', '.join(type_names)}.", path, pk, pv, @@ -108,13 +109,16 @@ def verify_prop(pk: str, pv: Any, t: Any) -> None: a_l = get_args(t) if pv not in a_l: raise DictValidationException( - f"In {path}: field {pk} value {pv} not in allowed {a_l}", path, pk, pv + f"In {path}: field {pk} value {pv} not in allowed {a_l}", + path, + pk, + pv, ) elif t in [int, bool, str, float]: if not isinstance(pv, t): raise DictValidationException( - f"In {path}: field {pk} value {pv} has invalid type {type(pv).__name__} while" - f" {t.__name__} is expected", + f"In {path}: field {pk} value {pv} has invalid type" + f" {type(pv).__name__} while {t.__name__} is expected", path, pk, pv, @@ -122,8 +126,8 @@ def verify_prop(pk: str, pv: Any, t: Any) -> None: elif is_typeddict(t): if not isinstance(pv, dict): raise DictValidationException( - f"In {path}: field {pk} value {pv} has invalid type {type(pv).__name__} while" - " dict is expected", + f"In {path}: field {pk} value {pv} has invalid type" + f" {type(pv).__name__} while dict is expected", path, pk, pv, @@ -132,8 +136,8 @@ def verify_prop(pk: str, pv: Any, t: Any) -> None: elif is_list_generic_type(t): if not isinstance(pv, list): raise DictValidationException( - f"In {path}: field {pk} value {pv} has invalid type {type(pv).__name__} while" - " list is expected", + f"In {path}: field {pk} value {pv} has invalid type" + f" {type(pv).__name__} while list is expected", path, pk, pv, @@ -145,8 +149,8 @@ def verify_prop(pk: str, pv: Any, t: Any) -> None: elif is_dict_generic_type(t): if not isinstance(pv, dict): raise DictValidationException( - f"In {path}: field {pk} value {pv} has invalid type {type(pv).__name__} while" - " dict is expected", + f"In {path}: field {pk} value {pv} has invalid type" + f" {type(pv).__name__} while dict is expected", path, pk, pv, @@ -156,7 +160,10 @@ def verify_prop(pk: str, pv: Any, t: Any) -> None: for d_k, d_v in pv.items(): if not isinstance(d_k, str): raise DictValidationException( - f"In {path}: field {pk} key {d_k} must be a string", path, pk, d_k + f"In {path}: field {pk} key {d_k} must be a string", + path, 
+ pk, + d_k, ) verify_prop(f"{pk}[{d_k}]", d_v, d_v_t) elif t is Any: @@ -174,8 +181,8 @@ def verify_prop(pk: str, pv: Any, t: Any) -> None: if inspect.isclass(t): if not isinstance(pv, t): raise DictValidationException( - f"In {path}: field {pk} expect class {type_name} but got instance of" - f" {pv_type_name}", + f"In {path}: field {pk} expect class {type_name} but got" + f" instance of {pv_type_name}", path, pk, ) @@ -183,7 +190,8 @@ def verify_prop(pk: str, pv: Any, t: Any) -> None: # dropped, just __name__ can be used type_name = getattr(t, "__name__", str(t)) raise DictValidationException( - f"In {path}: field {pk} has expected type {type_name} which lacks validator", + f"In {path}: field {pk} has expected type {type_name} which lacks" + " validator", path, pk, ) diff --git a/dlt/common/versioned_state.py b/dlt/common/versioned_state.py index a051a6660c..bcb8116f10 100644 --- a/dlt/common/versioned_state.py +++ b/dlt/common/versioned_state.py @@ -13,7 +13,9 @@ class TVersionedState(TypedDict, total=False): _state_engine_version: int -def generate_state_version_hash(state: TVersionedState, exclude_attrs: List[str] = None) -> str: +def generate_state_version_hash( + state: TVersionedState, exclude_attrs: List[str] = None +) -> str: # generates hash out of stored schema content, excluding hash itself, version and local state state_copy = copy(state) exclude_attrs = exclude_attrs or [] diff --git a/dlt/common/warnings.py b/dlt/common/warnings.py index 9c62c69bf8..27a0d2bcaf 100644 --- a/dlt/common/warnings.py +++ b/dlt/common/warnings.py @@ -31,7 +31,9 @@ def __init__( super().__init__(message, *args) self.message = message.rstrip(".") self.since = ( - since if isinstance(since, semver.VersionInfo) else semver.parse_version_info(since) + since + if isinstance(since, semver.VersionInfo) + else semver.parse_version_info(since) ) if expected_due: expected_due = ( @@ -39,11 +41,14 @@ def __init__( if isinstance(expected_due, semver.VersionInfo) else semver.parse_version_info(expected_due) ) - self.expected_due = expected_due if expected_due is not None else self.since.bump_minor() + self.expected_due = ( + expected_due if expected_due is not None else self.since.bump_minor() + ) def __str__(self) -> str: message = ( - f"{self.message}. Deprecated in dlt {self.since} to be removed in {self.expected_due}." + f"{self.message}. Deprecated in dlt {self.since} to be removed in" + f" {self.expected_due}." 
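# Aside: what the DltDeprecationWarning __str__ formatting above renders,
# assuming the class is importable from dlt.common.warnings and that
# `since`/`expected_due` are keyword arguments (the full __init__ signature is
# not visible in this hunk); the deprecation message itself is made up.
from dlt.common.warnings import DltDeprecationWarning

w = DltDeprecationWarning("The `example_option` argument is deprecated", since="0.4.0")
# expected_due defaults to since.bump_minor(), so this prints:
# "The `example_option` argument is deprecated. Deprecated in dlt 0.4.0 to be removed in 0.5.0."
print(str(w))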
) return message @@ -51,7 +56,9 @@ def __str__(self) -> str: class Dlt04DeprecationWarning(DltDeprecationWarning): V04 = semver.parse_version_info("0.4.0") - def __init__(self, message: str, *args: typing.Any, expected_due: VersionString = None) -> None: + def __init__( + self, message: str, *args: typing.Any, expected_due: VersionString = None + ) -> None: super().__init__( message, *args, since=Dlt04DeprecationWarning.V04, expected_due=expected_due ) diff --git a/dlt/destinations/decorators.py b/dlt/destinations/decorators.py index a920d336a2..2ded78978a 100644 --- a/dlt/destinations/decorators.py +++ b/dlt/destinations/decorators.py @@ -28,7 +28,14 @@ def destination( max_table_nesting: int = 0, spec: Type[CustomDestinationClientConfiguration] = None, ) -> Callable[ - [Callable[Concatenate[Union[TDataItems, str], TTableSchema, TDestinationCallableParams], Any]], + [ + Callable[ + Concatenate[ + Union[TDataItems, str], TTableSchema, TDestinationCallableParams + ], + Any, + ] + ], Callable[TDestinationCallableParams, _destination], ]: """A decorator that transforms a function that takes two positional arguments "table" and "items" and any number of keyword arguments with defaults @@ -61,12 +68,16 @@ def destination( def decorator( destination_callable: Callable[ - Concatenate[Union[TDataItems, str], TTableSchema, TDestinationCallableParams], Any + Concatenate[ + Union[TDataItems, str], TTableSchema, TDestinationCallableParams + ], + Any, ] ) -> Callable[TDestinationCallableParams, _destination]: @wraps(destination_callable) def wrapper( - *args: TDestinationCallableParams.args, **kwargs: TDestinationCallableParams.kwargs + *args: TDestinationCallableParams.args, + **kwargs: TDestinationCallableParams.kwargs, ) -> _destination: if args: logger.warning( diff --git a/dlt/destinations/exceptions.py b/dlt/destinations/exceptions.py index 5e6adb007d..6039e4c211 100644 --- a/dlt/destinations/exceptions.py +++ b/dlt/destinations/exceptions.py @@ -38,10 +38,10 @@ def __init__( self.dataset_name = dataset_name self.inner_exc = inner_exc super().__init__( - f"Connection with {client_type} to dataset name {dataset_name} failed. Please check if" - " you configured the credentials at all and provided the right credentials values. You" - " can be also denied access or your internet connection may be down. The actual reason" - f" given is: {reason}" + f"Connection with {client_type} to dataset name {dataset_name} failed." + " Please check if you configured the credentials at all and provided the" + " right credentials values. You can be also denied access or your internet" + f" connection may be down. The actual reason given is: {reason}" ) @@ -50,8 +50,9 @@ def __init__(self, client_type: str, dataset_name: str) -> None: self.client_type = client_type self.dataset_name = dataset_name super().__init__( - f"Connection with {client_type} to dataset {dataset_name} is closed. Open the" - " connection with 'client.open_connection' or with the 'with client:' statement" + f"Connection with {client_type} to dataset {dataset_name} is closed. 
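# Aside: a minimal sketch of the custom destination decorator defined above,
# assuming it is exposed as dlt.destination. Per the Concatenate[...] type in
# this hunk the decorated callable receives the data items first and the table
# schema second; the sink below just prints what it gets.
import dlt
from dlt.common.typing import TDataItems
from dlt.common.schema import TTableSchema

@dlt.destination(batch_size=10, max_table_nesting=0)
def print_sink(items: TDataItems, table: TTableSchema) -> None:
    # called once per batch of up to batch_size items for a given table
    print(f"{table['name']}: {len(items)} item(s)")

# dlt.pipeline(destination=print_sink, ...) would then route load jobs here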
Open" + " the connection with 'client.open_connection' or with the 'with client:'" + " statement" ) @@ -72,7 +73,8 @@ def __init__(self, job_id: str) -> None: class LoadJobTerminalException(DestinationTerminalException): def __init__(self, file_path: str, message: str) -> None: super().__init__( - f"Job with id/file name {file_path} encountered unrecoverable problem: {message}" + f"Job with id/file name {file_path} encountered unrecoverable problem:" + f" {message}" ) @@ -86,14 +88,18 @@ def __init__(self, from_state: TLoadJobState, to_state: TLoadJobState) -> None: class LoadJobFileTooBig(DestinationTerminalException): def __init__(self, file_name: str, max_size: int) -> None: super().__init__( - f"File {file_name} exceeds {max_size} and cannot be loaded. Split the file and try" - " again." + f"File {file_name} exceeds {max_size} and cannot be loaded. Split the file" + " and try again." ) class MergeDispositionException(DestinationTerminalException): def __init__( - self, dataset_name: str, staging_dataset_name: str, tables: Sequence[str], reason: str + self, + dataset_name: str, + staging_dataset_name: str, + tables: Sequence[str], + reason: str, ) -> None: self.dataset_name = dataset_name self.staging_dataset_name = staging_dataset_name @@ -101,11 +107,12 @@ def __init__( self.reason = reason msg = ( f"Merge sql job for dataset name {dataset_name}, staging dataset name" - f" {staging_dataset_name} COULD NOT BE GENERATED. Merge will not be performed. " + f" {staging_dataset_name} COULD NOT BE GENERATED. Merge will not be" + " performed. " ) msg += ( - f"Data for the following tables ({tables}) is loaded to staging dataset. You may need" - " to write your own materialization. The reason is:\n" + f"Data for the following tables ({tables}) is loaded to staging dataset." + " You may need to write your own materialization. 
The reason is:\n" ) msg += reason super().__init__(msg) @@ -114,7 +121,9 @@ def __init__( class InvalidFilesystemLayout(DestinationTerminalException): def __init__(self, invalid_placeholders: Sequence[str]) -> None: self.invalid_placeholders = invalid_placeholders - super().__init__(f"Invalid placeholders found in filesystem layout: {invalid_placeholders}") + super().__init__( + f"Invalid placeholders found in filesystem layout: {invalid_placeholders}" + ) class CantExtractTablePrefix(DestinationTerminalException): diff --git a/dlt/destinations/impl/athena/athena.py b/dlt/destinations/impl/athena/athena.py index b323832418..4ae130b384 100644 --- a/dlt/destinations/impl/athena/athena.py +++ b/dlt/destinations/impl/athena/athena.py @@ -22,7 +22,13 @@ import pyathena from pyathena import connect from pyathena.connection import Connection -from pyathena.error import OperationalError, DatabaseError, ProgrammingError, IntegrityError, Error +from pyathena.error import ( + OperationalError, + DatabaseError, + ProgrammingError, + IntegrityError, + Error, +) from pyathena.formatter import ( DefaultParameterFormatter, _DEFAULT_FORMATTERS, @@ -34,11 +40,20 @@ from dlt.common.utils import without_none from dlt.common.data_types import TDataType from dlt.common.schema import TColumnSchema, Schema, TSchemaTables, TTableSchema -from dlt.common.schema.typing import TTableSchema, TColumnType, TWriteDisposition, TTableFormat +from dlt.common.schema.typing import ( + TTableSchema, + TColumnType, + TWriteDisposition, + TTableFormat, +) from dlt.common.schema.utils import table_schema_has_type, get_table_format from dlt.common.destination import DestinationCapabilitiesContext from dlt.common.destination.reference import LoadJob, DoNothingFollowupJob, DoNothingJob -from dlt.common.destination.reference import TLoadJobState, NewLoadJob, SupportsStagingDestination +from dlt.common.destination.reference import ( + TLoadJobState, + NewLoadJob, + SupportsStagingDestination, +) from dlt.common.storages import FileStorage from dlt.common.data_writers.escape import escape_bigquery_identifier from dlt.destinations.sql_jobs import SqlStagingCopyJob @@ -120,7 +135,9 @@ def from_db_type( # add a formatter for pendulum to be used by pyathen dbapi -def _format_pendulum_datetime(formatter: Formatter, escaper: Callable[[str], str], val: Any) -> Any: +def _format_pendulum_datetime( + formatter: Formatter, escaper: Callable[[str], str], val: Any +) -> Any: # copied from https://github.com/laughingman7743/PyAthena/blob/f4b21a0b0f501f5c3504698e25081f491a541d4e/pyathena/formatter.py#L114 # https://docs.aws.amazon.com/athena/latest/ug/engine-versions-reference-0003.html#engine-versions-reference-0003-timestamp-changes # ICEBERG tables have TIMESTAMP(6), other tables have TIMESTAMP(3), we always generate TIMESTAMP(6) @@ -145,7 +162,9 @@ def __init__(self) -> None: formatters[datetime] = _format_pendulum_datetime formatters[Date] = _format_date - super(DefaultParameterFormatter, self).__init__(mappings=formatters, default=None) + super(DefaultParameterFormatter, self).__init__( + mappings=formatters, default=None + ) DLTAthenaFormatter._INSTANCE = self @@ -198,18 +217,23 @@ def create_dataset(self) -> None: self.execute_sql(f"CREATE DATABASE {self.fully_qualified_ddl_dataset_name()};") def drop_dataset(self) -> None: - self.execute_sql(f"DROP DATABASE {self.fully_qualified_ddl_dataset_name()} CASCADE;") + self.execute_sql( + f"DROP DATABASE {self.fully_qualified_ddl_dataset_name()} CASCADE;" + ) def fully_qualified_dataset_name(self, 
escape: bool = True) -> str: return ( - self.capabilities.escape_identifier(self.dataset_name) if escape else self.dataset_name + self.capabilities.escape_identifier(self.dataset_name) + if escape + else self.dataset_name ) def drop_tables(self, *tables: str) -> None: if not tables: return statements = [ - f"DROP TABLE IF EXISTS {self.make_qualified_ddl_table_name(table)};" for table in tables + f"DROP TABLE IF EXISTS {self.make_qualified_ddl_table_name(table)};" + for table in tables ] self.execute_many(statements) @@ -217,7 +241,8 @@ def drop_tables(self, *tables: str) -> None: @raise_database_error def begin_transaction(self) -> Iterator[DBTransaction]: logger.warning( - "Athena does not support transactions! Each SQL statement is auto-committed separately." + "Athena does not support transactions! Each SQL statement is auto-committed" + " separately." ) yield self @@ -276,7 +301,9 @@ def _convert_to_old_pyformat( @contextmanager @raise_database_error - def execute_query(self, query: AnyStr, *args: Any, **kwargs: Any) -> Iterator[DBApiCursor]: + def execute_query( + self, query: AnyStr, *args: Any, **kwargs: Any + ) -> Iterator[DBApiCursor]: assert isinstance(query, str) db_args = kwargs # convert sql and params to PyFormat, as athena does not support anything else @@ -329,13 +356,18 @@ def _from_db_type( ) -> TColumnType: return self.type_mapper.from_db_type(hive_t, precision, scale) - def _get_column_def_sql(self, c: TColumnSchema, table_format: TTableFormat = None) -> str: + def _get_column_def_sql( + self, c: TColumnSchema, table_format: TTableFormat = None + ) -> str: return ( f"{self.sql_client.escape_ddl_identifier(c['name'])} {self.type_mapper.to_db_type(c, table_format)}" ) def _get_table_update_sql( - self, table_name: str, new_columns: Sequence[TColumnSchema], generate_alter: bool + self, + table_name: str, + new_columns: Sequence[TColumnSchema], + generate_alter: bool, ) -> List[str]: bucket = self.config.staging_config.bucket_url dataset = self.sql_client.dataset_name @@ -346,8 +378,13 @@ def _get_table_update_sql( # or if we are in iceberg mode, we create iceberg tables for all tables table = self.prepare_load_table(table_name, self.in_staging_mode) table_format = table.get("table_format") - is_iceberg = self._is_iceberg_table(table) or table.get("write_disposition", None) == "skip" - columns = ", ".join([self._get_column_def_sql(c, table_format) for c in new_columns]) + is_iceberg = ( + self._is_iceberg_table(table) + or table.get("write_disposition", None) == "skip" + ) + columns = ", ".join( + [self._get_column_def_sql(c, table_format) for c in new_columns] + ) # this will fail if the table prefix is not properly defined table_prefix = self.table_prefix_layout.format(table_name=table_name) @@ -357,7 +394,9 @@ def _get_table_update_sql( qualified_table_name = self.sql_client.make_qualified_ddl_table_name(table_name) if generate_alter: # alter table to add new columns at the end - sql.append(f"""ALTER TABLE {qualified_table_name} ADD COLUMNS ({columns});""") + sql.append( + f"""ALTER TABLE {qualified_table_name} ADD COLUMNS ({columns});""" + ) else: if is_iceberg: sql.append(f"""CREATE TABLE {qualified_table_name} @@ -376,13 +415,16 @@ def _get_table_update_sql( LOCATION '{location}';""") return sql - def start_file_load(self, table: TTableSchema, file_path: str, load_id: str) -> LoadJob: + def start_file_load( + self, table: TTableSchema, file_path: str, load_id: str + ) -> LoadJob: """Starts SqlLoadJob for files ending with .sql or returns None to let derived classes to 
handle their specific jobs""" if table_schema_has_type(table, "time"): raise LoadJobTerminalException( file_path, "Athena cannot load TIME columns from parquet tables. Please convert" - " `datetime.time` objects in your data to `str` or `datetime.datetime`.", + " `datetime.time` objects in your data to `str` or" + " `datetime.datetime`.", ) job = super().start_file_load(table, file_path, load_id) if not job: @@ -393,10 +435,14 @@ def start_file_load(self, table: TTableSchema, file_path: str, load_id: str) -> ) return job - def _create_append_followup_jobs(self, table_chain: Sequence[TTableSchema]) -> List[NewLoadJob]: + def _create_append_followup_jobs( + self, table_chain: Sequence[TTableSchema] + ) -> List[NewLoadJob]: if self._is_iceberg_table(self.prepare_load_table(table_chain[0]["name"])): return [ - SqlStagingCopyJob.from_table_chain(table_chain, self.sql_client, {"replace": False}) + SqlStagingCopyJob.from_table_chain( + table_chain, self.sql_client, {"replace": False} + ) ] return super()._create_append_followup_jobs(table_chain) @@ -405,11 +451,15 @@ def _create_replace_followup_jobs( ) -> List[NewLoadJob]: if self._is_iceberg_table(self.prepare_load_table(table_chain[0]["name"])): return [ - SqlStagingCopyJob.from_table_chain(table_chain, self.sql_client, {"replace": True}) + SqlStagingCopyJob.from_table_chain( + table_chain, self.sql_client, {"replace": True} + ) ] return super()._create_replace_followup_jobs(table_chain) - def _create_merge_followup_jobs(self, table_chain: Sequence[TTableSchema]) -> List[NewLoadJob]: + def _create_merge_followup_jobs( + self, table_chain: Sequence[TTableSchema] + ) -> List[NewLoadJob]: # fall back to append jobs for merge return self._create_append_followup_jobs(table_chain) @@ -423,7 +473,9 @@ def should_load_data_to_staging_dataset(self, table: TTableSchema) -> bool: return True return super().should_load_data_to_staging_dataset(table) - def should_truncate_table_before_load_on_staging_destination(self, table: TTableSchema) -> bool: + def should_truncate_table_before_load_on_staging_destination( + self, table: TTableSchema + ) -> bool: # on athena we only truncate replace tables that are not iceberg table = self.prepare_load_table(table["name"]) if table["write_disposition"] == "replace" and not self._is_iceberg_table( diff --git a/dlt/destinations/impl/athena/configuration.py b/dlt/destinations/impl/athena/configuration.py index 59dfeee4ec..0b606112ed 100644 --- a/dlt/destinations/impl/athena/configuration.py +++ b/dlt/destinations/impl/athena/configuration.py @@ -2,7 +2,9 @@ from typing import ClassVar, Final, List, Optional from dlt.common.configuration import configspec -from dlt.common.destination.reference import DestinationClientDwhWithStagingConfiguration +from dlt.common.destination.reference import ( + DestinationClientDwhWithStagingConfiguration, +) from dlt.common.configuration.specs import AwsCredentials diff --git a/dlt/destinations/impl/bigquery/bigquery.py b/dlt/destinations/impl/bigquery/bigquery.py index d4261a1636..63d7f515f1 100644 --- a/dlt/destinations/impl/bigquery/bigquery.py +++ b/dlt/destinations/impl/bigquery/bigquery.py @@ -38,7 +38,10 @@ TABLE_EXPIRATION_HINT, ) from dlt.destinations.impl.bigquery.configuration import BigQueryClientConfiguration -from dlt.destinations.impl.bigquery.sql_client import BigQuerySqlClient, BQ_TERMINAL_REASONS +from dlt.destinations.impl.bigquery.sql_client import ( + BigQuerySqlClient, + BQ_TERMINAL_REASONS, +) from dlt.destinations.job_client_impl import SqlJobClientWithStaging from 
dlt.destinations.job_impl import NewReferenceJob from dlt.destinations.sql_jobs import SqlMergeJob @@ -108,9 +111,14 @@ def __init__( super().__init__(file_name) def state(self) -> TLoadJobState: - if not self.bq_load_job.done(retry=self.default_retry, timeout=self.http_timeout): + if not self.bq_load_job.done( + retry=self.default_retry, timeout=self.http_timeout + ): return "running" - if self.bq_load_job.output_rows is not None and self.bq_load_job.error_result is None: + if ( + self.bq_load_job.output_rows is not None + and self.bq_load_job.error_result is None + ): return "completed" reason = self.bq_load_job.error_result.get("reason") if reason in BQ_TERMINAL_REASONS: @@ -157,8 +165,8 @@ def gen_key_table_clauses( for_delete: bool, ) -> List[str]: sql: List[str] = [ - f"FROM {root_table_name} AS d WHERE EXISTS (SELECT 1 FROM {staging_root_table_name} AS" - f" s WHERE {clause.format(d='d', s='s')})" + f"FROM {root_table_name} AS d WHERE EXISTS (SELECT 1 FROM" + f" {staging_root_table_name} AS s WHERE {clause.format(d='d', s='s')})" for clause in key_clauses ] return sql @@ -180,7 +188,9 @@ def __init__(self, schema: Schema, config: BigQueryClientConfiguration) -> None: self.sql_client: BigQuerySqlClient = sql_client # type: ignore self.type_mapper = BigQueryTypeMapper(self.capabilities) - def _create_merge_followup_jobs(self, table_chain: Sequence[TTableSchema]) -> List[NewLoadJob]: + def _create_merge_followup_jobs( + self, table_chain: Sequence[TTableSchema] + ) -> List[NewLoadJob]: return [BigQueryMergeJob.from_table_chain(table_chain, self.sql_client)] def restore_file_load(self, file_path: str) -> LoadJob: @@ -216,7 +226,9 @@ def restore_file_load(self, file_path: str) -> LoadJob: raise DestinationTransientException(gace) from gace return job - def start_file_load(self, table: TTableSchema, file_path: str, load_id: str) -> LoadJob: + def start_file_load( + self, table: TTableSchema, file_path: str, load_id: str + ) -> LoadJob: job = super().start_file_load(table, file_path, load_id) if not job: @@ -246,7 +258,10 @@ def start_file_load(self, table: TTableSchema, file_path: str, load_id: str) -> return job def _get_table_update_sql( - self, table_name: str, new_columns: Sequence[TColumnSchema], generate_alter: bool + self, + table_name: str, + new_columns: Sequence[TColumnSchema], + generate_alter: bool, ) -> List[str]: table: Optional[TTableSchema] = self.prepare_load_table(table_name) sql = super()._get_table_update_sql(table_name, new_columns, generate_alter) @@ -256,15 +271,23 @@ def _get_table_update_sql( c for c in new_columns if c.get("partition") or c.get(PARTITION_HINT, False) ]: if len(partition_list) > 1: - col_names = [self.capabilities.escape_identifier(c["name"]) for c in partition_list] + col_names = [ + self.capabilities.escape_identifier(c["name"]) + for c in partition_list + ] raise DestinationSchemaWillNotUpdate( - canonical_name, col_names, "Partition requested for more than one column" + canonical_name, + col_names, + "Partition requested for more than one column", ) elif (c := partition_list[0])["data_type"] == "date": - sql[0] += f"\nPARTITION BY {self.capabilities.escape_identifier(c['name'])}" + sql[ + 0 + ] += f"\nPARTITION BY {self.capabilities.escape_identifier(c['name'])}" elif (c := partition_list[0])["data_type"] == "timestamp": sql[0] = ( - f"{sql[0]}\nPARTITION BY DATE({self.capabilities.escape_identifier(c['name'])})" + f"{sql[0]}\nPARTITION BY" + f" DATE({self.capabilities.escape_identifier(c['name'])})" ) # Automatic partitioning of an INT64 type 
requires us to be prescriptive - we treat the column as a UNIX timestamp. # This is due to the bounds requirement of GENERATE_ARRAY function for partitioning. @@ -273,7 +296,8 @@ def _get_table_update_sql( # See: https://dlthub.com/devel/dlt-ecosystem/destinations/bigquery#supported-column-hints elif (c := partition_list[0])["data_type"] == "bigint": sql[0] += ( - f"\nPARTITION BY RANGE_BUCKET({self.capabilities.escape_identifier(c['name'])}," + "\nPARTITION BY" + f" RANGE_BUCKET({self.capabilities.escape_identifier(c['name'])}," " GENERATE_ARRAY(-172800000, 691200000, 86400))" ) @@ -306,7 +330,11 @@ def _get_table_update_sql( sql[0] += ( "\nOPTIONS (" + ", ".join( - [f"{key}={value}" for key, value in table_options.items() if value is not None] + [ + f"{key}={value}" + for key, value in table_options.items() + if value is not None + ] ) + ")" ) @@ -321,12 +349,17 @@ def prepare_load_table( if TABLE_DESCRIPTION_HINT not in table: table[TABLE_DESCRIPTION_HINT] = ( # type: ignore[name-defined, typeddict-unknown-key, unused-ignore] get_inherited_table_hint( - self.schema.tables, table_name, TABLE_DESCRIPTION_HINT, allow_none=True + self.schema.tables, + table_name, + TABLE_DESCRIPTION_HINT, + allow_none=True, ) ) return table - def _get_column_def_sql(self, column: TColumnSchema, table_format: TTableFormat = None) -> str: + def _get_column_def_sql( + self, column: TColumnSchema, table_format: TTableFormat = None + ) -> str: name = self.capabilities.escape_identifier(column["name"]) column_def_sql = ( f"{name} {self.type_mapper.to_db_type(column, table_format)} {self._gen_not_null(column.get('nullable', True))}" @@ -345,7 +378,9 @@ def get_storage_table(self, table_name: str) -> Tuple[bool, TTableSchemaColumns] retry=self.sql_client._default_retry, timeout=self.config.http_timeout, ) - partition_field = table.time_partitioning.field if table.time_partitioning else None + partition_field = ( + table.time_partitioning.field if table.time_partitioning else None + ) for c in table.schema: schema_c: TColumnSchema = { "name": c.name, @@ -382,7 +417,8 @@ def _create_load_job(self, table: TTableSchema, file_path: str) -> bigquery.Load if table_schema_has_type(table, "complex"): raise LoadJobTerminalException( file_path, - "Bigquery cannot load into JSON data type from parquet. Use jsonl instead.", + "Bigquery cannot load into JSON data type from parquet. 
Use jsonl" + " instead.", ) source_format = bigquery.SourceFormat.PARQUET # parquet needs NUMERIC type auto-detection diff --git a/dlt/destinations/impl/bigquery/bigquery_adapter.py b/dlt/destinations/impl/bigquery/bigquery_adapter.py index 1d630e9802..aced2b7993 100644 --- a/dlt/destinations/impl/bigquery/bigquery_adapter.py +++ b/dlt/destinations/impl/bigquery/bigquery_adapter.py @@ -17,9 +17,15 @@ ROUND_HALF_AWAY_FROM_ZERO_HINT: Literal["x-bigquery-round-half-away-from-zero"] = ( "x-bigquery-round-half-away-from-zero" ) -ROUND_HALF_EVEN_HINT: Literal["x-bigquery-round-half-even"] = "x-bigquery-round-half-even" -TABLE_EXPIRATION_HINT: Literal["x-bigquery-table-expiration"] = "x-bigquery-table-expiration" -TABLE_DESCRIPTION_HINT: Literal["x-bigquery-table-description"] = "x-bigquery-table-description" +ROUND_HALF_EVEN_HINT: Literal["x-bigquery-round-half-even"] = ( + "x-bigquery-round-half-even" +) +TABLE_EXPIRATION_HINT: Literal["x-bigquery-table-expiration"] = ( + "x-bigquery-table-expiration" +) +TABLE_DESCRIPTION_HINT: Literal["x-bigquery-table-description"] = ( + "x-bigquery-table-description" +) def bigquery_adapter( @@ -87,7 +93,8 @@ def bigquery_adapter( cluster = [cluster] if not isinstance(cluster, list): raise ValueError( - "`cluster` must be a list of column names or a single column name as a string." + "`cluster` must be a list of column names or a single column name as a" + " string." ) for column_name in cluster: column_hints[column_name] = {"name": column_name, CLUSTER_HINT: True} # type: ignore[typeddict-unknown-key] @@ -98,8 +105,8 @@ def bigquery_adapter( round_half_away_from_zero = [round_half_away_from_zero] if not isinstance(round_half_away_from_zero, list): raise ValueError( - "`round_half_away_from_zero` must be a list of column names or a single column" - " name." + "`round_half_away_from_zero` must be a list of column names or a single" + " column name." ) for column_name in round_half_away_from_zero: column_hints[column_name] = {"name": column_name, ROUND_HALF_AWAY_FROM_ZERO_HINT: True} # type: ignore[typeddict-unknown-key] @@ -109,7 +116,8 @@ def bigquery_adapter( round_half_even = [round_half_even] if not isinstance(round_half_even, list): raise ValueError( - "`round_half_even` must be a list of column names or a single column name." + "`round_half_even` must be a list of column names or a single column" + " name." ) for column_name in round_half_even: column_hints[column_name] = {"name": column_name, ROUND_HALF_EVEN_HINT: True} # type: ignore[typeddict-unknown-key] @@ -119,37 +127,41 @@ def bigquery_adapter( set(round_half_even) ): raise ValueError( - f"Columns `{intersection_columns}` are present in both `round_half_away_from_zero`" - " and `round_half_even` which is not allowed. They must be mutually exclusive." + f"Columns `{intersection_columns}` are present in both" + " `round_half_away_from_zero` and `round_half_even` which is not" + " allowed. They must be mutually exclusive." ) if table_description: if not isinstance(table_description, str): raise ValueError( - "`table_description` must be string representing BigQuery table description." + "`table_description` must be string representing BigQuery table" + " description." ) additional_table_hints |= {TABLE_DESCRIPTION_HINT: table_description} # type: ignore[operator] if table_expiration_datetime: if not isinstance(table_expiration_datetime, str): raise ValueError( - "`table_expiration_datetime` must be string representing the datetime when the" - " BigQuery table." 
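# Aside: a usage sketch for bigquery_adapter based on the keyword arguments
# validated in this function (cluster, table_description,
# table_expiration_datetime, ...), assuming the resource is passed as the
# first argument. The resource and column names are made up and the import
# path is taken from this file's location in the patch.
import dlt
from dlt.destinations.impl.bigquery.bigquery_adapter import bigquery_adapter

@dlt.resource
def events():
    yield {"user_id": 1, "value": 10.5}

events_with_hints = bigquery_adapter(
    events(),
    cluster="user_id",
    table_description="Raw events, clustered by user id",
    table_expiration_datetime="2030-01-01",
)
# passing none of the supported hints raises ValueError ("AT LEAST one of ...")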
+ "`table_expiration_datetime` must be string representing the datetime" + " when the BigQuery table." ) try: - parsed_table_expiration_datetime = parser.parse(table_expiration_datetime).replace( - tzinfo=timezone.utc - ) + parsed_table_expiration_datetime = parser.parse( + table_expiration_datetime + ).replace(tzinfo=timezone.utc) additional_table_hints |= {TABLE_EXPIRATION_HINT: parsed_table_expiration_datetime} # type: ignore[operator] except ValueError as e: raise ValueError(f"{table_expiration_datetime} could not be parsed!") from e if column_hints or additional_table_hints: - resource.apply_hints(columns=column_hints, additional_table_hints=additional_table_hints) + resource.apply_hints( + columns=column_hints, additional_table_hints=additional_table_hints + ) else: raise ValueError( "AT LEAST one of `partition`, `cluster`, `round_half_away_from_zero`," - " `round_half_even`, `table_description` or `table_expiration_datetime` must be" - " specified." + " `round_half_even`, `table_description` or `table_expiration_datetime`" + " must be specified." ) return resource diff --git a/dlt/destinations/impl/bigquery/configuration.py b/dlt/destinations/impl/bigquery/configuration.py index a6686c3f2d..158addf373 100644 --- a/dlt/destinations/impl/bigquery/configuration.py +++ b/dlt/destinations/impl/bigquery/configuration.py @@ -6,7 +6,9 @@ from dlt.common.configuration.specs import GcpServiceAccountCredentials from dlt.common.utils import digest128 -from dlt.common.destination.reference import DestinationClientDwhWithStagingConfiguration +from dlt.common.destination.reference import ( + DestinationClientDwhWithStagingConfiguration, +) @configspec @@ -16,7 +18,9 @@ class BigQueryClientConfiguration(DestinationClientDwhWithStagingConfiguration): location: str = "US" http_timeout: float = 15.0 # connection timeout for http request to BigQuery api - file_upload_timeout: float = 30 * 60.0 # a timeout for file upload when loading local files + file_upload_timeout: float = ( + 30 * 60.0 + ) # a timeout for file upload when loading local files retry_deadline: float = ( 60.0 # how long to retry the operation in case of error, the backoff 60 s. ) @@ -29,8 +33,9 @@ def get_location(self) -> str: # default was changed in credentials, emit deprecation message if self.credentials.location != "US": warnings.warn( - "Setting BigQuery location in the credentials is deprecated. Please set the" - " location directly in bigquery section ie. destinations.bigquery.location='EU'" + "Setting BigQuery location in the credentials is deprecated. Please set" + " the location directly in bigquery section ie." 
+ " destinations.bigquery.location='EU'" ) return self.credentials.location diff --git a/dlt/destinations/impl/bigquery/sql_client.py b/dlt/destinations/impl/bigquery/sql_client.py index 95cb7ea73b..1445c8c6ed 100644 --- a/dlt/destinations/impl/bigquery/sql_client.py +++ b/dlt/destinations/impl/bigquery/sql_client.py @@ -5,7 +5,10 @@ from google.api_core import exceptions as api_core_exceptions from google.cloud import exceptions as gcp_exceptions from google.cloud.bigquery import dbapi as bq_dbapi -from google.cloud.bigquery.dbapi import Connection as DbApiConnection, Cursor as BQDbApiCursor +from google.cloud.bigquery.dbapi import ( + Connection as DbApiConnection, + Cursor as BQDbApiCursor, +) from google.cloud.bigquery.dbapi import exceptions as dbapi_exceptions from dlt.common.configuration.specs import GcpServiceAccountCredentialsWithoutDefaults @@ -205,14 +208,18 @@ def execute_sql( @contextmanager @raise_database_error - def execute_query(self, query: AnyStr, *args: Any, **kwargs: Any) -> Iterator[DBApiCursor]: + def execute_query( + self, query: AnyStr, *args: Any, **kwargs: Any + ) -> Iterator[DBApiCursor]: conn: DbApiConnection = None db_args = args or (kwargs or None) try: conn = DbApiConnection(client=self._client) curr = conn.cursor() # if session exists give it a preference - curr.execute(query, db_args, job_config=self._session_query or self._default_query) + curr.execute( + query, db_args, job_config=self._session_query or self._default_query + ) yield BigQueryDBApiCursorImpl(curr) # type: ignore finally: if conn: @@ -221,7 +228,9 @@ def execute_query(self, query: AnyStr, *args: Any, **kwargs: Any) -> Iterator[DB def fully_qualified_dataset_name(self, escape: bool = True) -> str: if escape: - project_id = self.capabilities.escape_identifier(self.credentials.project_id) + project_id = self.capabilities.escape_identifier( + self.credentials.project_id + ) dataset_name = self.capabilities.escape_identifier(self.dataset_name) else: project_id = self.credentials.project_id @@ -236,13 +245,19 @@ def _make_database_exception(cls, ex: Exception) -> Exception: cloud_ex = ex.args[0] reason = cls._get_reason_from_errors(cloud_ex) if reason is None: - if isinstance(ex, (dbapi_exceptions.DataError, dbapi_exceptions.IntegrityError)): + if isinstance( + ex, (dbapi_exceptions.DataError, dbapi_exceptions.IntegrityError) + ): return DatabaseTerminalException(ex) elif isinstance(ex, dbapi_exceptions.ProgrammingError): return DatabaseTransientException(ex) if reason == "notFound": return DatabaseUndefinedRelation(ex) - if reason == "invalidQuery" and "was not found" in str(ex) and "Dataset" in str(ex): + if ( + reason == "invalidQuery" + and "was not found" in str(ex) + and "Dataset" in str(ex) + ): return DatabaseUndefinedRelation(ex) if ( reason == "invalidQuery" @@ -263,7 +278,9 @@ def _make_database_exception(cls, ex: Exception) -> Exception: return DatabaseTransientException(ex) @staticmethod - def _get_reason_from_errors(gace: api_core_exceptions.GoogleAPICallError) -> Optional[str]: + def _get_reason_from_errors( + gace: api_core_exceptions.GoogleAPICallError, + ) -> Optional[str]: errors: List[StrAny] = getattr(gace, "errors", None) if errors and isinstance(errors, Sequence): return errors[0].get("reason") # type: ignore @@ -277,6 +294,6 @@ def is_dbapi_exception(ex: Exception) -> bool: class TransactionsNotImplementedError(NotImplementedError): def __init__(self) -> None: super().__init__( - "BigQuery does not support transaction management. 
Instead you may wrap your SQL script" - " in BEGIN TRANSACTION; ... COMMIT TRANSACTION;" + "BigQuery does not support transaction management. Instead you may wrap" + " your SQL script in BEGIN TRANSACTION; ... COMMIT TRANSACTION;" ) diff --git a/dlt/destinations/impl/databricks/__init__.py b/dlt/destinations/impl/databricks/__init__.py index 81884fae4b..07540cfc2a 100644 --- a/dlt/destinations/impl/databricks/__init__.py +++ b/dlt/destinations/impl/databricks/__init__.py @@ -1,5 +1,8 @@ from dlt.common.destination import DestinationCapabilitiesContext -from dlt.common.data_writers.escape import escape_databricks_identifier, escape_databricks_literal +from dlt.common.data_writers.escape import ( + escape_databricks_identifier, + escape_databricks_literal, +) from dlt.common.arithmetics import DEFAULT_NUMERIC_PRECISION, DEFAULT_NUMERIC_SCALE from dlt.destinations.impl.databricks.configuration import DatabricksClientConfiguration diff --git a/dlt/destinations/impl/databricks/configuration.py b/dlt/destinations/impl/databricks/configuration.py index 3bd2d12a5a..22b705265e 100644 --- a/dlt/destinations/impl/databricks/configuration.py +++ b/dlt/destinations/impl/databricks/configuration.py @@ -3,8 +3,13 @@ from dlt.common.typing import TSecretStrValue from dlt.common.configuration.exceptions import ConfigurationValueError -from dlt.common.configuration.specs.base_configuration import CredentialsConfiguration, configspec -from dlt.common.destination.reference import DestinationClientDwhWithStagingConfiguration +from dlt.common.configuration.specs.base_configuration import ( + CredentialsConfiguration, + configspec, +) +from dlt.common.destination.reference import ( + DestinationClientDwhWithStagingConfiguration, +) @configspec diff --git a/dlt/destinations/impl/databricks/databricks.py b/dlt/destinations/impl/databricks/databricks.py index 07e827cd28..6edc4a7ad4 100644 --- a/dlt/destinations/impl/databricks/databricks.py +++ b/dlt/destinations/impl/databricks/databricks.py @@ -1,4 +1,15 @@ -from typing import ClassVar, Dict, Optional, Sequence, Tuple, List, Any, Iterable, Type, cast +from typing import ( + ClassVar, + Dict, + Optional, + Sequence, + Tuple, + List, + Any, + Iterable, + Type, + cast, +) from urllib.parse import urlparse, urlunparse from dlt.common.destination import DestinationCapabilitiesContext @@ -18,7 +29,12 @@ from dlt.common.data_types import TDataType from dlt.common.storages.file_storage import FileStorage from dlt.common.schema import TColumnSchema, Schema, TTableSchemaColumns -from dlt.common.schema.typing import TTableSchema, TColumnType, TSchemaTables, TTableFormat +from dlt.common.schema.typing import ( + TTableSchema, + TColumnType, + TSchemaTables, + TTableFormat, +) from dlt.common.schema.utils import table_schema_has_type @@ -125,7 +141,9 @@ def __init__( else "" ) file_name = ( - FileStorage.get_file_name_from_file_path(bucket_path) if bucket_path else file_name + FileStorage.get_file_name_from_file_path(bucket_path) + if bucket_path + else file_name ) from_clause = "" credentials_clause = "" @@ -166,14 +184,14 @@ def __init__( else: raise LoadJobTerminalException( file_path, - f"Databricks cannot load data from staging bucket {bucket_path}. Only s3 and" - " azure buckets are supported", + f"Databricks cannot load data from staging bucket {bucket_path}." + " Only s3 and azure buckets are supported", ) else: raise LoadJobTerminalException( file_path, - "Cannot load from local file. Databricks does not support loading from local files." 
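# Aside: the Databricks load-job exceptions in this file require an s3 or azure
# staging bucket and, for several data types, suggest loading parquet instead
# of json. A minimal sketch of such a setup, assuming the standard
# dlt.pipeline(staging=...) and pipeline.run(loader_file_format=...) arguments;
# names are made up and credentials configuration is left out.
import dlt

pipeline = dlt.pipeline(
    pipeline_name="databricks_demo",
    destination="databricks",
    staging="filesystem",  # its bucket_url must point to an s3 or azure bucket
    dataset_name="raw",
)
info = pipeline.run(
    [{"id": 1, "amount": 10.5}],
    table_name="payments",
    loader_file_format="parquet",
)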
- " Configure staging with an s3 or azure storage bucket.", + "Cannot load from local file. Databricks does not support loading from" + " local files. Configure staging with an s3 or azure storage bucket.", ) # decide on source format, stage_file_path will either be a local file or a bucket path @@ -183,33 +201,34 @@ def __init__( if not config.get("data_writer.disable_compression"): raise LoadJobTerminalException( file_path, - "Databricks loader does not support gzip compressed JSON files. Please disable" + "Databricks loader does not support gzip compressed JSON files." + " Please disable" " compression in the data writer configuration:" " https://dlthub.com/docs/reference/performance#disabling-and-enabling-file-compression", ) if table_schema_has_type(table, "decimal"): raise LoadJobTerminalException( file_path, - "Databricks loader cannot load DECIMAL type columns from json files. Switch to" - " parquet format to load decimals.", + "Databricks loader cannot load DECIMAL type columns from json" + " files. Switch to parquet format to load decimals.", ) if table_schema_has_type(table, "binary"): raise LoadJobTerminalException( file_path, - "Databricks loader cannot load BINARY type columns from json files. Switch to" - " parquet format to load byte values.", + "Databricks loader cannot load BINARY type columns from json files." + " Switch to parquet format to load byte values.", ) if table_schema_has_type(table, "complex"): raise LoadJobTerminalException( file_path, - "Databricks loader cannot load complex columns (lists and dicts) from json" - " files. Switch to parquet format to load complex types.", + "Databricks loader cannot load complex columns (lists and dicts)" + " from json files. Switch to parquet format to load complex types.", ) if table_schema_has_type(table, "date"): raise LoadJobTerminalException( file_path, - "Databricks loader cannot load DATE type columns from json files. Switch to" - " parquet format to load dates.", + "Databricks loader cannot load DATE type columns from json files." 
+ " Switch to parquet format to load dates.", ) source_format = "JSON" @@ -242,7 +261,11 @@ def _to_temp_table(cls, select_sql: str, temp_table_name: str) -> str: @classmethod def gen_delete_from_sql( - cls, table_name: str, column_name: str, temp_table_name: str, temp_table_column: str + cls, + table_name: str, + column_name: str, + temp_table_name: str, + temp_table_column: str, ) -> str: # Databricks does not support subqueries in DELETE FROM statements so we use a MERGE statement instead return f"""MERGE INTO {table_name} @@ -256,13 +279,17 @@ class DatabricksClient(InsertValuesJobClient, SupportsStagingDestination): capabilities: ClassVar[DestinationCapabilitiesContext] = capabilities() def __init__(self, schema: Schema, config: DatabricksClientConfiguration) -> None: - sql_client = DatabricksSqlClient(config.normalize_dataset_name(schema), config.credentials) + sql_client = DatabricksSqlClient( + config.normalize_dataset_name(schema), config.credentials + ) super().__init__(schema, config, sql_client) self.config: DatabricksClientConfiguration = config self.sql_client: DatabricksSqlClient = sql_client self.type_mapper = DatabricksTypeMapper(self.capabilities) - def start_file_load(self, table: TTableSchema, file_path: str, load_id: str) -> LoadJob: + def start_file_load( + self, table: TTableSchema, file_path: str, load_id: str + ) -> LoadJob: job = super().start_file_load(table, file_path, load_id) if not job: @@ -272,21 +299,28 @@ def start_file_load(self, table: TTableSchema, file_path: str, load_id: str) -> table["name"], load_id, self.sql_client, - staging_config=cast(FilesystemConfiguration, self.config.staging_config), + staging_config=cast( + FilesystemConfiguration, self.config.staging_config + ), ) return job def restore_file_load(self, file_path: str) -> LoadJob: return EmptyLoadJob.from_file_path(file_path, "completed") - def _create_merge_followup_jobs(self, table_chain: Sequence[TTableSchema]) -> List[NewLoadJob]: + def _create_merge_followup_jobs( + self, table_chain: Sequence[TTableSchema] + ) -> List[NewLoadJob]: return [DatabricksMergeJob.from_table_chain(table_chain, self.sql_client)] def _make_add_column_sql( self, new_columns: Sequence[TColumnSchema], table_format: TTableFormat = None ) -> List[str]: # Override because databricks requires multiple columns in a single ADD COLUMN clause - return ["ADD COLUMN\n" + ",\n".join(self._get_column_def_sql(c) for c in new_columns)] + return [ + "ADD COLUMN\n" + + ",\n".join(self._get_column_def_sql(c) for c in new_columns) + ] def _get_table_update_sql( self, @@ -298,7 +332,9 @@ def _get_table_update_sql( sql = super()._get_table_update_sql(table_name, new_columns, generate_alter) cluster_list = [ - self.capabilities.escape_identifier(c["name"]) for c in new_columns if c.get("cluster") + self.capabilities.escape_identifier(c["name"]) + for c in new_columns + if c.get("cluster") ] if cluster_list: @@ -311,7 +347,9 @@ def _from_db_type( ) -> TColumnType: return self.type_mapper.from_db_type(bq_t, precision, scale) - def _get_column_def_sql(self, c: TColumnSchema, table_format: TTableFormat = None) -> str: + def _get_column_def_sql( + self, c: TColumnSchema, table_format: TTableFormat = None + ) -> str: name = self.capabilities.escape_identifier(c["name"]) return ( f"{name} {self.type_mapper.to_db_type(c)} {self._gen_not_null(c.get('nullable', True))}" @@ -319,7 +357,9 @@ def _get_column_def_sql(self, c: TColumnSchema, table_format: TTableFormat = Non def _get_storage_table_query_columns(self) -> List[str]: fields = 
super()._get_storage_table_query_columns() - fields[1] = ( # Override because this is the only way to get data type with precision + fields[ + 1 + ] = ( # Override because this is the only way to get data type with precision "full_data_type" ) return fields diff --git a/dlt/destinations/impl/databricks/sql_client.py b/dlt/destinations/impl/databricks/sql_client.py index 68ea863cc4..b66f69cd46 100644 --- a/dlt/destinations/impl/databricks/sql_client.py +++ b/dlt/destinations/impl/databricks/sql_client.py @@ -1,5 +1,15 @@ from contextlib import contextmanager, suppress -from typing import Any, AnyStr, ClassVar, Iterator, Optional, Sequence, List, Union, Dict +from typing import ( + Any, + AnyStr, + ClassVar, + Iterator, + Optional, + Sequence, + List, + Union, + Dict, +) from databricks import sql as databricks_lib from databricks.sql.client import ( @@ -68,7 +78,9 @@ def native_connection(self) -> "DatabricksSqlConnection": return self._conn def drop_dataset(self) -> None: - self.execute_sql("DROP SCHEMA IF EXISTS %s CASCADE;" % self.fully_qualified_dataset_name()) + self.execute_sql( + "DROP SCHEMA IF EXISTS %s CASCADE;" % self.fully_qualified_dataset_name() + ) def drop_tables(self, *tables: str) -> None: # Tables are drop with `IF EXISTS`, but databricks raises when the schema doesn't exist. @@ -88,7 +100,9 @@ def execute_sql( @contextmanager @raise_database_error - def execute_query(self, query: AnyStr, *args: Any, **kwargs: Any) -> Iterator[DBApiCursor]: + def execute_query( + self, query: AnyStr, *args: Any, **kwargs: Any + ) -> Iterator[DBApiCursor]: curr: DBApiCursor = None # TODO: databricks connector 3.0.0 will use :named paramstyle only # if args: @@ -137,7 +151,9 @@ def _make_database_exception(ex: Exception) -> Exception: return DatabaseTerminalException(ex) elif isinstance(ex, databricks_lib.OperationalError): return DatabaseTerminalException(ex) - elif isinstance(ex, (databricks_lib.ProgrammingError, databricks_lib.IntegrityError)): + elif isinstance( + ex, (databricks_lib.ProgrammingError, databricks_lib.IntegrityError) + ): return DatabaseTerminalException(ex) elif isinstance(ex, databricks_lib.DatabaseError): return DatabaseTransientException(ex) diff --git a/dlt/destinations/impl/destination/__init__.py b/dlt/destinations/impl/destination/__init__.py index 560c9d4eda..b2e3333691 100644 --- a/dlt/destinations/impl/destination/__init__.py +++ b/dlt/destinations/impl/destination/__init__.py @@ -8,7 +8,9 @@ def capabilities( naming_convention: str = "direct", max_table_nesting: Optional[int] = 0, ) -> DestinationCapabilitiesContext: - caps = DestinationCapabilitiesContext.generic_capabilities(preferred_loader_file_format) + caps = DestinationCapabilitiesContext.generic_capabilities( + preferred_loader_file_format + ) caps.supported_loader_file_formats = ["puae-jsonl", "parquet"] caps.supports_ddl_transactions = False caps.supports_transactions = False diff --git a/dlt/destinations/impl/destination/configuration.py b/dlt/destinations/impl/destination/configuration.py index 30e54a8313..241970b33b 100644 --- a/dlt/destinations/impl/destination/configuration.py +++ b/dlt/destinations/impl/destination/configuration.py @@ -18,7 +18,9 @@ @configspec class CustomDestinationClientConfiguration(DestinationClientConfiguration): destination_type: Final[str] = dataclasses.field(default="destination", init=False, repr=False, compare=False) # type: ignore - destination_callable: Optional[Union[str, TDestinationCallable]] = None # noqa: A003 + destination_callable: Optional[Union[str, 
TDestinationCallable]] = ( + None # noqa: A003 + ) loader_file_format: TLoaderFileFormat = "puae-jsonl" batch_size: int = 10 skip_dlt_columns_and_tables: bool = True diff --git a/dlt/destinations/impl/destination/destination.py b/dlt/destinations/impl/destination/destination.py index 4a3cabde34..d2324fad80 100644 --- a/dlt/destinations/impl/destination/destination.py +++ b/dlt/destinations/impl/destination/destination.py @@ -50,7 +50,9 @@ def __init__( # we create pre_resolved callable here self._callable = destination_callable self._state: TLoadJobState = "running" - self._storage_id = f"{self._parsed_file_name.table_name}.{self._parsed_file_name.file_id}" + self._storage_id = ( + f"{self._parsed_file_name.table_name}.{self._parsed_file_name.file_id}" + ) self.skipped_columns = skipped_columns try: if self._config.batch_size == 0: @@ -140,7 +142,9 @@ class DestinationClient(JobClientBase): capabilities: ClassVar[DestinationCapabilitiesContext] = capabilities() - def __init__(self, schema: Schema, config: CustomDestinationClientConfiguration) -> None: + def __init__( + self, schema: Schema, config: CustomDestinationClientConfiguration + ) -> None: super().__init__(schema, config) self.config: CustomDestinationClientConfiguration = config # create pre-resolved callable to avoid multiple config resolutions during execution of the jobs @@ -162,7 +166,9 @@ def update_stored_schema( ) -> Optional[TSchemaTables]: return super().update_stored_schema(only_tables, expected_update) - def start_file_load(self, table: TTableSchema, file_path: str, load_id: str) -> LoadJob: + def start_file_load( + self, table: TTableSchema, file_path: str, load_id: str + ) -> LoadJob: # skip internal tables and remove columns from schema if so configured skipped_columns: List[str] = [] if self.config.skip_dlt_columns_and_tables: @@ -207,6 +213,9 @@ def __enter__(self) -> "DestinationClient": return self def __exit__( - self, exc_type: Type[BaseException], exc_val: BaseException, exc_tb: TracebackType + self, + exc_type: Type[BaseException], + exc_val: BaseException, + exc_tb: TracebackType, ) -> None: pass diff --git a/dlt/destinations/impl/destination/factory.py b/dlt/destinations/impl/destination/factory.py index 7cca8f2202..da298f1618 100644 --- a/dlt/destinations/impl/destination/factory.py +++ b/dlt/destinations/impl/destination/factory.py @@ -35,10 +35,14 @@ class DestinationInfo(t.NamedTuple): """A registry of all the decorated destinations""" -class destination(Destination[CustomDestinationClientConfiguration, "DestinationClient"]): +class destination( + Destination[CustomDestinationClientConfiguration, "DestinationClient"] +): def capabilities(self) -> DestinationCapabilitiesContext: return capabilities( - preferred_loader_file_format=self.config_params.get("loader_file_format", "puae-jsonl"), + preferred_loader_file_format=self.config_params.get( + "loader_file_format", "puae-jsonl" + ), naming_convention=self.config_params.get("naming_convention", "direct"), max_table_nesting=self.config_params.get("max_table_nesting", None), ) @@ -67,8 +71,8 @@ def __init__( ) -> None: if spec and not issubclass(spec, CustomDestinationClientConfiguration): raise ValueError( - "A SPEC for a sink destination must use CustomDestinationClientConfiguration as a" - " base." + "A SPEC for a sink destination must use" + " CustomDestinationClientConfiguration as a base." 
) # resolve callable if callable(destination_callable): @@ -92,19 +96,22 @@ def __init__( # this is needed for cli commands to work if not destination_callable: logger.warning( - "No destination callable provided, providing dummy callable which will fail on" - " load." + "No destination callable provided, providing dummy callable which will" + " fail on load." ) def dummy_callable(*args: t.Any, **kwargs: t.Any) -> None: raise DestinationTransientException( - "You tried to load to a custom destination without a valid callable." + "You tried to load to a custom destination without a valid" + " callable." ) destination_callable = dummy_callable elif not callable(destination_callable): - raise ConfigurationValueError("Resolved Sink destination callable is not a callable.") + raise ConfigurationValueError( + "Resolved Sink destination callable is not a callable." + ) # resolve destination name if destination_name is None: diff --git a/dlt/destinations/impl/duckdb/__init__.py b/dlt/destinations/impl/duckdb/__init__.py index 5cbc8dea53..7f72a13898 100644 --- a/dlt/destinations/impl/duckdb/__init__.py +++ b/dlt/destinations/impl/duckdb/__init__.py @@ -1,4 +1,7 @@ -from dlt.common.data_writers.escape import escape_postgres_identifier, escape_duckdb_literal +from dlt.common.data_writers.escape import ( + escape_postgres_identifier, + escape_duckdb_literal, +) from dlt.common.destination import DestinationCapabilitiesContext from dlt.common.arithmetics import DEFAULT_NUMERIC_PRECISION, DEFAULT_NUMERIC_SCALE diff --git a/dlt/destinations/impl/duckdb/configuration.py b/dlt/destinations/impl/duckdb/configuration.py index 70d91dcb56..31c0cc9206 100644 --- a/dlt/destinations/impl/duckdb/configuration.py +++ b/dlt/destinations/impl/duckdb/configuration.py @@ -2,7 +2,17 @@ import dataclasses import threading from pathvalidate import is_valid_filepath -from typing import Any, ClassVar, Final, List, Optional, Tuple, TYPE_CHECKING, Type, Union +from typing import ( + Any, + ClassVar, + Final, + List, + Optional, + Tuple, + TYPE_CHECKING, + Type, + Union, +) from dlt.common import logger from dlt.common.configuration import configspec @@ -43,7 +53,9 @@ def borrow_conn(self, read_only: bool) -> Any: # obtain a lock because duck releases the GIL and we have refcount concurrency with self._conn_lock: if not hasattr(self, "_conn"): - self._conn = duckdb.connect(database=self._conn_str(), read_only=read_only) + self._conn = duckdb.connect( + database=self._conn_str(), read_only=read_only + ) self._conn_owner = True self._conn_borrows = 0 @@ -81,7 +93,9 @@ def parse_native_representation(self, native_value: Any) -> None: try: super().parse_native_representation(native_value) except InvalidConnectionString: - if native_value == ":pipeline:" or is_valid_filepath(native_value, platform="auto"): + if native_value == ":pipeline:" or is_valid_filepath( + native_value, platform="auto" + ): self.database = native_value else: raise @@ -123,7 +137,9 @@ def on_resolved(self) -> None: self.database = self._path_in_pipeline(DEFAULT_DUCK_DB_NAME) else: # maybe get database - maybe_database, maybe_is_default_path = self._path_from_pipeline(DEFAULT_DUCK_DB_NAME) + maybe_database, maybe_is_default_path = self._path_from_pipeline( + DEFAULT_DUCK_DB_NAME + ) # if pipeline context was not present or database was not set if not self.database or not maybe_is_default_path: # create database locally @@ -145,7 +161,8 @@ def _path_in_pipeline(self, rel_path: str) -> str: # pipeline is active, get the working directory return 
os.path.join(context.pipeline().working_dir, rel_path) raise RuntimeError( - "Attempting to use special duckdb database :pipeline: outside of pipeline context." + "Attempting to use special duckdb database :pipeline: outside of pipeline" + " context." ) def _path_to_pipeline(self, abspath: str) -> None: @@ -185,8 +202,8 @@ def _path_from_pipeline(self, default_path: str) -> Tuple[str, bool]: if not os.path.exists(pipeline_path): logger.warning( f"Duckdb attached to pipeline {pipeline.pipeline_name} in path" - f" {os.path.relpath(pipeline_path)} was deleted. Attaching to duckdb" - f" database '{default_path}' in current folder." + f" {os.path.relpath(pipeline_path)} was deleted. Attaching to" + f" duckdb database '{default_path}' in current folder." ) else: return pipeline_path, False diff --git a/dlt/destinations/impl/duckdb/duck.py b/dlt/destinations/impl/duckdb/duck.py index 735a4ce7e3..2bc8f1c5d2 100644 --- a/dlt/destinations/impl/duckdb/duck.py +++ b/dlt/destinations/impl/duckdb/duck.py @@ -110,7 +110,9 @@ def from_db_type( class DuckDbCopyJob(LoadJob, FollowupJob): - def __init__(self, table_name: str, file_path: str, sql_client: DuckDbSqlClient) -> None: + def __init__( + self, table_name: str, file_path: str, sql_client: DuckDbSqlClient + ) -> None: super().__init__(FileStorage.get_file_name_from_file_path(file_path)) qualified_table_name = sql_client.make_qualified_table_name(table_name) @@ -149,20 +151,26 @@ class DuckDbClient(InsertValuesJobClient): capabilities: ClassVar[DestinationCapabilitiesContext] = capabilities() def __init__(self, schema: Schema, config: DuckDbClientConfiguration) -> None: - sql_client = DuckDbSqlClient(config.normalize_dataset_name(schema), config.credentials) + sql_client = DuckDbSqlClient( + config.normalize_dataset_name(schema), config.credentials + ) super().__init__(schema, config, sql_client) self.config: DuckDbClientConfiguration = config self.sql_client: DuckDbSqlClient = sql_client # type: ignore self.active_hints = HINT_TO_POSTGRES_ATTR if self.config.create_indexes else {} self.type_mapper = DuckDbTypeMapper(self.capabilities) - def start_file_load(self, table: TTableSchema, file_path: str, load_id: str) -> LoadJob: + def start_file_load( + self, table: TTableSchema, file_path: str, load_id: str + ) -> LoadJob: job = super().start_file_load(table, file_path, load_id) if not job: job = DuckDbCopyJob(table["name"], file_path, self.sql_client) return job - def _get_column_def_sql(self, c: TColumnSchema, table_format: TTableFormat = None) -> str: + def _get_column_def_sql( + self, c: TColumnSchema, table_format: TTableFormat = None + ) -> str: hints_str = " ".join( self.active_hints.get(h, "") for h in self.active_hints.keys() diff --git a/dlt/destinations/impl/duckdb/factory.py b/dlt/destinations/impl/duckdb/factory.py index 6a0152df26..db0b080096 100644 --- a/dlt/destinations/impl/duckdb/factory.py +++ b/dlt/destinations/impl/duckdb/factory.py @@ -1,7 +1,10 @@ import typing as t from dlt.common.destination import Destination, DestinationCapabilitiesContext -from dlt.destinations.impl.duckdb.configuration import DuckDbCredentials, DuckDbClientConfiguration +from dlt.destinations.impl.duckdb.configuration import ( + DuckDbCredentials, + DuckDbClientConfiguration, +) from dlt.destinations.impl.duckdb import capabilities if t.TYPE_CHECKING: diff --git a/dlt/destinations/impl/duckdb/sql_client.py b/dlt/destinations/impl/duckdb/sql_client.py index 2863d4943e..d1b82485f7 100644 --- a/dlt/destinations/impl/duckdb/sql_client.py +++ 
b/dlt/destinations/impl/duckdb/sql_client.py @@ -114,7 +114,9 @@ def execute_sql( @contextmanager @raise_database_error - def execute_query(self, query: AnyStr, *args: Any, **kwargs: Any) -> Iterator[DBApiCursor]: + def execute_query( + self, query: AnyStr, *args: Any, **kwargs: Any + ) -> Iterator[DBApiCursor]: assert isinstance(query, str) db_args = args if args else kwargs if kwargs else None if db_args: @@ -144,7 +146,9 @@ def execute_query(self, query: AnyStr, *args: Any, **kwargs: Any) -> Iterator[DB def fully_qualified_dataset_name(self, escape: bool = True) -> str: return ( - self.capabilities.escape_identifier(self.dataset_name) if escape else self.dataset_name + self.capabilities.escape_identifier(self.dataset_name) + if escape + else self.dataset_name ) @classmethod @@ -173,7 +177,9 @@ def _make_database_exception(cls, ex: Exception) -> Exception: return term else: return DatabaseTransientException(ex) - elif isinstance(ex, (duckdb.DataError, duckdb.ProgrammingError, duckdb.IntegrityError)): + elif isinstance( + ex, (duckdb.DataError, duckdb.ProgrammingError, duckdb.IntegrityError) + ): return DatabaseTerminalException(ex) elif cls.is_dbapi_exception(ex): return DatabaseTransientException(ex) @@ -181,7 +187,9 @@ def _make_database_exception(cls, ex: Exception) -> Exception: return ex @staticmethod - def _maybe_make_terminal_exception_from_data_error(pg_ex: duckdb.Error) -> Optional[Exception]: + def _maybe_make_terminal_exception_from_data_error( + pg_ex: duckdb.Error, + ) -> Optional[Exception]: return None @staticmethod diff --git a/dlt/destinations/impl/dummy/__init__.py b/dlt/destinations/impl/dummy/__init__.py index 37b2e77c8a..03a8f2f28b 100644 --- a/dlt/destinations/impl/dummy/__init__.py +++ b/dlt/destinations/impl/dummy/__init__.py @@ -14,7 +14,9 @@ "dummy", ), ) -def _configure(config: DummyClientConfiguration = config.value) -> DummyClientConfiguration: +def _configure( + config: DummyClientConfiguration = config.value, +) -> DummyClientConfiguration: return config @@ -25,9 +27,13 @@ def capabilities() -> DestinationCapabilitiesContext: ) caps = DestinationCapabilitiesContext() caps.preferred_loader_file_format = config.loader_file_format - caps.supported_loader_file_formats = additional_formats + [config.loader_file_format] + caps.supported_loader_file_formats = additional_formats + [ + config.loader_file_format + ] caps.preferred_staging_file_format = None - caps.supported_staging_file_formats = additional_formats + [config.loader_file_format] + caps.supported_staging_file_formats = additional_formats + [ + config.loader_file_format + ] caps.max_identifier_length = 127 caps.max_column_identifier_length = 127 caps.max_query_length = 8 * 1024 * 1024 diff --git a/dlt/destinations/impl/dummy/dummy.py b/dlt/destinations/impl/dummy/dummy.py index 0d91220d88..c9d8bced24 100644 --- a/dlt/destinations/impl/dummy/dummy.py +++ b/dlt/destinations/impl/dummy/dummy.py @@ -96,7 +96,9 @@ class LoadDummyJob(LoadDummyBaseJob, FollowupJob): def create_followup_jobs(self, final_state: TLoadJobState) -> List[NewLoadJob]: if self.config.create_followup_jobs and final_state == "completed": new_job = NewReferenceJob( - file_name=self.file_name(), status="running", remote_path=self._file_name + file_name=self.file_name(), + status="running", + remote_path=self._file_name, ) CREATED_FOLLOWUP_JOBS[new_job.job_id()] = new_job return [new_job] @@ -136,7 +138,9 @@ def update_stored_schema( ) return applied_update - def start_file_load(self, table: TTableSchema, file_path: str, load_id: str) -> 
LoadJob: + def start_file_load( + self, table: TTableSchema, file_path: str, load_id: str + ) -> LoadJob: job_id = FileStorage.get_file_name_from_file_path(file_path) file_name = FileStorage.get_file_name_from_file_path(file_path) # return existing job if already there @@ -179,7 +183,10 @@ def __enter__(self) -> "DummyClient": return self def __exit__( - self, exc_type: Type[BaseException], exc_val: BaseException, exc_tb: TracebackType + self, + exc_type: Type[BaseException], + exc_val: BaseException, + exc_tb: TracebackType, ) -> None: pass diff --git a/dlt/destinations/impl/filesystem/factory.py b/dlt/destinations/impl/filesystem/factory.py index 029a5bdda5..86206fa220 100644 --- a/dlt/destinations/impl/filesystem/factory.py +++ b/dlt/destinations/impl/filesystem/factory.py @@ -1,6 +1,8 @@ import typing as t -from dlt.destinations.impl.filesystem.configuration import FilesystemDestinationClientConfiguration +from dlt.destinations.impl.filesystem.configuration import ( + FilesystemDestinationClientConfiguration, +) from dlt.destinations.impl.filesystem import capabilities from dlt.common.destination import Destination, DestinationCapabilitiesContext from dlt.common.storages.configuration import FileSystemCredentials @@ -9,7 +11,9 @@ from dlt.destinations.impl.filesystem.filesystem import FilesystemClient -class filesystem(Destination[FilesystemDestinationClientConfiguration, "FilesystemClient"]): +class filesystem( + Destination[FilesystemDestinationClientConfiguration, "FilesystemClient"] +): spec = FilesystemDestinationClientConfiguration def capabilities(self) -> DestinationCapabilitiesContext: diff --git a/dlt/destinations/impl/filesystem/filesystem.py b/dlt/destinations/impl/filesystem/filesystem.py index 33a597f915..ca4f14ee44 100644 --- a/dlt/destinations/impl/filesystem/filesystem.py +++ b/dlt/destinations/impl/filesystem/filesystem.py @@ -20,7 +20,9 @@ from dlt.destinations.job_impl import EmptyLoadJob from dlt.destinations.impl.filesystem import capabilities -from dlt.destinations.impl.filesystem.configuration import FilesystemDestinationClientConfiguration +from dlt.destinations.impl.filesystem.configuration import ( + FilesystemDestinationClientConfiguration, +) from dlt.destinations.job_impl import NewReferenceJob from dlt.destinations import path_utils @@ -81,7 +83,9 @@ def create_followup_jobs(self, final_state: TLoadJobState) -> List[NewLoadJob]: jobs = super().create_followup_jobs(final_state) if final_state == "completed": ref_job = NewReferenceJob( - file_name=self.file_name(), status="running", remote_path=self.make_remote_path() + file_name=self.file_name(), + status="running", + remote_path=self.make_remote_path(), ) jobs.append(ref_job) return jobs @@ -94,7 +98,9 @@ class FilesystemClient(JobClientBase, WithStagingDataset): fs_client: AbstractFileSystem fs_path: str - def __init__(self, schema: Schema, config: FilesystemDestinationClientConfiguration) -> None: + def __init__( + self, schema: Schema, config: FilesystemDestinationClientConfiguration + ) -> None: super().__init__(schema, config) self.fs_client, self.fs_path = fsspec_from_config(config) self.config: FilesystemDestinationClientConfiguration = config @@ -144,7 +150,9 @@ def initialize_storage(self, truncate_tables: Iterable[str] = None) -> None: # NOTE: without refresh you get random results here logger.info(f"Will truncate tables in {truncate_dir}") try: - all_files = self.fs_client.ls(truncate_dir, detail=False, refresh=True) + all_files = self.fs_client.ls( + truncate_dir, detail=False, refresh=True + 
) # logger.debug(f"Found {len(all_files)} CANDIDATE files in {truncate_dir}") # print(f"in truncate dir {truncate_dir}: {all_files}") for item in all_files: @@ -163,8 +171,8 @@ def initialize_storage(self, truncate_tables: Iterable[str] = None) -> None: raise FileExistsError(item) except FileNotFoundError: logger.info( - f"Directory or path to truncate tables {truncate_dir} does not exist but it" - " should be created previously!" + f"Directory or path to truncate tables {truncate_dir} does not" + " exist but it should be created previously!" ) def update_stored_schema( @@ -191,7 +199,9 @@ def _get_table_dirs(self, table_names: Iterable[str]) -> Set[str]: def is_storage_initialized(self) -> bool: return self.fs_client.isdir(self.dataset_path) # type: ignore[no-any-return] - def start_file_load(self, table: TTableSchema, file_path: str, load_id: str) -> LoadJob: + def start_file_load( + self, table: TTableSchema, file_path: str, load_id: str + ) -> LoadJob: cls = FollowupFilesystemJob if self.config.as_staging else LoadFilesystemJob return cls( file_path, @@ -214,7 +224,10 @@ def __enter__(self) -> "FilesystemClient": return self def __exit__( - self, exc_type: Type[BaseException], exc_val: BaseException, exc_tb: TracebackType + self, + exc_type: Type[BaseException], + exc_val: BaseException, + exc_tb: TracebackType, ) -> None: pass diff --git a/dlt/destinations/impl/motherduck/__init__.py b/dlt/destinations/impl/motherduck/__init__.py index 74c0e36ef3..3a7c9f5021 100644 --- a/dlt/destinations/impl/motherduck/__init__.py +++ b/dlt/destinations/impl/motherduck/__init__.py @@ -1,4 +1,7 @@ -from dlt.common.data_writers.escape import escape_postgres_identifier, escape_duckdb_literal +from dlt.common.data_writers.escape import ( + escape_postgres_identifier, + escape_duckdb_literal, +) from dlt.common.destination import DestinationCapabilitiesContext from dlt.common.arithmetics import DEFAULT_NUMERIC_PRECISION, DEFAULT_NUMERIC_SCALE diff --git a/dlt/destinations/impl/motherduck/configuration.py b/dlt/destinations/impl/motherduck/configuration.py index 3179295c54..e470e2dab1 100644 --- a/dlt/destinations/impl/motherduck/configuration.py +++ b/dlt/destinations/impl/motherduck/configuration.py @@ -2,7 +2,9 @@ from typing import Any, ClassVar, Final, List from dlt.common.configuration import configspec -from dlt.common.destination.reference import DestinationClientDwhWithStagingConfiguration +from dlt.common.destination.reference import ( + DestinationClientDwhWithStagingConfiguration, +) from dlt.common.destination.exceptions import DestinationTerminalException from dlt.common.typing import TSecretValue from dlt.common.utils import digest128 @@ -36,10 +38,14 @@ def borrow_conn(self, read_only: bool) -> Any: try: return super().borrow_conn(read_only) except (InvalidInputException, HTTPException) as ext_ex: - if "Failed to download extension" in str(ext_ex) and "motherduck" in str(ext_ex): + if "Failed to download extension" in str(ext_ex) and "motherduck" in str( + ext_ex + ): from importlib.metadata import version as pkg_version - raise MotherduckLocalVersionNotSupported(pkg_version("duckdb")) from ext_ex + raise MotherduckLocalVersionNotSupported( + pkg_version("duckdb") + ) from ext_ex raise @@ -51,8 +57,9 @@ def on_resolved(self) -> None: self._token_to_password() if self.drivername == MOTHERDUCK_DRIVERNAME and not self.password: raise ConfigurationValueError( - "Motherduck schema 'md' was specified without corresponding token or password. 
The" - " required format of connection string is: md:///?token=" + "Motherduck schema 'md' was specified without corresponding token or" + " password. The required format of connection string is:" + " md:///?token=" ) @@ -76,6 +83,6 @@ class MotherduckLocalVersionNotSupported(DestinationTerminalException): def __init__(self, duckdb_version: str) -> None: self.duckdb_version = duckdb_version super().__init__( - f"Looks like your local duckdb version ({duckdb_version}) is not supported by" - " Motherduck" + f"Looks like your local duckdb version ({duckdb_version}) is not supported" + " by Motherduck" ) diff --git a/dlt/destinations/impl/motherduck/motherduck.py b/dlt/destinations/impl/motherduck/motherduck.py index c695d9715e..27c2b324c6 100644 --- a/dlt/destinations/impl/motherduck/motherduck.py +++ b/dlt/destinations/impl/motherduck/motherduck.py @@ -15,6 +15,8 @@ class MotherDuckClient(DuckDbClient): def __init__(self, schema: Schema, config: MotherDuckClientConfiguration) -> None: super().__init__(schema, config) # type: ignore - sql_client = MotherDuckSqlClient(config.normalize_dataset_name(schema), config.credentials) + sql_client = MotherDuckSqlClient( + config.normalize_dataset_name(schema), config.credentials + ) self.config: MotherDuckClientConfiguration = config # type: ignore self.sql_client: MotherDuckSqlClient = sql_client diff --git a/dlt/destinations/impl/motherduck/sql_client.py b/dlt/destinations/impl/motherduck/sql_client.py index 7990f90947..704cfaba92 100644 --- a/dlt/destinations/impl/motherduck/sql_client.py +++ b/dlt/destinations/impl/motherduck/sql_client.py @@ -17,7 +17,10 @@ raise_open_connection_error, ) -from dlt.destinations.impl.duckdb.sql_client import DuckDbSqlClient, DuckDBDBApiCursorImpl +from dlt.destinations.impl.duckdb.sql_client import ( + DuckDbSqlClient, + DuckDBDBApiCursorImpl, +) from dlt.destinations.impl.motherduck import capabilities from dlt.destinations.impl.motherduck.configuration import MotherDuckCredentials @@ -36,6 +39,8 @@ def fully_qualified_dataset_name(self, escape: bool = True) -> str: else self.database_name ) dataset_name = ( - self.capabilities.escape_identifier(self.dataset_name) if escape else self.dataset_name + self.capabilities.escape_identifier(self.dataset_name) + if escape + else self.dataset_name ) return f"{database_name}.{dataset_name}" diff --git a/dlt/destinations/impl/mssql/__init__.py b/dlt/destinations/impl/mssql/__init__.py index e9d9fe24fd..3bdd5fda3c 100644 --- a/dlt/destinations/impl/mssql/__init__.py +++ b/dlt/destinations/impl/mssql/__init__.py @@ -1,4 +1,7 @@ -from dlt.common.data_writers.escape import escape_postgres_identifier, escape_mssql_literal +from dlt.common.data_writers.escape import ( + escape_postgres_identifier, + escape_mssql_literal, +) from dlt.common.destination import DestinationCapabilitiesContext from dlt.common.arithmetics import DEFAULT_NUMERIC_PRECISION, DEFAULT_NUMERIC_SCALE from dlt.common.wei import EVM_DECIMAL_PRECISION diff --git a/dlt/destinations/impl/mssql/configuration.py b/dlt/destinations/impl/mssql/configuration.py index 1d085f40c1..bd33c81c4b 100644 --- a/dlt/destinations/impl/mssql/configuration.py +++ b/dlt/destinations/impl/mssql/configuration.py @@ -8,7 +8,9 @@ from dlt.common.typing import TSecretValue from dlt.common.exceptions import SystemConfigurationException -from dlt.common.destination.reference import DestinationClientDwhWithStagingConfiguration +from dlt.common.destination.reference import ( + DestinationClientDwhWithStagingConfiguration, +) 
@configspec(init=False) @@ -31,9 +33,13 @@ def parse_native_representation(self, native_value: Any) -> None: # TODO: Support ODBC connection string or sqlalchemy URL super().parse_native_representation(native_value) if self.query is not None: - self.query = {k.lower(): v for k, v in self.query.items()} # Make case-insensitive. + self.query = { + k.lower(): v for k, v in self.query.items() + } # Make case-insensitive. self.driver = self.query.get("driver", self.driver) - self.connect_timeout = int(self.query.get("connect_timeout", self.connect_timeout)) + self.connect_timeout = int( + self.query.get("connect_timeout", self.connect_timeout) + ) if not self.is_partial(): self.resolve() @@ -41,7 +47,8 @@ def on_resolved(self) -> None: if self.driver not in self.SUPPORTED_DRIVERS: raise SystemConfigurationException( f"""The specified driver "{self.driver}" is not supported.""" - f" Choose one of the supported drivers: {', '.join(self.SUPPORTED_DRIVERS)}." + " Choose one of the supported drivers:" + f" {', '.join(self.SUPPORTED_DRIVERS)}." ) self.database = self.database.lower() @@ -68,8 +75,9 @@ def _get_driver(self) -> str: return d docs_url = "https://learn.microsoft.com/en-us/sql/connect/odbc/download-odbc-driver-for-sql-server?view=sql-server-ver16" raise SystemConfigurationException( - f"No supported ODBC driver found for MS SQL Server. See {docs_url} for information on" - f" how to install the '{self.SUPPORTED_DRIVERS[0]}' on your platform." + f"No supported ODBC driver found for MS SQL Server. See {docs_url} for" + f" information on how to install the '{self.SUPPORTED_DRIVERS[0]}' on your" + " platform." ) def _get_odbc_dsn_dict(self) -> Dict[str, Any]: diff --git a/dlt/destinations/impl/mssql/factory.py b/dlt/destinations/impl/mssql/factory.py index 2e19d7c2a8..fea49c02e3 100644 --- a/dlt/destinations/impl/mssql/factory.py +++ b/dlt/destinations/impl/mssql/factory.py @@ -2,7 +2,10 @@ from dlt.common.destination import Destination, DestinationCapabilitiesContext -from dlt.destinations.impl.mssql.configuration import MsSqlCredentials, MsSqlClientConfiguration +from dlt.destinations.impl.mssql.configuration import ( + MsSqlCredentials, + MsSqlClientConfiguration, +) from dlt.destinations.impl.mssql import capabilities if t.TYPE_CHECKING: diff --git a/dlt/destinations/impl/mssql/mssql.py b/dlt/destinations/impl/mssql/mssql.py index b6af345e36..8fca42afff 100644 --- a/dlt/destinations/impl/mssql/mssql.py +++ b/dlt/destinations/impl/mssql/mssql.py @@ -105,7 +105,9 @@ def generate_sql( f" {staging_table_name};" ) # recreate staging table - sql.append(f"SELECT * INTO {staging_table_name} FROM {table_name} WHERE 1 = 0;") + sql.append( + f"SELECT * INTO {staging_table_name} FROM {table_name} WHERE 1 = 0;" + ) return sql @@ -144,14 +146,18 @@ class MsSqlClient(InsertValuesJobClient): capabilities: ClassVar[DestinationCapabilitiesContext] = capabilities() def __init__(self, schema: Schema, config: MsSqlClientConfiguration) -> None: - sql_client = PyOdbcMsSqlClient(config.normalize_dataset_name(schema), config.credentials) + sql_client = PyOdbcMsSqlClient( + config.normalize_dataset_name(schema), config.credentials + ) super().__init__(schema, config, sql_client) self.config: MsSqlClientConfiguration = config self.sql_client = sql_client self.active_hints = HINT_TO_MSSQL_ATTR if self.config.create_indexes else {} self.type_mapper = MsSqlTypeMapper(self.capabilities) - def _create_merge_followup_jobs(self, table_chain: Sequence[TTableSchema]) -> List[NewLoadJob]: + def _create_merge_followup_jobs( + 
self, table_chain: Sequence[TTableSchema] + ) -> List[NewLoadJob]: return [MsSqlMergeJob.from_table_chain(table_chain, self.sql_client)] def _make_add_column_sql( @@ -159,10 +165,13 @@ def _make_add_column_sql( ) -> List[str]: # Override because mssql requires multiple columns in a single ADD COLUMN clause return [ - "ADD \n" + ",\n".join(self._get_column_def_sql(c, table_format) for c in new_columns) + "ADD \n" + + ",\n".join(self._get_column_def_sql(c, table_format) for c in new_columns) ] - def _get_column_def_sql(self, c: TColumnSchema, table_format: TTableFormat = None) -> str: + def _get_column_def_sql( + self, c: TColumnSchema, table_format: TTableFormat = None + ) -> str: sc_type = c["data_type"] if sc_type == "text" and c.get("unique"): # MSSQL does not allow index on large TEXT columns @@ -176,7 +185,9 @@ def _get_column_def_sql(self, c: TColumnSchema, table_format: TTableFormat = Non if c.get(h, False) is True ) column_name = self.capabilities.escape_identifier(c["name"]) - return f"{column_name} {db_type} {hints_str} {self._gen_not_null(c['nullable'])}" + return ( + f"{column_name} {db_type} {hints_str} {self._gen_not_null(c['nullable'])}" + ) def _create_replace_followup_jobs( self, table_chain: Sequence[TTableSchema] diff --git a/dlt/destinations/impl/mssql/sql_client.py b/dlt/destinations/impl/mssql/sql_client.py index cd1699adea..f9f425f02f 100644 --- a/dlt/destinations/impl/mssql/sql_client.py +++ b/dlt/destinations/impl/mssql/sql_client.py @@ -28,7 +28,9 @@ def handle_datetimeoffset(dto_value: bytes) -> datetime: # ref: https://github.com/mkleehammer/pyodbc/issues/134#issuecomment-281739794 - tup = struct.unpack("<6hI2h", dto_value) # e.g., (2017, 3, 16, 10, 35, 18, 500000000, -6, 0) + tup = struct.unpack( + "<6hI2h", dto_value + ) # e.g., (2017, 3, 16, 10, 35, 18, 500000000, -6, 0) return datetime( tup[0], tup[1], @@ -113,12 +115,15 @@ def _drop_views(self, *tables: str) -> None: if not tables: return statements = [ - f"DROP VIEW IF EXISTS {self.make_qualified_table_name(table)};" for table in tables + f"DROP VIEW IF EXISTS {self.make_qualified_table_name(table)};" + for table in tables ] self.execute_many(statements) def _drop_schema(self) -> None: - self.execute_sql("DROP SCHEMA IF EXISTS %s;" % self.fully_qualified_dataset_name()) + self.execute_sql( + "DROP SCHEMA IF EXISTS %s;" % self.fully_qualified_dataset_name() + ) def execute_sql( self, sql: AnyStr, *args: Any, **kwargs: Any @@ -132,11 +137,15 @@ def execute_sql( @contextmanager @raise_database_error - def execute_query(self, query: AnyStr, *args: Any, **kwargs: Any) -> Iterator[DBApiCursor]: + def execute_query( + self, query: AnyStr, *args: Any, **kwargs: Any + ) -> Iterator[DBApiCursor]: assert isinstance(query, str) curr: DBApiCursor = None if kwargs: - raise NotImplementedError("pyodbc does not support named parameters in queries") + raise NotImplementedError( + "pyodbc does not support named parameters in queries" + ) if args: # TODO: this is bad. 
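# Illustrative sketch, not from the patch: the handle_datetimeoffset converter
# above unpacks SQL Server DATETIMEOFFSET values with the "<6hI2h" struct layout
# (year, month, day, hour, minute, second, nanoseconds, tz hours, tz minutes).
# A self-contained round trip, assuming that conventional layout from the
# referenced pyodbc issue:
import struct
from datetime import datetime, timedelta, timezone

raw = struct.pack("<6hI2h", 2017, 3, 16, 10, 35, 18, 500000000, -6, 0)
tup = struct.unpack("<6hI2h", raw)
value = datetime(
    tup[0], tup[1], tup[2], tup[3], tup[4], tup[5],
    tup[6] // 1000,  # nanoseconds -> microseconds
    timezone(timedelta(hours=tup[7], minutes=tup[8])),
)
print(value)  # 2017-03-16 10:35:18.500000-06:00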
See duckdb & athena also query = query.replace("%s", "?") @@ -151,7 +160,9 @@ def execute_query(self, query: AnyStr, *args: Any, **kwargs: Any) -> Iterator[DB def fully_qualified_dataset_name(self, escape: bool = True) -> str: return ( - self.capabilities.escape_identifier(self.dataset_name) if escape else self.dataset_name + self.capabilities.escape_identifier(self.dataset_name) + if escape + else self.dataset_name ) @classmethod diff --git a/dlt/destinations/impl/postgres/__init__.py b/dlt/destinations/impl/postgres/__init__.py index 43e6af1996..ab2b1066c0 100644 --- a/dlt/destinations/impl/postgres/__init__.py +++ b/dlt/destinations/impl/postgres/__init__.py @@ -1,6 +1,12 @@ -from dlt.common.data_writers.escape import escape_postgres_identifier, escape_postgres_literal +from dlt.common.data_writers.escape import ( + escape_postgres_identifier, + escape_postgres_literal, +) from dlt.common.destination import DestinationCapabilitiesContext -from dlt.common.destination.reference import JobClientBase, DestinationClientConfiguration +from dlt.common.destination.reference import ( + JobClientBase, + DestinationClientConfiguration, +) from dlt.common.arithmetics import DEFAULT_NUMERIC_PRECISION, DEFAULT_NUMERIC_SCALE from dlt.common.wei import EVM_DECIMAL_PRECISION diff --git a/dlt/destinations/impl/postgres/configuration.py b/dlt/destinations/impl/postgres/configuration.py index 0d12abbac7..ad6a93aa4a 100644 --- a/dlt/destinations/impl/postgres/configuration.py +++ b/dlt/destinations/impl/postgres/configuration.py @@ -7,7 +7,9 @@ from dlt.common.utils import digest128 from dlt.common.typing import TSecretValue -from dlt.common.destination.reference import DestinationClientDwhWithStagingConfiguration +from dlt.common.destination.reference import ( + DestinationClientDwhWithStagingConfiguration, +) @configspec(init=False) @@ -22,7 +24,9 @@ class PostgresCredentials(ConnectionStringCredentials): def parse_native_representation(self, native_value: Any) -> None: super().parse_native_representation(native_value) - self.connect_timeout = int(self.query.get("connect_timeout", self.connect_timeout)) + self.connect_timeout = int( + self.query.get("connect_timeout", self.connect_timeout) + ) if not self.is_partial(): self.resolve() diff --git a/dlt/destinations/impl/postgres/postgres.py b/dlt/destinations/impl/postgres/postgres.py index f8fa3e341a..1168653be1 100644 --- a/dlt/destinations/impl/postgres/postgres.py +++ b/dlt/destinations/impl/postgres/postgres.py @@ -100,7 +100,9 @@ def generate_sql( f" {sql_client.fully_qualified_dataset_name()};" ) # recreate staging table - sql.append(f"CREATE TABLE {staging_table_name} (like {table_name} including all);") + sql.append( + f"CREATE TABLE {staging_table_name} (like {table_name} including all);" + ) return sql @@ -108,14 +110,18 @@ class PostgresClient(InsertValuesJobClient): capabilities: ClassVar[DestinationCapabilitiesContext] = capabilities() def __init__(self, schema: Schema, config: PostgresClientConfiguration) -> None: - sql_client = Psycopg2SqlClient(config.normalize_dataset_name(schema), config.credentials) + sql_client = Psycopg2SqlClient( + config.normalize_dataset_name(schema), config.credentials + ) super().__init__(schema, config, sql_client) self.config: PostgresClientConfiguration = config self.sql_client = sql_client self.active_hints = HINT_TO_POSTGRES_ATTR if self.config.create_indexes else {} self.type_mapper = PostgresTypeMapper(self.capabilities) - def _get_column_def_sql(self, c: TColumnSchema, table_format: TTableFormat = None) -> 
str: + def _get_column_def_sql( + self, c: TColumnSchema, table_format: TTableFormat = None + ) -> str: hints_str = " ".join( self.active_hints.get(h, "") for h in self.active_hints.keys() @@ -130,7 +136,9 @@ def _create_replace_followup_jobs( self, table_chain: Sequence[TTableSchema] ) -> List[NewLoadJob]: if self.config.replace_strategy == "staging-optimized": - return [PostgresStagingCopyJob.from_table_chain(table_chain, self.sql_client)] + return [ + PostgresStagingCopyJob.from_table_chain(table_chain, self.sql_client) + ] return super()._create_replace_followup_jobs(table_chain) def _from_db_type( diff --git a/dlt/destinations/impl/postgres/sql_client.py b/dlt/destinations/impl/postgres/sql_client.py index 366ed243ef..37ac94ef0f 100644 --- a/dlt/destinations/impl/postgres/sql_client.py +++ b/dlt/destinations/impl/postgres/sql_client.py @@ -90,7 +90,9 @@ def execute_sql( @contextmanager @raise_database_error - def execute_query(self, query: AnyStr, *args: Any, **kwargs: Any) -> Iterator[DBApiCursor]: + def execute_query( + self, query: AnyStr, *args: Any, **kwargs: Any + ) -> Iterator[DBApiCursor]: curr: DBApiCursor = None db_args = args if args else kwargs if kwargs else None with self._conn.cursor() as curr: @@ -109,12 +111,16 @@ def execute_fragments( self, fragments: Sequence[AnyStr], *args: Any, **kwargs: Any ) -> Optional[Sequence[Sequence[Any]]]: # compose the statements using psycopg2 library - composed = Composed(sql if isinstance(sql, Composable) else SQL(sql) for sql in fragments) + composed = Composed( + sql if isinstance(sql, Composable) else SQL(sql) for sql in fragments + ) return self.execute_sql(composed, *args, **kwargs) def fully_qualified_dataset_name(self, escape: bool = True) -> str: return ( - self.capabilities.escape_identifier(self.dataset_name) if escape else self.dataset_name + self.capabilities.escape_identifier(self.dataset_name) + if escape + else self.dataset_name ) def _reset_connection(self) -> None: @@ -124,7 +130,9 @@ def _reset_connection(self) -> None: @classmethod def _make_database_exception(cls, ex: Exception) -> Exception: - if isinstance(ex, (psycopg2.errors.UndefinedTable, psycopg2.errors.InvalidSchemaName)): + if isinstance( + ex, (psycopg2.errors.UndefinedTable, psycopg2.errors.InvalidSchemaName) + ): raise DatabaseUndefinedRelation(ex) if isinstance( ex, diff --git a/dlt/destinations/impl/qdrant/factory.py b/dlt/destinations/impl/qdrant/factory.py index df9cd64871..a6067c8528 100644 --- a/dlt/destinations/impl/qdrant/factory.py +++ b/dlt/destinations/impl/qdrant/factory.py @@ -2,7 +2,10 @@ from dlt.common.destination import Destination, DestinationCapabilitiesContext -from dlt.destinations.impl.qdrant.configuration import QdrantCredentials, QdrantClientConfiguration +from dlt.destinations.impl.qdrant.configuration import ( + QdrantCredentials, + QdrantClientConfiguration, +) from dlt.destinations.impl.qdrant import capabilities if t.TYPE_CHECKING: diff --git a/dlt/destinations/impl/qdrant/qdrant_adapter.py b/dlt/destinations/impl/qdrant/qdrant_adapter.py index 215d87a920..9145c33d3d 100644 --- a/dlt/destinations/impl/qdrant/qdrant_adapter.py +++ b/dlt/destinations/impl/qdrant/qdrant_adapter.py @@ -41,7 +41,8 @@ def qdrant_adapter( embed = [embed] if not isinstance(embed, list): raise ValueError( - "embed must be a list of column names or a single column name as a string" + "embed must be a list of column names or a single column name as a" + " string" ) for column_name in embed: diff --git a/dlt/destinations/impl/qdrant/qdrant_client.py 
b/dlt/destinations/impl/qdrant/qdrant_client.py index febfe38ec9..b6c01a66e9 100644 --- a/dlt/destinations/impl/qdrant/qdrant_client.py +++ b/dlt/destinations/impl/qdrant/qdrant_client.py @@ -5,7 +5,12 @@ from dlt.common.schema import Schema, TTableSchema, TSchemaTables from dlt.common.schema.utils import get_columns_names_with_prop from dlt.common.destination import DestinationCapabilitiesContext -from dlt.common.destination.reference import TLoadJobState, LoadJob, JobClientBase, WithStateSync +from dlt.common.destination.reference import ( + TLoadJobState, + LoadJob, + JobClientBase, + WithStateSync, +) from dlt.common.storages import FileStorage from dlt.destinations.job_impl import EmptyLoadJob @@ -33,7 +38,9 @@ def __init__( super().__init__(file_name) self.db_client = db_client self.collection_name = collection_name - self.embedding_fields = get_columns_names_with_prop(table_schema, VECTORIZE_HINT) + self.embedding_fields = get_columns_names_with_prop( + table_schema, VECTORIZE_HINT + ) self.unique_identifiers = self._list_unique_identifiers(table_schema) self.config = client_config @@ -43,7 +50,9 @@ def __init__( for line in f: data = json.loads(line) point_id = ( - self._generate_uuid(data, self.unique_identifiers, self.collection_name) + self._generate_uuid( + data, self.unique_identifiers, self.collection_name + ) if self.unique_identifiers else uuid.uuid4() ) @@ -52,7 +61,9 @@ def __init__( ids.append(point_id) docs.append(embedding_doc) - embedding_model = db_client._get_or_init_model(db_client.embedding_model_name) + embedding_model = db_client._get_or_init_model( + db_client.embedding_model_name + ) embeddings = list( embedding_model.embed( docs, @@ -114,7 +125,10 @@ def _upload_data( ) def _generate_uuid( - self, data: Dict[str, Any], unique_identifiers: Sequence[str], collection_name: str + self, + data: Dict[str, Any], + unique_identifiers: Sequence[str], + collection_name: str, ) -> str: """Generates deterministic UUID. Used for deduplication. 
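# Illustrative sketch, not from the patch: the hunk above only reformats the
# _generate_uuid signature; its body is not shown here. One minimal way to derive
# a deterministic, dedup-friendly point id from the unique identifier values and
# the collection name is a name-based UUID (uuid.uuid5) -- shown below as an
# assumption, not as dlt's actual implementation.
import uuid
from typing import Any, Dict, Sequence


def generate_deterministic_uuid(
    data: Dict[str, Any], unique_identifiers: Sequence[str], collection_name: str
) -> str:
    # identical identifier values always map to the same UUID, so re-loads
    # overwrite existing points instead of duplicating them
    key = "_".join(str(data.get(k)) for k in unique_identifiers)
    return str(uuid.uuid5(uuid.NAMESPACE_DNS, collection_name + key))


print(generate_deterministic_uuid({"id": 1}, ["id"], "my_collection"))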
@@ -262,7 +276,9 @@ def initialize_storage(self, truncate_tables: Iterable[str] = None) -> None: self._create_sentinel_collection() elif truncate_tables: for table_name in truncate_tables: - qualified_table_name = self._make_qualified_collection_name(table_name=table_name) + qualified_table_name = self._make_qualified_collection_name( + table_name=table_name + ) if self._collection_exists(qualified_table_name): continue @@ -270,7 +286,9 @@ def initialize_storage(self, truncate_tables: Iterable[str] = None) -> None: self._create_collection(full_collection_name=qualified_table_name) def is_storage_initialized(self) -> bool: - return self._collection_exists(self.sentinel_collection, qualify_table_name=False) + return self._collection_exists( + self.sentinel_collection, qualify_table_name=False + ) def _create_sentinel_collection(self) -> None: """Create an empty collection to indicate that the storage is initialized.""" @@ -317,7 +335,8 @@ def get_stored_state(self, pipeline_name: str) -> Optional[StateInfo]: scroll_filter=models.Filter( must=[ models.FieldCondition( - key="pipeline_name", match=models.MatchValue(value=pipeline_name) + key="pipeline_name", + match=models.MatchValue(value=pipeline_name), ) ] ), @@ -338,7 +357,8 @@ def get_stored_state(self, pipeline_name: str) -> Optional[StateInfo]: count_filter=models.Filter( must=[ models.FieldCondition( - key="load_id", match=models.MatchValue(value=load_id) + key="load_id", + match=models.MatchValue(value=load_id), ) ] ), @@ -352,7 +372,9 @@ def get_stored_state(self, pipeline_name: str) -> Optional[StateInfo]: def get_stored_schema(self) -> Optional[StorageSchemaInfo]: """Retrieves newest schema from destination storage""" try: - scroll_table_name = self._make_qualified_collection_name(self.schema.version_table_name) + scroll_table_name = self._make_qualified_collection_name( + self.schema.version_table_name + ) response = self.db_client.scroll( scroll_table_name, with_payload=True, @@ -371,16 +393,21 @@ def get_stored_schema(self) -> Optional[StorageSchemaInfo]: except Exception: return None - def get_stored_schema_by_hash(self, schema_hash: str) -> Optional[StorageSchemaInfo]: + def get_stored_schema_by_hash( + self, schema_hash: str + ) -> Optional[StorageSchemaInfo]: try: - scroll_table_name = self._make_qualified_collection_name(self.schema.version_table_name) + scroll_table_name = self._make_qualified_collection_name( + self.schema.version_table_name + ) response = self.db_client.scroll( scroll_table_name, with_payload=True, scroll_filter=models.Filter( must=[ models.FieldCondition( - key="version_hash", match=models.MatchValue(value=schema_hash) + key="version_hash", + match=models.MatchValue(value=schema_hash), ) ] ), @@ -391,7 +418,9 @@ def get_stored_schema_by_hash(self, schema_hash: str) -> Optional[StorageSchemaI except Exception: return None - def start_file_load(self, table: TTableSchema, file_path: str, load_id: str) -> LoadJob: + def start_file_load( + self, table: TTableSchema, file_path: str, load_id: str + ) -> LoadJob: return LoadQdrantJob( table, file_path, @@ -410,7 +439,9 @@ def complete_load(self, load_id: str) -> None: "status": 0, "inserted_at": str(pendulum.now()), } - loads_table_name = self._make_qualified_collection_name(self.schema.loads_table_name) + loads_table_name = self._make_qualified_collection_name( + self.schema.loads_table_name + ) self._create_point(properties, loads_table_name) def __enter__(self) -> "QdrantClient": @@ -434,7 +465,9 @@ def _update_schema_in_storage(self, schema: Schema) -> None: 
"inserted_at": str(pendulum.now()), "schema": schema_str, } - version_table_name = self._make_qualified_collection_name(self.schema.version_table_name) + version_table_name = self._make_qualified_collection_name( + self.schema.version_table_name + ) self._create_point(properties, version_table_name) def _execute_schema_update(self, only_tables: Iterable[str]) -> None: @@ -443,11 +476,15 @@ def _execute_schema_update(self, only_tables: Iterable[str]) -> None: if not exists: self._create_collection( - full_collection_name=self._make_qualified_collection_name(table_name) + full_collection_name=self._make_qualified_collection_name( + table_name + ) ) self._update_schema_in_storage(self.schema) - def _collection_exists(self, table_name: str, qualify_table_name: bool = True) -> bool: + def _collection_exists( + self, table_name: str, qualify_table_name: bool = True + ) -> bool: try: table_name = ( self._make_qualified_collection_name(table_name) diff --git a/dlt/destinations/impl/redshift/__init__.py b/dlt/destinations/impl/redshift/__init__.py index 8a8cae84b4..c5b292e7e6 100644 --- a/dlt/destinations/impl/redshift/__init__.py +++ b/dlt/destinations/impl/redshift/__init__.py @@ -1,4 +1,7 @@ -from dlt.common.data_writers.escape import escape_redshift_identifier, escape_redshift_literal +from dlt.common.data_writers.escape import ( + escape_redshift_identifier, + escape_redshift_literal, +) from dlt.common.destination import DestinationCapabilitiesContext from dlt.common.arithmetics import DEFAULT_NUMERIC_PRECISION, DEFAULT_NUMERIC_SCALE diff --git a/dlt/destinations/impl/redshift/redshift.py b/dlt/destinations/impl/redshift/redshift.py index eaa1968133..768a5faec8 100644 --- a/dlt/destinations/impl/redshift/redshift.py +++ b/dlt/destinations/impl/redshift/redshift.py @@ -3,7 +3,10 @@ from dlt.destinations.impl.postgres.sql_client import Psycopg2SqlClient -from dlt.common.schema.utils import table_schema_has_type, table_schema_has_type_with_precision +from dlt.common.schema.utils import ( + table_schema_has_type, + table_schema_has_type_with_precision, +) if platform.python_implementation() == "PyPy": import psycopg2cffi as psycopg2 @@ -29,7 +32,10 @@ from dlt.destinations.insert_job_client import InsertValuesJobClient from dlt.destinations.sql_jobs import SqlMergeJob -from dlt.destinations.exceptions import DatabaseTerminalException, LoadJobTerminalException +from dlt.destinations.exceptions import ( + DatabaseTerminalException, + LoadJobTerminalException, +) from dlt.destinations.job_client_impl import CopyRemoteFileLoadJob, LoadJob from dlt.destinations.impl.redshift import capabilities @@ -156,16 +162,16 @@ def execute(self, table: TTableSchema, bucket_path: str) -> None: if table_schema_has_type(table, "time"): raise LoadJobTerminalException( self.file_name(), - f"Redshift cannot load TIME columns from {ext} files. Switch to direct INSERT file" - " format or convert `datetime.time` objects in your data to `str` or" - " `datetime.datetime`", + f"Redshift cannot load TIME columns from {ext} files. Switch to direct" + " INSERT file format or convert `datetime.time` objects in your data" + " to `str` or `datetime.datetime`", ) if ext == "jsonl": if table_schema_has_type(table, "binary"): raise LoadJobTerminalException( self.file_name(), - "Redshift cannot load VARBYTE columns from json files. Switch to parquet to" - " load binaries.", + "Redshift cannot load VARBYTE columns from json files. 
Switch to" + " parquet to load binaries.", ) file_type = "FORMAT AS JSON 'auto'" dateformat = "dateformat 'auto' timeformat 'auto'" @@ -174,8 +180,9 @@ def execute(self, table: TTableSchema, bucket_path: str) -> None: if table_schema_has_type_with_precision(table, "binary"): raise LoadJobTerminalException( self.file_name(), - f"Redshift cannot load fixed width VARBYTE columns from {ext} files. Switch to" - " direct INSERT file format or use binary columns without precision.", + "Redshift cannot load fixed width VARBYTE columns from" + f" {ext} files. Switch to direct INSERT file format or use binary" + " columns without precision.", ) file_type = "PARQUET" # if table contains complex types then SUPER field will be used. @@ -229,16 +236,22 @@ class RedshiftClient(InsertValuesJobClient, SupportsStagingDestination): capabilities: ClassVar[DestinationCapabilitiesContext] = capabilities() def __init__(self, schema: Schema, config: RedshiftClientConfiguration) -> None: - sql_client = RedshiftSqlClient(config.normalize_dataset_name(schema), config.credentials) + sql_client = RedshiftSqlClient( + config.normalize_dataset_name(schema), config.credentials + ) super().__init__(schema, config, sql_client) self.sql_client = sql_client self.config: RedshiftClientConfiguration = config self.type_mapper = RedshiftTypeMapper(self.capabilities) - def _create_merge_followup_jobs(self, table_chain: Sequence[TTableSchema]) -> List[NewLoadJob]: + def _create_merge_followup_jobs( + self, table_chain: Sequence[TTableSchema] + ) -> List[NewLoadJob]: return [RedshiftMergeJob.from_table_chain(table_chain, self.sql_client)] - def _get_column_def_sql(self, c: TColumnSchema, table_format: TTableFormat = None) -> str: + def _get_column_def_sql( + self, c: TColumnSchema, table_format: TTableFormat = None + ) -> str: hints_str = " ".join( HINT_TO_REDSHIFT_ATTR.get(h, "") for h in HINT_TO_REDSHIFT_ATTR.keys() @@ -249,7 +262,9 @@ def _get_column_def_sql(self, c: TColumnSchema, table_format: TTableFormat = Non f"{column_name} {self.type_mapper.to_db_type(c)} {hints_str} {self._gen_not_null(c.get('nullable', True))}" ) - def start_file_load(self, table: TTableSchema, file_path: str, load_id: str) -> LoadJob: + def start_file_load( + self, table: TTableSchema, file_path: str, load_id: str + ) -> LoadJob: """Starts SqlLoadJob for files ending with .sql or returns None to let derived classes to handle their specific jobs""" job = super().start_file_load(table, file_path, load_id) if not job: diff --git a/dlt/destinations/impl/snowflake/configuration.py b/dlt/destinations/impl/snowflake/configuration.py index 5a1f7a65a9..57599d41d0 100644 --- a/dlt/destinations/impl/snowflake/configuration.py +++ b/dlt/destinations/impl/snowflake/configuration.py @@ -9,7 +9,9 @@ from dlt.common.configuration.specs import ConnectionStringCredentials from dlt.common.configuration.exceptions import ConfigurationValueError from dlt.common.configuration import configspec -from dlt.common.destination.reference import DestinationClientDwhWithStagingConfiguration +from dlt.common.destination.reference import ( + DestinationClientDwhWithStagingConfiguration, +) from dlt.common.utils import digest128 @@ -75,8 +77,9 @@ def parse_native_representation(self, native_value: Any) -> None: def on_resolved(self) -> None: if not self.password and not self.private_key: raise ConfigurationValueError( - "Please specify password or private_key. SnowflakeCredentials supports password and" - " private key authentication and one of those must be specified." 
+ "Please specify password or private_key. SnowflakeCredentials supports" + " password and private key authentication and one of those must be" + " specified." ) def to_url(self) -> URL: @@ -98,7 +101,9 @@ def to_url(self) -> URL: def to_connector_params(self) -> Dict[str, Any]: private_key: Optional[bytes] = None if self.private_key: - private_key = _read_private_key(self.private_key, self.private_key_passphrase) + private_key = _read_private_key( + self.private_key, self.private_key_passphrase + ) conn_params = dict( self.query or {}, user=self.username, diff --git a/dlt/destinations/impl/snowflake/snowflake.py b/dlt/destinations/impl/snowflake/snowflake.py index 7fafbf83b7..911f427451 100644 --- a/dlt/destinations/impl/snowflake/snowflake.py +++ b/dlt/destinations/impl/snowflake/snowflake.py @@ -102,7 +102,9 @@ def __init__( else "" ) file_name = ( - FileStorage.get_file_name_from_file_path(bucket_path) if bucket_path else file_name + FileStorage.get_file_name_from_file_path(bucket_path) + if bucket_path + else file_name ) from_clause = "" credentials_clause = "" @@ -149,9 +151,10 @@ def __init__( # when loading from bucket stage must be given raise LoadJobTerminalException( file_path, - f"Cannot load from bucket path {bucket_path} without a stage name. See" - " https://dlthub.com/docs/dlt-ecosystem/destinations/snowflake for" - " instructions on setting up the `stage_name`", + f"Cannot load from bucket path {bucket_path} without a stage" + " name. See" + " https://dlthub.com/docs/dlt-ecosystem/destinations/snowflake" + " for instructions on setting up the `stage_name`", ) from_clause = f"FROM @{stage_name}/" files_clause = f"FILES = ('{urlparse(bucket_path).path.lstrip('/')}')" @@ -172,8 +175,8 @@ def __init__( # PUT and COPY in one tx if local file, otherwise only copy if not bucket_path: client.execute_sql( - f'PUT file://{file_path} @{stage_name}/"{load_id}" OVERWRITE = TRUE,' - " AUTO_COMPRESS = FALSE" + f'PUT file://{file_path} @{stage_name}/"{load_id}" OVERWRITE =' + " TRUE, AUTO_COMPRESS = FALSE" ) client.execute_sql(f"""COPY INTO {qualified_table_name} {from_clause} @@ -196,13 +199,17 @@ class SnowflakeClient(SqlJobClientWithStaging, SupportsStagingDestination): capabilities: ClassVar[DestinationCapabilitiesContext] = capabilities() def __init__(self, schema: Schema, config: SnowflakeClientConfiguration) -> None: - sql_client = SnowflakeSqlClient(config.normalize_dataset_name(schema), config.credentials) + sql_client = SnowflakeSqlClient( + config.normalize_dataset_name(schema), config.credentials + ) super().__init__(schema, config, sql_client) self.config: SnowflakeClientConfiguration = config self.sql_client: SnowflakeSqlClient = sql_client # type: ignore self.type_mapper = SnowflakeTypeMapper(self.capabilities) - def start_file_load(self, table: TTableSchema, file_path: str, load_id: str) -> LoadJob: + def start_file_load( + self, table: TTableSchema, file_path: str, load_id: str + ) -> LoadJob: job = super().start_file_load(table, file_path, load_id) if not job: @@ -214,7 +221,9 @@ def start_file_load(self, table: TTableSchema, file_path: str, load_id: str) -> stage_name=self.config.stage_name, keep_staged_files=self.config.keep_staged_files, staging_credentials=( - self.config.staging_config.credentials if self.config.staging_config else None + self.config.staging_config.credentials + if self.config.staging_config + else None ), ) return job @@ -241,7 +250,9 @@ def _get_table_update_sql( sql = super()._get_table_update_sql(table_name, new_columns, generate_alter) 
cluster_list = [ - self.capabilities.escape_identifier(c["name"]) for c in new_columns if c.get("cluster") + self.capabilities.escape_identifier(c["name"]) + for c in new_columns + if c.get("cluster") ] if cluster_list: @@ -254,14 +265,18 @@ def _from_db_type( ) -> TColumnType: return self.type_mapper.from_db_type(bq_t, precision, scale) - def _get_column_def_sql(self, c: TColumnSchema, table_format: TTableFormat = None) -> str: + def _get_column_def_sql( + self, c: TColumnSchema, table_format: TTableFormat = None + ) -> str: name = self.capabilities.escape_identifier(c["name"]) return ( f"{name} {self.type_mapper.to_db_type(c)} {self._gen_not_null(c.get('nullable', True))}" ) def get_storage_table(self, table_name: str) -> Tuple[bool, TTableSchemaColumns]: - table_name = table_name.upper() # All snowflake tables are uppercased in information schema + table_name = ( + table_name.upper() + ) # All snowflake tables are uppercased in information schema exists, table = super().get_storage_table(table_name) if not exists: return exists, table diff --git a/dlt/destinations/impl/snowflake/sql_client.py b/dlt/destinations/impl/snowflake/sql_client.py index ba932277df..71b8d00d16 100644 --- a/dlt/destinations/impl/snowflake/sql_client.py +++ b/dlt/destinations/impl/snowflake/sql_client.py @@ -29,7 +29,9 @@ def df(self, chunk_size: int = None, **kwargs: Any) -> Optional[DataFrame]: return super().df(chunk_size=chunk_size, **kwargs) -class SnowflakeSqlClient(SqlClientBase[snowflake_lib.SnowflakeConnection], DBTransaction): +class SnowflakeSqlClient( + SqlClientBase[snowflake_lib.SnowflakeConnection], DBTransaction +): dbapi: ClassVar[DBApi] = snowflake_lib capabilities: ClassVar[DestinationCapabilitiesContext] = capabilities() @@ -97,7 +99,9 @@ def execute_sql( @contextmanager @raise_database_error - def execute_query(self, query: AnyStr, *args: Any, **kwargs: Any) -> Iterator[DBApiCursor]: + def execute_query( + self, query: AnyStr, *args: Any, **kwargs: Any + ) -> Iterator[DBApiCursor]: curr: DBApiCursor = None db_args = args if args else kwargs if kwargs else None with self._conn.cursor() as curr: # type: ignore[assignment] @@ -155,7 +159,9 @@ def _make_database_exception(cls, ex: Exception) -> Exception: return DatabaseTransientException(ex) elif isinstance(ex, TypeError): # snowflake raises TypeError on malformed query parameters - return DatabaseTransientException(snowflake_lib.errors.ProgrammingError(str(ex))) + return DatabaseTransientException( + snowflake_lib.errors.ProgrammingError(str(ex)) + ) elif cls.is_dbapi_exception(ex): return DatabaseTransientException(ex) else: diff --git a/dlt/destinations/impl/synapse/__init__.py b/dlt/destinations/impl/synapse/__init__.py index 53dbabc090..6b3d72ca75 100644 --- a/dlt/destinations/impl/synapse/__init__.py +++ b/dlt/destinations/impl/synapse/__init__.py @@ -1,4 +1,7 @@ -from dlt.common.data_writers.escape import escape_postgres_identifier, escape_mssql_literal +from dlt.common.data_writers.escape import ( + escape_postgres_identifier, + escape_mssql_literal, +) from dlt.common.destination import DestinationCapabilitiesContext from dlt.common.arithmetics import DEFAULT_NUMERIC_PRECISION, DEFAULT_NUMERIC_SCALE from dlt.common.wei import EVM_DECIMAL_PRECISION @@ -14,7 +17,9 @@ def capabilities() -> DestinationCapabilitiesContext: caps.preferred_staging_file_format = "parquet" caps.supported_staging_file_formats = ["parquet"] - caps.insert_values_writer_type = "select_union" # https://stackoverflow.com/a/77014299 + caps.insert_values_writer_type = ( 
# https://stackoverflow.com/a/77014299 + "select_union" + ) caps.escape_identifier = escape_postgres_identifier caps.escape_literal = escape_mssql_literal diff --git a/dlt/destinations/impl/synapse/sql_client.py b/dlt/destinations/impl/synapse/sql_client.py index 089c58e57c..dfe83cfc56 100644 --- a/dlt/destinations/impl/synapse/sql_client.py +++ b/dlt/destinations/impl/synapse/sql_client.py @@ -19,7 +19,9 @@ def drop_tables(self, *tables: str) -> None: return # Synapse does not support DROP TABLE IF EXISTS. # Workaround: use DROP TABLE and suppress non-existence errors. - statements = [f"DROP TABLE {self.make_qualified_table_name(table)};" for table in tables] + statements = [ + f"DROP TABLE {self.make_qualified_table_name(table)};" for table in tables + ] with suppress(DatabaseUndefinedRelation): self.execute_fragments(statements) diff --git a/dlt/destinations/impl/synapse/synapse.py b/dlt/destinations/impl/synapse/synapse.py index 457e128ba0..22df6973e7 100644 --- a/dlt/destinations/impl/synapse/synapse.py +++ b/dlt/destinations/impl/synapse/synapse.py @@ -23,7 +23,11 @@ from dlt.destinations.sql_jobs import SqlStagingCopyJob, SqlJobParams from dlt.destinations.sql_client import SqlClientBase from dlt.destinations.insert_job_client import InsertValuesJobClient -from dlt.destinations.job_client_impl import SqlJobClientBase, LoadJob, CopyRemoteFileLoadJob +from dlt.destinations.job_client_impl import ( + SqlJobClientBase, + LoadJob, + CopyRemoteFileLoadJob, +) from dlt.destinations.exceptions import LoadJobTerminalException from dlt.destinations.impl.mssql.mssql import ( @@ -68,13 +72,18 @@ def __init__(self, schema: Schema, config: SynapseClientConfiguration) -> None: self.active_hints.pop("unique", None) def _get_table_update_sql( - self, table_name: str, new_columns: Sequence[TColumnSchema], generate_alter: bool + self, + table_name: str, + new_columns: Sequence[TColumnSchema], + generate_alter: bool, ) -> List[str]: table = self.prepare_load_table(table_name, staging=self.in_staging_mode) table_index_type = cast(TTableIndexType, table.get(TABLE_INDEX_TYPE_HINT)) if self.in_staging_mode: final_table = self.prepare_load_table(table_name, staging=False) - final_table_index_type = cast(TTableIndexType, final_table.get(TABLE_INDEX_TYPE_HINT)) + final_table_index_type = cast( + TTableIndexType, final_table.get(TABLE_INDEX_TYPE_HINT) + ) else: final_table_index_type = table_index_type if final_table_index_type == "clustered_columnstore_index": @@ -127,10 +136,14 @@ def _create_replace_followup_jobs( self, table_chain: Sequence[TTableSchema] ) -> List[NewLoadJob]: if self.config.replace_strategy == "staging-optimized": - return [SynapseStagingCopyJob.from_table_chain(table_chain, self.sql_client)] + return [ + SynapseStagingCopyJob.from_table_chain(table_chain, self.sql_client) + ] return super()._create_replace_followup_jobs(table_chain) - def prepare_load_table(self, table_name: str, staging: bool = False) -> TTableSchema: + def prepare_load_table( + self, table_name: str, staging: bool = False + ) -> TTableSchema: table = super().prepare_load_table(table_name, staging) if staging and self.config.replace_strategy == "insert-from-staging": # Staging tables should always be heap tables, because "when you are @@ -150,14 +163,19 @@ def prepare_load_table(self, table_name: str, staging: bool = False) -> TTableSc if TABLE_INDEX_TYPE_HINT not in table: # If present in parent table, fetch hint from there. 
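# Illustrative sketch, not from the patch: the "select_union" writer type set in
# the capabilities above changes the shape of generated INSERT statements,
# because (per the linked Stack Overflow answer) Synapse dedicated pools do not
# accept multi-row VALUES lists. The toy strings below contrast the two shapes;
# table and column names are made up.
rows = [("a", 1), ("b", 2)]

values_style = "INSERT INTO demo (name, qty) VALUES " + ", ".join(
    f"('{name}', {qty})" for name, qty in rows
)
select_union_style = "INSERT INTO demo (name, qty) " + " UNION ALL ".join(
    f"SELECT '{name}', {qty}" for name, qty in rows
)
print(values_style)        # INSERT INTO demo (name, qty) VALUES ('a', 1), ('b', 2)
print(select_union_style)  # INSERT INTO demo (name, qty) SELECT 'a', 1 UNION ALL SELECT 'b', 2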
table[TABLE_INDEX_TYPE_HINT] = get_inherited_table_hint( # type: ignore[typeddict-unknown-key] - self.schema.tables, table_name, TABLE_INDEX_TYPE_HINT, allow_none=True + self.schema.tables, + table_name, + TABLE_INDEX_TYPE_HINT, + allow_none=True, ) if table[TABLE_INDEX_TYPE_HINT] is None: # type: ignore[typeddict-item] # Hint still not defined, fall back to default. table[TABLE_INDEX_TYPE_HINT] = self.config.default_table_index_type # type: ignore[typeddict-unknown-key] return table - def start_file_load(self, table: TTableSchema, file_path: str, load_id: str) -> LoadJob: + def start_file_load( + self, table: TTableSchema, file_path: str, load_id: str + ) -> LoadJob: job = super().start_file_load(table, file_path, load_id) if not job: assert NewReferenceJob.is_reference_job( @@ -167,7 +185,10 @@ def start_file_load(self, table: TTableSchema, file_path: str, load_id: str) -> table, file_path, self.sql_client, - cast(AzureCredentialsWithoutDefaults, self.config.staging_config.credentials), + cast( + AzureCredentialsWithoutDefaults, + self.config.staging_config.credentials, + ), self.config.staging_use_msi, ) return job @@ -197,7 +218,10 @@ def generate_sql( job_client = current.pipeline().destination_client() # type: ignore[operator] with job_client.with_staging_dataset(): # get table columns from schema - columns = [c for c in job_client.schema.get_table_columns(table["name"]).values()] + columns = [ + c + for c in job_client.schema.get_table_columns(table["name"]).values() + ] # generate CREATE TABLE statement create_table_stmt = job_client._get_table_update_sql( table["name"], columns, generate_alter=False @@ -228,9 +252,9 @@ def execute(self, table: TTableSchema, bucket_path: str) -> None: # an incompatibility error. raise LoadJobTerminalException( self.file_name(), - "Synapse cannot load TIME columns from Parquet files. Switch to direct INSERT" - " file format or convert `datetime.time` objects in your data to `str` or" - " `datetime.datetime`", + "Synapse cannot load TIME columns from Parquet files. Switch to" + " direct INSERT file format or convert `datetime.time` objects in" + " your data to `str` or `datetime.datetime`", ) file_type = "PARQUET" diff --git a/dlt/destinations/impl/weaviate/exceptions.py b/dlt/destinations/impl/weaviate/exceptions.py index ee798e4e76..ad7fce7ffd 100644 --- a/dlt/destinations/impl/weaviate/exceptions.py +++ b/dlt/destinations/impl/weaviate/exceptions.py @@ -1,4 +1,7 @@ -from dlt.common.destination.exceptions import DestinationException, DestinationTerminalException +from dlt.common.destination.exceptions import ( + DestinationException, + DestinationTerminalException, +) class WeaviateBatchError(DestinationException): @@ -8,9 +11,9 @@ class WeaviateBatchError(DestinationException): class PropertyNameConflict(DestinationTerminalException): def __init__(self) -> None: super().__init__( - "Your data contains items with identical property names when compared case insensitive." - " Weaviate cannot handle such data. Please clean up your data before loading or change" - " to case insensitive naming convention. See" - " https://dlthub.com/docs/dlt-ecosystem/destinations/weaviate#names-normalization for" - " details." + "Your data contains items with identical property names when compared case" + " insensitive. Weaviate cannot handle such data. Please clean up your data" + " before loading or change to case insensitive naming convention. See" + " https://dlthub.com/docs/dlt-ecosystem/destinations/weaviate#names-normalization" + " for details." 
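# Illustrative sketch, not from the patch: the weaviate_adapter being reformatted
# above attaches vectorize and tokenization hints to columns before loading. A
# minimal usage sketch; the resource, column names and values are made up, and the
# public import path is assumed rather than taken from this diff.
import dlt
from dlt.destinations.adapters import weaviate_adapter

data = [{"title": "dlt goes 88 columns", "body": "..."}]
docs = dlt.resource(data, name="docs")
# vectorize the "title" column and request word-level tokenization for it
weaviate_adapter(docs, vectorize=["title"], tokenization={"title": "word"})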
) diff --git a/dlt/destinations/impl/weaviate/naming.py b/dlt/destinations/impl/weaviate/naming.py index f5c94c872f..8b112d05df 100644 --- a/dlt/destinations/impl/weaviate/naming.py +++ b/dlt/destinations/impl/weaviate/naming.py @@ -1,7 +1,9 @@ import re from dlt.common.normalizers.naming import NamingConvention as BaseNamingConvention -from dlt.common.normalizers.naming.snake_case import NamingConvention as SnakeCaseNamingConvention +from dlt.common.normalizers.naming.snake_case import ( + NamingConvention as SnakeCaseNamingConvention, +) class NamingConvention(SnakeCaseNamingConvention): diff --git a/dlt/destinations/impl/weaviate/weaviate_adapter.py b/dlt/destinations/impl/weaviate/weaviate_adapter.py index a290ac65b4..4e31968f97 100644 --- a/dlt/destinations/impl/weaviate/weaviate_adapter.py +++ b/dlt/destinations/impl/weaviate/weaviate_adapter.py @@ -62,7 +62,8 @@ def weaviate_adapter( vectorize = [vectorize] if not isinstance(vectorize, list): raise ValueError( - "vectorize must be a list of column names or a single column name as a string" + "vectorize must be a list of column names or a single column name as a" + " string" ) # create weaviate-specific vectorize hints for column_name in vectorize: @@ -76,8 +77,8 @@ def weaviate_adapter( if method not in TOKENIZATION_METHODS: allowed_methods = ", ".join(TOKENIZATION_METHODS) raise ValueError( - f"Tokenization type {method} for column {column_name} is invalid. Allowed" - f" methods are: {allowed_methods}" + f"Tokenization type {method} for column {column_name} is invalid." + f" Allowed methods are: {allowed_methods}" ) if column_name in column_hints: column_hints[column_name][TOKENIZATION_HINT] = method # type: ignore diff --git a/dlt/destinations/impl/weaviate/weaviate_client.py b/dlt/destinations/impl/weaviate/weaviate_client.py index 6486a75e6e..2631b06c34 100644 --- a/dlt/destinations/impl/weaviate/weaviate_client.py +++ b/dlt/destinations/impl/weaviate/weaviate_client.py @@ -31,17 +31,28 @@ from dlt.common.schema.typing import TColumnSchema, TColumnType from dlt.common.schema.utils import get_columns_names_with_prop from dlt.common.destination import DestinationCapabilitiesContext -from dlt.common.destination.reference import TLoadJobState, LoadJob, JobClientBase, WithStateSync +from dlt.common.destination.reference import ( + TLoadJobState, + LoadJob, + JobClientBase, + WithStateSync, +) from dlt.common.data_types import TDataType from dlt.common.storages import FileStorage -from dlt.destinations.impl.weaviate.weaviate_adapter import VECTORIZE_HINT, TOKENIZATION_HINT +from dlt.destinations.impl.weaviate.weaviate_adapter import ( + VECTORIZE_HINT, + TOKENIZATION_HINT, +) from dlt.destinations.job_impl import EmptyLoadJob from dlt.destinations.job_client_impl import StorageSchemaInfo, StateInfo from dlt.destinations.impl.weaviate import capabilities from dlt.destinations.impl.weaviate.configuration import WeaviateClientConfiguration -from dlt.destinations.impl.weaviate.exceptions import PropertyNameConflict, WeaviateBatchError +from dlt.destinations.impl.weaviate.exceptions import ( + PropertyNameConflict, + WeaviateBatchError, +) from dlt.destinations.type_mapping import TypeMapper @@ -99,9 +110,9 @@ def _wrap(self: JobClientBase, *args: Any, **kwargs: Any) -> Any: if status_ex.status_code == 403: raise DestinationTerminalException(status_ex) if status_ex.status_code == 422: - if "conflict for property" in str(status_ex) or "none vectorizer module" in str( + if "conflict for property" in str( status_ex - ): + ) or "none 
vectorizer module" in str(status_ex): raise PropertyNameConflict() raise DestinationTerminalException(status_ex) # looks like there are no more terminal exception @@ -129,7 +140,9 @@ def _wrap(*args: Any, **kwargs: Any) -> Any: ) if "conflict for property" in message: raise PropertyNameConflict() - raise DestinationTransientException(f"Batch failed {errors} AND WILL BE RETRIED") + raise DestinationTransientException( + f"Batch failed {errors} AND WILL BE RETRIED" + ) except Exception: raise DestinationTransientException("Batch failed AND WILL BE RETRIED") @@ -188,7 +201,9 @@ def check_batch_result(results: List[StrAny]) -> None: weaviate_error_retries=weaviate.WeaviateErrorRetryConf( self.client_config.batch_retries ), - consistency_level=weaviate.ConsistencyLevel[self.client_config.batch_consistency], + consistency_level=weaviate.ConsistencyLevel[ + self.client_config.batch_consistency + ], num_workers=self.client_config.batch_workers, callback=check_batch_result, ) as batch: @@ -202,7 +217,9 @@ def check_batch_result(results: List[StrAny]) -> None: if key in data: data[key] = ensure_pendulum_datetime(data[key]).isoformat() if self.unique_identifiers: - uuid = self.generate_uuid(data, self.unique_identifiers, self.class_name) + uuid = self.generate_uuid( + data, self.unique_identifiers, self.class_name + ) else: uuid = None @@ -291,7 +308,8 @@ def make_qualified_class_name(self, table_name: str) -> str: def get_class_schema(self, table_name: str) -> Dict[str, Any]: """Get the Weaviate class schema for a table.""" return cast( - Dict[str, Any], self.db_client.schema.get(self.make_qualified_class_name(table_name)) + Dict[str, Any], + self.db_client.schema.get(self.make_qualified_class_name(table_name)), ) def create_class( @@ -315,7 +333,9 @@ def create_class( self.db_client.schema.create_class(updated_schema) - def create_class_property(self, class_name: str, prop_schema: Dict[str, Any]) -> None: + def create_class_property( + self, class_name: str, prop_schema: Dict[str, Any] + ) -> None: """Create a Weaviate class property. Args: @@ -350,7 +370,9 @@ def query_class(self, class_name: str, properties: List[str]) -> GetBuilder: Returns: A Weaviate query builder. """ - return self.db_client.query.get(self.make_qualified_class_name(class_name), properties) + return self.db_client.query.get( + self.make_qualified_class_name(class_name), properties + ) def create_object(self, obj: Dict[str, Any], class_name: str) -> None: """Create a Weaviate object. @@ -359,7 +381,9 @@ def create_object(self, obj: Dict[str, Any], class_name: str) -> None: obj: The object to create. class_name: The name of the class to create the object on. """ - self.db_client.data_object.create(obj, self.make_qualified_class_name(class_name)) + self.db_client.data_object.create( + obj, self.make_qualified_class_name(class_name) + ) def drop_storage(self) -> None: """Drop the dataset from Weaviate instance. 
@@ -428,7 +452,9 @@ def update_stored_schema( # Retrieve the schema from Weaviate applied_update: TSchemaTables = {} try: - schema_info = self.get_stored_schema_by_hash(self.schema.stored_version_hash) + schema_info = self.get_stored_schema_by_hash( + self.schema.stored_version_hash + ) except DestinationUndefinedEntity: schema_info = None if schema_info is None: @@ -450,8 +476,13 @@ def _execute_schema_update(self, only_tables: Iterable[str]) -> None: for table_name in only_tables or self.schema.tables: exists, existing_columns = self.get_storage_table(table_name) # TODO: detect columns where vectorization was added or removed and modify it. currently we ignore change of hints - new_columns = self.schema.get_new_table_columns(table_name, existing_columns) - logger.info(f"Found {len(new_columns)} updates for {table_name} in {self.schema.name}") + new_columns = self.schema.get_new_table_columns( + table_name, existing_columns + ) + logger.info( + f"Found {len(new_columns)} updates for {table_name} in" + f" {self.schema.name}" + ) if len(new_columns) > 0: if exists: for column in new_columns: @@ -549,7 +580,9 @@ def get_stored_schema(self) -> Optional[StorageSchemaInfo]: except IndexError: return None - def get_stored_schema_by_hash(self, schema_hash: str) -> Optional[StorageSchemaInfo]: + def get_stored_schema_by_hash( + self, schema_hash: str + ) -> Optional[StorageSchemaInfo]: try: record = self.get_records( self.schema.version_table_name, @@ -603,7 +636,9 @@ def make_weaviate_class_schema(self, table_name: str) -> Dict[str, Any]: } # check if any column requires vectorization - if get_columns_names_with_prop(self.schema.get_table(table_name), VECTORIZE_HINT): + if get_columns_names_with_prop( + self.schema.get_table(table_name), VECTORIZE_HINT + ): class_schema.update(self._vectorizer_config) else: class_schema.update(NON_VECTORIZED_CLASS) @@ -622,7 +657,9 @@ def _make_properties(self, table_name: str) -> List[Dict[str, Any]]: for column_name, column in self.schema.get_table_columns(table_name).items() ] - def _make_property_schema(self, column_name: str, column: TColumnSchema) -> Dict[str, Any]: + def _make_property_schema( + self, column_name: str, column: TColumnSchema + ) -> Dict[str, Any]: extra_kv = {} vectorizer_name = self._vectorizer_config["vectorizer"] @@ -646,7 +683,9 @@ def _make_property_schema(self, column_name: str, column: TColumnSchema) -> Dict **extra_kv, } - def start_file_load(self, table: TTableSchema, file_path: str, load_id: str) -> LoadJob: + def start_file_load( + self, table: TTableSchema, file_path: str, load_id: str + ) -> LoadJob: return LoadWeaviateJob( self.schema, table, diff --git a/dlt/destinations/insert_job_client.py b/dlt/destinations/insert_job_client.py index 776176078e..ea402ef549 100644 --- a/dlt/destinations/insert_job_client.py +++ b/dlt/destinations/insert_job_client.py @@ -13,7 +13,9 @@ class InsertValuesLoadJob(LoadJob, FollowupJob): - def __init__(self, table_name: str, file_path: str, sql_client: SqlClientBase[Any]) -> None: + def __init__( + self, table_name: str, file_path: str, sql_client: SqlClientBase[Any] + ) -> None: super().__init__(FileStorage.get_file_name_from_file_path(file_path)) self._sql_client = sql_client # insert file content immediately @@ -44,7 +46,9 @@ def _insert(self, qualified_table_name: str, file_path: str) -> Iterator[List[st max_rows = self._sql_client.capabilities.max_rows_per_insert insert_sql = [] - while content := f.read(self._sql_client.capabilities.max_query_length // 2): + while content := f.read( + 
self._sql_client.capabilities.max_query_length // 2 + ): # read one more line in order to # 1. complete the content which ends at "random" position, not an end line # 2. to modify its ending without a need to re-allocating the 8MB of "content" @@ -69,7 +73,10 @@ def _insert(self, qualified_table_name: str, file_path: str) -> Iterator[List[st for chunk in chunks(values_rows, max_rows - 1): processed += len(chunk) insert_sql.append(header.format(qualified_table_name)) - if self._sql_client.capabilities.insert_values_writer_type == "default": + if ( + self._sql_client.capabilities.insert_values_writer_type + == "default" + ): insert_sql.append(values_mark) if processed == len_rows: # On the last chunk we need to add the extra row read @@ -79,12 +86,20 @@ def _insert(self, qualified_table_name: str, file_path: str) -> Iterator[List[st insert_sql.append("".join(chunk).strip()[:-1] + ";\n") else: # otherwise write all content in a single INSERT INTO - if self._sql_client.capabilities.insert_values_writer_type == "default": + if ( + self._sql_client.capabilities.insert_values_writer_type + == "default" + ): insert_sql.extend( [header.format(qualified_table_name), values_mark, content] ) - elif self._sql_client.capabilities.insert_values_writer_type == "select_union": - insert_sql.extend([header.format(qualified_table_name), content]) + elif ( + self._sql_client.capabilities.insert_values_writer_type + == "select_union" + ): + insert_sql.extend( + [header.format(qualified_table_name), content] + ) if until_nl: insert_sql.append(until_nl) @@ -117,7 +132,9 @@ def restore_file_load(self, file_path: str) -> LoadJob: job = EmptyLoadJob.from_file_path(file_path, "completed") return job - def start_file_load(self, table: TTableSchema, file_path: str, load_id: str) -> LoadJob: + def start_file_load( + self, table: TTableSchema, file_path: str, load_id: str + ) -> LoadJob: job = super().start_file_load(table, file_path, load_id) if not job: # this is using sql_client internally and will raise a right exception diff --git a/dlt/destinations/job_client_impl.py b/dlt/destinations/job_client_impl.py index ea0d10d11d..0b56e70a0f 100644 --- a/dlt/destinations/job_client_impl.py +++ b/dlt/destinations/job_client_impl.py @@ -99,7 +99,11 @@ def _string_containts_ddl_queries(self, sql: str) -> bool: return False def _split_fragments(self, sql: str) -> List[str]: - return [s + (";" if not s.endswith(";") else "") for s in sql.split(";") if s.strip()] + return [ + s + (";" if not s.endswith(";") else "") + for s in sql.split(";") + if s.strip() + ] @staticmethod def is_sql_job(file_path: str) -> bool: @@ -154,7 +158,8 @@ def __init__( sql_client: SqlClientBase[TNativeConn], ) -> None: self.version_table_schema_columns = ", ".join( - sql_client.escape_column_name(col) for col in self._VERSION_TABLE_SCHEMA_COLUMNS + sql_client.escape_column_name(col) + for col in self._VERSION_TABLE_SCHEMA_COLUMNS ) self.state_table_columns = ", ".join( sql_client.escape_column_name(col) for col in self._STATE_TABLE_COLUMNS @@ -185,8 +190,8 @@ def update_stored_schema( schema_info = self.get_stored_schema_by_hash(self.schema.stored_version_hash) if schema_info is None: logger.info( - f"Schema with hash {self.schema.stored_version_hash} not found in the storage." - " upgrading" + f"Schema with hash {self.schema.stored_version_hash} not found in the" + " storage. 
upgrading" ) with self.maybe_ddl_transaction(): @@ -219,10 +224,14 @@ def should_truncate_table_before_load(self, table: TTableSchema) -> bool: and self.config.replace_strategy == "truncate-and-insert" ) - def _create_append_followup_jobs(self, table_chain: Sequence[TTableSchema]) -> List[NewLoadJob]: + def _create_append_followup_jobs( + self, table_chain: Sequence[TTableSchema] + ) -> List[NewLoadJob]: return [] - def _create_merge_followup_jobs(self, table_chain: Sequence[TTableSchema]) -> List[NewLoadJob]: + def _create_merge_followup_jobs( + self, table_chain: Sequence[TTableSchema] + ) -> List[NewLoadJob]: return [SqlMergeJob.from_table_chain(table_chain, self.sql_client)] def _create_replace_followup_jobs( @@ -231,7 +240,9 @@ def _create_replace_followup_jobs( jobs: List[NewLoadJob] = [] if self.config.replace_strategy in ["insert-from-staging", "staging-optimized"]: jobs.append( - SqlStagingCopyJob.from_table_chain(table_chain, self.sql_client, {"replace": True}) + SqlStagingCopyJob.from_table_chain( + table_chain, self.sql_client, {"replace": True} + ) ) return jobs @@ -249,7 +260,9 @@ def create_table_chain_completed_followup_jobs( jobs.extend(self._create_replace_followup_jobs(table_chain)) return jobs - def start_file_load(self, table: TTableSchema, file_path: str, load_id: str) -> LoadJob: + def start_file_load( + self, table: TTableSchema, file_path: str, load_id: str + ) -> LoadJob: """Starts SqlLoadJob for files ending with .sql or returns None to let derived classes to handle their specific jobs""" if SqlLoadJob.is_sql_job(file_path): # execute sql load job @@ -276,8 +289,8 @@ def complete_load(self, load_id: str) -> None: name = self.sql_client.make_qualified_table_name(self.schema.loads_table_name) now_ts = pendulum.now() self.sql_client.execute_sql( - f"INSERT INTO {name}(load_id, schema_name, status, inserted_at, schema_version_hash)" - " VALUES(%s, %s, %s, %s, %s);", + f"INSERT INTO {name}(load_id, schema_name, status, inserted_at," + " schema_version_hash) VALUES(%s, %s, %s, %s, %s);", load_id, self.schema.name, 0, @@ -290,7 +303,10 @@ def __enter__(self) -> "SqlJobClientBase": return self def __exit__( - self, exc_type: Type[BaseException], exc_val: BaseException, exc_tb: TracebackType + self, + exc_type: Type[BaseException], + exc_val: BaseException, + exc_tb: TracebackType, ) -> None: self.sql_client.close_connection() @@ -312,9 +328,9 @@ def _null_to_bool(v: str) -> bool: raise ValueError(v) fields = self._get_storage_table_query_columns() - db_params = self.sql_client.make_qualified_table_name(table_name, escape=False).split( - ".", 3 - ) + db_params = self.sql_client.make_qualified_table_name( + table_name, escape=False + ).split(".", 3) query = f""" SELECT {",".join(fields)} FROM INFORMATION_SCHEMA.COLUMNS @@ -334,7 +350,9 @@ def _null_to_bool(v: str) -> bool: numeric_precision = ( c[3] if self.capabilities.schema_supports_numeric_precision else None ) - numeric_scale = c[4] if self.capabilities.schema_supports_numeric_precision else None + numeric_scale = ( + c[4] if self.capabilities.schema_supports_numeric_precision else None + ) schema_c: TColumnSchemaBase = { "name": c[0], "nullable": _null_to_bool(c[2]), @@ -352,18 +370,22 @@ def _from_db_type( def get_stored_schema(self) -> StorageSchemaInfo: name = self.sql_client.make_qualified_table_name(self.schema.version_table_name) query = ( - f"SELECT {self.version_table_schema_columns} FROM {name} WHERE schema_name = %s ORDER" - " BY inserted_at DESC;" + f"SELECT {self.version_table_schema_columns} FROM {name} 
WHERE schema_name" + " = %s ORDER BY inserted_at DESC;" ) return self._row_to_schema_info(query, self.schema.name) def get_stored_state(self, pipeline_name: str) -> StateInfo: - state_table = self.sql_client.make_qualified_table_name(self.schema.state_table_name) - loads_table = self.sql_client.make_qualified_table_name(self.schema.loads_table_name) + state_table = self.sql_client.make_qualified_table_name( + self.schema.state_table_name + ) + loads_table = self.sql_client.make_qualified_table_name( + self.schema.loads_table_name + ) query = ( - f"SELECT {self.state_table_columns} FROM {state_table} AS s JOIN {loads_table} AS l ON" - " l.load_id = s._dlt_load_id WHERE pipeline_name = %s AND l.status = 0 ORDER BY" - " created_at DESC" + f"SELECT {self.state_table_columns} FROM {state_table} AS s JOIN" + f" {loads_table} AS l ON l.load_id = s._dlt_load_id WHERE pipeline_name =" + " %s AND l.status = 0 ORDER BY created_at DESC" ) with self.sql_client.execute_query(query, pipeline_name) as cur: row = cur.fetchone() @@ -382,7 +404,10 @@ def get_stored_state(self, pipeline_name: str) -> StateInfo: def get_stored_schema_by_hash(self, version_hash: str) -> StorageSchemaInfo: name = self.sql_client.make_qualified_table_name(self.schema.version_table_name) - query = f"SELECT {self.version_table_schema_columns} FROM {name} WHERE version_hash = %s;" + query = ( + f"SELECT {self.version_table_schema_columns} FROM {name} WHERE version_hash" + " = %s;" + ) return self._row_to_schema_info(query, version_hash) def _execute_schema_update_sql(self, only_tables: Iterable[str]) -> TSchemaTables: @@ -416,7 +441,9 @@ def _build_schema_update_sql( new_columns = self._create_table_update(table_name, storage_table) if len(new_columns) > 0: # build and add sql to execute - sql_statements = self._get_table_update_sql(table_name, new_columns, exists) + sql_statements = self._get_table_update_sql( + table_name, new_columns, exists + ) for sql in sql_statements: if not sql.endswith(";"): sql += ";" @@ -433,10 +460,16 @@ def _make_add_column_sql( self, new_columns: Sequence[TColumnSchema], table_format: TTableFormat = None ) -> List[str]: """Make one or more ADD COLUMN sql clauses to be joined in ALTER TABLE statement(s)""" - return [f"ADD COLUMN {self._get_column_def_sql(c, table_format)}" for c in new_columns] + return [ + f"ADD COLUMN {self._get_column_def_sql(c, table_format)}" + for c in new_columns + ] def _get_table_update_sql( - self, table_name: str, new_columns: Sequence[TColumnSchema], generate_alter: bool + self, + table_name: str, + new_columns: Sequence[TColumnSchema], + generate_alter: bool, ) -> List[str]: # build sql canonical_name = self.sql_client.make_qualified_table_name(table_name) @@ -446,7 +479,9 @@ def _get_table_update_sql( if not generate_alter: # build CREATE sql = f"CREATE TABLE {canonical_name} (\n" - sql += ",\n".join([self._get_column_def_sql(c, table_format) for c in new_columns]) + sql += ",\n".join( + [self._get_column_def_sql(c, table_format) for c in new_columns] + ) sql += ")" sql_result.append(sql) else: @@ -458,7 +493,10 @@ def _get_table_update_sql( else: # build ALTER as a separate statement for each column (redshift limitation) sql_result.extend( - [sql_base + col_statement for col_statement in add_column_statements] + [ + sql_base + col_statement + for col_statement in add_column_statements + ] ) # scan columns to get hints @@ -473,20 +511,22 @@ def _get_table_update_sql( ] if hint == "not_null": logger.warning( - f"Column(s) {hint_columns} with NOT NULL are being added to 
existing" - f" table {canonical_name}. If there's data in the table the operation" - " will fail." + f"Column(s) {hint_columns} with NOT NULL are being added to" + f" existing table {canonical_name}. If there's data in the" + " table the operation will fail." ) else: logger.warning( - f"Column(s) {hint_columns} with hint {hint} are being added to existing" - f" table {canonical_name}. Several hint types may not be added to" - " existing tables." + f"Column(s) {hint_columns} with hint {hint} are being added" + f" to existing table {canonical_name}. Several hint types" + " may not be added to existing tables." ) return sql_result @abstractmethod - def _get_column_def_sql(self, c: TColumnSchema, table_format: TTableFormat = None) -> str: + def _get_column_def_sql( + self, c: TColumnSchema, table_format: TTableFormat = None + ) -> str: pass @staticmethod @@ -498,7 +538,9 @@ def _create_table_update( ) -> Sequence[TColumnSchema]: # compare table with stored schema and produce delta updates = self.schema.get_new_table_columns(table_name, storage_columns) - logger.info(f"Found {len(updates)} updates for {table_name} in {self.schema.name}") + logger.info( + f"Found {len(updates)} updates for {table_name} in {self.schema.name}" + ) return updates def _row_to_schema_info(self, query: str, *args: Any) -> StorageSchemaInfo: @@ -523,14 +565,18 @@ def _row_to_schema_info(self, query: str, *args: Any) -> StorageSchemaInfo: # make utc datetime inserted_at = pendulum.instance(row[4]) - return StorageSchemaInfo(row[0], row[1], row[2], row[3], inserted_at, schema_str) + return StorageSchemaInfo( + row[0], row[1], row[2], row[3], inserted_at, schema_str + ) def _replace_schema_in_storage(self, schema: Schema) -> None: """ Save the given schema in storage and remove all previous versions with the same name """ name = self.sql_client.make_qualified_table_name(self.schema.version_table_name) - self.sql_client.execute_sql(f"DELETE FROM {name} WHERE schema_name = %s;", schema.name) + self.sql_client.execute_sql( + f"DELETE FROM {name} WHERE schema_name = %s;", schema.name + ) self._update_schema_in_storage(schema) def _update_schema_in_storage(self, schema: Schema) -> None: @@ -540,7 +586,9 @@ def _update_schema_in_storage(self, schema: Schema) -> None: schema_bytes = schema_str.encode("utf-8") if len(schema_bytes) > self.capabilities.max_text_data_type_length: # compress and to base64 - schema_str = base64.b64encode(zlib.compress(schema_bytes, level=9)).decode("ascii") + schema_str = base64.b64encode(zlib.compress(schema_bytes, level=9)).decode( + "ascii" + ) self._commit_schema_update(schema, schema_str) def _commit_schema_update(self, schema: Schema, schema_str: str) -> None: @@ -548,8 +596,8 @@ def _commit_schema_update(self, schema: Schema, schema_str: str) -> None: name = self.sql_client.make_qualified_table_name(self.schema.version_table_name) # values = schema.version_hash, schema.name, schema.version, schema.ENGINE_VERSION, str(now_ts), schema_str self.sql_client.execute_sql( - f"INSERT INTO {name}({self.version_table_schema_columns}) VALUES (%s, %s, %s, %s, %s," - " %s);", + f"INSERT INTO {name}({self.version_table_schema_columns}) VALUES (%s, %s," + " %s, %s, %s, %s);", schema.stored_version_hash, schema.name, schema.version, diff --git a/dlt/destinations/job_impl.py b/dlt/destinations/job_impl.py index 7a6b98544c..7b38cb3481 100644 --- a/dlt/destinations/job_impl.py +++ b/dlt/destinations/job_impl.py @@ -3,12 +3,19 @@ from dlt.common.storages import FileStorage -from dlt.common.destination.reference 
import NewLoadJob, FollowupJob, TLoadJobState, LoadJob +from dlt.common.destination.reference import ( + NewLoadJob, + FollowupJob, + TLoadJobState, + LoadJob, +) from dlt.common.storages.load_storage import ParsedLoadJobFileName class EmptyLoadJobWithoutFollowup(LoadJob): - def __init__(self, file_name: str, status: TLoadJobState, exception: str = None) -> None: + def __init__( + self, file_name: str, status: TLoadJobState, exception: str = None + ) -> None: self._status = status self._exception = exception super().__init__(file_name) @@ -17,7 +24,11 @@ def __init__(self, file_name: str, status: TLoadJobState, exception: str = None) def from_file_path( cls, file_path: str, status: TLoadJobState, message: str = None ) -> "EmptyLoadJobWithoutFollowup": - return cls(FileStorage.get_file_name_from_file_path(file_path), status, exception=message) + return cls( + FileStorage.get_file_name_from_file_path(file_path), + status, + exception=message, + ) def state(self) -> TLoadJobState: return self._status @@ -44,7 +55,11 @@ def new_file_path(self) -> str: class NewReferenceJob(NewLoadJobImpl): def __init__( - self, file_name: str, status: TLoadJobState, exception: str = None, remote_path: str = None + self, + file_name: str, + status: TLoadJobState, + exception: str = None, + remote_path: str = None, ) -> None: file_name = os.path.splitext(file_name)[0] + ".reference" super().__init__(file_name, status, exception) diff --git a/dlt/destinations/path_utils.py b/dlt/destinations/path_utils.py index 047cb274e0..d40da66563 100644 --- a/dlt/destinations/path_utils.py +++ b/dlt/destinations/path_utils.py @@ -7,7 +7,14 @@ from dlt.destinations.exceptions import InvalidFilesystemLayout, CantExtractTablePrefix # TODO: ensure layout only has supported placeholders -SUPPORTED_PLACEHOLDERS = {"schema_name", "table_name", "load_id", "file_id", "ext", "curr_date"} +SUPPORTED_PLACEHOLDERS = { + "schema_name", + "table_name", + "load_id", + "file_id", + "ext", + "curr_date", +} SUPPORTED_TABLE_NAME_PREFIX_PLACEHOLDERS = ("schema_name",) @@ -45,7 +52,9 @@ def create_path( def get_table_prefix_layout( layout: str, - supported_prefix_placeholders: Sequence[str] = SUPPORTED_TABLE_NAME_PREFIX_PLACEHOLDERS, + supported_prefix_placeholders: Sequence[ + str + ] = SUPPORTED_TABLE_NAME_PREFIX_PLACEHOLDERS, ) -> str: """get layout fragment that defines positions of the table, cutting other placeholders @@ -59,16 +68,24 @@ def get_table_prefix_layout( table_name_index = placeholders.index("table_name") # fail if any other prefix is defined before table_name - if [p for p in placeholders[:table_name_index] if p not in supported_prefix_placeholders]: + if [ + p + for p in placeholders[:table_name_index] + if p not in supported_prefix_placeholders + ]: if len(supported_prefix_placeholders) == 0: details = ( - "No other placeholders are allowed before {table_name} but you have %s present. " + "No other placeholders are allowed before {table_name} but you have %s" + " present. " % placeholders[:table_name_index] ) else: - details = "Only %s are allowed before {table_name} but you have %s present. " % ( - supported_prefix_placeholders, - placeholders[:table_name_index], + details = ( + "Only %s are allowed before {table_name} but you have %s present. 
" + % ( + supported_prefix_placeholders, + placeholders[:table_name_index], + ) ) raise CantExtractTablePrefix(layout, details) @@ -76,6 +93,8 @@ def get_table_prefix_layout( # this is to prevent selecting tables that have the same starting name prefix = layout[: layout.index("{table_name}") + 13] if prefix[-1] == "{": - raise CantExtractTablePrefix(layout, "A separator is required after a {table_name}. ") + raise CantExtractTablePrefix( + layout, "A separator is required after a {table_name}. " + ) return prefix diff --git a/dlt/destinations/sql_client.py b/dlt/destinations/sql_client.py index 9d872a238e..187e6f4b32 100644 --- a/dlt/destinations/sql_client.py +++ b/dlt/destinations/sql_client.py @@ -26,7 +26,13 @@ LoadClientNotConnected, DatabaseTerminalException, ) -from dlt.destinations.typing import DBApi, TNativeConn, DBApiCursor, DataFrame, DBTransaction +from dlt.destinations.typing import ( + DBApi, + TNativeConn, + DBApiCursor, + DataFrame, + DBTransaction, +) class SqlClientBase(ABC, Generic[TNativeConn]): @@ -62,7 +68,10 @@ def __enter__(self) -> "SqlClientBase[TNativeConn]": return self def __exit__( - self, exc_type: Type[BaseException], exc_val: BaseException, exc_tb: TracebackType + self, + exc_type: Type[BaseException], + exc_val: BaseException, + exc_tb: TracebackType, ) -> None: self.close_connection() @@ -87,17 +96,22 @@ def create_dataset(self) -> None: self.execute_sql("CREATE SCHEMA %s" % self.fully_qualified_dataset_name()) def drop_dataset(self) -> None: - self.execute_sql("DROP SCHEMA %s CASCADE;" % self.fully_qualified_dataset_name()) + self.execute_sql( + "DROP SCHEMA %s CASCADE;" % self.fully_qualified_dataset_name() + ) def truncate_tables(self, *tables: str) -> None: - statements = [self._truncate_table_sql(self.make_qualified_table_name(t)) for t in tables] + statements = [ + self._truncate_table_sql(self.make_qualified_table_name(t)) for t in tables + ] self.execute_many(statements) def drop_tables(self, *tables: str) -> None: if not tables: return statements = [ - f"DROP TABLE IF EXISTS {self.make_qualified_table_name(table)};" for table in tables + f"DROP TABLE IF EXISTS {self.make_qualified_table_name(table)};" + for table in tables ] self.execute_many(statements) @@ -227,7 +241,9 @@ def df(self, chunk_size: int = None, **kwargs: Any) -> Optional[DataFrame]: if chunk_size is None: return _wrap_result(self.native_cursor.fetchall(), columns, **kwargs) else: - df = _wrap_result(self.native_cursor.fetchmany(chunk_size), columns, **kwargs) + df = _wrap_result( + self.native_cursor.fetchmany(chunk_size), columns, **kwargs + ) # if no rows return None if df.shape[0] == 0: return None @@ -264,6 +280,8 @@ def _wrap(self: SqlClientBase[Any], *args: Any, **kwargs: Any) -> Any: try: return f(self, *args, **kwargs) except Exception as ex: - raise DestinationConnectionError(type(self).__name__, self.dataset_name, str(ex), ex) + raise DestinationConnectionError( + type(self).__name__, self.dataset_name, str(ex), ex + ) return _wrap # type: ignore diff --git a/dlt/destinations/sql_jobs.py b/dlt/destinations/sql_jobs.py index 91be3a60c9..6fca7b6db2 100644 --- a/dlt/destinations/sql_jobs.py +++ b/dlt/destinations/sql_jobs.py @@ -57,7 +57,10 @@ def from_table_chain( except Exception: # return failed job tables_str = yaml.dump( - table_chain, allow_unicode=True, default_flow_style=False, sort_keys=False + table_chain, + allow_unicode=True, + default_flow_style=False, + sort_keys=False, ) job = cls(file_info.file_name(), "failed", pretty_format_exception()) 
job._save_text_file("\n".join([cls.failed_text, tables_str])) @@ -76,7 +79,9 @@ def generate_sql( class SqlStagingCopyJob(SqlBaseJob): """Generates a list of sql statements that copy the data from staging dataset into destination dataset.""" - failed_text: str = "Tried to generate a staging copy sql job for the following tables:" + failed_text: str = ( + "Tried to generate a staging copy sql job for the following tables:" + ) @classmethod def _generate_clone_sql( @@ -116,7 +121,8 @@ def _generate_insert_sql( if params["replace"]: sql.append(sql_client._truncate_table_sql(table_name)) sql.append( - f"INSERT INTO {table_name}({columns}) SELECT {columns} FROM {staging_table_name};" + f"INSERT INTO {table_name}({columns}) SELECT {columns} FROM" + f" {staging_table_name};" ) return sql @@ -167,11 +173,15 @@ def _gen_key_table_clauses( if primary_keys or merge_keys: if primary_keys: clauses.append( - " AND ".join(["%s.%s = %s.%s" % ("{d}", c, "{s}", c) for c in primary_keys]) + " AND ".join( + ["%s.%s = %s.%s" % ("{d}", c, "{s}", c) for c in primary_keys] + ) ) if merge_keys: clauses.append( - " AND ".join(["%s.%s = %s.%s" % ("{d}", c, "{s}", c) for c in merge_keys]) + " AND ".join( + ["%s.%s = %s.%s" % ("{d}", c, "{s}", c) for c in merge_keys] + ) ) return clauses or ["1=1"] @@ -188,8 +198,9 @@ def gen_key_table_clauses( A list of clauses may be returned for engines that do not support OR in subqueries. Like BigQuery """ return [ - f"FROM {root_table_name} as d WHERE EXISTS (SELECT 1 FROM {staging_root_table_name} as" - f" s WHERE {' OR '.join([c.format(d='d',s='s') for c in key_clauses])})" + f"FROM {root_table_name} as d WHERE EXISTS (SELECT 1 FROM" + f" {staging_root_table_name} as s WHERE" + f" {' OR '.join([c.format(d='d',s='s') for c in key_clauses])})" ] @classmethod @@ -205,7 +216,9 @@ def gen_delete_temp_table_sql( select_statement = f"SELECT d.{unique_column} {key_table_clauses[0]}" sql.append(cls._to_temp_table(select_statement, temp_table_name)) for clause in key_table_clauses[1:]: - sql.append(f"INSERT INTO {temp_table_name} SELECT {unique_column} {clause};") + sql.append( + f"INSERT INTO {temp_table_name} SELECT {unique_column} {clause};" + ) return sql, temp_table_name @classmethod @@ -294,7 +307,10 @@ def gen_insert_temp_table_sql( ) else: # don't deduplicate - select_sql = f"SELECT {unique_column} FROM {staging_root_table_name} WHERE {condition}" + select_sql = ( + f"SELECT {unique_column} FROM {staging_root_table_name} WHERE" + f" {condition}" + ) return [cls._to_temp_table(select_sql, temp_table_name)], temp_table_name @classmethod @@ -339,14 +355,20 @@ def gen_merge_sql( escape_id = sql_client.capabilities.escape_identifier escape_lit = sql_client.capabilities.escape_literal if escape_id is None: - escape_id = DestinationCapabilitiesContext.generic_capabilities().escape_identifier + escape_id = ( + DestinationCapabilitiesContext.generic_capabilities().escape_identifier + ) if escape_lit is None: - escape_lit = DestinationCapabilitiesContext.generic_capabilities().escape_literal + escape_lit = ( + DestinationCapabilitiesContext.generic_capabilities().escape_literal + ) # get top level table full identifiers root_table_name = sql_client.make_qualified_table_name(root_table["name"]) with sql_client.with_staging_dataset(staging=True): - staging_root_table_name = sql_client.make_qualified_table_name(root_table["name"]) + staging_root_table_name = sql_client.make_qualified_table_name( + root_table["name"] + ) # get merge and primary keys from top level primary_keys = list( map( @@ 
-383,14 +405,15 @@ def gen_merge_sql( sql_client.fully_qualified_dataset_name(), staging_root_table_name, [t["name"] for t in table_chain], - f"There is no unique column (ie _dlt_id) in top table {root_table['name']} so" - " it is not possible to link child tables to it.", + "There is no unique column (ie _dlt_id) in top table" + f" {root_table['name']} so it is not possible to link child tables" + " to it.", ) # get first unique column unique_column = escape_id(unique_columns[0]) # create temp table with unique identifier - create_delete_temp_table_sql, delete_temp_table_name = cls.gen_delete_temp_table_sql( - unique_column, key_table_clauses + create_delete_temp_table_sql, delete_temp_table_name = ( + cls.gen_delete_temp_table_sql(unique_column, key_table_clauses) ) sql.extend(create_delete_temp_table_sql) @@ -405,20 +428,26 @@ def gen_merge_sql( staging_root_table_name, [t["name"] for t in table_chain], "There is no root foreign key (ie _dlt_root_id) in child table" - f" {table['name']} so it is not possible to refer to top level table" - f" {root_table['name']} unique column {unique_column}", + f" {table['name']} so it is not possible to refer to top level" + f" table {root_table['name']} unique column {unique_column}", ) root_key_column = escape_id(root_key_columns[0]) sql.append( cls.gen_delete_from_sql( - table_name, root_key_column, delete_temp_table_name, unique_column + table_name, + root_key_column, + delete_temp_table_name, + unique_column, ) ) # delete from top table now that child tables have been prcessed sql.append( cls.gen_delete_from_sql( - root_table_name, unique_column, delete_temp_table_name, unique_column + root_table_name, + unique_column, + delete_temp_table_name, + unique_column, ) ) @@ -430,7 +459,9 @@ def gen_merge_sql( not_deleted_cond = f"{escape_id(hard_delete_col)} IS NULL" if root_table["columns"][hard_delete_col]["data_type"] == "bool": # only True values indicate a delete for boolean columns - not_deleted_cond += f" OR {escape_id(hard_delete_col)} = {escape_lit(False)}" + not_deleted_cond += ( + f" OR {escape_id(hard_delete_col)} = {escape_lit(False)}" + ) # get dedup sort information dedup_sort = get_dedup_sort_tuple(root_table) @@ -438,7 +469,9 @@ def gen_merge_sql( insert_temp_table_name: str = None if len(table_chain) > 1: if len(primary_keys) > 0 or hard_delete_col is not None: - condition_columns = [hard_delete_col] if not_deleted_cond is not None else None + condition_columns = ( + [hard_delete_col] if not_deleted_cond is not None else None + ) ( create_insert_temp_table_sql, insert_temp_table_name, @@ -464,12 +497,18 @@ def gen_merge_sql( and table.get("parent") is not None # child table and hard_delete_col is not None ): - uniq_column = unique_column if table.get("parent") is None else root_key_column - insert_cond = f"{uniq_column} IN (SELECT * FROM {insert_temp_table_name})" + uniq_column = ( + unique_column if table.get("parent") is None else root_key_column + ) + insert_cond = ( + f"{uniq_column} IN (SELECT * FROM {insert_temp_table_name})" + ) columns = list(map(escape_id, get_columns_names_with_prop(table, "name"))) col_str = ", ".join(columns) - select_sql = f"SELECT {col_str} FROM {staging_table_name} WHERE {insert_cond}" + select_sql = ( + f"SELECT {col_str} FROM {staging_table_name} WHERE {insert_cond}" + ) if len(primary_keys) > 0 and len(table_chain) == 1: # without child tables we deduplicate inside the query instead of using a temp table select_sql = cls.gen_select_from_dedup_sql( diff --git a/dlt/destinations/type_mapping.py 
b/dlt/destinations/type_mapping.py index dcd938b33c..f7dfd0eee7 100644 --- a/dlt/destinations/type_mapping.py +++ b/dlt/destinations/type_mapping.py @@ -32,7 +32,9 @@ def to_db_datetime_type( # Override in subclass if db supports other timestamp types (e.g. with different time resolutions) return None - def to_db_time_type(self, precision: Optional[int], table_format: TTableFormat = None) -> str: + def to_db_time_type( + self, precision: Optional[int], table_format: TTableFormat = None + ) -> str: # Override in subclass if db supports other time types (e.g. with different time resolutions) return None @@ -42,7 +44,9 @@ def to_db_decimal_type(self, precision: Optional[int], scale: Optional[int]) -> return self.sct_to_unbound_dbt["decimal"] return self.sct_to_dbt["decimal"] % (precision_tup[0], precision_tup[1]) - def to_db_type(self, column: TColumnSchema, table_format: TTableFormat = None) -> str: + def to_db_type( + self, column: TColumnSchema, table_format: TTableFormat = None + ) -> str: precision, scale = column.get("precision"), column.get("scale") sc_t = column["data_type"] if sc_t == "bigint": @@ -112,6 +116,8 @@ def from_db_type( ) -> TColumnType: return without_none( dict( # type: ignore[return-value] - data_type=self.dbt_to_sct.get(db_type, "text"), precision=precision, scale=scale + data_type=self.dbt_to_sct.get(db_type, "text"), + precision=precision, + scale=scale, ) ) diff --git a/dlt/extract/concurrency.py b/dlt/extract/concurrency.py index 6a330b2645..d6339d4773 100644 --- a/dlt/extract/concurrency.py +++ b/dlt/extract/concurrency.py @@ -10,7 +10,12 @@ from dlt.common.exceptions import PipelineException from dlt.common.configuration.container import Container from dlt.common.runtime.signals import sleep -from dlt.extract.items import DataItemWithMeta, TItemFuture, ResolvablePipeItem, FuturePipeItem +from dlt.extract.items import ( + DataItemWithMeta, + TItemFuture, + ResolvablePipeItem, + FuturePipeItem, +) from dlt.extract.exceptions import ( DltSourceException, @@ -27,7 +32,10 @@ class FuturesPool: """ def __init__( - self, workers: int = 5, poll_interval: float = 0.01, max_parallel_items: int = 20 + self, + workers: int = 5, + poll_interval: float = 0.01, + max_parallel_items: int = 20, ) -> None: self.futures: Dict[TItemFuture, FuturePipeItem] = {} self._thread_pool: ThreadPoolExecutor = None @@ -96,7 +104,9 @@ def submit(self, pipe_item: ResolvablePipeItem) -> TItemFuture: """ # Sanity check, negative free slots means there's a bug somewhere - assert self.free_slots >= 0, "Worker pool has negative free slots, this should never happen" + assert ( + self.free_slots >= 0 + ), "Worker pool has negative free slots, this should never happen" if self.free_slots == 0: # Wait until some future is completed to ensure there's a free slot @@ -136,7 +146,13 @@ def _resolve_future(self, future: TItemFuture) -> Optional[ResolvablePipeItem]: return None # Raise if any future fails if isinstance( - ex, (PipelineException, ExtractorException, DltSourceException, PipeException) + ex, + ( + PipelineException, + ExtractorException, + DltSourceException, + PipeException, + ), ): raise ex raise ResourceExtractionError(pipe.name, future, str(ex), "future") from ex @@ -152,7 +168,9 @@ def _resolve_future(self, future: TItemFuture) -> Optional[ResolvablePipeItem]: def _next_done_future(self) -> Optional[TItemFuture]: """Get the done future in the pool (if any). 
This does not block.""" - return next((fut for fut in self.futures if fut.done() and not fut.cancelled()), None) + return next( + (fut for fut in self.futures if fut.done() and not fut.cancelled()), None + ) def resolve_next_future( self, use_configured_timeout: bool = False @@ -224,7 +242,9 @@ def stop_background_loop(loop: asyncio.AbstractEventLoop) -> None: ) wait_for_futures([future]) - self._async_pool.call_soon_threadsafe(stop_background_loop, self._async_pool) + self._async_pool.call_soon_threadsafe( + stop_background_loop, self._async_pool + ) self._async_pool_thread.join() self._async_pool = None diff --git a/dlt/extract/decorators.py b/dlt/extract/decorators.py index 28a2aca633..0dcc8c95b7 100644 --- a/dlt/extract/decorators.py +++ b/dlt/extract/decorators.py @@ -20,7 +20,12 @@ ) from typing_extensions import TypeVar -from dlt.common.configuration import with_config, get_fun_spec, known_sections, configspec +from dlt.common.configuration import ( + with_config, + get_fun_spec, + known_sections, + configspec, +) from dlt.common.configuration.container import Container from dlt.common.configuration.exceptions import ContextDefaultCannotBeCreated from dlt.common.configuration.resolve import inject_section @@ -118,7 +123,9 @@ def source( schema_contract: TSchemaContract = None, spec: Type[BaseConfiguration] = None, _impl_cls: Type[TDltSourceImpl] = DltSource, # type: ignore[assignment] -) -> Callable[[Callable[TSourceFunParams, Any]], Callable[TSourceFunParams, TDltSourceImpl]]: ... +) -> Callable[ + [Callable[TSourceFunParams, Any]], Callable[TSourceFunParams, TDltSourceImpl] +]: ... def source( @@ -195,7 +202,9 @@ def decorator( if not schema: # load the schema from file with name_schema.yaml/json from the same directory, the callable resides OR create new default schema - schema = _maybe_load_schema_for_callable(f, effective_name) or Schema(effective_name) + schema = _maybe_load_schema_for_callable(f, effective_name) or Schema( + effective_name + ) if name and name != schema.name: raise ExplicitSourceNameInvalid(name, schema.name) @@ -216,7 +225,9 @@ def _eval_rv(_rv: Any) -> TDltSourceImpl: _rv = list(_rv) # convert to source - s = _impl_cls.from_data(schema.clone(update_normalizers=True), source_section, _rv) + s = _impl_cls.from_data( + schema.clone(update_normalizers=True), source_section, _rv + ) # apply hints if max_table_nesting is not None: s.max_table_nesting = max_table_nesting @@ -231,7 +242,9 @@ def _wrap(*args: Any, **kwargs: Any) -> TDltSourceImpl: with Container().injectable_context(SourceSchemaInjectableContext(schema)): # configurations will be accessed in this section in the source proxy = Container()[PipelineContext] - pipeline_name = None if not proxy.is_active() else proxy.pipeline().pipeline_name + pipeline_name = ( + None if not proxy.is_active() else proxy.pipeline().pipeline_name + ) with inject_section( ConfigSectionContext( pipeline_name=pipeline_name, @@ -250,7 +263,9 @@ async def _wrap_coro(*args: Any, **kwargs: Any) -> TDltSourceImpl: with Container().injectable_context(SourceSchemaInjectableContext(schema)): # configurations will be accessed in this section in the source proxy = Container()[PipelineContext] - pipeline_name = None if not proxy.is_active() else proxy.pipeline().pipeline_name + pipeline_name = ( + None if not proxy.is_active() else proxy.pipeline().pipeline_name + ) with inject_section( ConfigSectionContext( pipeline_name=pipeline_name, @@ -264,7 +279,9 @@ async def _wrap_coro(*args: Any, **kwargs: Any) -> TDltSourceImpl: # get spec for 
wrapped function SPEC = get_fun_spec(conf_f) # get correct wrapper - wrapper = _wrap_coro if inspect.iscoroutinefunction(inspect.unwrap(f)) else _wrap + wrapper = ( + _wrap_coro if inspect.iscoroutinefunction(inspect.unwrap(f)) else _wrap + ) # store the source information _SOURCES[_wrap.__qualname__] = SourceInfo(SPEC, wrapper, func_module) if inspect.iscoroutinefunction(inspect.unwrap(f)): @@ -332,7 +349,9 @@ def resource( spec: Type[BaseConfiguration] = None, parallelized: bool = False, standalone: Literal[True] = True, -) -> Callable[[Callable[TResourceFunParams, Any]], Callable[TResourceFunParams, DltResource]]: ... +) -> Callable[ + [Callable[TResourceFunParams, Any]], Callable[TResourceFunParams, DltResource] +]: ... @overload @@ -436,7 +455,10 @@ def resource( """ def make_resource( - _name: str, _section: str, _data: Any, incremental: IncrementalResourceWrapper = None + _name: str, + _section: str, + _data: Any, + incremental: IncrementalResourceWrapper = None, ) -> DltResource: table_template = make_hints( table_name, @@ -501,7 +523,9 @@ def decorator( ) is_inner_resource = is_inner_callable(f) if conf_f != incr_f and is_inner_resource and not standalone: - raise ResourceInnerCallableConfigWrapDisallowed(resource_name, source_section) + raise ResourceInnerCallableConfigWrapDisallowed( + resource_name, source_section + ) # get spec for wrapped function SPEC = get_fun_spec(conf_f) @@ -517,7 +541,9 @@ def decorator( @wraps(conf_f) def _wrap(*args: Any, **kwargs: Any) -> DltResource: - _, mod_sig, bound_args = simulate_func_call(conf_f, skip_args, *args, **kwargs) + _, mod_sig, bound_args = simulate_func_call( + conf_f, skip_args, *args, **kwargs + ) actual_resource_name = ( name(bound_args.arguments) if callable(name) else resource_name ) @@ -569,7 +595,9 @@ def transformer( selected: bool = True, spec: Type[BaseConfiguration] = None, parallelized: bool = False, -) -> Callable[[Callable[Concatenate[TDataItem, TResourceFunParams], Any]], DltResource]: ... +) -> Callable[ + [Callable[Concatenate[TDataItem, TResourceFunParams], Any]], DltResource +]: ... @overload @@ -699,8 +727,8 @@ def transformer( """ if isinstance(f, DltResource): raise ValueError( - "Please pass `data_from=` argument as keyword argument. The only positional argument to" - " transformer is the decorated function" + "Please pass `data_from=` argument as keyword argument. The only positional" + " argument to transformer is the decorated function" ) return resource( # type: ignore diff --git a/dlt/extract/exceptions.py b/dlt/extract/exceptions.py index c3a20e72e5..1f389cd3e2 100644 --- a/dlt/extract/exceptions.py +++ b/dlt/extract/exceptions.py @@ -43,9 +43,9 @@ def __init__(self, pipe_name: str, has_parent: bool) -> None: self.has_parent = has_parent if has_parent: msg = ( - f"A pipe created from transformer {pipe_name} is unbound or its parent is unbound" - " or empty. Provide a resource in `data_from` argument or bind resources with |" - " operator." + f"A pipe created from transformer {pipe_name} is unbound or its parent" + " is unbound or empty. Provide a resource in `data_from` argument or" + " bind resources with | operator." 
) else: msg = "Pipe is empty and does not have a resource at its head" @@ -53,14 +53,16 @@ def __init__(self, pipe_name: str, has_parent: bool) -> None: class InvalidStepFunctionArguments(PipeException): - def __init__(self, pipe_name: str, func_name: str, sig: Signature, call_error: str) -> None: + def __init__( + self, pipe_name: str, func_name: str, sig: Signature, call_error: str + ) -> None: self.func_name = func_name self.sig = sig super().__init__( pipe_name, f"Unable to call {func_name}: {call_error}. The mapping/filtering function" - f" {func_name} requires first argument to take data item and optional second argument" - f" named 'meta', but the signature is {sig}", + f" {func_name} requires first argument to take data item and optional" + f" second argument named 'meta', but the signature is {sig}", ) @@ -75,28 +77,30 @@ def __init__(self, pipe_name: str, gen: Any, msg: str, kind: str) -> None: ) super().__init__( pipe_name, - f"extraction of resource {pipe_name} in {kind} {self.func_name} caused an exception:" - f" {msg}", + f"extraction of resource {pipe_name} in {kind} {self.func_name} caused an" + f" exception: {msg}", ) class PipeGenInvalid(PipeException): def __init__(self, pipe_name: str, gen: Any) -> None: msg = ( - "A pipe generator element must be an Iterator (ie. list or generator function)." - " Generator element is typically created from a `data` argument to pipeline.run or" - " extract method." + "A pipe generator element must be an Iterator (ie. list or generator" + " function). Generator element is typically created from a `data` argument" + " to pipeline.run or extract method." ) msg += ( - " dlt will evaluate functions that were passed as data argument. If you passed a" - " function the returned data type is not iterable. " + " dlt will evaluate functions that were passed as data argument. If you" + " passed a function the returned data type is not iterable. " ) type_name = str(type(gen)) msg += f" Generator type is {type_name}." if "DltSource" in type_name: msg += " Did you pass a @dlt.source decorated function without calling it?" if "DltResource" in type_name: - msg += " Did you pass a function that returns dlt.resource without calling it?" + msg += ( + " Did you pass a function that returns dlt.resource without calling it?" + ) super().__init__(pipe_name, msg) @@ -125,8 +129,8 @@ class DynamicNameNotStandaloneResource(DltResourceException): def __init__(self, resource_name: str) -> None: super().__init__( resource_name, - "You must set the resource as standalone to be able to dynamically set its name based" - " on call arguments", + "You must set the resource as standalone to be able to dynamically set its" + " name based on call arguments", ) @@ -139,18 +143,21 @@ class ResourceNotFoundError(DltResourceException, KeyError): def __init__(self, resource_name: str, context: str) -> None: self.resource_name = resource_name super().__init__( - resource_name, f"Resource with a name {resource_name} could not be found. {context}" + resource_name, + f"Resource with a name {resource_name} could not be found. {context}", ) class InvalidResourceDataType(DltResourceException): - def __init__(self, resource_name: str, item: Any, _typ: Type[Any], msg: str) -> None: + def __init__( + self, resource_name: str, item: Any, _typ: Type[Any], msg: str + ) -> None: self.item = item self._typ = _typ super().__init__( resource_name, - f"Cannot create resource {resource_name} from specified data. If you want to process" - " just one data item, enclose it in a list. 
" + f"Cannot create resource {resource_name} from specified data. If you want" + " to process just one data item, enclose it in a list. " + msg, ) @@ -161,8 +168,8 @@ def __init__(self, resource_name: str, item: Any, _typ: Type[Any]) -> None: resource_name, item, _typ, - "Parallel resource data must be a generator or a generator function. The provided" - f" data type for resource '{resource_name}' was {_typ.__name__}.", + "Parallel resource data must be a generator or a generator function. The" + f" provided data type for resource '{resource_name}' was {_typ.__name__}.", ) @@ -172,9 +179,9 @@ def __init__(self, resource_name: str, item: Any, _typ: Type[Any]) -> None: resource_name, item, _typ, - f"Resources cannot be strings or dictionaries but {_typ.__name__} was provided. Please" - " pass your data in a list or as a function yielding items. If you want to process" - " just one data item, enclose it in a list.", + f"Resources cannot be strings or dictionaries but {_typ.__name__} was" + " provided. Please pass your data in a list or as a function yielding" + " items. If you want to process just one data item, enclose it in a list.", ) @@ -184,8 +191,8 @@ def __init__(self, resource_name: str, item: Any, _typ: Type[Any]) -> None: resource_name, item, _typ, - "Please make sure that function decorated with @dlt.resource uses 'yield' to return the" - " data.", + "Please make sure that function decorated with @dlt.resource uses 'yield'" + " to return the data.", ) @@ -195,9 +202,9 @@ def __init__(self, resource_name: str, item: Any, _typ: Type[Any]) -> None: resource_name, item, _typ, - "Resources with multiple parallel data pipes are not yet supported. This problem most" - " often happens when you are creating a source with @dlt.source decorator that has" - " several resources with the same name.", + "Resources with multiple parallel data pipes are not yet supported. This" + " problem most often happens when you are creating a source with" + " @dlt.source decorator that has several resources with the same name.", ) @@ -207,23 +214,34 @@ def __init__(self, resource_name: str, item: Any, _typ: Type[Any]) -> None: resource_name, item, _typ, - "Transformer must be a function decorated with @dlt.transformer that takes data item as" - " its first argument. Only first argument may be 'positional only'.", + "Transformer must be a function decorated with @dlt.transformer that takes" + " data item as its first argument. Only first argument may be 'positional" + " only'.", ) class InvalidTransformerGeneratorFunction(DltResourceException): - def __init__(self, resource_name: str, func_name: str, sig: Signature, code: int) -> None: + def __init__( + self, resource_name: str, func_name: str, sig: Signature, code: int + ) -> None: self.func_name = func_name self.sig = sig self.code = code - msg = f"Transformer function {func_name} must take data item as its first argument. " + msg = ( + f"Transformer function {func_name} must take data item as its first" + " argument. " + ) if code == 1: msg += "The actual function does not take any arguments." 
elif code == 2: - msg += f"Only the first argument may be 'positional only', actual signature is {sig}" + msg += ( + "Only the first argument may be 'positional only', actual signature is" + f" {sig}" + ) elif code == 3: - msg += f"The first argument cannot be keyword only, actual signature is {sig}" + msg += ( + f"The first argument cannot be keyword only, actual signature is {sig}" + ) super().__init__(resource_name, msg) @@ -232,11 +250,11 @@ class ResourceInnerCallableConfigWrapDisallowed(DltResourceException): def __init__(self, resource_name: str, section: str) -> None: self.section = section msg = ( - f"Resource {resource_name} in section {section} is defined over an inner function and" - " requests config/secrets in its arguments. Requesting secret and config values via" - " 'dlt.secrets.values' or 'dlt.config.value' is disallowed for resources that are" - " inner functions. Use the dlt.source to get the required configuration and pass them" - " explicitly to your source." + f"Resource {resource_name} in section {section} is defined over an inner" + " function and requests config/secrets in its arguments. Requesting secret" + " and config values via 'dlt.secrets.values' or 'dlt.config.value' is" + " disallowed for resources that are inner functions. Use the dlt.source to" + " get the required configuration and pass them explicitly to your source." ) super().__init__(resource_name, msg) @@ -247,8 +265,8 @@ def __init__(self, resource_name: str, item: Any, _typ: Type[Any]) -> None: resource_name, item, _typ, - "Resource data missing. Did you forget the return statement in @dlt.resource decorated" - " function?", + "Resource data missing. Did you forget the return statement in" + " @dlt.resource decorated function?", ) @@ -258,8 +276,9 @@ def __init__(self, resource_name: str, item: Any, _typ: Type[Any]) -> None: resource_name, item, _typ, - f"Expected function or callable as first parameter to resource {resource_name} but" - f" {_typ.__name__} found. Please decorate a function with @dlt.resource", + "Expected function or callable as first parameter to resource" + f" {resource_name} but {_typ.__name__} found. Please decorate a function" + " with @dlt.resource", ) @@ -269,8 +288,8 @@ def __init__(self, resource_name: str, item: Any, _typ: Type[Any]) -> None: resource_name, item, _typ, - f"A parent resource of {resource_name} is of type {_typ.__name__}. Did you forget to" - " use '@dlt.resource` decorator or `resource` function?", + f"A parent resource of {resource_name} is of type {_typ.__name__}. Did you" + " forget to use '@dlt.resource` decorator or `resource` function?", ) @@ -279,14 +298,17 @@ def __init__(self, resource_name: str, func_name: str) -> None: self.func_name = func_name super().__init__( resource_name, - f"A data source {func_name} of a transformer {resource_name} is an undecorated" - " function. Please decorate it with '@dlt.resource' or pass to 'resource' function.", + f"A data source {func_name} of a transformer {resource_name} is an" + " undecorated function. 
Please decorate it with '@dlt.resource' or pass to" + " 'resource' function.", ) class DeletingResourcesNotSupported(DltResourceException): def __init__(self, source_name: str, resource_name: str) -> None: - super().__init__(resource_name, f"Resource cannot be removed the the source {source_name}") + super().__init__( + resource_name, f"Resource cannot be removed from the source {source_name}" + ) class ParametrizedResourceUnbound(DltResourceException): @@ -296,9 +318,9 @@ def __init__( self.func_name = func_name self.sig = sig msg = ( - f"The {kind} {resource_name} is parametrized and expects following arguments: {sig}." - f" Did you forget to bind the {func_name} function? For example from" - f" `source.{resource_name}.bind(...)" + f"The {kind} {resource_name} is parametrized and expects the following" + f" arguments: {sig}. Did you forget to bind the {func_name} function? For" + f" example from `source.{resource_name}.bind(...)" ) if error: msg += f" .Details: {error}" @@ -329,8 +351,8 @@ class SourceDataIsNone(DltSourceException): def __init__(self, source_name: str) -> None: self.source_name = source_name super().__init__( - f"No data returned or yielded from source function {source_name}. Did you forget the" - " return statement?" + f"No data returned or yielded from source function {source_name}. Did you" + " forget the return statement?" ) @@ -338,14 +360,17 @@ class SourceExhausted(DltSourceException): def __init__(self, source_name: str) -> None: self.source_name = source_name super().__init__( - f"Source {source_name} is exhausted or has active iterator. You can iterate or pass the" - " source to dlt pipeline only once." + f"Source {source_name} is exhausted or has an active iterator. You can iterate" + " or pass the source to a dlt pipeline only once." ) class ResourcesNotFoundError(DltSourceException): def __init__( - self, source_name: str, available_resources: Set[str], requested_resources: Set[str] + self, + source_name: str, + available_resources: Set[str], + requested_resources: Set[str], ) -> None: self.source_name = source_name self.available_resources = available_resources @@ -353,7 +378,8 @@ def __init__( self.not_found_resources = requested_resources.difference(available_resources) msg = ( f"The following resources could not be found in source {source_name}:" - f" {self.not_found_resources}. Available resources are: {available_resources}" + f" {self.not_found_resources}. Available resources are:" + f" {available_resources}" ) super().__init__(msg) @@ -364,8 +390,9 @@ def __init__(self, source_name: str, item: Any, _typ: Type[Any]) -> None: self.item = item self.typ = _typ super().__init__( - f"First parameter to the source {source_name} must be a function or callable but is" - f" {_typ.__name__}. Please decorate a function with @dlt.source" + f"First parameter to the source {source_name} must be a function or" + f" callable but is {_typ.__name__}. Please decorate a function with" + " @dlt.source" ) @@ -374,25 +401,25 @@ def __init__(self, source_name: str, _typ: Type[Any]) -> None: self.source_name = source_name self.typ = _typ super().__init__( - f"First parameter to the source {source_name} is a class {_typ.__name__}. Do not" - " decorate classes with @dlt.source. Instead implement __call__ in your class and pass" - " instance of such class to dlt.source() directly" + f"First parameter to the source {source_name} is a class {_typ.__name__}." + " Do not decorate classes with @dlt.source. 
Instead implement __call__ in" + " your class and pass instance of such class to dlt.source() directly" ) class CurrentSourceSchemaNotAvailable(DltSourceException): def __init__(self) -> None: super().__init__( - "Current source schema is available only when called from a function decorated with" - " dlt.source or dlt.resource" + "Current source schema is available only when called from a function" + " decorated with dlt.source or dlt.resource" ) class CurrentSourceNotAvailable(DltSourceException): def __init__(self) -> None: super().__init__( - "Current source is available only when called from a function decorated with" - " dlt.resource or dlt.transformer during the extract step" + "Current source is available only when called from a function decorated" + " with dlt.resource or dlt.transformer during the extract step" ) @@ -401,8 +428,8 @@ def __init__(self, source_name: str, schema_name: str) -> None: self.source_name = source_name self.schema_name = schema_name super().__init__( - f"Your explicit source name {source_name} is not a valid schema name. Please use a" - f" valid schema name ie. '{schema_name}'." + f"Your explicit source name {source_name} is not a valid schema name." + f" Please use a valid schema name ie. '{schema_name}'." ) @@ -410,9 +437,9 @@ class IncrementalUnboundError(DltResourceException): def __init__(self, cursor_path: str) -> None: super().__init__( "", - f"The incremental definition with cursor path {cursor_path} is used without being bound" - " to the resource. This most often happens when you create dynamic resource from a" - " generator function that uses incremental. See" + f"The incremental definition with cursor path {cursor_path} is used without" + " being bound to the resource. This most often happens when you create" + " dynamic resource from a generator function that uses incremental. 
See" " https://dlthub.com/docs/general-usage/incremental-loading#incremental-loading-with-last-value" " for an example.", ) diff --git a/dlt/extract/extract.py b/dlt/extract/extract.py index 2fc4fd77aa..381d567f92 100644 --- a/dlt/extract/extract.py +++ b/dlt/extract/extract.py @@ -27,11 +27,18 @@ TSchemaContract, TWriteDisposition, ) -from dlt.common.storages import NormalizeStorageConfiguration, LoadPackageInfo, SchemaStorage +from dlt.common.storages import ( + NormalizeStorageConfiguration, + LoadPackageInfo, + SchemaStorage, +) from dlt.common.storages.load_package import ParsedLoadJobFileName from dlt.common.utils import get_callable_name, get_full_class_name -from dlt.extract.decorators import SourceInjectableContext, SourceSchemaInjectableContext +from dlt.extract.decorators import ( + SourceInjectableContext, + SourceSchemaInjectableContext, +) from dlt.extract.exceptions import DataItemRequiredForDynamicTableHints from dlt.extract.incremental import IncrementalResourceWrapper from dlt.extract.pipe_iterator import PipeIterator @@ -103,7 +110,9 @@ def append_data(data_item: Any) -> None: # iterator/iterable/generator # create resource first without table template resources.append( - DltResource.from_data(data_item, name=table_name, section=pipeline.pipeline_name) + DltResource.from_data( + data_item, name=table_name, section=pipeline.pipeline_name + ) ) if isinstance(data, C_Sequence) and len(data) > 0: @@ -139,7 +148,9 @@ def add_item(item: Any) -> bool: data_info.append( { "name": item.name, - "data_type": "resource" if isinstance(item, DltResource) else "source", + "data_type": ( + "resource" if isinstance(item, DltResource) else "source" + ), } ) return False @@ -185,16 +196,21 @@ def _compute_metrics(self, load_id: str, source: DltSource) -> ExtractMetrics: } # aggregate by table name table_metrics = { - table_name: sum(map(lambda pair: pair[1], metrics), EMPTY_DATA_WRITER_METRICS) + table_name: sum( + map(lambda pair: pair[1], metrics), EMPTY_DATA_WRITER_METRICS + ) for table_name, metrics in itertools.groupby( job_metrics.items(), lambda pair: pair[0].table_name ) } # aggregate by resource name resource_metrics = { - resource_name: sum(map(lambda pair: pair[1], metrics), EMPTY_DATA_WRITER_METRICS) + resource_name: sum( + map(lambda pair: pair[1], metrics), EMPTY_DATA_WRITER_METRICS + ) for resource_name, metrics in itertools.groupby( - table_metrics.items(), lambda pair: source.schema.get_table(pair[0])["resource"] + table_metrics.items(), + lambda pair: source.schema.get_table(pair[0])["resource"], ) } # collect resource hints @@ -227,7 +243,10 @@ def _compute_metrics(self, load_id: str, source: DltSource) -> ExtractMetrics: if name == "columns": if hint: hints[name] = yaml.dump( - hint, allow_unicode=True, default_flow_style=False, sort_keys=False + hint, + allow_unicode=True, + default_flow_style=False, + sort_keys=False, ) continue hints[name] = hint @@ -236,7 +255,9 @@ def _compute_metrics(self, load_id: str, source: DltSource) -> ExtractMetrics: "started_at": None, "finished_at": None, "schema_name": source.schema.name, - "job_metrics": {job.job_id(): metrics for job, metrics in job_metrics.items()}, + "job_metrics": { + job.job_id(): metrics for job, metrics in job_metrics.items() + }, "table_metrics": table_metrics, "resource_metrics": resource_metrics, "dag": source.resources.selected_dag, @@ -248,13 +269,18 @@ def _write_empty_files( ) -> None: schema = source.schema json_extractor = extractors["puae-jsonl"] - resources_with_items = 
set().union(*[e.resources_with_items for e in extractors.values()]) + resources_with_items = set().union( + *[e.resources_with_items for e in extractors.values()] + ) # find REPLACE resources that did not yield any pipe items and create empty jobs for them # NOTE: do not include tables that have never seen data data_tables = {t["name"]: t for t in schema.data_tables(seen_data_only=True)} tables_by_resources = utils.group_tables_by_resource(data_tables) for resource in source.resources.selected.values(): - if resource.write_disposition != "replace" or resource.name in resources_with_items: + if ( + resource.write_disposition != "replace" + or resource.name in resources_with_items + ): continue if resource.name not in tables_by_resources: continue @@ -300,7 +326,9 @@ def _extract_single_source( "puae-jsonl": JsonLExtractor( load_id, self.extract_storage, schema, collector=collector ), - "arrow": ArrowExtractor(load_id, self.extract_storage, schema, collector=collector), + "arrow": ArrowExtractor( + load_id, self.extract_storage, schema, collector=collector + ), } last_item_format: Optional[TLoaderFileFormat] = None @@ -327,9 +355,13 @@ def _extract_single_source( resource = source.resources[pipe_item.pipe.name] # Fallback to last item's format or default (puae-jsonl) if the current item is an empty list item_format = ( - Extractor.item_format(pipe_item.item) or last_item_format or "puae-jsonl" + Extractor.item_format(pipe_item.item) + or last_item_format + or "puae-jsonl" + ) + extractors[item_format].write_items( + resource, pipe_item.item, pipe_item.meta ) - extractors[item_format].write_items(resource, pipe_item.item, pipe_item.meta) last_item_format = item_format self._write_empty_files(source, extractors) @@ -340,7 +372,9 @@ def _extract_single_source( # flush all buffered writers self.extract_storage.close_writers(load_id) # gather metrics - self._step_info_complete_load_id(load_id, self._compute_metrics(load_id, source)) + self._step_info_complete_load_id( + load_id, self._compute_metrics(load_id, source) + ) # remove the metrics of files processed in this extract run # NOTE: there may be more than one extract run per load id: ie. 
the resource and then dlt state self.extract_storage.remove_closed_files(load_id) diff --git a/dlt/extract/extractors.py b/dlt/extract/extractors.py index b8e615aae4..71064c1f79 100644 --- a/dlt/extract/extractors.py +++ b/dlt/extract/extractors.py @@ -121,7 +121,9 @@ def write_items(self, resource: DltResource, items: TDataItems, meta: Any) -> No def write_empty_items_file(self, table_name: str) -> None: table_name = self.naming.normalize_table_identifier(table_name) - self.storage.write_empty_items_file(self.load_id, self.schema.name, table_name, None) + self.storage.write_empty_items_file( + self.load_id, self.schema.name, table_name, None + ) def _get_static_table_name(self, resource: DltResource, meta: Any) -> Optional[str]: if resource._table_name_hint_fun: @@ -133,7 +135,9 @@ def _get_static_table_name(self, resource: DltResource, meta: Any) -> Optional[s return self.naming.normalize_table_identifier(table_name) def _get_dynamic_table_name(self, resource: DltResource, item: TDataItem) -> str: - return self.naming.normalize_table_identifier(resource._table_name_hint_fun(item)) + return self.naming.normalize_table_identifier( + resource._table_name_hint_fun(item) + ) def _write_item( self, @@ -160,7 +164,10 @@ def _write_to_dynamic_table(self, resource: DltResource, items: TDataItems) -> N table_name = self._get_dynamic_table_name(resource, item) if table_name in self._filtered_tables: continue - if table_name not in self._table_contracts or resource._table_has_other_dynamic_hints: + if ( + table_name not in self._table_contracts + or resource._table_has_other_dynamic_hints + ): item = self._compute_and_update_table( resource, table_name, item, TableNameMeta(table_name) ) @@ -176,9 +183,13 @@ def _write_to_static_table( if table_name not in self._filtered_tables: self._write_item(table_name, resource.name, items) - def _compute_table(self, resource: DltResource, items: TDataItems, meta: Any) -> TTableSchema: + def _compute_table( + self, resource: DltResource, items: TDataItems, meta: Any + ) -> TTableSchema: """Computes a schema for a new or dynamic table and normalizes identifiers""" - return self.schema.normalize_table_identifiers(resource.compute_table_schema(items, meta)) + return self.schema.normalize_table_identifiers( + resource.compute_table_schema(items, meta) + ) def _compute_and_update_table( self, resource: DltResource, table_name: str, items: TDataItems, meta: Any @@ -191,11 +202,14 @@ def _compute_and_update_table( computed_table["name"] = table_name # get or compute contract schema_contract = self._table_contracts.setdefault( - table_name, self.schema.resolve_contract_settings_for_table(table_name, computed_table) + table_name, + self.schema.resolve_contract_settings_for_table(table_name, computed_table), ) # this is a new table so allow evolve once - if schema_contract["columns"] != "evolve" and self.schema.is_new_table(table_name): + if schema_contract["columns"] != "evolve" and self.schema.is_new_table( + table_name + ): computed_table["x-normalizer"] = {"evolve-columns-once": True} # type: ignore[typeddict-unknown-key] existing_table = self.schema._schema_tables.get(table_name, None) if existing_table: @@ -257,7 +271,10 @@ def write_items(self, resource: DltResource, items: TDataItems, meta: Any) -> No super().write_items(resource, items, meta) def _apply_contract_filters( - self, item: "TAnyArrowItem", resource: DltResource, static_table_name: Optional[str] + self, + item: "TAnyArrowItem", + resource: DltResource, + static_table_name: Optional[str], ) -> 
"TAnyArrowItem": """Removes the columns (discard value) or rows (discard rows) as indicated by contract filters.""" # convert arrow schema names into normalized names @@ -274,7 +291,11 @@ def _apply_contract_filters( name for name, mode in filtered_columns.items() if mode == "discard_row" ]: is_null = pyarrow.pyarrow.compute.is_null(item[rev_mapping[column]]) - mask = is_null if mask is None else pyarrow.pyarrow.compute.and_(mask, is_null) + mask = ( + is_null + if mask is None + else pyarrow.pyarrow.compute.and_(mask, is_null) + ) # filter the table using the mask if mask is not None: item = item.filter(mask) @@ -324,9 +345,10 @@ def _compute_table( if (src_hint := src_column.get(hint_name)) is not None: if src_hint != hint: logger.warning( - f"In resource: {resource.name}, when merging arrow schema on column" - f" {col_name}. The hint {hint_name} value {src_hint} defined in" - f" resource is overwritten from arrow with value {hint}." + f"In resource: {resource.name}, when merging arrow" + f" schema on column {col_name}. The hint" + f" {hint_name} value {src_hint} defined in resource is" + f" overwritten from arrow with value {hint}." ) # we must override the columns to preserve the order in arrow table @@ -341,5 +363,7 @@ def _compute_and_update_table( ) -> TDataItems: items = super()._compute_and_update_table(resource, table_name, items, meta) # filter data item as filters could be updated in compute table - items = [self._apply_contract_filters(item, resource, table_name) for item in items] + items = [ + self._apply_contract_filters(item, resource, table_name) for item in items + ] return items diff --git a/dlt/extract/hints.py b/dlt/extract/hints.py index 01a99a23fe..59b75d3305 100644 --- a/dlt/extract/hints.py +++ b/dlt/extract/hints.py @@ -13,7 +13,12 @@ TSchemaContract, ) from dlt.common import logger -from dlt.common.schema.utils import DEFAULT_WRITE_DISPOSITION, merge_column, new_column, new_table +from dlt.common.schema.utils import ( + DEFAULT_WRITE_DISPOSITION, + merge_column, + new_column, + new_table, +) from dlt.common.typing import TDataItem, DictStrAny, DictStrStr from dlt.common.utils import update_dict_nested from dlt.common.validation import validate_dict_ignoring_xkeys @@ -22,8 +27,16 @@ InconsistentTableTemplate, ) from dlt.extract.incremental import Incremental -from dlt.extract.items import TFunHintTemplate, TTableHintTemplate, TableNameMeta, ValidateItem -from dlt.extract.utils import ensure_table_schema_columns, ensure_table_schema_columns_hint +from dlt.extract.items import ( + TFunHintTemplate, + TTableHintTemplate, + TableNameMeta, + ValidateItem, +) +from dlt.extract.utils import ( + ensure_table_schema_columns, + ensure_table_schema_columns_hint, +) from dlt.extract.validation import create_item_validator @@ -149,7 +162,9 @@ def columns(self) -> TTableHintTemplate[TTableSchemaColumns]: def schema_contract(self) -> TTableHintTemplate[TSchemaContract]: return self._hints.get("schema_contract") - def compute_table_schema(self, item: TDataItem = None, meta: Any = None) -> TTableSchema: + def compute_table_schema( + self, item: TDataItem = None, meta: Any = None + ) -> TTableSchema: """Computes the table schema based on hints and column definitions passed during resource creation. `item` parameter is used to resolve table hints based on data. 
`meta` parameter is taken from Pipe and may further specify table name if variant is to be used @@ -218,8 +233,8 @@ def apply_hints( if create_table_variant: if not isinstance(table_name, str): raise ValueError( - "Please provide string table name if you want to create a table variant of" - " hints" + "Please provide a string table name if you want to create a table" + " variant of hints" ) # select hints variant t = self._hints_variants.get(table_name, None) @@ -324,20 +339,21 @@ def _set_hints( # incremental cannot be specified in variant if hints_template.get("incremental"): raise InconsistentTableTemplate( - f"You can specify incremental only for the resource `{self.name}` hints, not in" - f" table `{table_name}` variant-" + f"You can specify incremental only for the resource `{self.name}`" + f" hints, not in table `{table_name}` variant." ) if hints_template.get("validator"): logger.warning( - f"A data item validator was created from column schema in {self.name} for a" - f" table `{table_name}` variant. Currently such validator is ignored." + "A data item validator was created from column schema in" + f" {self.name} for a table `{table_name}` variant. Currently such" + " validator is ignored." ) # dynamic hints will be ignored for name, hint in hints_template.items(): if callable(hint) and name not in NATURAL_CALLABLES: raise InconsistentTableTemplate( - f"Table `{table_name}` variant hint is resource {self.name} cannot have" - f" dynamic hint but {name} does." + f"Table `{table_name}` variant hint in resource" + f" {self.name} cannot have a dynamic hint but {name} does." ) self._hints_variants[table_name] = hints_template else: @@ -380,7 +396,9 @@ def _resolve_hint(item: TDataItem, hint: TTableHintTemplate[Any]) -> Any: return hint(item) if callable(hint) else hint @staticmethod - def _merge_key(hint: TColumnProp, keys: TColumnNames, partial: TPartialTableSchema) -> None: + def _merge_key( + hint: TColumnProp, keys: TColumnNames, partial: TPartialTableSchema + ) -> None: if isinstance(keys, str): keys = [keys] for key in keys: @@ -408,8 +426,11 @@ def validate_dynamic_hints(template: TResourceHints) -> None: table_name = template.get("name") # if any of the hints is a function, then name must be as well. 
if any( - callable(v) for k, v in template.items() if k not in ["name", *NATURAL_CALLABLES] + callable(v) + for k, v in template.items() + if k not in ["name", *NATURAL_CALLABLES] ) and not callable(table_name): raise InconsistentTableTemplate( - f"Table name {table_name} must be a function if any other table hint is a function" + f"Table name {table_name} must be a function if any other table hint is" + " a function" ) diff --git a/dlt/extract/incremental/__init__.py b/dlt/extract/incremental/__init__.py index e74e87d094..1f8cc078be 100644 --- a/dlt/extract/incremental/__init__.py +++ b/dlt/extract/incremental/__init__.py @@ -35,7 +35,11 @@ IncrementalCursorPathMissing, IncrementalPrimaryKeyMissing, ) -from dlt.extract.incremental.typing import IncrementalColumnState, TCursorValue, LastValueFunc +from dlt.extract.incremental.typing import ( + IncrementalColumnState, + TCursorValue, + LastValueFunc, +) from dlt.extract.pipe import Pipe from dlt.extract.items import SupportsPipe, TTableHintTemplate, ItemTransform from dlt.extract.incremental.transform import ( @@ -193,7 +197,9 @@ def merge(self, other: "Incremental[TCursorValue]") -> "Incremental[TCursorValue >>> my_resource(updated=incremental(initial_value='2023-01-01', end_value='2023-02-01')) """ # func, resource name and primary key are not part of the dict - kwargs = dict(self, last_value_func=self.last_value_func, primary_key=self._primary_key) + kwargs = dict( + self, last_value_func=self.last_value_func, primary_key=self._primary_key + ) for key, value in dict( other, last_value_func=other.last_value_func, primary_key=other.primary_key ).items(): @@ -204,7 +210,9 @@ def merge(self, other: "Incremental[TCursorValue]") -> "Incremental[TCursorValue constructor = self.__orig_class__ else: constructor = ( - other.__orig_class__ if hasattr(other, "__orig_class__") else other.__class__ + other.__orig_class__ + if hasattr(other, "__orig_class__") + else other.__class__ ) constructor = extract_inner_type(constructor) merged = constructor(**kwargs) @@ -221,30 +229,31 @@ def on_resolved(self) -> None: compile_path(self.cursor_path) if self.end_value is not None and self.initial_value is None: raise ConfigurationValueError( - "Incremental 'end_value' was specified without 'initial_value'. 'initial_value' is" - " required when using 'end_value'." + "Incremental 'end_value' was specified without 'initial_value'." + " 'initial_value' is required when using 'end_value'." ) self._cursor_datetime_check(self.initial_value, "initial_value") self._cursor_datetime_check(self.initial_value, "end_value") # Ensure end value is "higher" than initial value if ( self.end_value is not None - and self.last_value_func([self.end_value, self.initial_value]) != self.end_value + and self.last_value_func([self.end_value, self.initial_value]) + != self.end_value ): if self.last_value_func in (min, max): adject = "higher" if self.last_value_func is max else "lower" msg = ( - f"Incremental 'initial_value' ({self.initial_value}) is {adject} than" - f" 'end_value` ({self.end_value}). 'end_value' must be {adject} than" - " 'initial_value'" + f"Incremental 'initial_value' ({self.initial_value}) is" + f" {adject} than 'end_value` ({self.end_value}). 
'end_value' must" + f" be {adject} than 'initial_value'" ) else: msg = ( - f"Incremental 'initial_value' ({self.initial_value}) is greater than" - f" 'end_value' ({self.end_value}) as determined by the custom" + f"Incremental 'initial_value' ({self.initial_value}) is greater" + f" than 'end_value' ({self.end_value}) as determined by the custom" " 'last_value_func'. The result of" - f" '{self.last_value_func.__name__}([end_value, initial_value])' must equal" - " 'end_value'" + f" '{self.last_value_func.__name__}([end_value, initial_value])'" + " must equal 'end_value'" ) raise ConfigurationValueError(msg) @@ -277,7 +286,9 @@ def get_state(self) -> IncrementalColumnState: if not self.resource_name: raise IncrementalUnboundError(self.cursor_path) - self._cached_state = Incremental._get_state(self.resource_name, self.cursor_path) + self._cached_state = Incremental._get_state( + self.resource_name, self.cursor_path + ) if len(self._cached_state) == 0: # set the default like this, setdefault evaluates the default no matter if it is needed or not. and our default is heavy self._cached_state.update( @@ -292,7 +303,9 @@ def get_state(self) -> IncrementalColumnState: @staticmethod def _get_state(resource_name: str, cursor_path: str) -> IncrementalColumnState: state: IncrementalColumnState = ( - resource_state(resource_name).setdefault("incremental", {}).setdefault(cursor_path, {}) + resource_state(resource_name) + .setdefault("incremental", {}) + .setdefault(cursor_path, {}) ) # if state params is empty return state @@ -301,10 +314,10 @@ def _get_state(resource_name: str, cursor_path: str) -> IncrementalColumnState: def _cursor_datetime_check(value: Any, arg_name: str) -> None: if value and isinstance(value, datetime) and value.tzinfo is None: logger.warning( - f"The {arg_name} argument {value} is a datetime without timezone. This may result" - " in an error when such values are compared by Incremental class. Note that `dlt`" - " stores datetimes in timezone-aware types so the UTC timezone will be added by" - " the destination" + f"The {arg_name} argument {value} is a datetime without timezone. This" + " may result in an error when such values are compared by Incremental" + " class. Note that `dlt` stores datetimes in timezone-aware types so" + " the UTC timezone will be added by the destination" ) @property @@ -343,15 +356,17 @@ def _join_external_scheduler(self) -> None: data_type = py_type_to_sc_type(param_type) except Exception as ex: logger.warning( - f"Specified Incremental last value type {param_type} is not supported. Please use" - f" DateTime, Date, float, int or str to join external schedulers.({ex})" + f"Specified Incremental last value type {param_type} is not supported." + " Please use DateTime, Date, float, int or str to join external" + f" schedulers.({ex})" ) if param_type is Any: logger.warning( - "Could not find the last value type of Incremental class participating in external" - " schedule. Please add typing when declaring incremental argument in your resource" - " or pass initial_value from which the type can be inferred." + "Could not find the last value type of Incremental class participating" + " in external schedule. Please add typing when declaring incremental" + " argument in your resource or pass initial_value from which the type" + " can be inferred." 
) return @@ -370,7 +385,9 @@ def _ensure_airflow_end_date( context = get_current_context() start_date = context["data_interval_start"] - end_date = _ensure_airflow_end_date(start_date, context["data_interval_end"]) + end_date = _ensure_airflow_end_date( + start_date, context["data_interval_end"] + ) self.initial_value = coerce_from_date_types(data_type, start_date) if end_date is not None: self.end_value = coerce_from_date_types(data_type, end_date) @@ -379,13 +396,14 @@ def _ensure_airflow_end_date( logger.info( f"Found Airflow scheduler: initial value: {self.initial_value} from" f" data_interval_start {context['data_interval_start']}, end value:" - f" {self.end_value} from data_interval_end {context['data_interval_end']}" + f" {self.end_value} from data_interval_end" + f" {context['data_interval_end']}" ) return except TypeError as te: logger.warning( - f"Could not coerce Airflow execution dates into the last value type {param_type}." - f" ({te})" + "Could not coerce Airflow execution dates into the last value type" + f" {param_type}. ({te})" ) except Exception: pass @@ -411,8 +429,9 @@ def bind(self, pipe: SupportsPipe) -> "Incremental[TCursorValue]": # set initial value from last value, in case of a new state those are equal self.start_value = self.last_value logger.info( - f"Bind incremental on {self.resource_name} with initial_value: {self.initial_value}," - f" start_value: {self.start_value}, end_value: {self.end_value}" + f"Bind incremental on {self.resource_name} with initial_value:" + f" {self.initial_value}, start_value: {self.start_value}, end_value:" + f" {self.end_value}" ) # cache state self._cached_state = self.get_state() @@ -436,8 +455,9 @@ def can_close(self) -> bool: def __str__(self) -> str: return ( - f"Incremental at {id(self)} for resource {self.resource_name} with cursor path:" - f" {self.cursor_path} initial {self.initial_value} lv_func {self.last_value_func}" + f"Incremental at {id(self)} for resource {self.resource_name} with cursor" + f" path: {self.cursor_path} initial {self.initial_value} lv_func" + f" {self.last_value_func}" ) def _get_transformer(self, items: TDataItems) -> IncrementalTransform: @@ -487,7 +507,9 @@ class IncrementalResourceWrapper(ItemTransform[TDataItem]): """Keeps the injectable incremental""" _resource_name: str = None - def __init__(self, primary_key: Optional[TTableHintTemplate[TColumnNames]] = None) -> None: + def __init__( + self, primary_key: Optional[TTableHintTemplate[TColumnNames]] = None + ) -> None: """Creates a wrapper over a resource function that accepts Incremental instance in its argument to perform incremental loading. The wrapper delays instantiation of the Incremental to the moment of actual execution and is currently used by `dlt.resource` decorator. 
@@ -512,9 +534,9 @@ def get_incremental_arg(sig: inspect.Signature) -> Optional[inspect.Parameter]: for p in sig.parameters.values(): annotation = extract_inner_type(p.annotation) annotation = get_origin(annotation) or annotation - if (inspect.isclass(annotation) and issubclass(annotation, Incremental)) or isinstance( - p.default, Incremental - ): + if ( + inspect.isclass(annotation) and issubclass(annotation, Incremental) + ) or isinstance(p.default, Incremental): incremental_param = p break return incremental_param @@ -522,7 +544,9 @@ def get_incremental_arg(sig: inspect.Signature) -> Optional[inspect.Parameter]: def wrap(self, sig: inspect.Signature, func: TFun) -> TFun: """Wrap the callable to inject an `Incremental` object configured for the resource.""" incremental_param = self.get_incremental_arg(sig) - assert incremental_param, "Please use `should_wrap` to decide if to call this function" + assert ( + incremental_param + ), "Please use `should_wrap` to decide if to call this function" @wraps(func) def _wrap(*args: Any, **kwargs: Any) -> Any: @@ -533,7 +557,10 @@ def _wrap(*args: Any, **kwargs: Any) -> Any: if p.name in bound_args.arguments: explicit_value = bound_args.arguments[p.name] - if explicit_value is Incremental.EMPTY or p.default is Incremental.EMPTY: + if ( + explicit_value is Incremental.EMPTY + or p.default is Incremental.EMPTY + ): # drop incremental pass elif isinstance(explicit_value, Incremental): @@ -550,13 +577,15 @@ def _wrap(*args: Any, **kwargs: Any) -> Any: elif isinstance(p.default, Incremental): new_incremental = p.default.copy() - if (not new_incremental or new_incremental.is_partial()) and not self._incremental: + if ( + not new_incremental or new_incremental.is_partial() + ) and not self._incremental: if is_optional_type(p.annotation): bound_args.arguments[p.name] = None # Remove partial spec return func(*bound_args.args, **bound_args.kwargs) raise ValueError( - f"{p.name} Incremental argument has no default. Please wrap its typing in" - " Optional[] to allow no incremental" + f"{p.name} Incremental argument has no default. 
Please wrap its" + " typing in Optional[] to allow no incremental" ) # pass Generic information from annotation to new_incremental if ( @@ -569,7 +598,9 @@ def _wrap(*args: Any, **kwargs: Any) -> Any: # set the incremental only if not yet set or if it was passed explicitly # NOTE: the _incremental may be also set by applying hints to the resource see `set_template` in `DltResource` - if (new_incremental and p.name in bound_args.arguments) or not self._incremental: + if ( + new_incremental and p.name in bound_args.arguments + ) or not self._incremental: self._incremental = new_incremental self._incremental.resolve() # in case of transformers the bind will be called before this wrapper is set: because transformer is called for a first time late in the pipe @@ -602,7 +633,9 @@ def bind(self, pipe: SupportsPipe) -> "IncrementalResourceWrapper": self._resource_name = pipe.name if self._incremental: if self._allow_external_schedulers is not None: - self._incremental.allow_external_schedulers = self._allow_external_schedulers + self._incremental.allow_external_schedulers = ( + self._allow_external_schedulers + ) self._incremental.bind(pipe) return self diff --git a/dlt/extract/incremental/exceptions.py b/dlt/extract/incremental/exceptions.py index e318a028dc..8cedb18382 100644 --- a/dlt/extract/incremental/exceptions.py +++ b/dlt/extract/incremental/exceptions.py @@ -3,7 +3,9 @@ class IncrementalCursorPathMissing(PipeException): - def __init__(self, pipe_name: str, json_path: str, item: TDataItem, msg: str = None) -> None: + def __init__( + self, pipe_name: str, json_path: str, item: TDataItem, msg: str = None + ) -> None: self.json_path = json_path self.item = item msg = ( @@ -14,12 +16,14 @@ def __init__(self, pipe_name: str, json_path: str, item: TDataItem, msg: str = N class IncrementalPrimaryKeyMissing(PipeException): - def __init__(self, pipe_name: str, primary_key_column: str, item: TDataItem) -> None: + def __init__( + self, pipe_name: str, primary_key_column: str, item: TDataItem + ) -> None: self.primary_key_column = primary_key_column self.item = item msg = ( - f"Primary key column {primary_key_column} was not found in extracted data item. All" - " data items must contain this column. Use the same names of fields as in your JSON" - " document." + f"Primary key column {primary_key_column} was not found in extracted data" + " item. All data items must contain this column. Use the same names of" + " fields as in your JSON document." 
) super().__init__(pipe_name, msg) diff --git a/dlt/extract/incremental/transform.py b/dlt/extract/incremental/transform.py index 29b20de7b8..d06c87c8d8 100644 --- a/dlt/extract/incremental/transform.py +++ b/dlt/extract/incremental/transform.py @@ -75,12 +75,14 @@ def compute_unique_value( ) -> str: try: assert not self.deduplication_disabled, ( - f"{self.resource_name}: Attempt to compute unique values when deduplication is" - " disabled" + f"{self.resource_name}: Attempt to compute unique values when" + " deduplication is disabled" ) if primary_key: - return digest128(json.dumps(resolve_column_value(primary_key, row), sort_keys=True)) + return digest128( + json.dumps(resolve_column_value(primary_key, row), sort_keys=True) + ) elif primary_key is None: return digest128(json.dumps(row, sort_keys=True)) else: @@ -96,7 +98,9 @@ def __call__( @property def deduplication_disabled(self) -> bool: """Skip deduplication when length of the key is 0""" - return isinstance(self.primary_key, (list, tuple)) and len(self.primary_key) == 0 + return ( + isinstance(self.primary_key, (list, tuple)) and len(self.primary_key) == 0 + ) class JsonIncremental(IncrementalTransform): @@ -116,7 +120,9 @@ def find_cursor_value(self, row: TDataItem) -> Any: except Exception: pass if row_value is None: - raise IncrementalCursorPathMissing(self.resource_name, self.cursor_path, row) + raise IncrementalCursorPathMissing( + self.resource_name, self.cursor_path, row + ) return row_value def __call__( @@ -193,7 +199,9 @@ def __call__( class ArrowIncremental(IncrementalTransform): _dlt_index = "_dlt_index" - def compute_unique_values(self, item: "TAnyArrowItem", unique_columns: List[str]) -> List[str]: + def compute_unique_values( + self, item: "TAnyArrowItem", unique_columns: List[str] + ) -> List[str]: if not unique_columns: return [] rows = item.select(unique_columns).to_pylist() @@ -215,7 +223,9 @@ def _add_unique_index(self, tbl: "pa.Table") -> "pa.Table": """Creates unique index if necessary.""" # create unique index if necessary if self._dlt_index not in tbl.schema.names: - tbl = pyarrow.append_column(tbl, self._dlt_index, pa.array(numpy.arange(tbl.num_rows))) + tbl = pyarrow.append_column( + tbl, self._dlt_index, pa.array(numpy.arange(tbl.num_rows)) + ) return tbl def __call__( @@ -226,7 +236,9 @@ def __call__( if is_pandas: tbl = pandas_to_arrow(tbl) - primary_key = self.primary_key(tbl) if callable(self.primary_key) else self.primary_key + primary_key = ( + self.primary_key(tbl) if callable(self.primary_key) else self.primary_key + ) if primary_key: # create a list of unique columns if isinstance(primary_key, str): @@ -275,9 +287,9 @@ def __call__( self.resource_name, cursor_path, tbl, - f"Column name {cursor_path} was not found in the arrow table. Not nested JSON paths" - " are not supported for arrow tables and dataframes, the incremental cursor_path" - " must be a column name.", + f"Column name {cursor_path} was not found in the arrow table. Nested" + " JSON paths are not supported for arrow tables and dataframes;" + " the incremental cursor_path must be a column name.", ) from e # If end_value is provided, filter to include table rows that are "less" than end_value @@ -286,21 +298,29 @@ def __call__( tbl = tbl.filter(end_compare(tbl[cursor_path], end_value_scalar)) # Is max row value higher than end value? # NOTE: pyarrow bool *always* evaluates to python True. 
`as_py()` is necessary - end_out_of_range = not end_compare(row_value_scalar, end_value_scalar).as_py() + end_out_of_range = not end_compare( + row_value_scalar, end_value_scalar + ).as_py() if self.start_value is not None: start_value_scalar = to_arrow_scalar(self.start_value, cursor_data_type) # Remove rows lower or equal than the last start value keep_filter = last_value_compare(tbl[cursor_path], start_value_scalar) - start_out_of_range = bool(pa.compute.any(pa.compute.invert(keep_filter)).as_py()) + start_out_of_range = bool( + pa.compute.any(pa.compute.invert(keep_filter)).as_py() + ) tbl = tbl.filter(keep_filter) if not self.deduplication_disabled: # Deduplicate after filtering old values tbl = self._add_unique_index(tbl) # Remove already processed rows where the cursor is equal to the start value - eq_rows = tbl.filter(pa.compute.equal(tbl[cursor_path], start_value_scalar)) + eq_rows = tbl.filter( + pa.compute.equal(tbl[cursor_path], start_value_scalar) + ) # compute index, unique hash mapping - unique_values_index = self.compute_unique_values_with_index(eq_rows, unique_columns) + unique_values_index = self.compute_unique_values_with_index( + eq_rows, unique_columns + ) unique_values_index = [ (i, uq_val) for i, uq_val in unique_values_index @@ -310,7 +330,9 @@ def __call__( remove_idx = pa.array(i for i, _ in unique_values_index) # Filter the table tbl = tbl.filter( - pa.compute.invert(pa.compute.is_in(tbl[self._dlt_index], remove_idx)) + pa.compute.invert( + pa.compute.is_in(tbl[self._dlt_index], remove_idx) + ) ) if ( @@ -324,7 +346,9 @@ def __call__( # Compute unique hashes for all rows equal to row value self.unique_hashes = set( self.compute_unique_values( - tbl.filter(pa.compute.equal(tbl[cursor_path], row_value_scalar)), + tbl.filter( + pa.compute.equal(tbl[cursor_path], row_value_scalar) + ), unique_columns, ) ) @@ -333,7 +357,9 @@ def __call__( self.unique_hashes.update( set( self.compute_unique_values( - tbl.filter(pa.compute.equal(tbl[cursor_path], row_value_scalar)), + tbl.filter( + pa.compute.equal(tbl[cursor_path], row_value_scalar) + ), unique_columns, ) ) diff --git a/dlt/extract/items.py b/dlt/extract/items.py index fec31e2846..9bd0ecc1b9 100644 --- a/dlt/extract/items.py +++ b/dlt/extract/items.py @@ -128,7 +128,9 @@ def close(self) -> None: ItemTransformFunctionWithMeta = Callable[[TDataItem, str], TAny] ItemTransformFunctionNoMeta = Callable[[TDataItem], TAny] -ItemTransformFunc = Union[ItemTransformFunctionWithMeta[TAny], ItemTransformFunctionNoMeta[TAny]] +ItemTransformFunc = Union[ + ItemTransformFunctionWithMeta[TAny], ItemTransformFunctionNoMeta[TAny] +] class ItemTransform(ABC, Generic[TAny]): diff --git a/dlt/extract/pipe.py b/dlt/extract/pipe.py index 6517273db5..2203852128 100644 --- a/dlt/extract/pipe.py +++ b/dlt/extract/pipe.py @@ -1,7 +1,17 @@ import inspect import makefun from copy import copy -from typing import Any, AsyncIterator, Optional, Union, Callable, Iterable, Iterator, List, Tuple +from typing import ( + Any, + AsyncIterator, + Optional, + Union, + Callable, + Iterable, + Iterator, + List, + Tuple, +) from dlt.common.typing import AnyFun, AnyType, TDataItems from dlt.common.utils import get_callable_name @@ -32,7 +42,9 @@ class ForkPipe: - def __init__(self, pipe: "Pipe", step: int = -1, copy_on_fork: bool = False) -> None: + def __init__( + self, pipe: "Pipe", step: int = -1, copy_on_fork: bool = False + ) -> None: """A transformer that forks the `pipe` and sends the data items to forks added via `add_pipe` method.""" self._pipes: 
List[Tuple["Pipe", int]] = [] self.copy_on_fork = copy_on_fork @@ -58,7 +70,9 @@ def __call__(self, item: TDataItems, meta: Any) -> Iterator[ResolvablePipeItem]: class Pipe(SupportsPipe): - def __init__(self, name: str, steps: List[TPipeStep] = None, parent: "Pipe" = None) -> None: + def __init__( + self, name: str, steps: List[TPipeStep] = None, parent: "Pipe" = None + ) -> None: self.name = name self._gen_idx = 0 self._steps: List[TPipeStep] = [] @@ -109,7 +123,9 @@ def steps(self) -> List[TPipeStep]: def find(self, *step_type: AnyType) -> int: """Finds a step with object of type `step_type`""" - return next((i for i, v in enumerate(self._steps) if isinstance(v, step_type)), -1) + return next( + (i for i, v in enumerate(self._steps) if isinstance(v, step_type)), -1 + ) def __getitem__(self, i: int) -> TPipeStep: return self._steps[i] @@ -117,9 +133,13 @@ def __getitem__(self, i: int) -> TPipeStep: def __len__(self) -> int: return len(self._steps) - def fork(self, child_pipe: "Pipe", child_step: int = -1, copy_on_fork: bool = False) -> "Pipe": + def fork( + self, child_pipe: "Pipe", child_step: int = -1, copy_on_fork: bool = False + ) -> "Pipe": if len(self._steps) == 0: - raise CreatePipeException(self.name, f"Cannot fork to empty pipe {child_pipe}") + raise CreatePipeException( + self.name, f"Cannot fork to empty pipe {child_pipe}" + ) fork_step = self.tail if not isinstance(fork_step, ForkPipe): fork_step = ForkPipe(child_pipe, child_step, copy_on_fork) @@ -165,7 +185,8 @@ def remove_step(self, index: int) -> None: if index == self._gen_idx: raise CreatePipeException( self.name, - f"Step at index {index} holds a data generator for this pipe and cannot be removed", + f"Step at index {index} holds a data generator for this pipe and cannot" + " be removed", ) self._steps.pop(index) if index < self._gen_idx: @@ -299,10 +320,13 @@ def _wrap_gen(self, *args: Any, **kwargs: Any) -> Any: def _verify_head_step(self, step: TPipeStep) -> None: # first element must be Iterable, Iterator or Callable in resource pipe - if not isinstance(step, (Iterable, Iterator, AsyncIterator)) and not callable(step): + if not isinstance(step, (Iterable, Iterator, AsyncIterator)) and not callable( + step + ): raise CreatePipeException( self.name, - "A head of a resource pipe must be Iterable, Iterator, AsyncIterator or a Callable", + "A head of a resource pipe must be Iterable, Iterator, AsyncIterator or" + " a Callable", ) def _wrap_transform_step_meta(self, step_no: int, step: TPipeStep) -> TPipeStep: @@ -310,18 +334,20 @@ def _wrap_transform_step_meta(self, step_no: int, step: TPipeStep) -> TPipeStep: if isinstance(step, (Iterable, Iterator)) and not callable(step): if self.has_parent: raise CreatePipeException( - self.name, "Iterable or Iterator cannot be a step in transformer pipe" + self.name, + "Iterable or Iterator cannot be a step in transformer pipe", ) else: raise CreatePipeException( - self.name, "Iterable or Iterator can only be a first step in resource pipe" + self.name, + "Iterable or Iterator can only be a first step in resource pipe", ) if not callable(step): raise CreatePipeException( self.name, - "Pipe step must be a callable taking one data item as argument and optional second" - " meta argument", + "Pipe step must be a callable taking one data item as argument and" + " optional second meta argument", ) else: # check the signature @@ -350,7 +376,11 @@ def _partial(*args: Any, **kwargs: Any) -> Any: "meta", inspect._ParameterKind.KEYWORD_ONLY, default=None ) kwargs_arg = next( - (p for p in 
sig.parameters.values() if p.kind == inspect.Parameter.VAR_KEYWORD), + ( + p + for p in sig.parameters.values() + if p.kind == inspect.Parameter.VAR_KEYWORD + ), None, ) if kwargs_arg: @@ -378,7 +408,9 @@ def _ensure_transform_step(self, step_no: int, step: TPipeStep) -> None: if step_no == self._gen_idx: # error for gen step if len(sig.parameters) == 0: - raise InvalidTransformerGeneratorFunction(self.name, callable_name, sig, code=1) + raise InvalidTransformerGeneratorFunction( + self.name, callable_name, sig, code=1 + ) else: # show the sig without first argument raise ParametrizedResourceUnbound( @@ -389,7 +421,9 @@ def _ensure_transform_step(self, step_no: int, step: TPipeStep) -> None: str(ty_ex), ) else: - raise InvalidStepFunctionArguments(self.name, callable_name, sig, str(ty_ex)) + raise InvalidStepFunctionArguments( + self.name, callable_name, sig, str(ty_ex) + ) def _clone(self, new_name: str = None, with_parent: bool = False) -> "Pipe": """Clones the pipe steps, optionally renaming the pipe. Used internally to clone a list of connected pipes.""" diff --git a/dlt/extract/pipe_iterator.py b/dlt/extract/pipe_iterator.py index 1edd9bd039..7ae49e28bb 100644 --- a/dlt/extract/pipe_iterator.py +++ b/dlt/extract/pipe_iterator.py @@ -37,7 +37,12 @@ ResourceExtractionError, ) from dlt.extract.pipe import Pipe -from dlt.extract.items import DataItemWithMeta, PipeItem, ResolvablePipeItem, SourcePipeItem +from dlt.extract.items import ( + DataItemWithMeta, + PipeItem, + ResolvablePipeItem, + SourcePipeItem, +) from dlt.extract.utils import wrap_async_iterator from dlt.extract.concurrency import FuturesPool @@ -96,7 +101,9 @@ def from_pipe( # create extractor sources = [SourcePipeItem(pipe.gen, 0, pipe, None)] - return cls(max_parallel_items, workers, futures_poll_interval, sources, next_item_mode) + return cls( + max_parallel_items, workers, futures_poll_interval, sources, next_item_mode + ) @classmethod @with_config(spec=PipeIteratorConfiguration) @@ -125,7 +132,9 @@ def _fork_pipeline(pipe: Pipe) -> None: # make the parent yield by sending a clone of item to itself with position at the end if yield_parents and pipe.parent in pipes: # fork is last step of the pipe so it will yield - pipe.parent.fork(pipe.parent, len(pipe.parent) - 1, copy_on_fork=copy_on_fork) + pipe.parent.fork( + pipe.parent, len(pipe.parent) - 1, copy_on_fork=copy_on_fork + ) _fork_pipeline(pipe.parent) else: # head of independent pipe must be iterator @@ -142,7 +151,9 @@ def _fork_pipeline(pipe: Pipe) -> None: _fork_pipeline(pipe) # create extractor - return cls(max_parallel_items, workers, futures_poll_interval, sources, next_item_mode) + return cls( + max_parallel_items, workers, futures_poll_interval, sources, next_item_mode + ) def __next__(self) -> PipeItem: pipe_item: Union[ResolvablePipeItem, SourcePipeItem] = None @@ -192,7 +203,10 @@ def __next__(self) -> PipeItem: if isinstance(item, AsyncIterator): self._sources.append( SourcePipeItem( - wrap_async_iterator(item), pipe_item.step, pipe_item.pipe, pipe_item.meta + wrap_async_iterator(item), + pipe_item.step, + pipe_item.pipe, + pipe_item.meta, ), ) pipe_item = None @@ -208,12 +222,15 @@ def __next__(self) -> PipeItem: # if we are at the end of the pipe then yield element if pipe_item.step == len(pipe_item.pipe) - 1: # must be resolved - if isinstance(item, (Iterator, Awaitable, AsyncIterator)) or callable(item): + if isinstance(item, (Iterator, Awaitable, AsyncIterator)) or callable( + item + ): raise PipeItemProcessingError( pipe_item.pipe.name, - f"Pipe 
item at step {pipe_item.step} was not fully evaluated and is of type" - f" {type(pipe_item.item).__name__}. This is internal error or you are" - " yielding something weird from resources ie. functions or awaitables.", + f"Pipe item at step {pipe_item.step} was not fully evaluated" + f" and is of type {type(pipe_item.item).__name__}. This is" + " internal error or you are yielding something weird from" + " resources ie. functions or awaitables.", ) # mypy not able to figure out that item was resolved return pipe_item # type: ignore @@ -235,7 +252,12 @@ def __next__(self) -> PipeItem: inspect.signature(step), str(ty_ex), ) - except (PipelineException, ExtractorException, DltSourceException, PipeException): + except ( + PipelineException, + ExtractorException, + DltSourceException, + PipeException, + ): raise except Exception as ex: raise ResourceExtractionError( @@ -259,11 +281,14 @@ def _get_source_item(self) -> ResolvablePipeItem: # always reset to end of list for fifo mode, also take into account that new sources can be added # if too many new sources is added we switch to fifo not to exhaust them if self._next_item_mode == "fifo" or ( - sources_count - self._initial_sources_count >= self._futures_pool.max_parallel_items + sources_count - self._initial_sources_count + >= self._futures_pool.max_parallel_items ): self._current_source_index = sources_count - 1 else: - self._current_source_index = (self._current_source_index - 1) % sources_count + self._current_source_index = ( + self._current_source_index - 1 + ) % sources_count while True: # if we have checked all sources once and all returned None, return and poll/resolve some futures if self._current_source_index == first_evaluated_index: @@ -280,7 +305,9 @@ def _get_source_item(self) -> ResolvablePipeItem: if not isinstance(pipe_item, ResolvablePipeItem): # keep the item assigned step and pipe when creating resolvable item if isinstance(pipe_item, DataItemWithMeta): - return ResolvablePipeItem(pipe_item.data, step, pipe, pipe_item.meta) + return ResolvablePipeItem( + pipe_item.data, step, pipe, pipe_item.meta + ) else: return ResolvablePipeItem(pipe_item, step, pipe, meta) @@ -291,7 +318,9 @@ def _get_source_item(self) -> ResolvablePipeItem: if first_evaluated_index is None: first_evaluated_index = self._current_source_index # always go round robin if None was returned or item is to be run as future - self._current_source_index = (self._current_source_index - 1) % sources_count + self._current_source_index = ( + self._current_source_index - 1 + ) % sources_count except StopIteration: # remove empty iterator and try another source @@ -300,7 +329,12 @@ def _get_source_item(self) -> ResolvablePipeItem: if self._current_source_index < self._initial_sources_count: self._initial_sources_count -= 1 return self._get_source_item() - except (PipelineException, ExtractorException, DltSourceException, PipeException): + except ( + PipelineException, + ExtractorException, + DltSourceException, + PipeException, + ): raise except Exception as ex: raise ResourceExtractionError(pipe.name, gen, str(ex), "generator") from ex @@ -324,7 +358,10 @@ def __enter__(self) -> "PipeIterator": return self def __exit__( - self, exc_type: Type[BaseException], exc_val: BaseException, exc_tb: types.TracebackType + self, + exc_type: Type[BaseException], + exc_val: BaseException, + exc_tb: types.TracebackType, ) -> None: self.close() @@ -333,7 +370,9 @@ def clone_pipes( pipes: Sequence[Pipe], existing_cloned_pairs: Dict[int, Pipe] = None ) -> Tuple[List[Pipe], Dict[int, 
Pipe]]: """This will clone pipes and fix the parent/dependent references""" - cloned_pipes = [p._clone() for p in pipes if id(p) not in (existing_cloned_pairs or {})] + cloned_pipes = [ + p._clone() for p in pipes if id(p) not in (existing_cloned_pairs or {}) + ] cloned_pairs = {id(p): c for p, c in zip(pipes, cloned_pipes)} if existing_cloned_pairs: cloned_pairs.update(existing_cloned_pairs) diff --git a/dlt/extract/resource.py b/dlt/extract/resource.py index 4776158bbb..5db23836b8 100644 --- a/dlt/extract/resource.py +++ b/dlt/extract/resource.py @@ -16,7 +16,14 @@ from dlt.common.configuration.resolve import inject_section from dlt.common.configuration.specs import known_sections from dlt.common.configuration.specs.config_section_context import ConfigSectionContext -from dlt.common.typing import AnyFun, DictStrAny, StrAny, TDataItem, TDataItems, NoneType +from dlt.common.typing import ( + AnyFun, + DictStrAny, + StrAny, + TDataItem, + TDataItems, + NoneType, +) from dlt.common.configuration.container import Container from dlt.common.pipeline import ( PipelineContext, @@ -162,7 +169,10 @@ def from_data( else: # some other data type that is not supported raise InvalidResourceDataType( - name, data, type(data), f"The data type of supplied type is {type(data).__name__}" + name, + data, + type(data), + f"The data type of supplied type is {type(data).__name__}", ) @property @@ -238,8 +248,13 @@ def select_tables(self, *table_names: Iterable[str]) -> "DltResource": """ def _filter(item: TDataItem, meta: Any = None) -> bool: - is_in_meta = isinstance(meta, TableNameMeta) and meta.table_name in table_names - is_in_dyn = self._table_name_hint_fun and self._table_name_hint_fun(item) in table_names + is_in_meta = ( + isinstance(meta, TableNameMeta) and meta.table_name in table_names + ) + is_in_dyn = ( + self._table_name_hint_fun + and self._table_name_hint_fun(item) in table_names + ) return is_in_meta or is_in_dyn # add filtering function at the end of pipe @@ -378,13 +393,17 @@ def parallelize(self) -> "DltResource": ) and not (callable(self._pipe.gen) and self.is_transformer) ): - raise InvalidParallelResourceDataType(self.name, self._pipe.gen, type(self._pipe.gen)) + raise InvalidParallelResourceDataType( + self.name, self._pipe.gen, type(self._pipe.gen) + ) self._pipe.replace_gen(wrap_parallel_iterator(self._pipe.gen)) # type: ignore # TODO return self def add_step( - self, item_transform: ItemTransformFunctionWithMeta[TDataItems], insert_at: int = None + self, + item_transform: ItemTransformFunctionWithMeta[TDataItems], + insert_at: int = None, ) -> "DltResource": # noqa: A003 if insert_at is None: self._pipe.append_step(item_transform) @@ -409,7 +428,9 @@ def _set_hints( self.add_step(incremental) if incremental: - primary_key = table_schema_template.get("primary_key", incremental.primary_key) + primary_key = table_schema_template.get( + "primary_key", incremental.primary_key + ) if primary_key is not None: incremental.primary_key = primary_key @@ -498,7 +519,9 @@ def __iter__(self) -> Iterator[TDataItem]: section_context = self._get_config_section_context() # managed pipe iterator will set the context on each call to __next__ - with inject_section(section_context), Container().injectable_context(state_context): + with inject_section(section_context), Container().injectable_context( + state_context + ): pipe_iterator: ManagedPipeIterator = ManagedPipeIterator.from_pipes([self._pipe]) # type: ignore pipe_iterator.set_context([state_context, section_context]) @@ -550,7 +573,10 @@ def 
_get_config_section_context(self) -> ConfigSectionContext: "", self.source_name or default_schema_name or self.name, ), - source_state_key=self.source_name or default_schema_name or self.section or uniq_id(), + source_state_key=self.source_name + or default_schema_name + or self.section + or uniq_id(), ) def __str__(self) -> str: @@ -572,14 +598,15 @@ def __str__(self) -> str: if self.requires_args: head_sig = inspect.signature(self._pipe.gen) # type: ignore info += ( - "\nThis resource is parametrized and takes the following arguments" - f" {head_sig}. You must call this resource before loading." + "\nThis resource is parametrized and takes the following" + f" arguments {head_sig}. You must call this resource before" + " loading." ) else: info += ( - "\nIf you want to see the data items in the resource you must iterate it or" - " convert to list ie. list(resource). Note that, like any iterator, you can" - " iterate the resource only once." + "\nIf you want to see the data items in the resource you must" + " iterate it or convert to list ie. list(resource). Note that," + " like any iterator, you can iterate the resource only once." ) else: info += "\nThis resource is not bound to the data" @@ -596,7 +623,9 @@ def _ensure_valid_transformer_resource(name: str, data: Any) -> None: name, get_callable_name(data), inspect.signature(data), valid_code ) else: - raise InvalidTransformerDataTypeGeneratorFunctionRequired(name, data, type(data)) + raise InvalidTransformerDataTypeGeneratorFunctionRequired( + name, data, type(data) + ) @staticmethod def _get_parent_pipe(name: str, data_from: Union["DltResource", Pipe]) -> Pipe: @@ -608,7 +637,9 @@ def _get_parent_pipe(name: str, data_from: Union["DltResource", Pipe]) -> Pipe: else: # if this is generator function provide nicer exception if callable(data_from): - raise InvalidParentResourceIsAFunction(name, get_callable_name(data_from)) + raise InvalidParentResourceIsAFunction( + name, get_callable_name(data_from) + ) else: raise InvalidParentResourceDataType(name, data_from, type(data_from)) @@ -618,7 +649,9 @@ def validate_transformer_generator_function(f: AnyFun) -> int: if len(sig.parameters) == 0: return 1 # transformer may take only one positional only argument - pos_only_len = sum(1 for p in sig.parameters.values() if p.kind == p.POSITIONAL_ONLY) + pos_only_len = sum( + 1 for p in sig.parameters.values() if p.kind == p.POSITIONAL_ONLY + ) if pos_only_len > 1: return 2 first_ar = next(iter(sig.parameters.values())) @@ -626,7 +659,10 @@ def validate_transformer_generator_function(f: AnyFun) -> int: if pos_only_len == 1 and first_ar.kind != first_ar.POSITIONAL_ONLY: return 2 # first arg must be positional or kw_pos - if first_ar.kind not in (first_ar.POSITIONAL_ONLY, first_ar.POSITIONAL_OR_KEYWORD): + if first_ar.kind not in ( + first_ar.POSITIONAL_ONLY, + first_ar.POSITIONAL_OR_KEYWORD, + ): return 3 return 0 diff --git a/dlt/extract/source.py b/dlt/extract/source.py index 5d9799e29c..b9e0c762b4 100644 --- a/dlt/extract/source.py +++ b/dlt/extract/source.py @@ -9,7 +9,9 @@ from dlt.common.configuration.resolve import inject_section from dlt.common.configuration.specs import known_sections from dlt.common.configuration.specs.config_section_context import ConfigSectionContext -from dlt.common.normalizers.json.relational import DataItemNormalizer as RelationalNormalizer +from dlt.common.normalizers.json.relational import ( + DataItemNormalizer as RelationalNormalizer, +) from dlt.common.schema import Schema from dlt.common.schema.typing import 
TColumnName, TSchemaContract from dlt.common.typing import StrAny, TDataItem @@ -21,7 +23,11 @@ source_state, pipeline_state, ) -from dlt.common.utils import graph_find_scc_nodes, flatten_list_or_items, graph_edges_to_nodes +from dlt.common.utils import ( + graph_find_scc_nodes, + flatten_list_or_items, + graph_edges_to_nodes, +) from dlt.extract.items import TDecompositionStrategy from dlt.extract.pipe_iterator import ManagedPipeIterator @@ -69,7 +75,9 @@ def extracted(self) -> Dict[str, DltResource]: mock_template = make_hints( pipe.name, write_disposition=resource.write_disposition ) - resource = DltResource(pipe, mock_template, False, section=resource.section) + resource = DltResource( + pipe, mock_template, False, section=resource.section + ) resource.source_name = resource.source_name extracted[resource.name] = resource else: @@ -132,7 +140,9 @@ def add(self, *resources: DltResource) -> None: def _clone_new_pipes(self, resource_names: Sequence[str]) -> None: # clone all new pipes and keep - _, self._cloned_pairs = ManagedPipeIterator.clone_pipes(self._new_pipes, self._cloned_pairs) + _, self._cloned_pairs = ManagedPipeIterator.clone_pipes( + self._new_pipes, self._cloned_pairs + ) # self._cloned_pairs.update(cloned_pairs) # replace pipes in resources, the cloned_pipes preserve parent connections for name in resource_names: @@ -218,11 +228,15 @@ def name(self) -> str: @property def max_table_nesting(self) -> int: """A schema hint that sets the maximum depth of nested table above which the remaining nodes are loaded as structs or JSON.""" - return RelationalNormalizer.get_normalizer_config(self._schema).get("max_nesting") + return RelationalNormalizer.get_normalizer_config(self._schema).get( + "max_nesting" + ) @max_table_nesting.setter def max_table_nesting(self, value: int) -> None: - RelationalNormalizer.update_normalizer_config(self._schema, {"max_nesting": value}) + RelationalNormalizer.update_normalizer_config( + self._schema, {"max_nesting": value} + ) @property def schema_contract(self) -> TSchemaContract: @@ -245,7 +259,9 @@ def exhausted(self) -> bool: @property def root_key(self) -> bool: """Enables merging on all resources by propagating root foreign key to child tables. 
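Editor's sketch (illustrative, not part of the diff): the two schema hints reformatted around here, `max_table_nesting` and `root_key`, are normally set directly on a source instance. A minimal sketch assuming a hypothetical source `my_source` with one resource; only the property names and their behavior come from the code above.

import dlt

@dlt.resource(write_disposition="merge", primary_key="id")
def events():
    # nested lists become child tables unless nesting is capped
    yield {"id": 1, "items": [{"sku": "a"}, {"sku": "b"}]}

@dlt.source
def my_source():
    return events

source = my_source()
source.max_table_nesting = 1  # nodes below this depth stay as JSON/structs instead of child tables
source.root_key = True        # propagate the root row's _dlt_id (as _dlt_root_id) so child tables can be merged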
This option is most useful if you plan to change write disposition of a resource to disable/enable merge""" - config = RelationalNormalizer.get_normalizer_config(self._schema).get("propagation") + config = RelationalNormalizer.get_normalizer_config(self._schema).get( + "propagation" + ) return ( config is not None and "root" in config @@ -257,13 +273,14 @@ def root_key(self) -> bool: def root_key(self, value: bool) -> None: if value is True: RelationalNormalizer.update_normalizer_config( - self._schema, {"propagation": {"root": {"_dlt_id": TColumnName("_dlt_root_id")}}} + self._schema, + {"propagation": {"root": {"_dlt_id": TColumnName("_dlt_root_id")}}}, ) else: if self.root_key: - propagation_config = RelationalNormalizer.get_normalizer_config(self._schema)[ - "propagation" - ] + propagation_config = RelationalNormalizer.get_normalizer_config( + self._schema + )["propagation"] propagation_config["root"].pop("_dlt_id") # type: ignore @property @@ -316,7 +333,10 @@ def decompose(self, strategy: TDecompositionStrategy) -> List["DltSource"]: scc = graph_find_scc_nodes(graph_edges_to_nodes(dag, directed=False)) # components contain elements that are not currently selected selected_set = set(self.resources.selected.keys()) - return [self.with_resources(*component.intersection(selected_set)) for component in scc] + return [ + self.with_resources(*component.intersection(selected_set)) + for component in scc + ] else: raise ValueError(strategy) @@ -368,7 +388,9 @@ def clone(self, with_name: str = None) -> "DltSource": """ # mind that resources and pipes are cloned when added to the DltResourcesDict in the source constructor return DltSource( - self.schema.clone(with_name=with_name), self.section, list(self._resources.values()) + self.schema.clone(with_name=with_name), + self.section, + list(self._resources.values()), ) def __iter__(self) -> Iterator[TDataItem]: @@ -384,7 +406,9 @@ def __iter__(self) -> Iterator[TDataItem]: section_context = self._get_config_section_context() # managed pipe iterator will set the context on each call to __next__ - with inject_section(section_context), Container().injectable_context(state_context): + with inject_section(section_context), Container().injectable_context( + state_context + ): pipe_iterator: ManagedPipeIterator = ManagedPipeIterator.from_pipes(self._resources.selected_pipes) # type: ignore pipe_iterator.set_context([section_context, state_context]) _iter = map(lambda item: item.item, pipe_iterator) @@ -392,7 +416,9 @@ def __iter__(self) -> Iterator[TDataItem]: def _get_config_section_context(self) -> ConfigSectionContext: proxy = Container()[PipelineContext] - pipeline_name = None if not proxy.is_active() else proxy.pipeline().pipeline_name + pipeline_name = ( + None if not proxy.is_active() else proxy.pipeline().pipeline_name + ) return ConfigSectionContext( pipeline_name=pipeline_name, sections=(known_sections.SOURCES, self.section, self.name), @@ -416,8 +442,8 @@ def __setattr__(self, name: str, value: Any) -> None: def __str__(self) -> str: info = ( f"DltSource {self.name} section {self.section} contains" - f" {len(self.resources)} resource(s) of which {len(self.selected_resources)} are" - " selected" + f" {len(self.resources)} resource(s) of which" + f" {len(self.selected_resources)} are selected" ) for r in self.resources.values(): selected_info = "selected" if r.selected else "not selected" @@ -430,12 +456,13 @@ def __str__(self) -> str: info += f"\nresource {r.name} is {selected_info}" if self.exhausted: info += ( - "\nSource is already iterated 
and cannot be used again ie. to display or load data." + "\nSource is already iterated and cannot be used again ie. to display" + " or load data." ) else: info += ( - "\nIf you want to see the data items in this source you must iterate it or convert" - " to list ie. list(source)." + "\nIf you want to see the data items in this source you must iterate it" + " or convert to list ie. list(source)." ) info += " Note that, like any iterator, you can iterate the source only once." info += f"\ninstance id: {id(self)}" diff --git a/dlt/extract/storage.py b/dlt/extract/storage.py index 251d7a5ce9..5c6a87c8f9 100644 --- a/dlt/extract/storage.py +++ b/dlt/extract/storage.py @@ -27,7 +27,9 @@ def __init__(self, package_storage: PackageStorage) -> None: super().__init__(self.load_file_type) self.package_storage = package_storage - def _get_data_item_path_template(self, load_id: str, _: str, table_name: str) -> str: + def _get_data_item_path_template( + self, load_id: str, _: str, table_name: str + ) -> str: file_name = PackageStorage.build_job_file_name(table_name, "%s") file_path = self.package_storage.get_job_file_path( load_id, PackageStorage.NEW_JOBS_FOLDER, file_name @@ -53,14 +55,19 @@ def __init__(self, config: NormalizeStorageConfiguration) -> None: self.new_packages_folder = uniq_id(8) self.storage.create_folder(self.new_packages_folder, exists_ok=True) self.new_packages = PackageStorage( - FileStorage(os.path.join(self.storage.storage_path, self.new_packages_folder)), "new" + FileStorage( + os.path.join(self.storage.storage_path, self.new_packages_folder) + ), + "new", ) self._item_storages: Dict[TLoaderFileFormat, ExtractorItemStorage] = { "puae-jsonl": JsonLExtractorStorage(self.new_packages), "arrow": ArrowExtractorStorage(self.new_packages), } - def create_load_package(self, schema: Schema, reuse_exiting_package: bool = True) -> str: + def create_load_package( + self, schema: Schema, reuse_exiting_package: bool = True + ) -> str: """Creates a new load package for given `schema` or returns if such package already exists. 
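Editor's sketch (illustrative, not part of the diff): the DltSource methods reformatted above, with_resources(), clone(with_name=...) and plain iteration, as they are typically used. `my_source` and the resource name "events" are hypothetical.

source = my_source()
subset = source.with_resources("events")       # keep only the named resources selected
renamed = source.clone(with_name="events_v2")  # clone schema and resources under a new name
rows = list(subset)                            # like any iterator, a source can be consumed only once
print(source)                                  # __str__ reports resources and their selection state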
You can prevent reuse of the existing package by setting `reuse_exiting_package` to False @@ -81,7 +88,9 @@ def create_load_package(self, schema: Schema, reuse_exiting_package: bool = True self.new_packages.save_schema(load_id, schema) return load_id - def get_storage(self, loader_file_format: TLoaderFileFormat) -> ExtractorItemStorage: + def get_storage( + self, loader_file_format: TLoaderFileFormat + ) -> ExtractorItemStorage: return self._item_storages[loader_file_format] def close_writers(self, load_id: str) -> None: @@ -101,9 +110,12 @@ def remove_closed_files(self, load_id: str) -> None: def commit_new_load_package(self, load_id: str, schema: Schema) -> None: self.new_packages.save_schema(load_id, schema) self.storage.rename_tree( - os.path.join(self.new_packages_folder, self.new_packages.get_package_path(load_id)), os.path.join( - NormalizeStorage.EXTRACTED_FOLDER, self.new_packages.get_package_path(load_id) + self.new_packages_folder, self.new_packages.get_package_path(load_id) + ), + os.path.join( + NormalizeStorage.EXTRACTED_FOLDER, + self.new_packages.get_package_path(load_id), ), ) diff --git a/dlt/extract/utils.py b/dlt/extract/utils.py index 69edcab93d..2473ec69f6 100644 --- a/dlt/extract/utils.py +++ b/dlt/extract/utils.py @@ -21,8 +21,18 @@ from dlt.common.exceptions import MissingDependencyException from dlt.common.pipeline import reset_resource_state -from dlt.common.schema.typing import TColumnNames, TAnySchemaColumns, TTableSchemaColumns -from dlt.common.typing import AnyFun, DictStrAny, TDataItem, TDataItems, TAnyFunOrGenerator +from dlt.common.schema.typing import ( + TColumnNames, + TAnySchemaColumns, + TTableSchemaColumns, +) +from dlt.common.typing import ( + AnyFun, + DictStrAny, + TDataItem, + TDataItems, + TAnyFunOrGenerator, +) from dlt.common.utils import get_callable_name from dlt.extract.exceptions import ( @@ -70,7 +80,8 @@ def ensure_table_schema_columns(columns: TAnySchemaColumns) -> TTableSchemaColum # Assume list of columns return {col["name"]: col for col in columns} elif pydantic is not None and ( - isinstance(columns, pydantic.BaseModel) or issubclass(columns, pydantic.BaseModel) + isinstance(columns, pydantic.BaseModel) + or issubclass(columns, pydantic.BaseModel) ): return pydantic.pydantic_to_table_schema_columns(columns) @@ -95,7 +106,9 @@ def wrapper(item: TDataItem) -> TTableSchemaColumns: return ensure_table_schema_columns(columns) -def reset_pipe_state(pipe: SupportsPipe, source_state_: Optional[DictStrAny] = None) -> None: +def reset_pipe_state( + pipe: SupportsPipe, source_state_: Optional[DictStrAny] = None +) -> None: """Resets the resource state for a `pipe` and all its parent pipes""" if pipe.has_parent: reset_pipe_state(pipe.parent, source_state_) @@ -124,11 +137,15 @@ def simulate_func_call( return sig, no_item_sig, bound_args -def check_compat_transformer(name: str, f: AnyFun, sig: inspect.Signature) -> inspect.Parameter: +def check_compat_transformer( + name: str, f: AnyFun, sig: inspect.Signature +) -> inspect.Parameter: sig_arg_count = len(sig.parameters) callable_name = get_callable_name(f) if sig_arg_count == 0: - raise InvalidStepFunctionArguments(name, callable_name, sig, "Function takes no arguments") + raise InvalidStepFunctionArguments( + name, callable_name, sig, "Function takes no arguments" + ) # see if meta is present in kwargs meta_arg = next((p for p in sig.parameters.values() if p.name == "meta"), None) diff --git a/dlt/extract/validation.py b/dlt/extract/validation.py index 504eee1bfc..5a8b9923af 100644 --- 
a/dlt/extract/validation.py +++ b/dlt/extract/validation.py @@ -7,7 +7,11 @@ PydanticBaseModel = Any # type: ignore[misc, assignment] from dlt.common.typing import TDataItems -from dlt.common.schema.typing import TAnySchemaColumns, TSchemaContract, TSchemaEvolutionMode +from dlt.common.schema.typing import ( + TAnySchemaColumns, + TSchemaContract, + TSchemaEvolutionMode, +) from dlt.extract.items import TTableHintTemplate, ValidateItem @@ -23,7 +27,10 @@ def __init__( column_mode: TSchemaEvolutionMode, data_mode: TSchemaEvolutionMode, ) -> None: - from dlt.common.libs.pydantic import apply_schema_contract_to_model, create_list_model + from dlt.common.libs.pydantic import ( + apply_schema_contract_to_model, + create_list_model, + ) self.column_mode: TSchemaEvolutionMode = column_mode self.data_mode: TSchemaEvolutionMode = data_mode @@ -43,7 +50,9 @@ def __call__( return validate_items( self.table_name, self.list_model, item, self.column_mode, self.data_mode ) - return validate_item(self.table_name, self.model, item, self.column_mode, self.data_mode) + return validate_item( + self.table_name, self.model, item, self.column_mode, self.data_mode + ) def __str__(self, *args: Any, **kwargs: Any) -> str: return f"PydanticValidator(model={self.model.__qualname__})" @@ -81,7 +90,9 @@ def create_item_validator( ) return ( PydanticValidator( - columns, expanded_schema_contract["columns"], expanded_schema_contract["data_type"] + columns, + expanded_schema_contract["columns"], + expanded_schema_contract["data_type"], ), schema_contract or expanded_schema_contract, ) diff --git a/dlt/helpers/airflow_helper.py b/dlt/helpers/airflow_helper.py index e01cf790d2..cd8242741b 100644 --- a/dlt/helpers/airflow_helper.py +++ b/dlt/helpers/airflow_helper.py @@ -30,9 +30,13 @@ from dlt.common.data_writers import TLoaderFileFormat from dlt.common.schema.typing import TWriteDisposition, TSchemaContract from dlt.common.utils import uniq_id -from dlt.common.normalizers.naming.snake_case import NamingConvention as SnakeCaseNamingConvention +from dlt.common.normalizers.naming.snake_case import ( + NamingConvention as SnakeCaseNamingConvention, +) from dlt.common.configuration.container import Container -from dlt.common.configuration.specs.config_providers_context import ConfigProvidersContext +from dlt.common.configuration.specs.config_providers_context import ( + ConfigProvidersContext, +) from dlt.common.runtime.collector import NULL_COLLECTOR from dlt.extract import DltSource @@ -44,7 +48,9 @@ DEFAULT_RETRY_NO_RETRY = Retrying(stop=stop_after_attempt(1), reraise=True) DEFAULT_RETRY_BACKOFF = Retrying( - stop=stop_after_attempt(5), wait=wait_exponential(multiplier=1.5, min=4, max=10), reraise=True + stop=stop_after_attempt(5), + wait=wait_exponential(multiplier=1.5, min=4, max=10), + reraise=True, ) @@ -126,7 +132,9 @@ def __init__( data_dir = os.path.join("/home/airflow/gcs/data", f"dlt_{uniq_id(8)}") else: # create random path - data_dir = os.path.join(local_data_folder or gettempdir(), f"dlt_{uniq_id(8)}") + data_dir = os.path.join( + local_data_folder or gettempdir(), f"dlt_{uniq_id(8)}" + ) os.environ["DLT_DATA_DIR"] = data_dir # delete existing config providers in container, they will get reloaded on next use @@ -203,7 +211,9 @@ def run( schema_contract=schema_contract, pipeline_name=pipeline_name, ) - return PythonOperator(task_id=self._task_name(pipeline, data), python_callable=f, **kwargs) + return PythonOperator( + task_id=self._task_name(pipeline, data), python_callable=f, **kwargs + ) def _run( self, @@ 
-246,7 +256,10 @@ def _run( logger.LOGGER = ti.log # set global number of buffered items - if dlt.config.get("data_writer.buffer_max_items") is None and self.buffer_max_items > 0: + if ( + dlt.config.get("data_writer.buffer_max_items") is None + and self.buffer_max_items > 0 + ): dlt.config["data_writer.buffer_max_items"] = self.buffer_max_items logger.info(f"Set data_writer.buffer_max_items to {self.buffer_max_items}") @@ -256,7 +269,9 @@ def _run( logger.info("Set load.abort_task_if_any_job_failed to True") if self.log_progress_period > 0 and task_pipeline.collector == NULL_COLLECTOR: - task_pipeline.collector = log(log_period=self.log_progress_period, logger=logger.LOGGER) + task_pipeline.collector = log( + log_period=self.log_progress_period, logger=logger.LOGGER + ) logger.info(f"Enabled log progress with period {self.log_progress_period}") logger.info(f"Pipeline data in {task_pipeline.working_dir}") @@ -278,7 +293,8 @@ def log_after_attempt(retry_state: RetryCallState) -> None: ): with attempt: logger.info( - "Running the pipeline, attempt=%s" % attempt.retry_state.attempt_number + "Running the pipeline, attempt=%s" + % attempt.retry_state.attempt_number ) load_info = task_pipeline.run( data, @@ -318,7 +334,9 @@ def add_run( pipeline: Pipeline, data: Any, *, - decompose: Literal["none", "serialize", "parallel", "parallel-isolated"] = "none", + decompose: Literal[ + "none", "serialize", "parallel", "parallel-isolated" + ] = "none", table_name: str = None, write_disposition: TWriteDisposition = None, loader_file_format: TLoaderFileFormat = None, @@ -371,15 +389,17 @@ def add_run( # make sure that pipeline was created after dag was initialized if not pipeline.pipelines_dir.startswith(os.environ["DLT_DATA_DIR"]): raise ValueError( - "Please create your Pipeline instance after AirflowTasks are created. The dlt" - " pipelines directory is not set correctly." + "Please create your Pipeline instance after AirflowTasks are created." + " The dlt pipelines directory is not set correctly." 
) with self: # use factory function to make a task, in order to parametrize it # passing arguments to task function (_run) is serializing # them and running template engine on them - def make_task(pipeline: Pipeline, data: Any, name: str = None) -> PythonOperator: + def make_task( + pipeline: Pipeline, data: Any, name: str = None + ) -> PythonOperator: f = functools.partial( self._run, pipeline, @@ -489,7 +509,9 @@ def make_task(pipeline: Pipeline, data: Any, name: str = None) -> PythonOperator ) -def airflow_get_execution_dates() -> Tuple[pendulum.DateTime, Optional[pendulum.DateTime]]: +def airflow_get_execution_dates() -> ( + Tuple[pendulum.DateTime, Optional[pendulum.DateTime]] +): # prefer logging to task logger try: from airflow.operators.python import get_current_context # noqa diff --git a/dlt/helpers/dbt/configuration.py b/dlt/helpers/dbt/configuration.py index 70fa4d1ac5..b303827b22 100644 --- a/dlt/helpers/dbt/configuration.py +++ b/dlt/helpers/dbt/configuration.py @@ -25,6 +25,11 @@ def on_resolved(self) -> None: if not self.package_profiles_dir: # use "profile.yml" located in the same folder as current module self.package_profiles_dir = os.path.dirname(__file__) - if self.package_repository_ssh_key and self.package_repository_ssh_key[-1] != "\n": + if ( + self.package_repository_ssh_key + and self.package_repository_ssh_key[-1] != "\n" + ): # must end with new line, otherwise won't be parsed by Crypto - self.package_repository_ssh_key = TSecretValue(self.package_repository_ssh_key + "\n") + self.package_repository_ssh_key = TSecretValue( + self.package_repository_ssh_key + "\n" + ) diff --git a/dlt/helpers/dbt/dbt_utils.py b/dlt/helpers/dbt/dbt_utils.py index b4097e4434..1e065f2a31 100644 --- a/dlt/helpers/dbt/dbt_utils.py +++ b/dlt/helpers/dbt/dbt_utils.py @@ -169,7 +169,9 @@ def run_dbt_command( # oftentimes dbt tries to exit on error raise DBTProcessingError(command, None, sys_ex) except FailFastException as ff: - dbt_exc = DBTProcessingError(command, parse_dbt_execution_results(ff.result), ff.result) + dbt_exc = DBTProcessingError( + command, parse_dbt_execution_results(ff.result), ff.result + ) # detect incremental model out of sync if is_incremental_schema_out_of_sync_error(ff.result): raise IncrementalSchemaOutOfSyncError(dbt_exc) from ff diff --git a/dlt/helpers/dbt/runner.py b/dlt/helpers/dbt/runner.py index 7b1f79dc77..d981dbf446 100644 --- a/dlt/helpers/dbt/runner.py +++ b/dlt/helpers/dbt/runner.py @@ -91,14 +91,17 @@ def _log_dbt_run_results(self, results: Sequence[DBTNodeResult]) -> None: logger.error(f"Model {res.model_name} error! 
Error: {res.message}") else: logger.info( - f"Model {res.model_name} {res.status} in {res.time} seconds with {res.message}" + f"Model {res.model_name} {res.status} in {res.time} seconds with" + f" {res.message}" ) def ensure_newest_package(self) -> None: """Clones or brings the dbt package at `package_location` up to date.""" from git import GitError - with git_custom_key_command(self.config.package_repository_ssh_key) as ssh_command: + with git_custom_key_command( + self.config.package_repository_ssh_key + ) as ssh_command: try: ensure_remote_head( self.package_path, @@ -107,7 +110,9 @@ def ensure_newest_package(self) -> None: ) except GitError as err: # cleanup package folder - logger.info(f"Package will be cloned due to {type(err).__name__}:{str(err)}") + logger.info( + f"Package will be cloned due to {type(err).__name__}:{str(err)}" + ) logger.info( f"Will clone {self.config.package_location} head" f" {self.config.package_repository_branch} into {self.package_path}" @@ -122,7 +127,10 @@ def ensure_newest_package(self) -> None: @with_custom_environ def _run_dbt_command( - self, command: str, command_args: Sequence[str] = None, package_vars: StrAny = None + self, + command: str, + command_args: Sequence[str] = None, + package_vars: StrAny = None, ) -> Sequence[DBTNodeResult]: logger.info( f"Exec dbt command: {command} {command_args} {package_vars} on profile" @@ -184,7 +192,9 @@ def run( DBTProcessingError: `run` command failed. Contains a list of models with their execution statuses and error messages """ return self._run_dbt_command( - "run", cmd_params, self._get_package_vars(additional_vars, destination_dataset_name) + "run", + cmd_params, + self._get_package_vars(additional_vars, destination_dataset_name), ) def test( @@ -209,11 +219,16 @@ def test( DBTProcessingError: `test` command failed. 
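Editor's sketch (illustrative, not part of the diff): calling the run()/test() methods whose bodies are reformatted above. `runner` stands for a DBTPackageRunner instance obtained from dlt's dbt helper (how it is created is assumed, not shown in this patch); the dataset name, vars and dbt flags are examples only.

run_results = runner.run(
    ["--fail-fast"],                                 # cmd_params passed through to `dbt run`
    additional_vars={"source_dataset_name": "raw"},  # forwarded as dbt package vars
    destination_dataset_name="analytics",
)
test_results = runner.test(destination_dataset_name="analytics")
for res in list(run_results) + list(test_results):
    # DBTNodeResult fields used by _log_dbt_run_results above
    print(res.model_name, res.status, res.time, res.message)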
Contains a list of models with their execution statuses and error messages """ return self._run_dbt_command( - "test", cmd_params, self._get_package_vars(additional_vars, destination_dataset_name) + "test", + cmd_params, + self._get_package_vars(additional_vars, destination_dataset_name), ) def _run_db_steps( - self, run_params: Sequence[str], package_vars: StrAny, source_tests_selector: str + self, + run_params: Sequence[str], + package_vars: StrAny, + source_tests_selector: str, ) -> Sequence[DBTNodeResult]: if self.repo_storage: # make sure we use package from the remote head @@ -240,7 +255,9 @@ def _run_db_steps( return self.run(run_params, package_vars) except IncrementalSchemaOutOfSyncError: if self.config.auto_full_refresh_when_out_of_sync: - logger.warning("Attempting full refresh due to incremental model out of sync") + logger.warning( + "Attempting full refresh due to incremental model out of sync" + ) return self.run(run_params + ["--full-refresh"], package_vars) else: raise diff --git a/dlt/helpers/dbt_cloud/client.py b/dlt/helpers/dbt_cloud/client.py index 67d315f0d1..0c0a51f5ea 100644 --- a/dlt/helpers/dbt_cloud/client.py +++ b/dlt/helpers/dbt_cloud/client.py @@ -41,11 +41,15 @@ def __init__( self.accounts_url = f"accounts/{self.account_id}" def get_endpoint(self, endpoint: str) -> Any: - response = requests.get(f"{self.base_api_url}/{endpoint}", headers=self._headers) + response = requests.get( + f"{self.base_api_url}/{endpoint}", headers=self._headers + ) results = response.json() return results - def post_endpoint(self, endpoint: str, json_body: Optional[Dict[Any, Any]] = None) -> Any: + def post_endpoint( + self, endpoint: str, json_body: Optional[Dict[Any, Any]] = None + ) -> Any: response = requests.post( f"{self.base_api_url}/{endpoint}", headers=self._headers, @@ -101,15 +105,17 @@ def trigger_job_run( """ if not (self.account_id and job_id): raise InvalidCredentialsException( - f"account_id and job_id are required, got account_id: {self.account_id} and job_id:" - f" {job_id}" + "account_id and job_id are required, got account_id:" + f" {self.account_id} and job_id: {job_id}" ) json_body = {} if data: json_body.update(data) - response = self.post_endpoint(f"{self.accounts_url}/jobs/{job_id}/run", json_body=json_body) + response = self.post_endpoint( + f"{self.accounts_url}/jobs/{job_id}/run", json_body=json_body + ) return int(response["data"]["id"]) def get_run_status(self, run_id: Union[int, str]) -> Dict[Any, Any]: @@ -133,8 +139,8 @@ def get_run_status(self, run_id: Union[int, str]) -> Dict[Any, Any]: """ if not (self.account_id and run_id): raise InvalidCredentialsException( - f"account_id and run_id are required, got account_id: {self.account_id} and run_id:" - f" {run_id}." + "account_id and run_id are required, got account_id:" + f" {self.account_id} and run_id: {run_id}." 
) response = self.get_endpoint(f"{self.accounts_url}/runs/{run_id}") diff --git a/dlt/helpers/streamlit_app/blocks/load_info.py b/dlt/helpers/streamlit_app/blocks/load_info.py index 134b5ad5a4..71d5cbb00e 100644 --- a/dlt/helpers/streamlit_app/blocks/load_info.py +++ b/dlt/helpers/streamlit_app/blocks/load_info.py @@ -10,8 +10,9 @@ def last_load_info(pipeline: dlt.Pipeline) -> None: loads_df = query_data_live( pipeline, - f"SELECT load_id, inserted_at FROM {pipeline.default_schema.loads_table_name} WHERE" - " status = 0 ORDER BY inserted_at DESC LIMIT 101 ", + "SELECT load_id, inserted_at FROM" + f" {pipeline.default_schema.loads_table_name} WHERE status = 0 ORDER BY" + " inserted_at DESC LIMIT 101 ", ) if loads_df is None: @@ -24,7 +25,8 @@ def last_load_info(pipeline: dlt.Pipeline) -> None: if loads_df.shape[0] > 0: rel_time = ( humanize.naturaldelta( - pendulum.now() - pendulum.from_timestamp(loads_df.iloc[0, 1].timestamp()) + pendulum.now() + - pendulum.from_timestamp(loads_df.iloc[0, 1].timestamp()) ) + " ago" ) diff --git a/dlt/helpers/streamlit_app/blocks/query.py b/dlt/helpers/streamlit_app/blocks/query.py index a03e9a0cd9..859b1f1fd8 100644 --- a/dlt/helpers/streamlit_app/blocks/query.py +++ b/dlt/helpers/streamlit_app/blocks/query.py @@ -37,8 +37,8 @@ def maybe_run_query( raise MissingDependencyException( "DLT Streamlit Helpers", ["altair"], - "DLT Helpers for Streamlit should be run within a streamlit" - " app.", + "DLT Helpers for Streamlit should be run within a" + " streamlit app.", ) # try altair @@ -46,7 +46,8 @@ def maybe_run_query( alt.Chart(df) .mark_bar() .encode( - x=f"{df.columns[1]}:Q", y=alt.Y(f"{df.columns[0]}:N", sort="-x") + x=f"{df.columns[1]}:Q", + y=alt.Y(f"{df.columns[0]}:N", sort="-x"), ) ) st.altair_chart(bar_chart, use_container_width=True) diff --git a/dlt/helpers/streamlit_app/pages/load_info.py b/dlt/helpers/streamlit_app/pages/load_info.py index ee13cf2531..026c38aacb 100644 --- a/dlt/helpers/streamlit_app/pages/load_info.py +++ b/dlt/helpers/streamlit_app/pages/load_info.py @@ -22,8 +22,9 @@ def write_load_status_page(pipeline: Pipeline) -> None: try: loads_df = query_data_live( pipeline, - f"SELECT load_id, inserted_at FROM {pipeline.default_schema.loads_table_name} WHERE" - " status = 0 ORDER BY inserted_at DESC LIMIT 101 ", + "SELECT load_id, inserted_at FROM" + f" {pipeline.default_schema.loads_table_name} WHERE status = 0 ORDER BY" + " inserted_at DESC LIMIT 101 ", ) if loads_df is not None: @@ -57,24 +58,29 @@ def write_load_status_page(pipeline: Pipeline) -> None: schemas_df = query_data_live( pipeline, "SELECT schema_name, inserted_at, version, version_hash FROM" - f" {pipeline.default_schema.version_table_name} ORDER BY inserted_at DESC LIMIT" - " 101 ", + f" {pipeline.default_schema.version_table_name} ORDER BY inserted_at" + " DESC LIMIT 101 ", ) st.markdown("**100 recent schema updates**") st.dataframe(schemas_df) except CannotRestorePipelineException as restore_ex: - st.error("Seems like the pipeline does not exist. Did you run it at least once?") + st.error( + "Seems like the pipeline does not exist. Did you run it at least once?" + ) st.exception(restore_ex) except ConfigFieldMissingException as cf_ex: st.error( - "Pipeline credentials/configuration is missing. This most often happen when you run the" - " streamlit app from different folder than the `.dlt` with `toml` files resides." + "Pipeline credentials/configuration is missing. 
This most often happen when" + " you run the streamlit app from different folder than the `.dlt` with" + " `toml` files resides." ) st.text(str(cf_ex)) except Exception as ex: - st.error("Pipeline info could not be prepared. Did you load the data at least once?") + st.error( + "Pipeline info could not be prepared. Did you load the data at least once?" + ) st.exception(ex) @@ -83,7 +89,9 @@ def show_state_versions(pipeline: dlt.Pipeline) -> None: remote_state = None with pipeline.destination_client() as client: if isinstance(client, WithStateSync): - remote_state = load_pipeline_state_from_destination(pipeline.pipeline_name, client) + remote_state = load_pipeline_state_from_destination( + pipeline.pipeline_name, client + ) local_state = pipeline.state @@ -111,7 +119,8 @@ def show_state_versions(pipeline: dlt.Pipeline) -> None: if remote_state_version != str(local_state["_state_version"]): st.text("") st.warning( - "Looks like that local state is not yet synchronized or synchronization is disabled", + "Looks like that local state is not yet synchronized or synchronization is" + " disabled", icon="⚠️", ) diff --git a/dlt/helpers/streamlit_app/utils.py b/dlt/helpers/streamlit_app/utils.py index cf1728c33b..6b2dab495c 100644 --- a/dlt/helpers/streamlit_app/utils.py +++ b/dlt/helpers/streamlit_app/utils.py @@ -38,7 +38,9 @@ def render_with_pipeline(render_func: Callable[..., None]) -> None: render_func(pipeline) -def query_using_cache(pipeline: dlt.Pipeline, ttl: int) -> Callable[..., Optional[pd.DataFrame]]: +def query_using_cache( + pipeline: dlt.Pipeline, ttl: int +) -> Callable[..., Optional[pd.DataFrame]]: @st.cache_data(ttl=ttl) def do_query( # type: ignore[return] query: str, diff --git a/dlt/load/exceptions.py b/dlt/load/exceptions.py index e85dffd2e9..0046f39784 100644 --- a/dlt/load/exceptions.py +++ b/dlt/load/exceptions.py @@ -11,20 +11,23 @@ def __init__(self, load_id: str, job_id: str, failed_message: str) -> None: self.job_id = job_id self.failed_message = failed_message super().__init__( - f"Job for {job_id} failed terminally in load {load_id} with message {failed_message}." - " The package is aborted and cannot be retried." + f"Job for {job_id} failed terminally in load {load_id} with message" + f" {failed_message}. The package is aborted and cannot be retried." ) class LoadClientJobRetry(DestinationTransientException): - def __init__(self, load_id: str, job_id: str, retry_count: int, max_retry_count: int) -> None: + def __init__( + self, load_id: str, job_id: str, retry_count: int, max_retry_count: int + ) -> None: self.load_id = load_id self.job_id = job_id self.retry_count = retry_count self.max_retry_count = max_retry_count super().__init__( - f"Job for {job_id} had {retry_count} retries which a multiple of {max_retry_count}." - " Exiting retry loop. You can still rerun the load package to retry this job." + f"Job for {job_id} had {retry_count} retries which a multiple of" + f" {max_retry_count}. Exiting retry loop. You can still rerun the load" + " package to retry this job." ) @@ -36,8 +39,8 @@ def __init__( self.supported_types = supported_file_format self.file_path = file_path super().__init__( - f"Loader does not support writer {file_format} in file {file_path}. Supported writers:" - f" {supported_file_format}" + f"Loader does not support writer {file_format} in file {file_path}." 
+ f" Supported writers: {supported_file_format}" ) @@ -47,6 +50,6 @@ def __init__(self, table_name: str, write_disposition: str, file_name: str) -> N self.write_disposition = write_disposition self.file_name = file_name super().__init__( - f"Loader does not support {write_disposition} in table {table_name} when loading file" - f" {file_name}" + f"Loader does not support {write_disposition} in table {table_name} when" + f" loading file {file_name}" ) diff --git a/dlt/load/load.py b/dlt/load/load.py index f02a21f98e..9b8b6acbab 100644 --- a/dlt/load/load.py +++ b/dlt/load/load.py @@ -16,7 +16,11 @@ WithStepInfo, ) from dlt.common.schema.utils import get_top_level_table -from dlt.common.storages.load_storage import LoadPackageInfo, ParsedLoadJobFileName, TJobState +from dlt.common.storages.load_storage import ( + LoadPackageInfo, + ParsedLoadJobFileName, + TJobState, +) from dlt.common.storages.load_package import LoadPackageStateInjectableContext from dlt.common.runners import TRunMetrics, Runnable, workermethod, NullExecutor from dlt.common.runtime.collector import Collector, NULL_COLLECTOR @@ -102,7 +106,9 @@ def get_destination_client(self, schema: Schema) -> JobClientBase: return self.destination.client(schema, self.initial_client_config) def get_staging_destination_client(self, schema: Schema) -> JobClientBase: - return self.staging_destination.client(schema, self.initial_staging_client_config) + return self.staging_destination.client( + schema, self.initial_staging_client_config + ) def is_staging_destination_job(self, file_path: str) -> bool: return ( @@ -145,7 +151,9 @@ def w_spool_job( self.capabilities.supported_loader_file_formats, file_path, ) - logger.info(f"Will load file {file_path} with table name {job_info.table_name}") + logger.info( + f"Will load file {file_path} with table name {job_info.table_name}" + ) table = client.prepare_load_table(job_info.table_name) if table["write_disposition"] not in ["append", "replace", "merge"]: raise LoadClientUnsupportedWriteDisposition( @@ -166,17 +174,23 @@ def w_spool_job( with self.maybe_with_staging_dataset(client, use_staging_dataset): job = client.start_file_load( table, - self.load_storage.normalized_packages.storage.make_full_path(file_path), + self.load_storage.normalized_packages.storage.make_full_path( + file_path + ), load_id, ) except (DestinationTerminalException, TerminalValueError): # if job irreversibly cannot be started, mark it as failed logger.exception(f"Terminal problem when adding job {file_path}") - job = EmptyLoadJob.from_file_path(file_path, "failed", pretty_format_exception()) + job = EmptyLoadJob.from_file_path( + file_path, "failed", pretty_format_exception() + ) except (DestinationTransientException, Exception): # return no job so file stays in new jobs (root) folder logger.exception(f"Temporary problem when adding job {file_path}") - job = EmptyLoadJob.from_file_path(file_path, "retry", pretty_format_exception()) + job = EmptyLoadJob.from_file_path( + file_path, "retry", pretty_format_exception() + ) self.load_storage.normalized_packages.start_job(load_id, job.file_name()) return job @@ -203,18 +217,29 @@ def retrieve_jobs( # list all files that were started but not yet completed started_jobs = self.load_storage.normalized_packages.list_started_jobs(load_id) - logger.info(f"Found {len(started_jobs)} that are already started and should be continued") + logger.info( + f"Found {len(started_jobs)} that are already started and should be" + " continued" + ) if len(started_jobs) == 0: return 0, jobs for file_path 
in started_jobs: try: logger.info(f"Will retrieve {file_path}") - client = staging_client if self.is_staging_destination_job(file_path) else client + client = ( + staging_client + if self.is_staging_destination_job(file_path) + else client + ) job = client.restore_file_load(file_path) except DestinationTerminalException: - logger.exception(f"Job retrieval for {file_path} failed, job will be terminated") - job = EmptyLoadJob.from_file_path(file_path, "failed", pretty_format_exception()) + logger.exception( + f"Job retrieval for {file_path} failed, job will be terminated" + ) + job = EmptyLoadJob.from_file_path( + file_path, "failed", pretty_format_exception() + ) # proceed to appending job, do not reraise except (DestinationTransientException, Exception): # raise on all temporary exceptions, typically network / server problems @@ -237,7 +262,9 @@ def create_followup_jobs( # check for merge jobs only for jobs executing on the destination, the staging destination jobs must be excluded # NOTE: we may move that logic to the interface starting_job_file_name = starting_job.file_name() - if state == "completed" and not self.is_staging_destination_job(starting_job_file_name): + if state == "completed" and not self.is_staging_destination_job( + starting_job_file_name + ): client = self.destination.client(schema, self.initial_client_config) top_job_table = get_top_level_table( schema.tables, starting_job.job_file_info().table_name @@ -245,7 +272,10 @@ def create_followup_jobs( # if all tables of chain completed, create follow up jobs all_jobs = self.load_storage.normalized_packages.list_all_jobs(load_id) if table_chain := get_completed_table_chain( - schema, all_jobs, top_job_table, starting_job.job_file_info().job_id() + schema, + all_jobs, + top_job_table, + starting_job.job_file_info().job_id(), ): if follow_up_jobs := client.create_table_chain_completed_followup_jobs( table_chain @@ -254,7 +284,9 @@ def create_followup_jobs( jobs = jobs + starting_job.create_followup_jobs(state) return jobs - def complete_jobs(self, load_id: str, jobs: List[LoadJob], schema: Schema) -> List[LoadJob]: + def complete_jobs( + self, load_id: str, jobs: List[LoadJob], schema: Schema + ) -> List[LoadJob]: """Run periodically in the main thread to collect job execution statuses. 
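Editor's sketch (illustrative, not part of the diff): the failed/retry job handling in the loader code around here surfaces to users through the LoadInfo returned by a pipeline run. Pipeline, dataset and table names are examples; raise_on_failed_jobs() is assumed from dlt's documented LoadInfo API rather than shown in this patch.

import dlt

pipeline = dlt.pipeline(
    pipeline_name="demo_load", destination="duckdb", dataset_name="demo_data"
)
load_info = pipeline.run([{"id": 1}, {"id": 2}], table_name="items")
load_info.raise_on_failed_jobs()  # raise if any job in the package failed terminally
print(load_info)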
After detecting change of status, it commits the job state by moving it to the right folder @@ -292,7 +324,9 @@ def _schedule_followup_jobs(followup_jobs: Iterable[NewLoadJob]) -> None: remaining_jobs.append(job) elif state == "failed": # create followup jobs - _schedule_followup_jobs(self.create_followup_jobs(load_id, state, job, schema)) + _schedule_followup_jobs( + self.create_followup_jobs(load_id, state, job, schema) + ) # try to get exception message from job failed_message = job.exception() @@ -300,35 +334,47 @@ def _schedule_followup_jobs(followup_jobs: Iterable[NewLoadJob]) -> None: load_id, job.file_name(), failed_message ) logger.error( - f"Job for {job.job_id()} failed terminally in load {load_id} with message" - f" {failed_message}" + f"Job for {job.job_id()} failed terminally in load {load_id} with" + f" message {failed_message}" ) elif state == "retry": # try to get exception message from job retry_message = job.exception() # move back to new folder to try again - self.load_storage.normalized_packages.retry_job(load_id, job.file_name()) + self.load_storage.normalized_packages.retry_job( + load_id, job.file_name() + ) logger.warning( - f"Job for {job.job_id()} retried in load {load_id} with message {retry_message}" + f"Job for {job.job_id()} retried in load {load_id} with message" + f" {retry_message}" ) elif state == "completed": # create followup jobs - _schedule_followup_jobs(self.create_followup_jobs(load_id, state, job, schema)) + _schedule_followup_jobs( + self.create_followup_jobs(load_id, state, job, schema) + ) # move to completed folder after followup jobs are created # in case of exception when creating followup job, the loader will retry operation and try to complete again - self.load_storage.normalized_packages.complete_job(load_id, job.file_name()) + self.load_storage.normalized_packages.complete_job( + load_id, job.file_name() + ) logger.info(f"Job for {job.job_id()} completed in load {load_id}") if state in ["failed", "completed"]: self.collector.update("Jobs") if state == "failed": self.collector.update( - "Jobs", 1, message="WARNING: Some of the jobs failed!", label="Failed" + "Jobs", + 1, + message="WARNING: Some of the jobs failed!", + label="Failed", ) return remaining_jobs - def complete_package(self, load_id: str, schema: Schema, aborted: bool = False) -> None: + def complete_package( + self, load_id: str, schema: Schema, aborted: bool = False + ) -> None: # do not commit load id for aborted packages if not aborted: with self.get_destination_client(schema) as job_client: @@ -336,18 +382,23 @@ def complete_package(self, load_id: str, schema: Schema, aborted: bool = False) self.load_storage.complete_load_package(load_id, aborted) # collect package info self._loaded_packages.append(self.load_storage.get_load_package_info(load_id)) - self._step_info_complete_load_id(load_id, metrics={"started_at": None, "finished_at": None}) + self._step_info_complete_load_id( + load_id, metrics={"started_at": None, "finished_at": None} + ) # delete jobs only now self.load_storage.maybe_remove_completed_jobs(load_id) logger.info( - f"All jobs completed, archiving package {load_id} with aborted set to {aborted}" + f"All jobs completed, archiving package {load_id} with aborted set to" + f" {aborted}" ) def load_single_package(self, load_id: str, schema: Schema) -> None: new_jobs = self.get_new_jobs_info(load_id) # initialize analytical storage ie. 
create dataset required by passed schema with self.get_destination_client(schema) as job_client: - if (expected_update := self.load_storage.begin_schema_update(load_id)) is not None: + if ( + expected_update := self.load_storage.begin_schema_update(load_id) + ) is not None: # init job client applied_update = init_client( job_client, @@ -365,8 +416,9 @@ def load_single_package(self, load_id: str, schema: Schema) -> None: # init staging client if self.staging_destination: assert isinstance(job_client, SupportsStagingDestination), ( - f"Job client for destination {self.destination.destination_type} does not" - " implement SupportsStagingDestination" + "Job client for destination" + f" {self.destination.destination_type} does not implement" + " SupportsStagingDestination" ) with self.get_staging_destination_client(schema) as staging_client: init_client( @@ -383,7 +435,9 @@ def load_single_package(self, load_id: str, schema: Schema) -> None: # initialize staging destination and spool or retrieve unfinished jobs if self.staging_destination: with self.get_staging_destination_client(schema) as staging_client: - jobs_count, jobs = self.retrieve_jobs(job_client, load_id, staging_client) + jobs_count, jobs = self.retrieve_jobs( + job_client, load_id, staging_client + ) else: jobs_count, jobs = self.retrieve_jobs(job_client, load_id) @@ -395,14 +449,19 @@ def load_single_package(self, load_id: str, schema: Schema) -> None: self.complete_package(load_id, schema, False) return # update counter we only care about the jobs that are scheduled to be loaded - package_info = self.load_storage.normalized_packages.get_load_package_info(load_id) + package_info = self.load_storage.normalized_packages.get_load_package_info( + load_id + ) total_jobs = reduce(lambda p, c: p + len(c), package_info.jobs.values(), 0) no_failed_jobs = len(package_info.jobs["failed_jobs"]) no_completed_jobs = len(package_info.jobs["completed_jobs"]) + no_failed_jobs self.collector.update("Jobs", no_completed_jobs, total_jobs) if no_failed_jobs > 0: self.collector.update( - "Jobs", no_failed_jobs, message="WARNING: Some of the jobs failed!", label="Failed" + "Jobs", + no_failed_jobs, + message="WARNING: Some of the jobs failed!", + label="Failed", ) # loop until all jobs are processed while True: @@ -410,8 +469,10 @@ def load_single_package(self, load_id: str, schema: Schema) -> None: remaining_jobs = self.complete_jobs(load_id, jobs, schema) if len(remaining_jobs) == 0: # get package status - package_info = self.load_storage.normalized_packages.get_load_package_info( - load_id + package_info = ( + self.load_storage.normalized_packages.get_load_package_info( + load_id + ) ) # possibly raise on failed jobs if self.config.raise_on_failed_jobs: @@ -458,7 +519,9 @@ def run(self, pool: Optional[Executor]) -> TRunMetrics: load_id = loads[0] logger.info(f"Loading schema from load package in {load_id}") schema = self.load_storage.normalized_packages.load_schema(load_id) - logger.info(f"Loaded schema name {schema.name} and version {schema.stored_version}") + logger.info( + f"Loaded schema name {schema.name} and version {schema.stored_version}" + ) container = Container() # get top load id and mark as being processed @@ -487,11 +550,15 @@ def get_step_info( _dataset_name: str = None for load_package in self._loaded_packages: # TODO: each load id may have a separate dataset so construct a list of datasets here - if isinstance(self.initial_client_config, DestinationClientDwhConfiguration): + if isinstance( + self.initial_client_config, 
DestinationClientDwhConfiguration + ): _dataset_name = self.initial_client_config.normalize_dataset_name( load_package.schema ) - metrics[load_package.load_id] = self._step_info_metrics(load_package.load_id) + metrics[load_package.load_id] = self._step_info_metrics( + load_package.load_id + ) return LoadInfo( pipeline, @@ -501,7 +568,9 @@ def get_step_info( self.initial_client_config.destination_name, self.initial_client_config.environment, ( - Destination.normalize_type(self.initial_staging_client_config.destination_type) + Destination.normalize_type( + self.initial_staging_client_config.destination_type + ) if self.initial_staging_client_config else None ), @@ -510,7 +579,11 @@ def get_step_info( if self.initial_staging_client_config else None ), - str(self.initial_staging_client_config) if self.initial_staging_client_config else None, + ( + str(self.initial_staging_client_config) + if self.initial_staging_client_config + else None + ), self.initial_client_config.fingerprint(), _dataset_name, list(load_ids), diff --git a/dlt/load/utils.py b/dlt/load/utils.py index 067ae33613..ea7eebd592 100644 --- a/dlt/load/utils.py +++ b/dlt/load/utils.py @@ -30,7 +30,10 @@ def get_completed_table_chain( # returns ordered list of tables from parent to child leaf tables table_chain: List[TTableSchema] = [] # allow for jobless tables for those write disposition - skip_jobless_table = top_merged_table["write_disposition"] not in ("replace", "merge") + skip_jobless_table = top_merged_table["write_disposition"] not in ( + "replace", + "merge", + ) # make sure all the jobs for the table chain is completed for table in map( @@ -40,7 +43,9 @@ def get_completed_table_chain( table_jobs = PackageStorage.filter_jobs_for_table(all_jobs, table["name"]) # skip tables that never seen data if not has_table_seen_data(table): - assert len(table_jobs) == 0, f"Tables that never seen data cannot have jobs {table}" + assert ( + len(table_jobs) == 0 + ), f"Tables that never seen data cannot have jobs {table}" continue # skip jobless tables if len(table_jobs) == 0 and skip_jobless_table: @@ -86,14 +91,18 @@ def init_client( dlt_tables = set(schema.dlt_table_names()) # tables without data (TODO: normalizer removes such jobs, write tests and remove the line below) tables_no_data = set( - table["name"] for table in schema.data_tables() if not has_table_seen_data(table) + table["name"] + for table in schema.data_tables() + if not has_table_seen_data(table) ) # get all tables that actually have load jobs with data tables_with_jobs = set(job.table_name for job in new_jobs) - tables_no_data # get tables to truncate by extending tables with jobs with all their child tables truncate_tables = set( - _extend_tables_with_table_chain(schema, tables_with_jobs, tables_with_jobs, truncate_filter) + _extend_tables_with_table_chain( + schema, tables_with_jobs, tables_with_jobs, truncate_filter + ) ) applied_update = _init_dataset_and_update_schema( @@ -114,7 +123,8 @@ def init_client( _init_dataset_and_update_schema( job_client, expected_update, - staging_tables | {schema.version_table_name}, # keep only schema version + staging_tables + | {schema.version_table_name}, # keep only schema version staging_tables, # all eligible tables must be also truncated staging_info=True, ) @@ -136,14 +146,15 @@ def _init_dataset_and_update_schema( ) job_client.initialize_storage() logger.info( - f"Client for {job_client.config.destination_type} will update schema to package schema" - f" {staging_text}" + f"Client for {job_client.config.destination_type} will 
update schema to package" + f" schema {staging_text}" ) applied_update = job_client.update_stored_schema( only_tables=update_tables, expected_update=expected_update ) logger.info( - f"Client for {job_client.config.destination_type} will truncate tables {staging_text}" + f"Client for {job_client.config.destination_type} will truncate tables" + f" {staging_text}" ) job_client.initialize_storage(truncate_tables=truncate_tables) return applied_update @@ -167,7 +178,10 @@ def _extend_tables_with_table_chain( # for replace and merge write dispositions we should include tables # without jobs in the table chain, because child tables may need # processing due to changes in the root table - skip_jobless_table = top_job_table["write_disposition"] not in ("replace", "merge") + skip_jobless_table = top_job_table["write_disposition"] not in ( + "replace", + "merge", + ) for table in map( lambda t: fill_hints_from_parent_and_clone_table(schema.tables, t), get_child_tables(schema.tables, top_job_table["name"]), diff --git a/dlt/normalize/exceptions.py b/dlt/normalize/exceptions.py index a172196899..57d94cb714 100644 --- a/dlt/normalize/exceptions.py +++ b/dlt/normalize/exceptions.py @@ -12,5 +12,6 @@ def __init__(self, load_id: str, job_id: str, failed_message: str) -> None: self.job_id = job_id self.failed_message = failed_message super().__init__( - f"Job for {job_id} failed terminally in load {load_id} with message {failed_message}." + f"Job for {job_id} failed terminally in load {load_id} with message" + f" {failed_message}." ) diff --git a/dlt/normalize/items_normalizers.py b/dlt/normalize/items_normalizers.py index fc1e152ff2..0ea849d7d7 100644 --- a/dlt/normalize/items_normalizers.py +++ b/dlt/normalize/items_normalizers.py @@ -5,7 +5,11 @@ from dlt.common.data_writers import DataWriterMetrics from dlt.common.json import custom_pua_decode, may_have_pua from dlt.common.runtime import signals -from dlt.common.schema.typing import TSchemaEvolutionMode, TTableSchemaColumns, TSchemaContractDict +from dlt.common.schema.typing import ( + TSchemaEvolutionMode, + TTableSchemaColumns, + TSchemaContractDict, +) from dlt.common.schema.utils import has_table_seen_data from dlt.common.storages import ( NormalizeStorage, @@ -43,7 +47,9 @@ def __init__( self.config = config @abstractmethod - def __call__(self, extracted_items_file: str, root_table_name: str) -> List[TSchemaUpdate]: ... + def __call__( + self, extracted_items_file: str, root_table_name: str + ) -> List[TSchemaUpdate]: ... 
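Editor's sketch (illustrative, not part of the diff): the table-chain helpers in dlt/load/utils.py above special-case the "replace" and "merge" write dispositions; such a disposition usually originates on a resource. Resource, pipeline and dataset names are examples only.

import dlt

@dlt.resource(write_disposition="replace")
def customers():
    # the nested list yields a child table (customers__orders) that belongs to the same table chain
    yield {"id": 1, "name": "Alice", "orders": [{"order_id": "a-1"}]}

pipeline = dlt.pipeline(
    pipeline_name="demo_replace", destination="duckdb", dataset_name="demo_data"
)
pipeline.run(customers())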
class JsonLItemsNormalizer(ItemsNormalizer): @@ -74,7 +80,11 @@ def _filter_columns( return row def _normalize_chunk( - self, root_table_name: str, items: List[TDataItem], may_have_pua: bool, skip_write: bool + self, + root_table_name: str, + items: List[TDataItem], + may_have_pua: bool, + skip_write: bool, ) -> TSchemaUpdate: column_schemas = self._column_schemas schema_update: TSchemaUpdate = {} @@ -106,7 +116,9 @@ def _normalize_chunk( # filter columns or full rows if schema contract said so # do it before schema inference in `coerce_row` to not trigger costly migration code - filtered_columns = self._filtered_tables_columns.get(table_name, None) + filtered_columns = self._filtered_tables_columns.get( + table_name, None + ) if filtered_columns: row = self._filter_columns(filtered_columns, row) # type: ignore[arg-type] # if whole row got dropped @@ -120,7 +132,9 @@ def _normalize_chunk( row[k] = custom_pua_decode(v) # type: ignore # coerce row of values into schema table, generating partial table with new columns if any - row, partial_table = schema.coerce_row(table_name, parent_table, row) + row, partial_table = schema.coerce_row( + table_name, parent_table, row + ) # if we detect a migration, check schema contract if partial_table: @@ -138,8 +152,10 @@ def _normalize_chunk( if entity == "tables": self._filtered_tables.add(name) elif entity == "columns": - filtered_columns = self._filtered_tables_columns.setdefault( - table_name, {} + filtered_columns = ( + self._filtered_tables_columns.setdefault( + table_name, {} + ) ) filtered_columns[name] = mode @@ -154,7 +170,9 @@ def _normalize_chunk( table_updates.append(partial_table) # update our columns - column_schemas[table_name] = schema.get_table_columns(table_name) + column_schemas[table_name] = schema.get_table_columns( + table_name + ) # apply new filters if filtered_columns and filters: @@ -199,7 +217,9 @@ def __call__( root_table_name, items, may_have_pua(line), skip_write=False ) schema_updates.append(partial_update) - logger.debug(f"Processed {line_no+1} lines from file {extracted_items_file}") + logger.debug( + f"Processed {line_no+1} lines from file {extracted_items_file}" + ) if line is None and root_table_name in self.schema.tables: # TODO: we should push the truncate jobs via package state # not as empty jobs. 
empty jobs should be reserved for @@ -218,7 +238,8 @@ def __call__( self.schema.get_table_columns(root_table_name), ) logger.debug( - f"No lines in file {extracted_items_file}, written empty load job file" + f"No lines in file {extracted_items_file}, written empty load job" + " file" ) return schema_updates @@ -228,7 +249,11 @@ class ParquetItemsNormalizer(ItemsNormalizer): REWRITE_ROW_GROUPS = 1 def _write_with_dlt_columns( - self, extracted_items_file: str, root_table_name: str, add_load_id: bool, add_dlt_id: bool + self, + extracted_items_file: str, + root_table_name: str, + add_load_id: bool, + add_dlt_id: bool, ) -> List[TSchemaUpdate]: new_columns: List[Any] = [] schema = self.schema @@ -255,7 +280,9 @@ def _write_with_dlt_columns( ( -1, pa.field("_dlt_load_id", load_id_type, nullable=False), - lambda batch: pa.array([load_id] * batch.num_rows, type=load_id_type), + lambda batch: pa.array( + [load_id] * batch.num_rows, type=load_id_type + ), ) ) @@ -264,7 +291,11 @@ def _write_with_dlt_columns( { "name": root_table_name, "columns": { - "_dlt_id": {"name": "_dlt_id", "data_type": "text", "nullable": False} + "_dlt_id": { + "name": "_dlt_id", + "data_type": "text", + "nullable": False, + } }, } ) @@ -322,16 +353,26 @@ def _fix_schema_precisions(self, root_table_name: str) -> List[TSchemaUpdate]: if not new_cols: return [] return [ - {root_table_name: [schema.update_table({"name": root_table_name, "columns": new_cols})]} + { + root_table_name: [ + schema.update_table({"name": root_table_name, "columns": new_cols}) + ] + } ] - def __call__(self, extracted_items_file: str, root_table_name: str) -> List[TSchemaUpdate]: + def __call__( + self, extracted_items_file: str, root_table_name: str + ) -> List[TSchemaUpdate]: base_schema_update = self._fix_schema_precisions(root_table_name) add_dlt_id = self.config.parquet_normalizer.add_dlt_id add_dlt_load_id = self.config.parquet_normalizer.add_dlt_load_id - if add_dlt_id or add_dlt_load_id or self.load_storage.loader_file_format != "arrow": + if ( + add_dlt_id + or add_dlt_load_id + or self.load_storage.loader_file_format != "arrow" + ): schema_update = self._write_with_dlt_columns( extracted_items_file, root_table_name, add_dlt_load_id, add_dlt_id ) @@ -342,14 +383,18 @@ def __call__(self, extracted_items_file: str, root_table_name: str) -> List[TSch with self.normalize_storage.extracted_packages.storage.open_file( extracted_items_file, "rb" ) as f: - file_metrics = DataWriterMetrics(extracted_items_file, get_row_count(f), f.tell(), 0, 0) + file_metrics = DataWriterMetrics( + extracted_items_file, get_row_count(f), f.tell(), 0, 0 + ) parts = ParsedLoadJobFileName.parse(extracted_items_file) self.load_storage.import_items_file( self.load_id, self.schema.name, parts.table_name, - self.normalize_storage.extracted_packages.storage.make_full_path(extracted_items_file), + self.normalize_storage.extracted_packages.storage.make_full_path( + extracted_items_file + ), file_metrics, ) diff --git a/dlt/normalize/normalize.py b/dlt/normalize/normalize.py index 4a17b9eef8..ce329516dc 100644 --- a/dlt/normalize/normalize.py +++ b/dlt/normalize/normalize.py @@ -128,19 +128,31 @@ def _create_load_storage(file_format: TLoaderFileFormat) -> LoadStorage: destination_caps.preferred_loader_file_format or destination_caps.preferred_staging_file_format ) - return LoadStorage(False, file_format, supported_formats, loader_storage_config) + return LoadStorage( + False, file_format, supported_formats, loader_storage_config + ) # process all files with data items and write 
to buffered item storage with Container().injectable_context(destination_caps): schema = Schema.from_stored_schema(stored_schema) normalize_storage = NormalizeStorage(False, normalize_storage_config) - def _get_items_normalizer(file_format: TLoaderFileFormat) -> ItemsNormalizer: + def _get_items_normalizer( + file_format: TLoaderFileFormat, + ) -> ItemsNormalizer: if file_format in item_normalizers: return item_normalizers[file_format] - klass = ParquetItemsNormalizer if file_format == "parquet" else JsonLItemsNormalizer + klass = ( + ParquetItemsNormalizer + if file_format == "parquet" + else JsonLItemsNormalizer + ) norm = item_normalizers[file_format] = klass( - _create_load_storage(file_format), normalize_storage, schema, load_id, config + _create_load_storage(file_format), + normalize_storage, + schema, + load_id, + config, ) return norm @@ -157,8 +169,9 @@ def _get_items_normalizer(file_format: TLoaderFileFormat) -> ItemsNormalizer: root_tables.add(root_table_name) normalizer = _get_items_normalizer(parsed_file_name.file_format) logger.debug( - f"Processing extracted items in {extracted_items_file} in load_id" - f" {load_id} with table name {root_table_name} and schema {schema.name}" + f"Processing extracted items in {extracted_items_file} in" + f" load_id {load_id} with table name {root_table_name} and" + f" schema {schema.name}" ) partial_updates = normalizer(extracted_items_file, root_table_name) schema_updates.extend(partial_updates) @@ -181,7 +194,8 @@ def update_table(self, schema: Schema, schema_updates: List[TSchemaUpdate]) -> N for schema_update in schema_updates: for table_name, table_updates in schema_update.items(): logger.info( - f"Updating schema for table {table_name} with {len(table_updates)} deltas" + f"Updating schema for table {table_name} with" + f" {len(table_updates)} deltas" ) for partial_table in table_updates: # merge columns @@ -204,7 +218,9 @@ def group_worker_files(files: Sequence[str], no_groups: int) -> List[Sequence[st l_idx = idx + 1 return chunk_files - def map_parallel(self, schema: Schema, load_id: str, files: Sequence[str]) -> TWorkerRV: + def map_parallel( + self, schema: Schema, load_id: str, files: Sequence[str] + ) -> TWorkerRV: workers: int = getattr(self.pool, "_max_workers", 1) chunk_files = self.group_worker_files(files, workers) schema_dict: TStoredSchema = schema.to_dict() @@ -244,12 +260,16 @@ def map_parallel(self, schema: Schema, load_id: str, files: Sequence[str]) -> TW # update metrics self.collector.update("Files", len(result.file_metrics)) self.collector.update( - "Items", sum(result.file_metrics, EMPTY_DATA_WRITER_METRICS).items_count + "Items", + sum( + result.file_metrics, EMPTY_DATA_WRITER_METRICS + ).items_count, ) except CannotCoerceColumnException as exc: # schema conflicts resulting from parallel executing logger.warning( - f"Parallel schema update conflict, retrying task ({str(exc)}" + "Parallel schema update conflict, retrying task" + f" ({str(exc)}" ) # delete all files produced by the task for metrics in result.file_metrics: @@ -268,7 +288,9 @@ def map_parallel(self, schema: Schema, load_id: str, files: Sequence[str]) -> TW return summary - def map_single(self, schema: Schema, load_id: str, files: Sequence[str]) -> TWorkerRV: + def map_single( + self, schema: Schema, load_id: str, files: Sequence[str] + ) -> TWorkerRV: result = Normalize.w_normalize_files( self.config, self.normalize_storage.config, @@ -290,9 +312,13 @@ def spool_files( # process files in parallel or in single thread, depending on map_f schema_updates, 
writer_metrics = map_f(schema, load_id, files) # compute metrics - job_metrics = {ParsedLoadJobFileName.parse(m.file_path): m for m in writer_metrics} + job_metrics = { + ParsedLoadJobFileName.parse(m.file_path): m for m in writer_metrics + } table_metrics: Dict[str, DataWriterMetrics] = { - table_name: sum(map(lambda pair: pair[1], metrics), EMPTY_DATA_WRITER_METRICS) + table_name: sum( + map(lambda pair: pair[1], metrics), EMPTY_DATA_WRITER_METRICS + ) for table_name, metrics in itertools.groupby( job_metrics.items(), lambda pair: pair[0].table_name ) @@ -306,18 +332,21 @@ def spool_files( # mark that table have seen data only if there was data if "seen-data" not in x_normalizer: logger.info( - f"Table {table_name} has seen data for a first time with load id {load_id}" + f"Table {table_name} has seen data for a first time with load id" + f" {load_id}" ) x_normalizer["seen-data"] = True # schema is updated, save it to schema volume if schema.is_modified: logger.info( - f"Saving schema {schema.name} with version {schema.stored_version}:{schema.version}" + f"Saving schema {schema.name} with version" + f" {schema.stored_version}:{schema.version}" ) self.schema_storage.save_schema(schema) else: logger.info( - f"Schema {schema.name} with version {schema.version} was not modified. Save skipped" + f"Schema {schema.name} with version {schema.version} was not modified." + " Save skipped" ) # save schema new package self.load_storage.new_packages.save_schema(load_id, schema) @@ -339,12 +368,16 @@ def spool_files( { "started_at": None, "finished_at": None, - "job_metrics": {job.job_id(): metrics for job, metrics in job_metrics.items()}, + "job_metrics": { + job.job_id(): metrics for job, metrics in job_metrics.items() + }, "table_metrics": table_metrics, }, ) - def spool_schema_files(self, load_id: str, schema: Schema, files: Sequence[str]) -> str: + def spool_schema_files( + self, load_id: str, schema: Schema, files: Sequence[str] + ) -> str: # delete existing folder for the case that this is a retry self.load_storage.new_packages.delete_package(load_id, not_exists_ok=True) # normalized files will go here before being atomically renamed @@ -358,12 +391,15 @@ def spool_schema_files(self, load_id: str, schema: Schema, files: Sequence[str]) except CannotCoerceColumnException as exc: # schema conflicts resulting from parallel executing logger.warning( - f"Parallel schema update conflict, switching to single thread ({str(exc)}" + "Parallel schema update conflict, switching to single thread" + f" ({str(exc)}" ) # start from scratch self.load_storage.new_packages.delete_package(load_id) self.load_storage.new_packages.create_package(load_id) - self.spool_files(load_id, schema.clone(update_normalizers=True), self.map_single, files) + self.spool_files( + load_id, schema.clone(update_normalizers=True), self.map_single, files + ) return load_id @@ -386,18 +422,22 @@ def run(self, pool: Optional[Executor]) -> TRunMetrics: storage_schema = self.schema_storage[schema.name] if schema.stored_version_hash != storage_schema.stored_version_hash: logger.warning( - f"When normalizing package {load_id} with schema {schema.name}: the storage" - f" schema hash {storage_schema.stored_version_hash} is different from" - f" extract package schema hash {schema.stored_version_hash}. Storage schema" - " was used." 
+ f"When normalizing package {load_id} with schema {schema.name}:" + " the storage schema hash" + f" {storage_schema.stored_version_hash} is different from" + f" extract package schema hash {schema.stored_version_hash}." + " Storage schema was used." ) schema = storage_schema except FileNotFoundError: pass # read all files to normalize placed as new jobs - schema_files = self.normalize_storage.extracted_packages.list_new_jobs(load_id) + schema_files = self.normalize_storage.extracted_packages.list_new_jobs( + load_id + ) logger.info( - f"Found {len(schema_files)} files in schema {schema.name} load_id {load_id}" + f"Found {len(schema_files)} files in schema {schema.name} load_id" + f" {load_id}" ) if len(schema_files) == 0: # delete empty package @@ -411,14 +451,18 @@ def run(self, pool: Optional[Executor]) -> TRunMetrics: self.spool_schema_files(load_id, schema, schema_files) # return info on still pending packages (if extractor saved something in the meantime) - return TRunMetrics(False, len(self.normalize_storage.extracted_packages.list_packages())) + return TRunMetrics( + False, len(self.normalize_storage.extracted_packages.list_packages()) + ) def get_load_package_info(self, load_id: str) -> LoadPackageInfo: """Returns information on extracted/normalized/completed package with given load_id, all jobs and their statuses.""" try: return self.load_storage.get_load_package_info(load_id) except LoadPackageNotFound: - return self.normalize_storage.extracted_packages.get_load_package_info(load_id) + return self.normalize_storage.extracted_packages.get_load_package_info( + load_id + ) def get_step_info( self, @@ -431,4 +475,6 @@ def get_step_info( load_package = self.get_load_package_info(load_id) load_packages.append(load_package) metrics[load_id] = self._step_info_metrics(load_id) - return NormalizeInfo(pipeline, metrics, load_ids, load_packages, pipeline.first_run) + return NormalizeInfo( + pipeline, metrics, load_ids, load_packages, pipeline.first_run + ) diff --git a/dlt/pipeline/__init__.py b/dlt/pipeline/__init__.py index 4101e58320..2a99f5fcd8 100644 --- a/dlt/pipeline/__init__.py +++ b/dlt/pipeline/__init__.py @@ -7,12 +7,23 @@ from dlt.common.configuration import with_config from dlt.common.configuration.container import Container from dlt.common.configuration.inject import get_orig_args, last_config -from dlt.common.destination import TLoaderFileFormat, Destination, TDestinationReferenceArg +from dlt.common.destination import ( + TLoaderFileFormat, + Destination, + TDestinationReferenceArg, +) from dlt.common.pipeline import LoadInfo, PipelineContext, get_dlt_pipelines_dir -from dlt.pipeline.configuration import PipelineConfiguration, ensure_correct_pipeline_kwargs +from dlt.pipeline.configuration import ( + PipelineConfiguration, + ensure_correct_pipeline_kwargs, +) from dlt.pipeline.pipeline import Pipeline -from dlt.pipeline.progress import _from_name as collector_from_name, TCollectorArg, _NULL_COLLECTOR +from dlt.pipeline.progress import ( + _from_name as collector_from_name, + TCollectorArg, + _NULL_COLLECTOR, +) from dlt.pipeline.warnings import credentials_argument_deprecated @@ -81,7 +92,8 @@ def pipeline( @overload def pipeline() -> Pipeline: # type: ignore """When called without any arguments, returns the recently created `Pipeline` instance. - If not found, it creates a new instance with all the pipeline options set to defaults.""" + If not found, it creates a new instance with all the pipeline options set to defaults. 
+ """ @with_config(spec=PipelineConfiguration, auto_pipeline_section=True) @@ -120,7 +132,8 @@ def pipeline( pipelines_dir = get_dlt_pipelines_dir() destination = Destination.from_reference( - destination or kwargs["destination_type"], destination_name=kwargs["destination_name"] + destination or kwargs["destination_type"], + destination_name=kwargs["destination_name"], ) staging = Destination.from_reference( staging or kwargs.get("staging_type", None), diff --git a/dlt/pipeline/configuration.py b/dlt/pipeline/configuration.py index d7ffca6e89..e443436901 100644 --- a/dlt/pipeline/configuration.py +++ b/dlt/pipeline/configuration.py @@ -40,5 +40,9 @@ def on_resolved(self) -> None: def ensure_correct_pipeline_kwargs(f: AnyFun, **kwargs: Any) -> None: for arg_name in kwargs: - if not hasattr(PipelineConfiguration, arg_name) and not arg_name.startswith("_dlt"): - raise TypeError(f"{f.__name__} got an unexpected keyword argument '{arg_name}'") + if not hasattr(PipelineConfiguration, arg_name) and not arg_name.startswith( + "_dlt" + ): + raise TypeError( + f"{f.__name__} got an unexpected keyword argument '{arg_name}'" + ) diff --git a/dlt/pipeline/current.py b/dlt/pipeline/current.py index 25fd398623..293de78389 100644 --- a/dlt/pipeline/current.py +++ b/dlt/pipeline/current.py @@ -1,6 +1,10 @@ """Easy access to active pipelines, state, sources and schemas""" -from dlt.common.pipeline import source_state as _state, resource_state, get_current_pipe_name +from dlt.common.pipeline import ( + source_state as _state, + resource_state, + get_current_pipe_name, +) from dlt.pipeline import pipeline as _pipeline from dlt.extract.decorators import get_source_schema from dlt.common.storages.load_package import ( diff --git a/dlt/pipeline/dbt.py b/dlt/pipeline/dbt.py index e647e475ed..30bfaaab90 100644 --- a/dlt/pipeline/dbt.py +++ b/dlt/pipeline/dbt.py @@ -44,9 +44,13 @@ def get_venv( # try to restore existing venv with contextlib.suppress(VenvNotFound): # TODO: check dlt version in venv and update it if local version updated - return _restore_venv(venv_dir, [pipeline.destination.spec().destination_type], dbt_version) + return _restore_venv( + venv_dir, [pipeline.destination.spec().destination_type], dbt_version + ) - return _create_venv(venv_dir, [pipeline.destination.spec().destination_type], dbt_version) + return _create_venv( + venv_dir, [pipeline.destination.spec().destination_type], dbt_version + ) def package( diff --git a/dlt/pipeline/exceptions.py b/dlt/pipeline/exceptions.py index d3538a8377..9742beed70 100644 --- a/dlt/pipeline/exceptions.py +++ b/dlt/pipeline/exceptions.py @@ -8,21 +8,25 @@ class InvalidPipelineName(PipelineException, ValueError): def __init__(self, pipeline_name: str, details: str) -> None: super().__init__( pipeline_name, - f"The pipeline name {pipeline_name} contains invalid characters. The pipeline name is" - " used to create a pipeline working directory and must be a valid directory name. The" - f" actual error is: {details}", + f"The pipeline name {pipeline_name} contains invalid characters. The" + " pipeline name is used to create a pipeline working directory and must be" + f" a valid directory name. 
The actual error is: {details}", ) class PipelineConfigMissing(PipelineException): def __init__( - self, pipeline_name: str, config_elem: str, step: TPipelineStep, _help: str = None + self, + pipeline_name: str, + config_elem: str, + step: TPipelineStep, + _help: str = None, ) -> None: self.config_elem = config_elem self.step = step msg = ( - f"Configuration element {config_elem} was not provided and {step} step cannot be" - " executed" + f"Configuration element {config_elem} was not provided and {step} step" + " cannot be executed" ) if _help: msg += f"\n{_help}\n" @@ -32,8 +36,8 @@ def __init__( class CannotRestorePipelineException(PipelineException): def __init__(self, pipeline_name: str, pipelines_dir: str, reason: str) -> None: msg = ( - f"Pipeline with name {pipeline_name} in working directory {pipelines_dir} could not be" - f" restored: {reason}" + f"Pipeline with name {pipeline_name} in working directory" + f" {pipelines_dir} could not be restored: {reason}" ) super().__init__(pipeline_name, msg) @@ -89,19 +93,20 @@ def __init__( self.to_engine = to_engine super().__init__( pipeline_name, - f"No engine upgrade path for state in pipeline {pipeline_name} from {init_engine} to" - f" {to_engine}, stopped at {from_engine}. You possibly tried to run an older dlt" - " version against a destination you have previously loaded data to with a newer dlt" - " version.", + f"No engine upgrade path for state in pipeline {pipeline_name} from" + f" {init_engine} to {to_engine}, stopped at {from_engine}. You possibly" + " tried to run an older dlt version against a destination you have" + " previously loaded data to with a newer dlt version.", ) class PipelineHasPendingDataException(PipelineException): def __init__(self, pipeline_name: str, pipelines_dir: str) -> None: msg = ( - f" Operation failed because pipeline with name {pipeline_name} in working directory" - f" {pipelines_dir} contains pending extracted files or load packages. Use `dlt pipeline" - " sync` to reset the local state then run this operation again." + f" Operation failed because pipeline with name {pipeline_name} in working" + f" directory {pipelines_dir} contains pending extracted files or load" + " packages. Use `dlt pipeline sync` to reset the local state then run this" + " operation again." ) super().__init__(pipeline_name, msg) @@ -109,9 +114,9 @@ def __init__(self, pipeline_name: str, pipelines_dir: str) -> None: class PipelineNeverRan(PipelineException): def __init__(self, pipeline_name: str, pipelines_dir: str) -> None: msg = ( - f" Operation failed because pipeline with name {pipeline_name} in working directory" - f" {pipelines_dir} was never run or never synced with destination. Use `dlt pipeline" - " sync` to synchronize." + f" Operation failed because pipeline with name {pipeline_name} in working" + f" directory {pipelines_dir} was never run or never synced with" + " destination. Use `dlt pipeline sync` to synchronize." 
) super().__init__(pipeline_name, msg) @@ -119,5 +124,6 @@ def __init__(self, pipeline_name: str, pipelines_dir: str) -> None: class PipelineNotActive(PipelineException): def __init__(self, pipeline_name: str) -> None: super().__init__( - pipeline_name, f"Pipeline {pipeline_name} is not active so it cannot be deactivated" + pipeline_name, + f"Pipeline {pipeline_name} is not active so it cannot be deactivated", ) diff --git a/dlt/pipeline/helpers.py b/dlt/pipeline/helpers.py index c242a26eaa..9229ede977 100644 --- a/dlt/pipeline/helpers.py +++ b/dlt/pipeline/helpers.py @@ -1,5 +1,16 @@ import contextlib -from typing import Callable, Sequence, Iterable, Optional, Any, List, Dict, Tuple, Union, TypedDict +from typing import ( + Callable, + Sequence, + Iterable, + Optional, + Any, + List, + Dict, + Tuple, + Union, + TypedDict, +) from itertools import chain from dlt.common.jsonpath import resolve_paths, TAnyJsonPath, compile_paths @@ -50,7 +61,10 @@ def retry_load( def _retry_load(ex: BaseException) -> bool: # do not retry in normalize or extract stages - if isinstance(ex, PipelineStepFailed) and ex.step not in retry_on_pipeline_steps: + if ( + isinstance(ex, PipelineStepFailed) + and ex.step not in retry_on_pipeline_steps + ): return False # do not retry on terminal exceptions if isinstance(ex, TerminalException) or ( @@ -78,7 +92,9 @@ class DropCommand: def __init__( self, pipeline: Pipeline, - resources: Union[Iterable[Union[str, TSimpleRegex]], Union[str, TSimpleRegex]] = (), + resources: Union[ + Iterable[Union[str, TSimpleRegex]], Union[str, TSimpleRegex] + ] = (), schema_name: Optional[str] = None, state_paths: TAnyJsonPath = (), drop_all: bool = False, @@ -93,7 +109,9 @@ def __init__( if not pipeline.default_schema_name: raise PipelineNeverRan(pipeline.pipeline_name, pipeline.pipelines_dir) - self.schema = pipeline.schemas[schema_name or pipeline.default_schema_name].clone() + self.schema = pipeline.schemas[ + schema_name or pipeline.default_schema_name + ].clone() self.schema_tables = self.schema.tables self.drop_tables = not state_only self.drop_state = True @@ -102,9 +120,13 @@ def __init__( resources = set(resources) resource_names = [] if drop_all: - self.resource_pattern = compile_simple_regex(TSimpleRegex("re:.*")) # Match everything + self.resource_pattern = compile_simple_regex( + TSimpleRegex("re:.*") + ) # Match everything elif resources: - self.resource_pattern = compile_simple_regexes(TSimpleRegex(r) for r in resources) + self.resource_pattern = compile_simple_regexes( + TSimpleRegex(r) for r in resources + ) else: self.resource_pattern = None @@ -112,9 +134,13 @@ def __init__( data_tables = { t["name"]: t for t in self.schema.data_tables() } # Don't remove _dlt tables - resource_tables = group_tables_by_resource(data_tables, pattern=self.resource_pattern) + resource_tables = group_tables_by_resource( + data_tables, pattern=self.resource_pattern + ) if self.drop_tables: - self.tables_to_drop = list(chain.from_iterable(resource_tables.values())) + self.tables_to_drop = list( + chain.from_iterable(resource_tables.values()) + ) self.tables_to_drop.reverse() else: self.tables_to_drop = [] @@ -138,8 +164,8 @@ def __init__( ) if self.resource_pattern and not resource_tables: self.info["warnings"].append( - f"Specified resource(s) {str(resources)} did not select any table(s) in schema" - f" {self.schema.name}. Possible resources are:" + f"Specified resource(s) {str(resources)} did not select any table(s) in" + f" schema {self.schema.name}. 
Possible resources are:" f" {list(group_tables_by_resource(data_tables).keys())}" ) self._new_state = self._create_modified_state() @@ -156,8 +182,8 @@ def _drop_destination_tables(self) -> None: table_names = [tbl["name"] for tbl in self.tables_to_drop] for table_name in table_names: assert table_name not in self.schema._schema_tables, ( - f"You are dropping table {table_name} in {self.schema.name} but it is still present" - " in the schema" + f"You are dropping table {table_name} in {self.schema.name} but it is" + " still present in the schema" ) with self.pipeline._sql_job_client(self.schema) as client: client.drop_tables(*table_names, replace_schema=True) @@ -191,11 +217,13 @@ def _create_modified_state(self) -> Dict[str, Any]: resolved_paths = resolve_paths(self.state_paths_to_drop, source_state) if self.state_paths_to_drop and not resolved_paths: self.info["warnings"].append( - f"State paths {self.state_paths_to_drop} did not select any paths in source" - f" {source_name}" + f"State paths {self.state_paths_to_drop} did not select any paths" + f" in source {source_name}" ) _delete_source_state_keys(resolved_paths, source_state) - self.info["state_paths"].extend(f"{source_name}.{p}" for p in resolved_paths) + self.info["state_paths"].extend( + f"{source_name}.{p}" for p in resolved_paths + ) return state # type: ignore[return-value] def _extract_state(self) -> None: @@ -243,4 +271,6 @@ def drop( drop_all: bool = False, state_only: bool = False, ) -> None: - return DropCommand(pipeline, resources, schema_name, state_paths, drop_all, state_only)() + return DropCommand( + pipeline, resources, schema_name, state_paths, drop_all, state_only + )() diff --git a/dlt/pipeline/pipeline.py b/dlt/pipeline/pipeline.py index de1f7afced..e2fa1f0d73 100644 --- a/dlt/pipeline/pipeline.py +++ b/dlt/pipeline/pipeline.py @@ -149,10 +149,14 @@ def _wrap(self: "Pipeline", *args: Any, **kwargs: Any) -> Any: self.activate() # backup and restore state - should_extract_state = may_extract_state and self.config.restore_from_destination + should_extract_state = ( + may_extract_state and self.config.restore_from_destination + ) with self.managed_state(extract_state=should_extract_state) as state: # add the state to container as a context - with self._container.injectable_context(StateInjectableContext(state=state)): + with self._container.injectable_context( + StateInjectableContext(state=state) + ): return f(self, *args, **kwargs) return _wrap # type: ignore @@ -211,7 +215,9 @@ def _wrap(self: "Pipeline", *args: Any, **kwargs: Any) -> Any: try: # start a trace step for wrapped function if trace: - trace_step = start_trace_step(trace, cast(TPipelineStep, f.__name__), self) + trace_step = start_trace_step( + trace, cast(TPipelineStep, f.__name__), self + ) step_info = f(self, *args, **kwargs) return step_info @@ -256,7 +262,9 @@ def decorator(f: TFun) -> TFun: def _wrap(self: "Pipeline", *args: Any, **kwargs: Any) -> Any: # add section context to the container to be used by all configuration without explicit sections resolution with inject_section( - ConfigSectionContext(pipeline_name=self.pipeline_name, sections=sections) + ConfigSectionContext( + pipeline_name=self.pipeline_name, sections=sections + ) ): return f(self, *args, **kwargs) @@ -278,7 +286,9 @@ class Pipeline(SupportsPipeline): "destinations", } ) - LOCAL_STATE_PROPS: ClassVar[List[str]] = list(get_type_hints(TPipelineLocalState).keys()) + LOCAL_STATE_PROPS: ClassVar[List[str]] = list( + get_type_hints(TPipelineLocalState).keys() + ) 
DEFAULT_DATASET_SUFFIX: ClassVar[str] = "_dataset" pipeline_name: str @@ -352,7 +362,9 @@ def __init__( # we overwrite the state with the values from init self._set_dataset_name(dataset_name) self.credentials = credentials - self._configure(import_schema_path, export_schema_path, must_attach_to_local_pipeline) + self._configure( + import_schema_path, export_schema_path, must_attach_to_local_pipeline + ) def drop(self, pipeline_name: str = None) -> "Pipeline": """Deletes local pipeline state, schemas and any working files. @@ -422,12 +434,16 @@ def extract( ): if source.exhausted: raise SourceExhausted(source.name) - self._extract_source(extract_step, source, max_parallel_items, workers) + self._extract_source( + extract_step, source, max_parallel_items, workers + ) # extract state if self.config.restore_from_destination: # this will update state version hash so it will not be extracted again by with_state_sync self._bump_version_and_extract_state( - self._container[StateInjectableContext].state, True, extract_step + self._container[StateInjectableContext].state, + True, + extract_step, ) # commit load packages extract_step.commit_packages() @@ -453,7 +469,9 @@ def normalize( workers = 1 if loader_file_format and loader_file_format in INTERNAL_LOADER_FILE_FORMATS: - raise ValueError(f"{loader_file_format} is one of internal dlt file formats.") + raise ValueError( + f"{loader_file_format} is one of internal dlt file formats." + ) # check if any schema is present, if not then no data was extracted if not self.default_schema_name: return None @@ -469,7 +487,9 @@ def normalize( _load_storage_config=self._load_storage_config(), ) # run with destination context - with self._maybe_destination_capabilities(loader_file_format=loader_file_format): + with self._maybe_destination_capabilities( + loader_file_format=loader_file_format + ): # shares schema storage with the pipeline so we do not need to install normalize_step: Normalize = Normalize( collector=self.collector, @@ -531,7 +551,9 @@ def load( is_storage_owner=False, config=load_config, initial_client_config=client.config, - initial_staging_client_config=staging_client.config if staging_client else None, + initial_staging_client_config=( + staging_client.config if staging_client else None + ), ) try: with signals.delayed_signals(): @@ -643,9 +665,9 @@ def run( # if there were any pending loads, load them and **exit** if data is not None: logger.warn( - "The pipeline `run` method will now load the pending load packages. The data" - " you passed to the run function will not be loaded. In order to do that you" - " must run the pipeline again" + "The pipeline `run` method will now load the pending load packages." + " The data you passed to the run function will not be loaded. In" + " order to do that you must run the pipeline again" ) return self.load(destination, dataset_name, credentials=credentials) @@ -696,8 +718,13 @@ def sync_destination( # if remote state is newer or same # print(f'REMOTE STATE: {(remote_state or {}).get("_state_version")} >= {state["_state_version"]}') # TODO: check if remote_state["_state_version"] is not in 10 recent version. then we know remote is newer. 
- if remote_state and remote_state["_state_version"] >= state["_state_version"]: - state_changed = remote_state["_version_hash"] != state.get("_version_hash") + if ( + remote_state + and remote_state["_state_version"] >= state["_state_version"] + ): + state_changed = remote_state["_version_hash"] != state.get( + "_version_hash" + ) # print(f"MERGED STATE: {bool(merged_state)}") if state_changed: # see if state didn't change the pipeline name @@ -705,8 +732,8 @@ def sync_destination( raise CannotRestorePipelineException( state["pipeline_name"], self.pipelines_dir, - "destination state contains state for pipeline with name" - f" {remote_state['pipeline_name']}", + "destination state contains state for pipeline with" + f" name {remote_state['pipeline_name']}", ) # if state was modified force get all schemas restored_schemas = self._get_schemas_from_destination( @@ -743,7 +770,9 @@ def sync_destination( if self.default_schema_name is None: should_wipe = True else: - with self._get_destination_clients(self.default_schema)[0] as job_client: + with self._get_destination_clients(self.default_schema)[ + 0 + ] as job_client: # and storage is not initialized should_wipe = not job_client.is_storage_initialized() if should_wipe: @@ -831,8 +860,8 @@ def last_trace(self) -> PipelineTrace: return load_trace(self.working_dir) @deprecated( - "Please use list_extracted_load_packages instead. Flat extracted storage format got dropped" - " in dlt 0.4.0", + "Please use list_extracted_load_packages instead. Flat extracted storage format" + " got dropped in dlt 0.4.0", category=Dlt04DeprecationWarning, ) def list_extracted_resources(self) -> Sequence[str]: @@ -856,7 +885,11 @@ def get_load_package_info(self, load_id: str) -> LoadPackageInfo: try: return self._get_load_storage().get_load_package_info(load_id) except LoadPackageNotFound: - return self._get_normalize_storage().extracted_packages.get_load_package_info(load_id) + return ( + self._get_normalize_storage().extracted_packages.get_load_package_info( + load_id + ) + ) def get_load_package_state(self, load_id: str) -> TLoadPackageState: """Returns information on extracted/normalized/completed package with given load_id, all jobs and their statuses.""" @@ -864,15 +897,24 @@ def get_load_package_state(self, load_id: str) -> TLoadPackageState: def list_failed_jobs_in_package(self, load_id: str) -> Sequence[LoadJobInfo]: """List all failed jobs and associated error messages for a specified `load_id`""" - return self._get_load_storage().get_load_package_info(load_id).jobs.get("failed_jobs", []) + return ( + self._get_load_storage() + .get_load_package_info(load_id) + .jobs.get("failed_jobs", []) + ) def drop_pending_packages(self, with_partial_loads: bool = True) -> None: """Deletes all extracted and normalized packages, including those that are partially loaded by default""" # delete normalized packages load_storage = self._get_load_storage() for load_id in load_storage.normalized_packages.list_packages(): - package_info = load_storage.normalized_packages.get_load_package_info(load_id) - if PackageStorage.is_package_partially_loaded(package_info) and not with_partial_loads: + package_info = load_storage.normalized_packages.get_load_package_info( + load_id + ) + if ( + PackageStorage.is_package_partially_loaded(package_info) + and not with_partial_loads + ): continue load_storage.normalized_packages.delete_package(load_id) # delete extracted files @@ -881,15 +923,17 @@ def drop_pending_packages(self, with_partial_loads: bool = True) -> None: 
normalize_storage.extracted_packages.delete_package(load_id) @with_schemas_sync - def sync_schema(self, schema_name: str = None, credentials: Any = None) -> TSchemaTables: + def sync_schema( + self, schema_name: str = None, credentials: Any = None + ) -> TSchemaTables: """Synchronizes the schema `schema_name` with the destination. If no name is provided, the default schema will be synchronized.""" if not schema_name and not self.default_schema_name: raise PipelineConfigMissing( self.pipeline_name, "default_schema_name", "load", - "Pipeline contains no schemas. Please extract any data with `extract` or `run`" - " methods.", + "Pipeline contains no schemas. Please extract any data with `extract`" + " or `run` methods.", ) schema = self.schemas[schema_name] if schema_name else self.default_schema @@ -918,7 +962,9 @@ def get_local_state_val(self, key: str) -> Any: state = self._get_state() return state["_local"][key] # type: ignore - def sql_client(self, schema_name: str = None, credentials: Any = None) -> SqlClientBase[Any]: + def sql_client( + self, schema_name: str = None, credentials: Any = None + ) -> SqlClientBase[Any]: """Returns a sql client configured to query/change the destination and dataset that were used to load the data. Use the client with `with` statement to manage opening and closing connection to the destination: >>> with pipeline.sql_client() as client: @@ -940,7 +986,9 @@ def sql_client(self, schema_name: str = None, credentials: Any = None) -> SqlCli schema = self._get_schema_or_create(schema_name) return self._sql_job_client(schema, credentials).sql_client - def destination_client(self, schema_name: str = None, credentials: Any = None) -> JobClientBase: + def destination_client( + self, schema_name: str = None, credentials: Any = None + ) -> JobClientBase: """Get the destination job client for the configured destination Use the client with `with` statement to manage opening and closing connection to the destination: >>> with pipeline.destination_client() as client: @@ -961,13 +1009,17 @@ def _get_schema_or_create(self, schema_name: str = None) -> Schema: with self._maybe_destination_capabilities(): return Schema(self.pipeline_name) - def _sql_job_client(self, schema: Schema, credentials: Any = None) -> SqlJobClientBase: + def _sql_job_client( + self, schema: Schema, credentials: Any = None + ) -> SqlJobClientBase: client_config = self._get_destination_client_initial_config(credentials) client = self._get_destination_clients(schema, client_config)[0] if isinstance(client, SqlJobClientBase): return client else: - raise SqlClientNotAvailable(self.pipeline_name, self.destination.destination_name) + raise SqlClientNotAvailable( + self.pipeline_name, self.destination.destination_name + ) def _get_normalize_storage(self) -> NormalizeStorage: return NormalizeStorage(True, self._normalize_storage_config()) @@ -987,7 +1039,9 @@ def _normalize_storage_config(self) -> NormalizeStorageConfiguration: ) def _load_storage_config(self) -> LoadStorageConfiguration: - return LoadStorageConfiguration(load_volume_path=os.path.join(self.working_dir, "load")) + return LoadStorageConfiguration( + load_volume_path=os.path.join(self.working_dir, "load") + ) def _init_working_dir(self, pipeline_name: str, pipelines_dir: str) -> None: self.pipeline_name = pipeline_name @@ -1002,7 +1056,10 @@ def _init_working_dir(self, pipeline_name: str, pipelines_dir: str) -> None: self._wipe_working_folder() def _configure( - self, import_schema_path: str, export_schema_path: str, must_attach_to_local_pipeline: 
bool + self, + import_schema_path: str, + export_schema_path: str, + must_attach_to_local_pipeline: bool, ) -> None: # create schema storage and folders self._schema_storage_config = SchemaStorageConfiguration( @@ -1032,7 +1089,9 @@ def _configure( self._create_pipeline() # create schema storage - self._schema_storage = LiveSchemaStorage(self._schema_storage_config, makedirs=True) + self._schema_storage = LiveSchemaStorage( + self._schema_storage_config, makedirs=True + ) def _create_pipeline(self) -> None: self._wipe_working_folder() @@ -1060,7 +1119,9 @@ def _extract_source( # (2) load import schema an overwrite pipeline schema if import schema modified # (3) load pipeline schema if no import schema is present pipeline_schema = self.schemas[source.schema.name] - pipeline_schema = pipeline_schema.clone() # use clone until extraction complete + pipeline_schema = ( + pipeline_schema.clone() + ) # use clone until extraction complete # apply all changes in the source schema to pipeline schema # NOTE: we do not apply contracts to changes done programmatically pipeline_schema.update_schema(source.schema) @@ -1087,7 +1148,10 @@ def _extract_source( return load_id def _get_destination_client_initial_config( - self, destination: TDestination = None, credentials: Any = None, as_staging: bool = False + self, + destination: TDestination = None, + credentials: Any = None, + as_staging: bool = False, ) -> DestinationClientConfiguration: destination = destination or self.destination if not destination: @@ -1095,8 +1159,9 @@ def _get_destination_client_initial_config( self.pipeline_name, "destination", "load", - "Please provide `destination` argument to `pipeline`, `run` or `load` method" - " directly or via .dlt config.toml file or environment variable.", + "Please provide `destination` argument to `pipeline`, `run` or `load`" + " method directly or via .dlt config.toml file or environment" + " variable.", ) # create initial destination client config client_spec = destination.spec @@ -1104,7 +1169,9 @@ def _get_destination_client_initial_config( if not as_staging: # explicit credentials passed to dlt.pipeline should not be applied to staging credentials = credentials or self.credentials - if credentials is not None and not isinstance(credentials, CredentialsConfiguration): + if credentials is not None and not isinstance( + credentials, CredentialsConfiguration + ): # use passed credentials as initial value. initial value may resolve credentials credentials = initialize_credentials( client_spec.get_resolvable_fields()["credentials"], credentials @@ -1114,8 +1181,8 @@ def _get_destination_client_initial_config( if issubclass(client_spec, DestinationClientDwhConfiguration): if not self.dataset_name and self.full_refresh: logger.warning( - "Full refresh may not work if dataset name is not set. Please set the" - " dataset_name argument in dlt.pipeline or run method" + "Full refresh may not work if dataset name is not set. 
Please set" + " the dataset_name argument in dlt.pipeline or run method" ) # set default schema name to load all incoming data to a single dataset, no matter what is the current schema name default_schema_name = ( @@ -1148,19 +1215,27 @@ def _get_destination_clients( if self.staging: if not initial_staging_config: # this is just initial config - without user configuration injected - initial_staging_config = self._get_destination_client_initial_config( - self.staging, as_staging=True + initial_staging_config = ( + self._get_destination_client_initial_config( + self.staging, as_staging=True + ) ) # create the client - that will also resolve the config staging_client = self.staging.client(schema, initial_staging_config) if not initial_config: # config is not provided then get it with injected credentials - initial_config = self._get_destination_client_initial_config(self.destination) + initial_config = self._get_destination_client_initial_config( + self.destination + ) # attach the staging client config to destination client config - if its type supports it if ( self.staging - and isinstance(initial_config, DestinationClientDwhWithStagingConfiguration) - and isinstance(staging_client.config, DestinationClientStagingConfiguration) + and isinstance( + initial_config, DestinationClientDwhWithStagingConfiguration + ) + and isinstance( + staging_client.config, DestinationClientStagingConfiguration + ) ): initial_config.staging_config = staging_client.config # create instance with initial_config properly set @@ -1180,8 +1255,9 @@ def _get_destination_capabilities(self) -> DestinationCapabilitiesContext: self.pipeline_name, "destination", "normalize", - "Please provide `destination` argument to `pipeline`, `run` or `load` method" - " directly or via .dlt config.toml file or environment variable.", + "Please provide `destination` argument to `pipeline`, `run` or `load`" + " method directly or via .dlt config.toml file or environment" + " variable.", ) return self.destination.capabilities() @@ -1236,15 +1312,17 @@ def _set_destinations( and not self.staging ): logger.warning( - f"The destination {self.destination.destination_name} requires the filesystem" - " staging destination to be set, but it was not provided. Setting it to" - " 'filesystem'." + f"The destination {self.destination.destination_name} requires the" + " filesystem staging destination to be set, but it was not provided." + " Setting it to 'filesystem'." 
) staging = "filesystem" staging_name = "filesystem" if staging: - staging_module = Destination.from_reference(staging, destination_name=staging_name) + staging_module = Destination.from_reference( + staging, destination_name=staging_name + ) if staging_module and not issubclass( staging_module.spec, DestinationClientStagingConfiguration ): @@ -1282,7 +1360,9 @@ def _maybe_destination_capabilities( loader_file_format, ) caps.supported_loader_file_formats = ( - destination_caps.supported_staging_file_formats if stage_caps else None + destination_caps.supported_staging_file_formats + if stage_caps + else None ) or destination_caps.supported_loader_file_formats yield caps finally: @@ -1311,10 +1391,15 @@ def _resolve_loader_file_format( if not dest_caps.preferred_loader_file_format: raise DestinationLoadingWithoutStagingNotSupported(destination) file_format = dest_caps.preferred_loader_file_format - elif stage_caps and dest_caps.preferred_staging_file_format in possible_file_formats: + elif ( + stage_caps + and dest_caps.preferred_staging_file_format in possible_file_formats + ): file_format = dest_caps.preferred_staging_file_format else: - file_format = possible_file_formats[0] if len(possible_file_formats) > 0 else None + file_format = ( + possible_file_formats[0] if len(possible_file_formats) > 0 else None + ) if file_format not in possible_file_formats: raise DestinationIncompatibleLoaderFileFormatException( destination, @@ -1335,8 +1420,9 @@ def _set_dataset_name(self, new_dataset_name: str) -> None: fields = self.destination.spec().get_resolvable_fields() dataset_name_type = fields.get("dataset_name") # if dataset is required (default!) we create a default dataset name - destination_needs_dataset = dataset_name_type is not None and not is_optional_type( - dataset_name_type + destination_needs_dataset = ( + dataset_name_type is not None + and not is_optional_type(dataset_name_type) ) # if destination is not specified - generate dataset if not self.destination or destination_needs_dataset: @@ -1397,7 +1483,9 @@ def _optional_sql_job_client(self, schema_name: str) -> Optional[SqlJobClientBas logger.info(f"Sql Client not available: {pip_ex}") except SqlClientNotAvailable: # fallback is sql client not available for destination - logger.info("Client not available because destination does not support sql client") + logger.info( + "Client not available because destination does not support sql client" + ) except ConfigFieldMissingException: # probably credentials are missing logger.info("Client not available due to missing credentials") @@ -1415,7 +1503,9 @@ def _restore_state_from_destination(self) -> Optional[TPipelineState]: schema = Schema(schema_name) with self._get_destination_clients(schema)[0] as job_client: if isinstance(job_client, WithStateSync): - state = load_pipeline_state_from_destination(self.pipeline_name, job_client) + state = load_pipeline_state_from_destination( + self.pipeline_name, job_client + ) if state is None: logger.info( "The state was not found in the destination" @@ -1466,7 +1556,8 @@ def _get_schemas_from_destination( schema = Schema.from_dict(json.loads(schema_info.schema)) logger.info( f"The schema {schema.name} version {schema.version} hash" - f" {schema.stored_version_hash} was restored from the destination" + f" {schema.stored_version_hash} was restored from the" + " destination" f" {self.destination.destination_name}:{self.dataset_name}" ) restored_schemas.append(schema) @@ -1513,8 +1604,9 @@ def _state_to_props(self, state: TPipelineState) -> None: if 
state_destination: if self.destination.destination_type != state_destination: logger.warning( - f"The destination {state_destination}:{state.get('destination_name')} in" - " state differs from destination" + "The destination" + f" {state_destination}:{state.get('destination_name')} in state" + " differs from destination" f" {self.destination.destination_type}:{self.destination.destination_name} in" " pipeline and will be ignored" ) @@ -1549,12 +1641,16 @@ def _bump_version_and_extract_state( Storage will be created on demand. In that case the extracted package will be immediately committed. """ - _, hash_, _ = bump_pipeline_state_version_if_modified(self._props_to_state(state)) + _, hash_, _ = bump_pipeline_state_version_if_modified( + self._props_to_state(state) + ) should_extract = hash_ != state["_local"].get("_last_extracted_hash") if should_extract and extract_state: data = state_resource(state) extract_ = extract or Extract( - self._schema_storage, self._normalize_storage_config(), original_data=data + self._schema_storage, + self._normalize_storage_config(), + original_data=data, ) self._extract_source( extract_, data_to_sources(data, self, self.default_schema)[0], 1, 1 diff --git a/dlt/pipeline/platform.py b/dlt/pipeline/platform.py index 0955e91b51..b34ead45f2 100644 --- a/dlt/pipeline/platform.py +++ b/dlt/pipeline/platform.py @@ -4,7 +4,12 @@ from dlt.common.managed_thread_pool import ManagedThreadPool from urllib.parse import urljoin -from dlt.pipeline.trace import PipelineTrace, PipelineStepTrace, TPipelineStep, SupportsPipeline +from dlt.pipeline.trace import ( + PipelineTrace, + PipelineStepTrace, + TPipelineStep, + SupportsPipeline, +) from dlt.common import json from dlt.common import logger from dlt.common.pipeline import LoadInfo @@ -39,7 +44,8 @@ def _future_send() -> None: response = requests.put(url, data=trace_dump) if response.status_code != 200: logger.debug( - f"Failed to send trace to platform, response code: {response.status_code}" + "Failed to send trace to platform, response code:" + f" {response.status_code}" ) except Exception as e: logger.debug(f"Exception while sending trace to platform: {e}") @@ -84,7 +90,8 @@ def _future_send() -> None: response = requests.put(url, data=json.dumps(payload)) if response.status_code != 200: logger.debug( - f"Failed to send state to platform, response code: {response.status_code}" + "Failed to send state to platform, response code:" + f" {response.status_code}" ) except Exception as e: logger.debug(f"Exception while sending state to platform: {e}") @@ -92,7 +99,9 @@ def _future_send() -> None: _THREAD_POOL.thread_pool.submit(_future_send) -def on_start_trace(trace: PipelineTrace, step: TPipelineStep, pipeline: SupportsPipeline) -> None: +def on_start_trace( + trace: PipelineTrace, step: TPipelineStep, pipeline: SupportsPipeline +) -> None: pass @@ -114,5 +123,7 @@ def on_end_trace_step( _sync_schemas_to_platform(trace, pipeline) -def on_end_trace(trace: PipelineTrace, pipeline: SupportsPipeline, send_state: bool) -> None: +def on_end_trace( + trace: PipelineTrace, pipeline: SupportsPipeline, send_state: bool +) -> None: _send_trace_to_platform(trace, pipeline) diff --git a/dlt/pipeline/progress.py b/dlt/pipeline/progress.py index 89eda4cac5..ec32eff4b4 100644 --- a/dlt/pipeline/progress.py +++ b/dlt/pipeline/progress.py @@ -7,7 +7,10 @@ EnlightenCollector as enlighten, AliveCollector as alive_progress, ) -from dlt.common.runtime.collector import Collector as _Collector, NULL_COLLECTOR as _NULL_COLLECTOR +from 
dlt.common.runtime.collector import ( + Collector as _Collector, + NULL_COLLECTOR as _NULL_COLLECTOR, +) TSupportedCollectors = Literal["tqdm", "enlighten", "log", "alive_progress"] TCollectorArg = Union[_Collector, TSupportedCollectors] diff --git a/dlt/pipeline/state_sync.py b/dlt/pipeline/state_sync.py index 5366b9c46d..25144c6ced 100644 --- a/dlt/pipeline/state_sync.py +++ b/dlt/pipeline/state_sync.py @@ -27,7 +27,11 @@ # state table columns STATE_TABLE_COLUMNS: TTableSchemaColumns = { "version": {"name": "version", "data_type": "bigint", "nullable": False}, - "engine_version": {"name": "engine_version", "data_type": "bigint", "nullable": False}, + "engine_version": { + "name": "engine_version", + "data_type": "bigint", + "nullable": False, + }, "pipeline_name": {"name": "pipeline_name", "data_type": "text", "nullable": False}, "state": {"name": "state", "data_type": "text", "nullable": False}, "created_at": {"name": "created_at", "data_type": "timestamp", "nullable": False}, @@ -64,7 +68,9 @@ def generate_pipeline_state_version_hash(state: TPipelineState) -> str: return generate_state_version_hash(state, exclude_attrs=["_local"]) -def bump_pipeline_state_version_if_modified(state: TPipelineState) -> Tuple[int, str, str]: +def bump_pipeline_state_version_if_modified( + state: TPipelineState, +) -> Tuple[int, str, str]: return bump_state_version_if_modified(state, exclude_attrs=["_local"]) @@ -128,7 +134,10 @@ def state_resource(state: TPipelineState) -> DltResource: "version_hash": state["_version_hash"], } return dlt.resource( - [state_doc], name=STATE_TABLE_NAME, write_disposition="append", columns=STATE_TABLE_COLUMNS + [state_doc], + name=STATE_TABLE_NAME, + write_disposition="append", + columns=STATE_TABLE_COLUMNS, ) diff --git a/dlt/pipeline/trace.py b/dlt/pipeline/trace.py index b610d1751f..e556835e8a 100644 --- a/dlt/pipeline/trace.py +++ b/dlt/pipeline/trace.py @@ -51,10 +51,17 @@ class SerializableResolvedValueTrace(NamedTuple): def asdict(self) -> StrAny: """A dictionary representation that is safe to load.""" - return {k: v for k, v in self._asdict().items() if k not in ("value", "default_value")} + return { + k: v + for k, v in self._asdict().items() + if k not in ("value", "default_value") + } def asstr(self, verbosity: int = 0) -> str: - return f"{self.key}->{self.value} in {'.'.join(self.sections)} by {self.provider_name}" + return ( + f"{self.key}->{self.value} in {'.'.join(self.sections)} by" + f" {self.provider_name}" + ) def __str__(self) -> str: return self.asstr(verbosity=0) @@ -146,8 +153,8 @@ def asstr(self, verbosity: int = 0) -> str: else: elapsed_str = "---" msg = ( - f"Run started at {self.started_at} and {completed_str} in {elapsed_str} with" - f" {len(self.steps)} steps." + f"Run started at {self.started_at} and {completed_str} in" + f" {elapsed_str} with {len(self.steps)} steps." 
) if verbosity > 0 and len(self.resolved_config_values) > 0: msg += "\nFollowing config and secret values were resolved:\n" @@ -337,7 +344,9 @@ def load_trace(trace_path: str) -> PipelineTrace: return None -def get_exception_traces(exc: BaseException, container: Container = None) -> List[ExceptionTrace]: +def get_exception_traces( + exc: BaseException, container: Container = None +) -> List[ExceptionTrace]: """Gets exception trace chain and extend it with data available in Container context""" traces = get_exception_trace_chain(exc) container = container or Container() diff --git a/dlt/pipeline/track.py b/dlt/pipeline/track.py index 990c59050e..294e4dc771 100644 --- a/dlt/pipeline/track.py +++ b/dlt/pipeline/track.py @@ -29,7 +29,9 @@ def _add_sentry_tags(span: Span, pipeline: SupportsPipeline) -> None: pass -def slack_notify_load_success(incoming_hook: str, load_info: LoadInfo, trace: PipelineTrace) -> int: +def slack_notify_load_success( + incoming_hook: str, load_info: LoadInfo, trace: PipelineTrace +) -> int: """Sends a markdown formatted success message and returns http status code from the Slack incoming hook""" try: author = github_info().get("github_user", "") @@ -45,8 +47,12 @@ def _get_step_elapsed(step: PipelineStepTrace) -> str: return f"`{step.step.upper()}`: _{humanize.precisedelta(elapsed)}_ " load_step = trace.steps[-1] - normalize_step = next((step for step in trace.steps if step.step == "normalize"), None) - extract_step = next((step for step in trace.steps if step.step == "extract"), None) + normalize_step = next( + (step for step in trace.steps if step.step == "normalize"), None + ) + extract_step = next( + (step for step in trace.steps if step.step == "extract"), None + ) message = f"""The {author}pipeline *{load_info.pipeline.pipeline_name}* just loaded *{len(load_info.loads_ids)}* load package(s) to destination *{load_info.destination_type}* and into dataset *{load_info.dataset_name}*. 
🚀 *{humanize.precisedelta(total_elapsed)}* of which {_get_step_elapsed(load_step)}{_get_step_elapsed(normalize_step)}{_get_step_elapsed(extract_step)}""" @@ -58,7 +64,9 @@ def _get_step_elapsed(step: PipelineStepTrace) -> str: return -1 -def on_start_trace(trace: PipelineTrace, step: TPipelineStep, pipeline: SupportsPipeline) -> None: +def on_start_trace( + trace: PipelineTrace, step: TPipelineStep, pipeline: SupportsPipeline +) -> None: if pipeline.runtime_config.sentry_dsn: # https://getsentry.github.io/sentry-python/api.html#sentry_sdk.Hub.capture_event # print(f"START SENTRY TX: {trace.transaction_id} SCOPE: {Hub.current.scope}") @@ -95,12 +103,20 @@ def on_end_trace_step( props = { "elapsed": (step.finished_at - trace.started_at).total_seconds(), "success": step.step_exception is None, - "destination_name": pipeline.destination.destination_name if pipeline.destination else None, - "destination_type": pipeline.destination.destination_type if pipeline.destination else None, + "destination_name": ( + pipeline.destination.destination_name if pipeline.destination else None + ), + "destination_type": ( + pipeline.destination.destination_type if pipeline.destination else None + ), "pipeline_name_hash": digest128(pipeline.pipeline_name), - "dataset_name_hash": digest128(pipeline.dataset_name) if pipeline.dataset_name else None, + "dataset_name_hash": ( + digest128(pipeline.dataset_name) if pipeline.dataset_name else None + ), "default_schema_name_hash": ( - digest128(pipeline.default_schema_name) if pipeline.default_schema_name else None + digest128(pipeline.default_schema_name) + if pipeline.default_schema_name + else None ), "transaction_id": trace.transaction_id, } @@ -113,7 +129,9 @@ def on_end_trace_step( dlthub_telemetry_track("pipeline", step.step, props) -def on_end_trace(trace: PipelineTrace, pipeline: SupportsPipeline, send_state: bool) -> None: +def on_end_trace( + trace: PipelineTrace, pipeline: SupportsPipeline, send_state: bool +) -> None: if pipeline.runtime_config.sentry_dsn: # print(f"---END SENTRY TX: {trace.transaction_id} SCOPE: {Hub.current.scope}") with contextlib.suppress(Exception): diff --git a/dlt/pipeline/warnings.py b/dlt/pipeline/warnings.py index 87fcbc1f0c..a128fa5ccf 100644 --- a/dlt/pipeline/warnings.py +++ b/dlt/pipeline/warnings.py @@ -6,7 +6,9 @@ def credentials_argument_deprecated( - caller_name: str, credentials: t.Optional[t.Any], destination: TDestinationReferenceArg = None + caller_name: str, + credentials: t.Optional[t.Any], + destination: TDestinationReferenceArg = None, ) -> None: if credentials is None: return @@ -14,8 +16,9 @@ def credentials_argument_deprecated( dest_name = Destination.to_name(destination) if destination else "postgres" warnings.warn( - f"The `credentials argument` to {caller_name} is deprecated and will be removed in a future" - " version. Pass the same credentials to the `destination` instance instead, e.g." + f"The `credentials argument` to {caller_name} is deprecated and will be removed" + " in a future version. Pass the same credentials to the `destination` instance" + " instead, e.g." 
f" {caller_name}(destination=dlt.destinations.{dest_name}(credentials=...))", Dlt04DeprecationWarning, stacklevel=2, diff --git a/dlt/reflection/script_inspector.py b/dlt/reflection/script_inspector.py index f9068d31e4..613945f934 100644 --- a/dlt/reflection/script_inspector.py +++ b/dlt/reflection/script_inspector.py @@ -109,7 +109,9 @@ def load_script_module( # path must be first so we always load our module of sys.path.insert(0, sys_path) try: - logger.info(f"Importing pipeline script from path {module_path} and module: {module}") + logger.info( + f"Importing pipeline script from path {module_path} and module: {module}" + ) if ignore_missing_imports: return _import_module(f"{module}") else: @@ -129,15 +131,17 @@ def inspect_pipeline_script( DltSource, "__init__", patch__init__ ), patch.object(ManagedPipeIterator, "__init__", patch__init__): return load_script_module( - module_path, script_relative_path, ignore_missing_imports=ignore_missing_imports + module_path, + script_relative_path, + ignore_missing_imports=ignore_missing_imports, ) class PipelineIsRunning(DltException): def __init__(self, obj: object, args: Tuple[str, ...], kwargs: DictStrAny) -> None: super().__init__( - "The pipeline script instantiates the pipeline on import. Did you forget to use if" - f" __name__ == 'main':? in {obj.__class__.__name__}", + "The pipeline script instantiates the pipeline on import. Did you forget" + f" to use if __name__ == 'main':? in {obj.__class__.__name__}", obj, args, kwargs, diff --git a/dlt/reflection/script_visitor.py b/dlt/reflection/script_visitor.py index 52b19fe031..352dd37ab0 100644 --- a/dlt/reflection/script_visitor.py +++ b/dlt/reflection/script_visitor.py @@ -109,13 +109,20 @@ def visit_Call(self, node: ast.Call) -> Any: pass else: # check if this is a call to any known source - if alias_name in self.known_sources or alias_name in self.known_resources: + if ( + alias_name in self.known_sources + or alias_name in self.known_resources + ): # set parent to the outer function node.parent = find_outer_func_def(node) # type: ignore if alias_name in self.known_sources: - decorated_calls = self.known_source_calls.setdefault(alias_name, []) + decorated_calls = self.known_source_calls.setdefault( + alias_name, [] + ) else: - decorated_calls = self.known_resource_calls.setdefault(alias_name, []) + decorated_calls = self.known_resource_calls.setdefault( + alias_name, [] + ) decorated_calls.append(node) # visit the children super().generic_visit(node) diff --git a/dlt/sources/credentials.py b/dlt/sources/credentials.py index 2883c0c688..e653c785de 100644 --- a/dlt/sources/credentials.py +++ b/dlt/sources/credentials.py @@ -8,7 +8,10 @@ from dlt.common.configuration.specs import ConnectionStringCredentials from dlt.common.configuration.specs import OAuth2Credentials from dlt.common.configuration.specs import CredentialsConfiguration, configspec -from dlt.common.storages.configuration import FileSystemCredentials, FilesystemConfiguration +from dlt.common.storages.configuration import ( + FileSystemCredentials, + FilesystemConfiguration, +) __all__ = [ diff --git a/dlt/sources/helpers/requests/retry.py b/dlt/sources/helpers/requests/retry.py index c9a813598f..6d0184116a 100644 --- a/dlt/sources/helpers/requests/retry.py +++ b/dlt/sources/helpers/requests/retry.py @@ -116,13 +116,18 @@ def _make_retry( respect_retry_after_header: bool, max_delay: TimedeltaSeconds, ) -> Retrying: - retry_conds = [retry_if_status(status_codes), retry_if_exception_type(tuple(exceptions))] + retry_conds = [ + 
retry_if_status(status_codes), + retry_if_exception_type(tuple(exceptions)), + ] if condition is not None: if callable(condition): retry_condition = [condition] retry_conds.extend([retry_if_predicate(c) for c in retry_condition]) - wait_cls = wait_exponential_retry_after if respect_retry_after_header else wait_exponential + wait_cls = ( + wait_exponential_retry_after if respect_retry_after_header else wait_exponential + ) return Retrying( wait=wait_cls(multiplier=backoff_factor, max=max_delay), retry=(retry_any(*retry_conds)), @@ -189,7 +194,9 @@ def __init__( ) -> None: self._adapter = HTTPAdapter(pool_maxsize=max_connections) self._local = local() - self._session_kwargs = dict(timeout=request_timeout, raise_for_status=raise_for_status) + self._session_kwargs = dict( + timeout=request_timeout, raise_for_status=raise_for_status + ) self._retry_kwargs: Dict[str, Any] = dict( status_codes=status_codes, exceptions=exceptions, @@ -220,9 +227,7 @@ def __init__( self.options = lambda *a, **kw: self.session.options(*a, **kw) self.request = lambda *a, **kw: self.session.request(*a, **kw) - self._config_version: int = ( - 0 # Incrementing marker to ensure per-thread sessions are recreated on config changes - ) + self._config_version: int = 0 # Incrementing marker to ensure per-thread sessions are recreated on config changes def update_from_config(self, config: RunConfiguration) -> None: """Update session/retry settings from RunConfiguration""" diff --git a/dlt/sources/helpers/requests/session.py b/dlt/sources/helpers/requests/session.py index 0a4d277848..90b0f01f5b 100644 --- a/dlt/sources/helpers/requests/session.py +++ b/dlt/sources/helpers/requests/session.py @@ -14,7 +14,9 @@ DEFAULT_TIMEOUT = 60 -def _timeout_to_seconds(timeout: TRequestTimeout) -> Optional[Union[Tuple[float, float], float]]: +def _timeout_to_seconds( + timeout: TRequestTimeout, +) -> Optional[Union[Tuple[float, float], float]]: return ( (to_seconds(timeout[0]), to_seconds(timeout[1])) if isinstance(timeout, tuple) diff --git a/dlt/sources/helpers/rest_client/auth.py b/dlt/sources/helpers/rest_client/auth.py index 5d7a2f7eb2..57afd3ed52 100644 --- a/dlt/sources/helpers/rest_client/auth.py +++ b/dlt/sources/helpers/rest_client/auth.py @@ -102,7 +102,8 @@ def parse_native_representation(self, value: Any) -> None: raise NativeValueError( type(self), value, - f"HttpBasicAuth username and password must be a tuple of two strings, got {type(value)}", + "HttpBasicAuth username and password must be a tuple of two strings, got" + f" {type(value)}", ) def __call__(self, request: PreparedRequest) -> PreparedRequest: @@ -208,8 +209,10 @@ def load_private_key(self) -> "PrivateKeyTypes": private_key_bytes = self.private_key.encode("utf-8") return serialization.load_pem_private_key( private_key_bytes, - password=self.private_key_passphrase.encode("utf-8") - if self.private_key_passphrase - else None, + password=( + self.private_key_passphrase.encode("utf-8") + if self.private_key_passphrase + else None + ), backend=default_backend(), ) diff --git a/dlt/sources/helpers/rest_client/paginators.py b/dlt/sources/helpers/rest_client/paginators.py index ce414322a0..3061b26e5b 100644 --- a/dlt/sources/helpers/rest_client/paginators.py +++ b/dlt/sources/helpers/rest_client/paginators.py @@ -92,8 +92,8 @@ def update_state(self, response: Response) -> None: total = int(total) except ValueError: raise ValueError( - f"Total count is not an integer in response for {self.__class__.__name__}. 
" - f"Expected an integer, got {total}" + "Total count is not an integer in response for" + f" {self.__class__.__name__}. Expected an integer, got {total}" ) self.offset += self.limit diff --git a/docs/examples/_template/_template.py b/docs/examples/_template/_template.py index cdd38f8204..5237ec4453 100644 --- a/docs/examples/_template/_template.py +++ b/docs/examples/_template/_template.py @@ -20,7 +20,9 @@ if __name__ == "__main__": # run a pipeline pipeline = dlt.pipeline( - pipeline_name="example_pipeline", destination="duckdb", dataset_name="example_data" + pipeline_name="example_pipeline", + destination="duckdb", + dataset_name="example_data", ) # Extract, normalize, and load the data load_info = pipeline.run([1, 2, 3], table_name="player") diff --git a/docs/examples/archive/_helpers.py b/docs/examples/archive/_helpers.py index 0f490ff85f..c1f917e116 100644 --- a/docs/examples/archive/_helpers.py +++ b/docs/examples/archive/_helpers.py @@ -6,7 +6,9 @@ "type": "service_account", "project_id": "zinc-mantra-353207", "private_key": "XFhETkYxMSY7Og0jJDgjKDcuUz8kK1kAXltcfyQqIjYCBjs2bDc3PzcOCBobHwg1TVpDNDAkLCUqMiciMD9KBBEWJgIiDDY1IB09bzInMkAdMDtCFwYBJ18QGyR/LBVEFQNQOjhIB0UXHhFSOD4hDiRMYCYkIxgkMTgqJTZBOWceIkgHPCU6EiQtHyRcH0MmWh4xDjowBkcMGSY8I38cLgk6NVYAGEU3ExcvPVQvBUYyIS5BClkyHB4MPkATM0BCeFwcFS9dNg8AJA40B0pYJUUxAjkbCzhZQj9mODk6f0Y6JRUBJyQhZysEWkU8MwU1LCsELF4gBStNWzsHAh4PXTVAOxA3PSgJUksFFgAwVxkZGiMwJT4UEgwFEn8/FRd/O1UmKzYRH19kCjBaLCAGIB0VUVk+Bh0zJzQtElJKOBIFAGULRQY7BVInOSAoGBdaMCYgIhMnCBhfNQsDFABFIH8+MD0JBjM0PEQxBwRGXwAiIBkoExgcFCYQQzE6AUAHCCQzSjpdKwcYFAIkHg1CG0o3NSBMEztEBQRYCgB9NwQofw8FOAohDzgCbBQ7MzQoJigUEyQzJlsWNRk7CxYDJS43Jj5BIj5IQQ8UPUtELURCRjBHFRcZMzs+MVAgAmQfGyJ/JjcTHgVWBzBJXEQ6TRgHXD0YCUI7fDQVAiUCMCALM1MbBxw8LCkCJQEySwIZNTJDSyBBJCE0OgsBIkBGSwkfEH8DUjlKM1E+H30nGxwAMxYpG0IpMARoA08dDQFWExs/Lh06VT0hHicQNlsiQQIHDE4UAV4ABAAjMkMFPTB9ISU3fws2GysuBBo1GR84OCJQWgdLBCg3R0Y8FwIYDUwACyAmOR1GIUYgBw86DDIFKkcRXkE9Exo6ERIxACIFHHxGRUJ/XicRPh0GIRBnRQwrQyc7JRRNNB0ieScTO0UYJzwRFRAdIH0WGjVDEVYGSkNSRyBvEk80OzkWDCtfLSc4dEYbJn84JD83ACYzREw6XR9EHxofFiEJQgR0BTBIMQQRBzccJjFMZQERRhsGGTo4NgYjMBkiMisDGyVAJCwbGExmRw48fyEgEUUdKREZBh0UOT89ITJcJSsZHhwjEyckBhURHAAuRhtkPBEEExkvPUNFEzslexlDJx4TIB5GIBZKNxwqGSN/HwAxEjwbXQNGB0YXGwIAASYDWBwibh0UJgZfFiEkJCQbW3kwESk7ODAFKhsACiFhADknNwwSEwoZEDNbYwM5SH8xUwobMCUGDnlBAzQwXiIPKwE5MUxDCjNCJCIhDCI3ThUnfCYkRxkoUiIbMxsfNWEpNzJGPDc4FAElJUxqHxIkfytbKAoMEjhBTkIhNkMsJ1spMydBI08aNwMHJw8aNxk5ARdbFBM9Fj8bPT4ZLhMsdTE9JCImFy8/OwoAGm8XAyF/MS8vJxsLAUZ9KjIrPwxVWwoNJB0OfDo3QR0vVwUWESBHFX1cMl5NDjskPUFOCltnB0cLDyg3ET1fKgoGfAY+O38/EA40MCBGBFgEPTMSLTsOJiAmHSNjNBQVHTwCIBQuUEoGRB4aGQ0YKBxHPg8GIUoaFEAcCikkNT4ONUNgBSBHfyMZAipBNyIBHyEnNx8vTD0kIggqN3g7FAgAAjUDCTI0JRcUMB8DNwo7DBhHOBhBRzcHBBI8EQERGQ5ZGHRBPjt/USwsHDBTAw5XET5AHgYSI0YNBQQmbkYhOiAuFjghQycCAWkpFUceOFUIEgEsBTVOGD8lEVFQLgc1DjU2bDoyBX8FEQpHHyUwW3cQEScNOUgGPhJRRzZmSkUdIj4UCRlCVxUsSRJBIk0lIjsWRAYoFWULHEcBRhclJw0RWSFnNj82fwFQM0EeUgoBWwBCAy0wNQU+Jzk7OFRAMhMCXQYsKyIRPFteGRdHRj4XBwNDBCYCXAkVKzA9GgkKJhAmGh8aLxt/DS4OIRtSFDl4ETEUGFtXMgEAJzYXSikkFQMkUBgVQ1A0QV4XGAA7BSIENDYgPQBUKS4jJhM6EwQsUBMHYTQsQn8oUjM2PBNdEmowHEA4HxFaNj4lQDd8CjxJPyA6ChtAUEZHHT0iOAVeCDMXFSAzXxUxMkMSIAg+RzwqKzVURkE2fxEQB0IyDQgzHBA5KDcDOS8aRSZZQ0BDMAkkEwIgMQwkKwx8JRkEFjgkWwkyJkUfdEAsSBMtGyA4RiVKBENDJCd/WzUvIzc2IBN6HTgcOQsJODYhUEVBRwQUe1hETkZeMS82VH0hPyc0PSZLODE4X1kAXlt7", # noqa - "client_email": "data-load-tool-public-demo@zinc-mantra-353207.iam.gserviceaccount.com", + "client_email": ( + "data-load-tool-public-demo@zinc-mantra-353207.iam.gserviceaccount.com" + ), } # we do not want to have 
this key verbatim in repo so we decode it here @@ -14,7 +16,8 @@ [ _a ^ _b for _a, _b in zip( - base64.b64decode(_bigquery_credentials["private_key"]), b"quickstart-sv" * 150 + base64.b64decode(_bigquery_credentials["private_key"]), + b"quickstart-sv" * 150, ) ] ).decode("utf-8") diff --git a/docs/examples/archive/credentials/explicit.py b/docs/examples/archive/credentials/explicit.py index b1bc25fce6..56b48a1f22 100644 --- a/docs/examples/archive/credentials/explicit.py +++ b/docs/examples/archive/credentials/explicit.py @@ -14,7 +14,9 @@ def simple_data( # you do not need to follow the recommended way to keep your config and secrets # here we pass the secrets explicitly to the `simple_data` -data = simple_data(dlt.config["custom.simple_data.api_url"], dlt.secrets["simple_data_secret_7"]) +data = simple_data( + dlt.config["custom.simple_data.api_url"], dlt.secrets["simple_data_secret_7"] +) # you should see the values from config.toml or secrets.toml coming from custom locations # config.toml: # [custom] @@ -27,7 +29,9 @@ def simple_data( # all the config providers are being searched, let's use the environment variables which have precedence over the toml files os.environ["CUSTOM__SIMPLE_DATA__API_URL"] = "api url from env" os.environ["SIMPLE_DATA_SECRET_7"] = "api secret from env" -data = simple_data(dlt.config["custom.simple_data.api_url"], dlt.secrets["simple_data_secret_7"]) +data = simple_data( + dlt.config["custom.simple_data.api_url"], dlt.secrets["simple_data_secret_7"] +) print(list(data)) # you are free to pass credentials from custom location to destination diff --git a/docs/examples/archive/dbt_run_jaffle.py b/docs/examples/archive/dbt_run_jaffle.py index 098b35fff8..a9bee8ee33 100644 --- a/docs/examples/archive/dbt_run_jaffle.py +++ b/docs/examples/archive/dbt_run_jaffle.py @@ -3,19 +3,22 @@ pipeline = dlt.pipeline(destination="duckdb", dataset_name="jaffle_jaffle") print( - "create or restore virtual environment in which dbt is installed, use the newest version of dbt" + "create or restore virtual environment in which dbt is installed, use the newest" + " version of dbt" ) venv = dlt.dbt.get_venv(pipeline) print("get runner, optionally pass the venv") -dbt = dlt.dbt.package(pipeline, "https://github.com/dbt-labs/jaffle_shop.git", venv=venv) +dbt = dlt.dbt.package( + pipeline, "https://github.com/dbt-labs/jaffle_shop.git", venv=venv +) print("run the package (clone/pull repo, deps, seed, source tests, run)") models = dbt.run_all() for m in models: print( - f"Model {m.model_name} materialized in {m.time} with status {m.status} and message" - f" {m.message}" + f"Model {m.model_name} materialized in {m.time} with status {m.status} and" + f" message {m.message}" ) print("") @@ -23,7 +26,8 @@ models = dbt.test() for m in models: print( - f"Test {m.model_name} executed in {m.time} with status {m.status} and message {m.message}" + f"Test {m.model_name} executed in {m.time} with status {m.status} and message" + f" {m.message}" ) print("") diff --git a/docs/examples/archive/quickstart.py b/docs/examples/archive/quickstart.py index 6e49f1af7a..4969298cae 100644 --- a/docs/examples/archive/quickstart.py +++ b/docs/examples/archive/quickstart.py @@ -17,7 +17,9 @@ "type": "service_account", "project_id": "zinc-mantra-353207", "private_key": 
"XFhETkYxMSY7Og0jJDgjKDcuUz8kK1kAXltcfyQqIjYCBjs2bDc3PzcOCBobHwg1TVpDNDAkLCUqMiciMD9KBBEWJgIiDDY1IB09bzInMkAdMDtCFwYBJ18QGyR/LBVEFQNQOjhIB0UXHhFSOD4hDiRMYCYkIxgkMTgqJTZBOWceIkgHPCU6EiQtHyRcH0MmWh4xDjowBkcMGSY8I38cLgk6NVYAGEU3ExcvPVQvBUYyIS5BClkyHB4MPkATM0BCeFwcFS9dNg8AJA40B0pYJUUxAjkbCzhZQj9mODk6f0Y6JRUBJyQhZysEWkU8MwU1LCsELF4gBStNWzsHAh4PXTVAOxA3PSgJUksFFgAwVxkZGiMwJT4UEgwFEn8/FRd/O1UmKzYRH19kCjBaLCAGIB0VUVk+Bh0zJzQtElJKOBIFAGULRQY7BVInOSAoGBdaMCYgIhMnCBhfNQsDFABFIH8+MD0JBjM0PEQxBwRGXwAiIBkoExgcFCYQQzE6AUAHCCQzSjpdKwcYFAIkHg1CG0o3NSBMEztEBQRYCgB9NwQofw8FOAohDzgCbBQ7MzQoJigUEyQzJlsWNRk7CxYDJS43Jj5BIj5IQQ8UPUtELURCRjBHFRcZMzs+MVAgAmQfGyJ/JjcTHgVWBzBJXEQ6TRgHXD0YCUI7fDQVAiUCMCALM1MbBxw8LCkCJQEySwIZNTJDSyBBJCE0OgsBIkBGSwkfEH8DUjlKM1E+H30nGxwAMxYpG0IpMARoA08dDQFWExs/Lh06VT0hHicQNlsiQQIHDE4UAV4ABAAjMkMFPTB9ISU3fws2GysuBBo1GR84OCJQWgdLBCg3R0Y8FwIYDUwACyAmOR1GIUYgBw86DDIFKkcRXkE9Exo6ERIxACIFHHxGRUJ/XicRPh0GIRBnRQwrQyc7JRRNNB0ieScTO0UYJzwRFRAdIH0WGjVDEVYGSkNSRyBvEk80OzkWDCtfLSc4dEYbJn84JD83ACYzREw6XR9EHxofFiEJQgR0BTBIMQQRBzccJjFMZQERRhsGGTo4NgYjMBkiMisDGyVAJCwbGExmRw48fyEgEUUdKREZBh0UOT89ITJcJSsZHhwjEyckBhURHAAuRhtkPBEEExkvPUNFEzslexlDJx4TIB5GIBZKNxwqGSN/HwAxEjwbXQNGB0YXGwIAASYDWBwibh0UJgZfFiEkJCQbW3kwESk7ODAFKhsACiFhADknNwwSEwoZEDNbYwM5SH8xUwobMCUGDnlBAzQwXiIPKwE5MUxDCjNCJCIhDCI3ThUnfCYkRxkoUiIbMxsfNWEpNzJGPDc4FAElJUxqHxIkfytbKAoMEjhBTkIhNkMsJ1spMydBI08aNwMHJw8aNxk5ARdbFBM9Fj8bPT4ZLhMsdTE9JCImFy8/OwoAGm8XAyF/MS8vJxsLAUZ9KjIrPwxVWwoNJB0OfDo3QR0vVwUWESBHFX1cMl5NDjskPUFOCltnB0cLDyg3ET1fKgoGfAY+O38/EA40MCBGBFgEPTMSLTsOJiAmHSNjNBQVHTwCIBQuUEoGRB4aGQ0YKBxHPg8GIUoaFEAcCikkNT4ONUNgBSBHfyMZAipBNyIBHyEnNx8vTD0kIggqN3g7FAgAAjUDCTI0JRcUMB8DNwo7DBhHOBhBRzcHBBI8EQERGQ5ZGHRBPjt/USwsHDBTAw5XET5AHgYSI0YNBQQmbkYhOiAuFjghQycCAWkpFUceOFUIEgEsBTVOGD8lEVFQLgc1DjU2bDoyBX8FEQpHHyUwW3cQEScNOUgGPhJRRzZmSkUdIj4UCRlCVxUsSRJBIk0lIjsWRAYoFWULHEcBRhclJw0RWSFnNj82fwFQM0EeUgoBWwBCAy0wNQU+Jzk7OFRAMhMCXQYsKyIRPFteGRdHRj4XBwNDBCYCXAkVKzA9GgkKJhAmGh8aLxt/DS4OIRtSFDl4ETEUGFtXMgEAJzYXSikkFQMkUBgVQ1A0QV4XGAA7BSIENDYgPQBUKS4jJhM6EwQsUBMHYTQsQn8oUjM2PBNdEmowHEA4HxFaNj4lQDd8CjxJPyA6ChtAUEZHHT0iOAVeCDMXFSAzXxUxMkMSIAg+RzwqKzVURkE2fxEQB0IyDQgzHBA5KDcDOS8aRSZZQ0BDMAkkEwIgMQwkKwx8JRkEFjgkWwkyJkUfdEAsSBMtGyA4RiVKBENDJCd/WzUvIzc2IBN6HTgcOQsJODYhUEVBRwQUe1hETkZeMS82VH0hPyc0PSZLODE4X1kAXlt7", # noqa - "client_email": "data-load-tool-public-demo@zinc-mantra-353207.iam.gserviceaccount.com", + "client_email": ( + "data-load-tool-public-demo@zinc-mantra-353207.iam.gserviceaccount.com" + ), } db_dsn = "postgres://loader:loader@localhost:5432/dlt_data" @@ -28,7 +30,8 @@ [ _a ^ _b for _a, _b in zip( - base64.b64decode(gcp_credentials_json["private_key"]), b"quickstart-sv" * 150 + base64.b64decode(gcp_credentials_json["private_key"]), + b"quickstart-sv" * 150, ) ] ).decode("utf-8") diff --git a/docs/examples/archive/rasa_example.py b/docs/examples/archive/rasa_example.py index e83e6c61f7..f2879c5fe5 100644 --- a/docs/examples/archive/rasa_example.py +++ b/docs/examples/archive/rasa_example.py @@ -17,7 +17,9 @@ # in case of rasa the base source can be a file, database table (see sql_query.py), kafka topic, rabbitmq queue etc. 
corresponding to store or broker type # for the simplicity let's use jsonl source to read all files with events in a directory -event_files = jsonl_files([file for file in os.scandir("docs/examples/data/rasa_trackers")]) +event_files = jsonl_files( + [file for file in os.scandir("docs/examples/data/rasa_trackers")] +) info = dlt.pipeline( full_refresh=True, diff --git a/docs/examples/archive/read_table.py b/docs/examples/archive/read_table.py index 6cccf0efdb..b2b41798ed 100644 --- a/docs/examples/archive/read_table.py +++ b/docs/examples/archive/read_table.py @@ -10,7 +10,10 @@ # get data from table, we preserve method signature from pandas items = query_table( - "blocks__transactions", source_dsn, table_schema_name="mainnet_2_ethereum", coerce_float=False + "blocks__transactions", + source_dsn, + table_schema_name="mainnet_2_ethereum", + coerce_float=False, ) # the data is also an iterator @@ -20,7 +23,9 @@ print(f"{k}:{v} ({type(v)}:{py_type_to_sc_type(type(v))})") # get data from query -items = query_sql("select * from mainnet_2_ethereum.blocks__transactions limit 10", source_dsn) +items = query_sql( + "select * from mainnet_2_ethereum.blocks__transactions limit 10", source_dsn +) # and load it into a local postgres instance # the connection string does not have the password part. provide it in DESTINATION__CREDENTIALS__PASSWORD diff --git a/docs/examples/archive/singer_tap_example.py b/docs/examples/archive/singer_tap_example.py index a9b105fe93..e8e1650efe 100644 --- a/docs/examples/archive/singer_tap_example.py +++ b/docs/examples/archive/singer_tap_example.py @@ -22,13 +22,17 @@ "files": [ { "entity": "annotations_202205", - "path": os.path.abspath("examples/data/singer_taps/model_annotations.csv"), + "path": os.path.abspath( + "examples/data/singer_taps/model_annotations.csv" + ), "keys": ["message id"], } ] } print("running tap-csv") - tap_source = tap(venv, "tap-csv", csv_tap_config, "examples/data/singer_taps/csv_catalog.json") + tap_source = tap( + venv, "tap-csv", csv_tap_config, "examples/data/singer_taps/csv_catalog.json" + ) info = dlt.pipeline("meltano_csv", destination="postgres").run( tap_source, credentials="postgres://loader@localhost:5432/dlt_data" ) diff --git a/docs/examples/archive/singer_tap_jsonl_example.py b/docs/examples/archive/singer_tap_jsonl_example.py index c926a9f153..b180e6edfa 100644 --- a/docs/examples/archive/singer_tap_jsonl_example.py +++ b/docs/examples/archive/singer_tap_jsonl_example.py @@ -11,7 +11,11 @@ p = dlt.pipeline(destination="postgres", full_refresh=True) # now load a pipeline created from jsonl resource that feeds messages into singer tap transformer -pipe = jsonl_file("docs/examples/data/singer_taps/tap_hubspot.jsonl") | singer_raw_stream() +pipe = ( + jsonl_file("docs/examples/data/singer_taps/tap_hubspot.jsonl") | singer_raw_stream() +) # provide hubspot schema -info = p.run(pipe, schema=schema, credentials="postgres://loader@localhost:5432/dlt_data") +info = p.run( + pipe, schema=schema, credentials="postgres://loader@localhost:5432/dlt_data" +) print(info) diff --git a/docs/examples/archive/sources/google_sheets.py b/docs/examples/archive/sources/google_sheets.py index 69855154ae..2d9cdd5320 100644 --- a/docs/examples/archive/sources/google_sheets.py +++ b/docs/examples/archive/sources/google_sheets.py @@ -1,7 +1,10 @@ from typing import Any, Iterator, Sequence, Union, cast import dlt -from dlt.common.configuration.specs import GcpServiceAccountCredentials, GcpOAuthCredentials +from dlt.common.configuration.specs import ( + 
GcpServiceAccountCredentials, + GcpOAuthCredentials, +) from dlt.common.typing import DictStrAny, StrAny from dlt.common.exceptions import MissingDependencyException diff --git a/docs/examples/archive/sources/rasa/rasa.py b/docs/examples/archive/sources/rasa/rasa.py index 60643fe17e..99238c04d6 100644 --- a/docs/examples/archive/sources/rasa/rasa.py +++ b/docs/examples/archive/sources/rasa/rasa.py @@ -35,7 +35,8 @@ def events(source_events: TDataItems) -> Iterator[TDataItem]: # recover start_timestamp from state if given if store_last_timestamp: start_timestamp = max( - initial_timestamp or 0, dlt.current.source_state().get("start_timestamp", 0) + initial_timestamp or 0, + dlt.current.source_state().get("start_timestamp", 0), ) # we expect tracker store events here last_timestamp: int = None @@ -46,7 +47,9 @@ def _proc_event(source_event: TDataItem) -> Iterator[TDataItem]: # must be a dict assert isinstance(source_event, dict) # filter out events - if timestamp_within(source_event["timestamp"], start_timestamp, end_timestamp): + if timestamp_within( + source_event["timestamp"], start_timestamp, end_timestamp + ): # yield tracker table with all-event index event_type = source_event["event"] last_timestamp = source_event["timestamp"] diff --git a/docs/examples/archive/sources/singer_tap.py b/docs/examples/archive/sources/singer_tap.py index 3c733c33f1..f14fdeac1f 100644 --- a/docs/examples/archive/sources/singer_tap.py +++ b/docs/examples/archive/sources/singer_tap.py @@ -52,12 +52,16 @@ def get_source_from_stream( @dlt.transformer() -def singer_raw_stream(singer_messages: TDataItems, use_state: bool = True) -> Iterator[TDataItem]: +def singer_raw_stream( + singer_messages: TDataItems, use_state: bool = True +) -> Iterator[TDataItem]: if use_state: state = dlt.current.source_state() else: state = None - yield from get_source_from_stream(cast(Iterator[SingerMessage], singer_messages), state) + yield from get_source_from_stream( + cast(Iterator[SingerMessage], singer_messages), state + ) @dlt.source(spec=BaseConfiguration) # use BaseConfiguration spec to prevent injections @@ -93,7 +97,10 @@ def singer_messages() -> Iterator[TDataItem]: else: state = None if state is not None and state.get("singer"): - state_params = ("--state", as_config_file(dlt.current.source_state()["singer"])) + state_params = ( + "--state", + as_config_file(dlt.current.source_state()["singer"]), + ) else: state_params = () # type: ignore diff --git a/docs/examples/archive/sources/sql_query.py b/docs/examples/archive/sources/sql_query.py index 8cd60992b2..bcdcc1e382 100644 --- a/docs/examples/archive/sources/sql_query.py +++ b/docs/examples/archive/sources/sql_query.py @@ -13,7 +13,9 @@ import pandas except ImportError: raise MissingDependencyException( - "SQL Query Source", ["pandas"], "SQL Query Source temporarily uses pandas as DB interface" + "SQL Query Source", + ["pandas"], + "SQL Query Source temporarily uses pandas as DB interface", ) try: diff --git a/docs/examples/chess/chess.py b/docs/examples/chess/chess.py index df1fb18845..ec6be9d0f7 100644 --- a/docs/examples/chess/chess.py +++ b/docs/examples/chess/chess.py @@ -32,7 +32,9 @@ def players() -> Iterator[TDataItems]: @dlt.transformer(data_from=players, write_disposition="replace") @dlt.defer def players_profiles(username: Any) -> TDataItems: - print(f"getting {username} profile via thread {threading.current_thread().name}") + print( + f"getting {username} profile via thread {threading.current_thread().name}" + ) sleep(1) # add some latency to show parallel 
runs return _get_data_with_retry(f"player/{username}") @@ -52,7 +54,10 @@ def players_games(username: Any) -> Iterator[TDataItems]: # look for parallel run configuration in `config.toml`! # mind the full_refresh: it makes the pipeline to load to a distinct dataset each time it is run and always is resetting the schema and state load_info = dlt.pipeline( - pipeline_name="chess_games", destination="postgres", dataset_name="chess", full_refresh=True + pipeline_name="chess_games", + destination="postgres", + dataset_name="chess", + full_refresh=True, ).run(chess(max_players=5, month=9)) # display where the data went print(load_info) diff --git a/docs/examples/chess_production/chess_production.py b/docs/examples/chess_production/chess_production.py index c0f11203c8..f79c27c8c5 100644 --- a/docs/examples/chess_production/chess_production.py +++ b/docs/examples/chess_production/chess_production.py @@ -57,7 +57,9 @@ def players() -> Iterator[TDataItems]: # it uses `paralellized` flag to enable parallel run in thread pool. @dlt.transformer(data_from=players, write_disposition="replace", parallelized=True) def players_profiles(username: Any) -> TDataItems: - print(f"getting {username} profile via thread {threading.current_thread().name}") + print( + f"getting {username} profile via thread {threading.current_thread().name}" + ) sleep(1) # add some latency to show parallel runs return _get_data_with_retry(f"player/{username}") @@ -84,7 +86,10 @@ def load_data_with_retry(pipeline, data): reraise=True, ): with attempt: - logger.info(f"Running the pipeline, attempt={attempt.retry_state.attempt_number}") + logger.info( + "Running the pipeline," + f" attempt={attempt.retry_state.attempt_number}" + ) load_info = pipeline.run(data) logger.info(str(load_info)) @@ -92,12 +97,15 @@ def load_data_with_retry(pipeline, data): load_info.raise_on_failed_jobs() # send notification send_slack_message( - pipeline.runtime_config.slack_incoming_hook, "Data was successfully loaded!" + pipeline.runtime_config.slack_incoming_hook, + "Data was successfully loaded!", ) except Exception: # we get here after all the failed retries # send notification - send_slack_message(pipeline.runtime_config.slack_incoming_hook, "Something went wrong!") + send_slack_message( + pipeline.runtime_config.slack_incoming_hook, "Something went wrong!" + ) raise # we get here after a successful attempt @@ -106,14 +114,19 @@ def load_data_with_retry(pipeline, data): # print the information on the first load package and all jobs inside logger.info(f"First load package info: {load_info.load_packages[0]}") # print the information on the first completed job in first load package - logger.info(f"First completed job info: {load_info.load_packages[0].jobs['completed_jobs'][0]}") + logger.info( + "First completed job info:" + f" {load_info.load_packages[0].jobs['completed_jobs'][0]}" + ) # check for schema updates: schema_updates = [p.schema_update for p in load_info.load_packages] # send notifications if there are schema updates if schema_updates: # send notification - send_slack_message(pipeline.runtime_config.slack_incoming_hook, "Schema was updated!") + send_slack_message( + pipeline.runtime_config.slack_incoming_hook, "Schema was updated!" 
+ ) # To run simple tests with `sql_client`, such as checking table counts and # warning if there is no data, you can use the `execute_query` method diff --git a/docs/examples/custom_destination_bigquery/custom_destination_bigquery.py b/docs/examples/custom_destination_bigquery/custom_destination_bigquery.py index ea60b9b00d..16d1d3ba0f 100644 --- a/docs/examples/custom_destination_bigquery/custom_destination_bigquery.py +++ b/docs/examples/custom_destination_bigquery/custom_destination_bigquery.py @@ -74,7 +74,9 @@ def bigquery_insert( ) # since we have set the batch_size to 0, we get a filepath and can load the file directly with open(items, "rb") as f: - load_job = client.load_table_from_file(f, BIGQUERY_TABLE_ID, job_config=job_config) + load_job = client.load_table_from_file( + f, BIGQUERY_TABLE_ID, job_config=job_config + ) load_job.result() # Waits for the job to complete. diff --git a/docs/examples/incremental_loading/incremental_loading.py b/docs/examples/incremental_loading/incremental_loading.py index f1de4eecfe..3365e54e1c 100644 --- a/docs/examples/incremental_loading/incremental_loading.py +++ b/docs/examples/incremental_loading/incremental_loading.py @@ -32,7 +32,9 @@ @dlt.source(max_table_nesting=2) def zendesk_support( credentials: Dict[str, str] = dlt.secrets.value, - start_date: Optional[TAnyDateTime] = pendulum.datetime(year=2000, month=1, day=1), # noqa: B008 + start_date: Optional[TAnyDateTime] = pendulum.datetime( + year=2000, month=1, day=1 + ), # noqa: B008 end_date: Optional[TAnyDateTime] = None, ): """ diff --git a/docs/examples/nested_data/nested_data.py b/docs/examples/nested_data/nested_data.py index afda16a51a..314208d69b 100644 --- a/docs/examples/nested_data/nested_data.py +++ b/docs/examples/nested_data/nested_data.py @@ -49,7 +49,9 @@ def mongodb_collection( write_disposition: Optional[str] = dlt.config.value, ) -> Any: # set up mongo client - client: Any = MongoClient(connection_url, uuidRepresentation="standard", tz_aware=True) + client: Any = MongoClient( + connection_url, uuidRepresentation="standard", tz_aware=True + ) mongo_database = client.get_default_database() if not database else client[database] collection_obj = mongo_database[collection] diff --git a/docs/examples/pdf_to_weaviate/pdf_to_weaviate.py b/docs/examples/pdf_to_weaviate/pdf_to_weaviate.py index 809a6cfbd6..7e5a5330ea 100644 --- a/docs/examples/pdf_to_weaviate/pdf_to_weaviate.py +++ b/docs/examples/pdf_to_weaviate/pdf_to_weaviate.py @@ -79,7 +79,9 @@ def pdf_to_text(file_item, separate_pages: bool = False): client = weaviate.Client("http://localhost:8080") # get text of all the invoices in InvoiceText class we just created above - print(client.query.get("InvoiceText", ["text", "file_name", "mtime", "page_id"]).do()) + print( + client.query.get("InvoiceText", ["text", "file_name", "mtime", "page_id"]).do() + ) # make sure nothing failed load_info.raise_on_failed_jobs() diff --git a/docs/examples/qdrant_zendesk/qdrant_zendesk.py b/docs/examples/qdrant_zendesk/qdrant_zendesk.py index 65f399104a..bb193158b1 100644 --- a/docs/examples/qdrant_zendesk/qdrant_zendesk.py +++ b/docs/examples/qdrant_zendesk/qdrant_zendesk.py @@ -45,7 +45,9 @@ @dlt.source(max_table_nesting=2) def zendesk_support( credentials: Dict[str, str] = dlt.secrets.value, - start_date: Optional[TAnyDateTime] = pendulum.datetime(year=2000, month=1, day=1), # noqa: B008 + start_date: Optional[TAnyDateTime] = pendulum.datetime( + year=2000, month=1, day=1 + ), # noqa: B008 end_date: Optional[TAnyDateTime] = None, ): """ @@ 
-76,7 +78,9 @@ def zendesk_support( # when two events have the same timestamp @dlt.resource(primary_key="id", write_disposition="append") def tickets_data( - updated_at: dlt.sources.incremental[pendulum.DateTime] = dlt.sources.incremental( + updated_at: dlt.sources.incremental[ + pendulum.DateTime + ] = dlt.sources.incremental( "updated_at", initial_value=start_date_obj, end_value=end_date_obj, diff --git a/docs/tools/check_embedded_snippets.py b/docs/tools/check_embedded_snippets.py index 96e1227745..a264c00d2c 100644 --- a/docs/tools/check_embedded_snippets.py +++ b/docs/tools/check_embedded_snippets.py @@ -68,7 +68,8 @@ def collect_snippets(markdown_files: List[str], verbose: bool) -> List[Snippet]: elif current_snippet: current_snippet.code += line assert not current_snippet, ( - "It seems that the last snippet in the file was not closed. Please check the file " + "It seems that the last snippet in the file was not closed. Please check" + " the file " + file ) @@ -83,7 +84,9 @@ def collect_snippets(markdown_files: List[str], verbose: bool) -> List[Snippet]: return snippets -def filter_snippets(snippets: List[Snippet], files: str, snippet_numbers: str) -> List[Snippet]: +def filter_snippets( + snippets: List[Snippet], files: str, snippet_numbers: str +) -> List[Snippet]: """ Filter out snippets based on file or snippet number """ @@ -100,8 +103,8 @@ def filter_snippets(snippets: List[Snippet], files: str, snippet_numbers: str) - filtered_snippets.append(snippet) if filtered_count: fmt.note( - f"{filtered_count} Snippets skipped based on file and snippet number settings." - f" {len(filtered_snippets)} snippets remaining." + f"{filtered_count} Snippets skipped based on file and snippet number" + f" settings. {len(filtered_snippets)} snippets remaining." ) else: fmt.note("0 Snippets skipped based on file and snippet number settings") @@ -120,7 +123,9 @@ def check_language(snippets: List[Snippet]) -> None: failed_count = 0 for snippet in snippets: if snippet.language not in ALLOWED_LANGUAGES: - fmt.warning(f"{str(snippet)} has an invalid language {snippet.language} setting.") + fmt.warning( + f"{str(snippet)} has an invalid language {snippet.language} setting." + ) failed_count += 1 if failed_count: @@ -193,7 +198,9 @@ def lint_snippets(snippets: List[Snippet], verbose: bool) -> None: for snippet in snippets: count += 1 prepare_for_linting(snippet) - result = subprocess.run(["ruff", "check", LINT_FILE], capture_output=True, text=True) + result = subprocess.run( + ["ruff", "check", LINT_FILE], capture_output=True, text=True + ) if verbose: fmt.echo(f"Linting {snippet} ({count} of {len(snippets)})") if "error" in result.stdout.lower(): @@ -236,27 +243,30 @@ def typecheck_snippets(snippets: List[Snippet], verbose: bool) -> None: if __name__ == "__main__": fmt.note( - "Welcome to Snippet Checker 3000, run 'python check_embedded_snippets.py --help' for help." + "Welcome to Snippet Checker 3000, run 'python check_embedded_snippets.py" + " --help' for help." ) # setup cli parser = argparse.ArgumentParser( description=( - "Check embedded snippets. Discover, parse, lint, and type check all code snippets in" - " the docs." + "Check embedded snippets. Discover, parse, lint, and type check all code" + " snippets in the docs." ), formatter_class=argparse.ArgumentDefaultsHelpFormatter, ) parser.add_argument( "command", help=( - 'Which checks to run. "full" will run all checks, parse, lint or typecheck will only' - " run that specific step" + 'Which checks to run. 
"full" will run all checks, parse, lint or typecheck' + " will only run that specific step" ), choices=["full", "parse", "lint", "typecheck"], default="full", ) - parser.add_argument("-v", "--verbose", help="Increase output verbosity", action="store_true") + parser.add_argument( + "-v", "--verbose", help="Increase output verbosity", action="store_true" + ) parser.add_argument( "-f", "--files", @@ -267,8 +277,8 @@ def typecheck_snippets(snippets: List[Snippet], verbose: bool) -> None: "-s", "--snippetnumbers", help=( - "Filter checked snippets to snippetnumbers contained in this string, example:" - ' "13,412,345"' + "Filter checked snippets to snippetnumbers contained in this string," + ' example: "13,412,345"' ), type=lambda i: i.split(","), default=None, diff --git a/docs/tools/fix_grammar_gpt.py b/docs/tools/fix_grammar_gpt.py index 051448a2d4..9f17879243 100644 --- a/docs/tools/fix_grammar_gpt.py +++ b/docs/tools/fix_grammar_gpt.py @@ -27,23 +27,28 @@ if __name__ == "__main__": load_dotenv() - fmt.note("Welcome to Grammar Fixer 3000, run 'python fix_grammar_gpt.py --help' for help.") + fmt.note( + "Welcome to Grammar Fixer 3000, run 'python fix_grammar_gpt.py --help' for" + " help." + ) # setup cli parser = argparse.ArgumentParser( description=( - "Fixes the grammar of our docs with open ai. Requires an .env file with the open ai" - " key." + "Fixes the grammar of our docs with open ai. Requires an .env file with the" + " open ai key." ), formatter_class=argparse.ArgumentDefaultsHelpFormatter, ) - parser.add_argument("-v", "--verbose", help="Increase output verbosity", action="store_true") + parser.add_argument( + "-v", "--verbose", help="Increase output verbosity", action="store_true" + ) parser.add_argument( "-f", "--files", help=( - "Specify the file name. Grammar Checker will filter all .md files containing this" - " string in the filepath." + "Specify the file name. Grammar Checker will filter all .md files" + " containing this string in the filepath." 
), type=str, ) @@ -63,7 +68,9 @@ for file_path in markdown_files: count += 1 - fmt.note(f"Fixing grammar for file {file_path} ({count} of {len(markdown_files)})") + fmt.note( + f"Fixing grammar for file {file_path} ({count} of {len(markdown_files)})" + ) with open(file_path, "r", encoding="utf-8") as f: doc = f.readlines() diff --git a/docs/tools/lint_setup/template.py b/docs/tools/lint_setup/template.py index c72c4dba62..8e34de44a1 100644 --- a/docs/tools/lint_setup/template.py +++ b/docs/tools/lint_setup/template.py @@ -3,7 +3,17 @@ # mypy: disable-error-code="name-defined,import-not-found,import-untyped,empty-body,no-redef" # some universal imports -from typing import Optional, Dict, List, Any, Iterable, Iterator, Tuple, Sequence, Callable +from typing import ( + Optional, + Dict, + List, + Any, + Iterable, + Iterator, + Tuple, + Sequence, + Callable, +) import os diff --git a/docs/website/docs/dlt-ecosystem/transformations/dbt/dbt-snippets.py b/docs/website/docs/dlt-ecosystem/transformations/dbt/dbt-snippets.py index 4cb960b19f..02b24b1fd8 100644 --- a/docs/website/docs/dlt-ecosystem/transformations/dbt/dbt-snippets.py +++ b/docs/website/docs/dlt-ecosystem/transformations/dbt/dbt-snippets.py @@ -9,7 +9,9 @@ def run_dbt_standalone_snippet() -> None: None, # we do not need dataset name and we do not pass any credentials in environment to dlt working_dir=".", # the package below will be cloned to current dir package_location="https://github.com/dbt-labs/jaffle_shop.git", - package_profiles_dir=os.path.abspath("."), # profiles.yml must be placed in this dir + package_profiles_dir=os.path.abspath( + "." + ), # profiles.yml must be placed in this dir package_profile_name="duckdb_dlt_dbt_test", # name of the profile ) @@ -18,6 +20,6 @@ def run_dbt_standalone_snippet() -> None: for m in models: print( - f"Model {m.model_name} materialized in {m.time} with status {m.status} and message" - f" {m.message}" + f"Model {m.model_name} materialized in {m.time} with status {m.status} and" + f" message {m.message}" ) diff --git a/docs/website/docs/general-usage/snippets/destination-snippets.py b/docs/website/docs/general-usage/snippets/destination-snippets.py index 3484d943a0..4c5d5ccfe9 100644 --- a/docs/website/docs/general-usage/snippets/destination-snippets.py +++ b/docs/website/docs/general-usage/snippets/destination-snippets.py @@ -33,7 +33,9 @@ def destination_instantiation_snippet() -> None: # @@@DLT_SNIPPET_START instance import dlt - azure_bucket = filesystem("az://dlt-azure-bucket", destination_name="production_az_bucket") + azure_bucket = filesystem( + "az://dlt-azure-bucket", destination_name="production_az_bucket" + ) pipeline = dlt.pipeline("pipeline", destination=azure_bucket) # @@@DLT_SNIPPET_END instance assert pipeline.destination.destination_name == "production_az_bucket" @@ -45,7 +47,9 @@ def destination_instantiation_snippet() -> None: # pass full credentials - together with the password (not recommended) pipeline = dlt.pipeline( "pipeline", - destination=postgres(credentials="postgresql://loader:loader@localhost:5432/dlt_data"), + destination=postgres( + credentials="postgresql://loader:loader@localhost:5432/dlt_data" + ), ) # @@@DLT_SNIPPET_END config_explicit @@ -75,7 +79,8 @@ def destination_instantiation_snippet() -> None: # fill only the account name, leave key to be taken from secrets credentials.azure_storage_account_name = "production_storage" pipeline = dlt.pipeline( - "pipeline", destination=filesystem("az://dlt-azure-bucket", credentials=credentials) + "pipeline", + 
destination=filesystem("az://dlt-azure-bucket", credentials=credentials), ) # @@@DLT_SNIPPET_END config_partial_spec diff --git a/docs/website/docs/getting-started-snippets.py b/docs/website/docs/getting-started-snippets.py index eb00df9986..e9de1eb6e2 100644 --- a/docs/website/docs/getting-started-snippets.py +++ b/docs/website/docs/getting-started-snippets.py @@ -99,7 +99,9 @@ def db_snippet() -> None: # use any sql database supported by SQLAlchemy, below we use a public mysql instance to get data # NOTE: you'll need to install pymysql with "pip install pymysql" # NOTE: loading data from public mysql instance may take several seconds - engine = create_engine("mysql+pymysql://rfamro@mysql-rfam-public.ebi.ac.uk:4497/Rfam") + engine = create_engine( + "mysql+pymysql://rfamro@mysql-rfam-public.ebi.ac.uk:4497/Rfam" + ) with engine.connect() as conn: # select genome table, stream data in batches of 100 elements rows = conn.execution_options(yield_per=100).exec_driver_sql( @@ -113,7 +115,9 @@ def db_snippet() -> None: ) # here we convert the rows into dictionaries on the fly with a map function - load_info = pipeline.run(map(lambda row: dict(row._mapping), rows), table_name="genome") + load_info = pipeline.run( + map(lambda row: dict(row._mapping), rows), table_name="genome" + ) print(load_info) # @@@DLT_SNIPPET_END db @@ -147,7 +151,9 @@ def incremental_snippet() -> None: @dlt.resource(table_name="issues", write_disposition="append") def get_issues( - created_at=dlt.sources.incremental("created_at", initial_value="1970-01-01T00:00:00Z") + created_at=dlt.sources.incremental( + "created_at", initial_value="1970-01-01T00:00:00Z" + ) ): # NOTE: we read only open issues to minimize number of calls to the API. There's a limit of ~50 calls for not authenticated Github users url = "https://api.github.com/repos/dlt-hub/dlt/issues?per_page=100&sort=created&directions=desc&state=open" @@ -194,7 +200,9 @@ def incremental_merge_snippet() -> None: primary_key="id", ) def get_issues( - updated_at=dlt.sources.incremental("updated_at", initial_value="1970-01-01T00:00:00Z") + updated_at=dlt.sources.incremental( + "updated_at", initial_value="1970-01-01T00:00:00Z" + ) ): # NOTE: we read only open issues to minimize number of calls to the API. There's a limit of ~50 calls for not authenticated Github users url = f"https://api.github.com/repos/dlt-hub/dlt/issues?since={updated_at.last_value}&per_page=100&sort=updated&directions=desc&state=open" @@ -230,7 +238,9 @@ def table_dispatch_snippet() -> None: import dlt from dlt.sources.helpers import requests - @dlt.resource(primary_key="id", table_name=lambda i: i["type"], write_disposition="append") + @dlt.resource( + primary_key="id", table_name=lambda i: i["type"], write_disposition="append" + ) def repo_events(last_created_at=dlt.sources.incremental("created_at")): url = "https://api.github.com/repos/dlt-hub/dlt/events?per_page=100" diff --git a/docs/website/docs/intro-snippets.py b/docs/website/docs/intro-snippets.py index f270dcee6e..4b89f9365a 100644 --- a/docs/website/docs/intro-snippets.py +++ b/docs/website/docs/intro-snippets.py @@ -59,7 +59,9 @@ def db_snippet() -> None: # MySQL instance to get data. 
# NOTE: you'll need to install pymysql with `pip install pymysql` # NOTE: loading data from public mysql instance may take several seconds - engine = create_engine("mysql+pymysql://rfamro@mysql-rfam-public.ebi.ac.uk:4497/Rfam") + engine = create_engine( + "mysql+pymysql://rfamro@mysql-rfam-public.ebi.ac.uk:4497/Rfam" + ) with engine.connect() as conn: # Select genome table, stream data in batches of 100 elements @@ -73,7 +75,9 @@ def db_snippet() -> None: ) # Convert the rows into dictionaries on the fly with a map function - load_info = pipeline.run(map(lambda row: dict(row._mapping), rows), table_name="genome") + load_info = pipeline.run( + map(lambda row: dict(row._mapping), rows), table_name="genome" + ) print(load_info) # @@@DLT_SNIPPET_END db diff --git a/docs/website/docs/reference/performance_snippets/performance-snippets.py b/docs/website/docs/reference/performance_snippets/performance-snippets.py index 68ec8ed72d..59c97d3a96 100644 --- a/docs/website/docs/reference/performance_snippets/performance-snippets.py +++ b/docs/website/docs/reference/performance_snippets/performance-snippets.py @@ -14,13 +14,19 @@ def read_table(limit): while item_slice := list(islice(rows, 1000)): now = pendulum.now().isoformat() yield [ - {"row": _id, "description": "this is row with id {_id}", "timestamp": now} + { + "row": _id, + "description": "this is row with id {_id}", + "timestamp": now, + } for _id in item_slice ] # this prevents process pool to run the initialization code again if __name__ == "__main__" or "PYTEST_CURRENT_TEST" in os.environ: - pipeline = dlt.pipeline("parallel_load", destination="duckdb", full_refresh=True) + pipeline = dlt.pipeline( + "parallel_load", destination="duckdb", full_refresh=True + ) pipeline.extract(read_table(1000000)) load_id = pipeline.list_extracted_load_packages()[0] diff --git a/docs/website/docs/tutorial/load-data-from-an-api-snippets.py b/docs/website/docs/tutorial/load-data-from-an-api-snippets.py index d53af9e3d9..cf19f46a1b 100644 --- a/docs/website/docs/tutorial/load-data-from-an-api-snippets.py +++ b/docs/website/docs/tutorial/load-data-from-an-api-snippets.py @@ -53,7 +53,9 @@ def incremental_snippet() -> None: @dlt.resource(table_name="issues", write_disposition="append") def get_issues( - created_at=dlt.sources.incremental("created_at", initial_value="1970-01-01T00:00:00Z") + created_at=dlt.sources.incremental( + "created_at", initial_value="1970-01-01T00:00:00Z" + ) ): # NOTE: we read only open issues to minimize number of calls to the API. # There's a limit of ~50 calls for not authenticated Github users. @@ -107,7 +109,9 @@ def incremental_merge_snippet() -> None: primary_key="id", ) def get_issues( - updated_at=dlt.sources.incremental("updated_at", initial_value="1970-01-01T00:00:00Z") + updated_at=dlt.sources.incremental( + "updated_at", initial_value="1970-01-01T00:00:00Z" + ) ): # NOTE: we read only open issues to minimize number of calls to # the API. 
There's a limit of ~50 calls for not authenticated @@ -149,7 +153,9 @@ def table_dispatch_snippet() -> None: import dlt from dlt.sources.helpers import requests - @dlt.resource(primary_key="id", table_name=lambda i: i["type"], write_disposition="append") + @dlt.resource( + primary_key="id", table_name=lambda i: i["type"], write_disposition="append" + ) def repo_events(last_created_at=dlt.sources.incremental("created_at")): url = "https://api.github.com/repos/dlt-hub/dlt/events?per_page=100" diff --git a/docs/website/docs/utils.py b/docs/website/docs/utils.py index 10b56cc8b7..f6c6b52144 100644 --- a/docs/website/docs/utils.py +++ b/docs/website/docs/utils.py @@ -19,7 +19,8 @@ def parse_toml_file(filename: str) -> None: tomlkit.loads(toml_snippet) except Exception as e: print( - f"Error while testing snippet between: {current_marker} and {line.strip()}" + f"Error while testing snippet between: {current_marker} and" + f" {line.strip()}" ) raise e current_lines = [] diff --git a/pyproject.toml b/pyproject.toml index 444541aa43..ca719b7923 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -198,12 +198,12 @@ pymongo = ">=4.3.3" pandas = ">2" [tool.black] # https://black.readthedocs.io/en/stable/usage_and_configuration/the_basics.html#configuration-via-a-file -line-length = 100 +line-length = 88 preview = true [tool.isort] # https://pycqa.github.io/isort/docs/configuration/options.html color_output = true -line_length = 100 +line_length = 88 profile = "black" src_paths = ["dlt"] multi_line_output = 3 diff --git a/tests/cases.py b/tests/cases.py index 85caec4b8d..a5398ee21f 100644 --- a/tests/cases.py +++ b/tests/cases.py @@ -88,8 +88,18 @@ {"name": "col9_null", "data_type": "complex", "nullable": True, "variant": True}, {"name": "col10_null", "data_type": "date", "nullable": True}, {"name": "col11_null", "data_type": "time", "nullable": True}, - {"name": "col1_precision", "data_type": "bigint", "precision": 16, "nullable": False}, - {"name": "col4_precision", "data_type": "timestamp", "precision": 3, "nullable": False}, + { + "name": "col1_precision", + "data_type": "bigint", + "precision": 16, + "nullable": False, + }, + { + "name": "col4_precision", + "data_type": "timestamp", + "precision": 3, + "nullable": False, + }, {"name": "col5_precision", "data_type": "text", "precision": 25, "nullable": False}, { "name": "col6_precision", @@ -98,7 +108,12 @@ "scale": 2, "nullable": False, }, - {"name": "col7_precision", "data_type": "binary", "precision": 19, "nullable": False}, + { + "name": "col7_precision", + "data_type": "binary", + "precision": 19, + "nullable": False, + }, {"name": "col11_precision", "data_type": "time", "precision": 3, "nullable": False}, ] TABLE_UPDATE_COLUMNS_SCHEMA: TTableSchemaColumns = {t["name"]: t for t in TABLE_UPDATE} @@ -152,7 +167,11 @@ def table_update_and_row( exclude_col_names = list(exclude_columns or []) if exclude_types: exclude_col_names.extend( - [key for key, value in column_schemas.items() if value["data_type"] in exclude_types] + [ + key + for key, value in column_schemas.items() + if value["data_type"] in exclude_types + ] ) for col_name in set(exclude_col_names): del column_schemas[col_name] @@ -178,25 +197,33 @@ def assert_all_data_types_row( else: db_mapping = {col_name: db_row[i] for i, col_name in enumerate(schema)} - expected_rows = {key: value for key, value in TABLE_ROW_ALL_DATA_TYPES.items() if key in schema} + expected_rows = { + key: value for key, value in TABLE_ROW_ALL_DATA_TYPES.items() if key in schema + } # prepare date to be compared: 
convert into pendulum instance, adjust microsecond precision if "col4" in expected_rows: parsed_date = pendulum.instance(db_mapping["col4"]) - db_mapping["col4"] = reduce_pendulum_datetime_precision(parsed_date, timestamp_precision) + db_mapping["col4"] = reduce_pendulum_datetime_precision( + parsed_date, timestamp_precision + ) expected_rows["col4"] = reduce_pendulum_datetime_precision( ensure_pendulum_datetime(expected_rows["col4"]), # type: ignore[arg-type] timestamp_precision, ) if "col4_precision" in expected_rows: parsed_date = pendulum.instance(db_mapping["col4_precision"]) - db_mapping["col4_precision"] = reduce_pendulum_datetime_precision(parsed_date, 3) + db_mapping["col4_precision"] = reduce_pendulum_datetime_precision( + parsed_date, 3 + ) expected_rows["col4_precision"] = reduce_pendulum_datetime_precision( ensure_pendulum_datetime(expected_rows["col4_precision"]), 3 # type: ignore[arg-type] ) if "col11_precision" in expected_rows: parsed_time = ensure_pendulum_time(db_mapping["col11_precision"]) - db_mapping["col11_precision"] = reduce_pendulum_datetime_precision(parsed_time, 3) + db_mapping["col11_precision"] = reduce_pendulum_datetime_precision( + parsed_time, 3 + ) expected_rows["col11_precision"] = reduce_pendulum_datetime_precision( ensure_pendulum_time(expected_rows["col11_precision"]), 3 # type: ignore[arg-type] ) @@ -212,7 +239,9 @@ def assert_all_data_types_row( except ValueError: if not allow_base64_binary: raise - db_mapping[binary_col] = base64.b64decode(db_mapping[binary_col], validate=True) + db_mapping[binary_col] = base64.b64decode( + db_mapping[binary_col], validate=True + ) else: db_mapping[binary_col] = bytes(db_mapping[binary_col]) @@ -239,7 +268,9 @@ def assert_all_data_types_row( for key, expected in expected_rows.items(): actual = db_mapping[key] - assert expected == actual, f"Expected {expected} but got {actual} for column {key}" + assert ( + expected == actual + ), f"Expected {expected} but got {actual} for column {key}" assert db_mapping == expected_rows @@ -278,20 +309,30 @@ def arrow_table_all_data_types( "string": [random.choice(ascii_lowercase) for _ in range(num_rows)], "float": [round(random.uniform(0, 100), 4) for _ in range(num_rows)], "int": [random.randrange(0, 100) for _ in range(num_rows)], - "datetime": pd.date_range("2021-01-01T01:02:03.1234", periods=num_rows, tz="UTC"), + "datetime": pd.date_range( + "2021-01-01T01:02:03.1234", periods=num_rows, tz="UTC" + ), "date": pd.date_range("2021-01-01", periods=num_rows, tz="UTC").date, "binary": [random.choice(ascii_lowercase).encode() for _ in range(num_rows)], - "decimal": [Decimal(str(round(random.uniform(0, 100), 4))) for _ in range(num_rows)], + "decimal": [ + Decimal(str(round(random.uniform(0, 100), 4))) for _ in range(num_rows) + ], "bool": [random.choice([True, False]) for _ in range(num_rows)], - "string_null": [random.choice(ascii_lowercase) for _ in range(num_rows - 1)] + [None], + "string_null": [random.choice(ascii_lowercase) for _ in range(num_rows - 1)] + [ + None + ], "null": pd.Series([None for _ in range(num_rows)]), } if include_name_clash: - data["pre Normalized Column"] = [random.choice(ascii_lowercase) for _ in range(num_rows)] + data["pre Normalized Column"] = [ + random.choice(ascii_lowercase) for _ in range(num_rows) + ] include_not_normalized_name = True if include_not_normalized_name: - data["Pre Normalized Column"] = [random.choice(ascii_lowercase) for _ in range(num_rows)] + data["Pre Normalized Column"] = [ + random.choice(ascii_lowercase) for _ in 
range(num_rows) + ] if include_json: data["json"] = [{"a": random.randrange(0, 100)} for _ in range(num_rows)] diff --git a/tests/cli/common/test_cli_invoke.py b/tests/cli/common/test_cli_invoke.py index d367a97261..c2e7458115 100644 --- a/tests/cli/common/test_cli_invoke.py +++ b/tests/cli/common/test_cli_invoke.py @@ -59,10 +59,14 @@ def test_invoke_pipeline(script_runner: ScriptRunner) -> None: p = dlt.pipeline(pipeline_name="dummy_pipeline") p._wipe_working_folder() - shutil.copytree("tests/cli/cases/deploy_pipeline", TEST_STORAGE_ROOT, dirs_exist_ok=True) + shutil.copytree( + "tests/cli/cases/deploy_pipeline", TEST_STORAGE_ROOT, dirs_exist_ok=True + ) with set_working_dir(TEST_STORAGE_ROOT): - with custom_environ({"COMPETED_PROB": "1.0", "DLT_DATA_DIR": get_dlt_data_dir()}): + with custom_environ( + {"COMPETED_PROB": "1.0", "DLT_DATA_DIR": get_dlt_data_dir()} + ): venv = Venv.restore_current() venv.run_script("dummy_pipeline.py") # we check output test_pipeline_command else @@ -81,7 +85,14 @@ def test_invoke_pipeline(script_runner: ScriptRunner) -> None: try: # use debug flag to raise an exception result = script_runner.run( - ["dlt", "--debug", "pipeline", "dummy_pipeline", "load-package", "NON EXISTENT"] + [ + "dlt", + "--debug", + "pipeline", + "dummy_pipeline", + "load-package", + "NON EXISTENT", + ] ) # exception terminates command assert result.returncode == 1 @@ -101,7 +112,10 @@ def test_invoke_init_chess_and_template(script_runner: ScriptRunner) -> None: assert "Verified source chess was added to your project!" in result.stdout assert result.returncode == 0 result = script_runner.run(["dlt", "init", "debug_pipeline", "dummy"]) - assert "Your new pipeline debug_pipeline is ready to be customized!" in result.stdout + assert ( + "Your new pipeline debug_pipeline is ready to be customized!" + in result.stdout + ) assert result.returncode == 0 @@ -118,21 +132,39 @@ def test_invoke_deploy_project(script_runner: ScriptRunner) -> None: # store dlt data in test storage (like patch_home_dir) with custom_environ({"DLT_DATA_DIR": get_dlt_data_dir()}): result = script_runner.run( - ["dlt", "deploy", "debug_pipeline.py", "github-action", "--schedule", "@daily"] + [ + "dlt", + "deploy", + "debug_pipeline.py", + "github-action", + "--schedule", + "@daily", + ] ) assert result.returncode == -4 assert "The pipeline script does not exist" in result.stderr - result = script_runner.run(["dlt", "deploy", "debug_pipeline.py", "airflow-composer"]) + result = script_runner.run( + ["dlt", "deploy", "debug_pipeline.py", "airflow-composer"] + ) assert result.returncode == -4 assert "The pipeline script does not exist" in result.stderr # now init result = script_runner.run(["dlt", "init", "chess", "dummy"]) assert result.returncode == 0 result = script_runner.run( - ["dlt", "deploy", "chess_pipeline.py", "github-action", "--schedule", "@daily"] + [ + "dlt", + "deploy", + "chess_pipeline.py", + "github-action", + "--schedule", + "@daily", + ] ) assert "NOTE: You must run the pipeline locally" in result.stdout - result = script_runner.run(["dlt", "deploy", "chess_pipeline.py", "airflow-composer"]) + result = script_runner.run( + ["dlt", "deploy", "chess_pipeline.py", "airflow-composer"] + ) assert "NOTE: You must run the pipeline locally" in result.stdout @@ -140,7 +172,14 @@ def test_invoke_deploy_mock(script_runner: ScriptRunner) -> None: # NOTE: you can mock only once per test with ScriptRunner !! 
with patch("dlt.cli.deploy_command.deploy_command") as _deploy_command: script_runner.run( - ["dlt", "deploy", "debug_pipeline.py", "github-action", "--schedule", "@daily"] + [ + "dlt", + "deploy", + "debug_pipeline.py", + "github-action", + "--schedule", + "@daily", + ] ) assert _deploy_command.called assert _deploy_command.call_args[1] == { @@ -183,13 +222,17 @@ def test_invoke_deploy_mock(script_runner: ScriptRunner) -> None: } # no schedule fails _deploy_command.reset_mock() - result = script_runner.run(["dlt", "deploy", "debug_pipeline.py", "github-action"]) + result = script_runner.run( + ["dlt", "deploy", "debug_pipeline.py", "github-action"] + ) assert not _deploy_command.called assert result.returncode != 0 assert "the following arguments are required: --schedule" in result.stderr # airflow without schedule works _deploy_command.reset_mock() - result = script_runner.run(["dlt", "deploy", "debug_pipeline.py", "airflow-composer"]) + result = script_runner.run( + ["dlt", "deploy", "debug_pipeline.py", "airflow-composer"] + ) assert _deploy_command.called assert result.returncode == 0 assert _deploy_command.call_args[1] == { @@ -203,7 +246,14 @@ def test_invoke_deploy_mock(script_runner: ScriptRunner) -> None: # env secrets format _deploy_command.reset_mock() result = script_runner.run( - ["dlt", "deploy", "debug_pipeline.py", "airflow-composer", "--secrets-format", "env"] + [ + "dlt", + "deploy", + "debug_pipeline.py", + "airflow-composer", + "--secrets-format", + "env", + ] ) assert _deploy_command.called assert result.returncode == 0 diff --git a/tests/cli/common/test_telemetry_command.py b/tests/cli/common/test_telemetry_command.py index 1b6588c9c8..39869de413 100644 --- a/tests/cli/common/test_telemetry_command.py +++ b/tests/cli/common/test_telemetry_command.py @@ -8,13 +8,18 @@ from dlt.common.configuration.container import Container from dlt.common.configuration.paths import DOT_DLT from dlt.common.configuration.providers import ConfigTomlProvider, CONFIG_TOML -from dlt.common.configuration.specs.config_providers_context import ConfigProvidersContext +from dlt.common.configuration.specs.config_providers_context import ( + ConfigProvidersContext, +) from dlt.common.storages import FileStorage from dlt.common.typing import DictStrAny from dlt.common.utils import set_working_dir from dlt.cli.utils import track_command -from dlt.cli.telemetry_command import telemetry_status_command, change_telemetry_status_command +from dlt.cli.telemetry_command import ( + telemetry_status_command, + change_telemetry_status_command, +) from tests.utils import patch_random_home_dir, start_test_telemetry, test_storage @@ -30,9 +35,9 @@ def _initial_providers(): glob_ctx = ConfigProvidersContext() glob_ctx.providers = _initial_providers() - with set_working_dir(test_storage.make_full_path("project")), Container().injectable_context( - glob_ctx - ), patch( + with set_working_dir( + test_storage.make_full_path("project") + ), Container().injectable_context(glob_ctx), patch( "dlt.common.configuration.specs.config_providers_context.ConfigProvidersContext.initial_providers", _initial_providers, ): @@ -174,7 +179,10 @@ def test_instrumentation_wrappers() -> None: ) msg = SENT_ITEMS[0] assert msg["event"] == "command_deploy" - assert msg["properties"]["deployment_method"] == DeploymentMethods.github_actions.value + assert ( + msg["properties"]["deployment_method"] + == DeploymentMethods.github_actions.value + ) assert msg["properties"]["success"] is False diff --git a/tests/cli/conftest.py 
b/tests/cli/conftest.py index 78efcd03c4..2743415ce2 100644 --- a/tests/cli/conftest.py +++ b/tests/cli/conftest.py @@ -1 +1,6 @@ -from tests.utils import preserve_environ, autouse_test_storage, unload_modules, wipe_pipeline +from tests.utils import ( + preserve_environ, + autouse_test_storage, + unload_modules, + wipe_pipeline, +) diff --git a/tests/cli/test_config_toml_writer.py b/tests/cli/test_config_toml_writer.py index 8ccac21f99..7b3f44b729 100644 --- a/tests/cli/test_config_toml_writer.py +++ b/tests/cli/test_config_toml_writer.py @@ -36,25 +36,39 @@ def test_write_value(example_toml): assert toml_table["species"] == "Homo sapiens" # Test with is_default_of_interest=True and non-optional, non-final hint - write_value(toml_table, "species", str, overwrite_existing=True, is_default_of_interest=True) + write_value( + toml_table, "species", str, overwrite_existing=True, is_default_of_interest=True + ) assert toml_table["species"] == "species" # Test with is_default_of_interest=False and non-optional, non-final hint, and no default write_value( - toml_table, "population", int, overwrite_existing=True, is_default_of_interest=False + toml_table, + "population", + int, + overwrite_existing=True, + is_default_of_interest=False, ) # non default get typed example value assert "population" in toml_table # Test with optional hint write_value( - toml_table, "habitat", Optional[str], overwrite_existing=True, is_default_of_interest=False + toml_table, + "habitat", + Optional[str], + overwrite_existing=True, + is_default_of_interest=False, ) assert "habitat" not in toml_table # test with optional hint of interest write_value( - toml_table, "habitat", Optional[str], overwrite_existing=True, is_default_of_interest=True + toml_table, + "habitat", + Optional[str], + overwrite_existing=True, + is_default_of_interest=True, ) assert "habitat" in toml_table @@ -82,7 +96,9 @@ def test_write_value(example_toml): def test_write_values(example_toml): values = [ WritableConfigValue("species", str, "Homo sapiens", ("taxonomy", "genus")), - WritableConfigValue("species", str, "Mus musculus", ("taxonomy", "genus", "subgenus")), + WritableConfigValue( + "species", str, "Mus musculus", ("taxonomy", "genus", "subgenus") + ), WritableConfigValue("genome_size", float, 3.2, ("genomic_info",)), ] write_values(example_toml, values, overwrite_existing=True) @@ -107,7 +123,10 @@ def test_write_values(example_toml): write_values(example_toml, new_values, overwrite_existing=True) assert example_toml["taxonomy"]["genus"]["species"] == "Canis lupus" - assert example_toml["taxonomy"]["genus"]["subgenus"]["species"] == "Canis lupus familiaris" + assert ( + example_toml["taxonomy"]["genus"]["subgenus"]["species"] + == "Canis lupus familiaris" + ) assert example_toml["genomic_info"]["genome_size"] == 2.8 @@ -138,24 +157,35 @@ def test_write_values_without_defaults(example_toml): WritableConfigValue("species", str, None, ("taxonomy", "genus")), WritableConfigValue("genome_size", float, None, ("genomic_info",)), WritableConfigValue("is_animal", bool, None, ("animal_info",)), - WritableConfigValue("chromosomes", list, None, ("genomic_info", "chromosome_data")), + WritableConfigValue( + "chromosomes", list, None, ("genomic_info", "chromosome_data") + ), WritableConfigValue("genes", dict, None, ("genomic_info", "gene_data")), ] write_values(example_toml, values, overwrite_existing=True) assert example_toml["taxonomy"]["genus"]["species"] == "species" - assert example_toml["taxonomy"]["genus"]["species"].trivia.comment == EXAMPLE_COMMENT 
+ assert ( + example_toml["taxonomy"]["genus"]["species"].trivia.comment == EXAMPLE_COMMENT + ) assert example_toml["genomic_info"]["genome_size"] == 1.0 assert example_toml["genomic_info"]["genome_size"].trivia.comment == EXAMPLE_COMMENT assert example_toml["animal_info"]["is_animal"] is True - assert example_toml["genomic_info"]["chromosome_data"]["chromosomes"] == ["a", "b", "c"] + assert example_toml["genomic_info"]["chromosome_data"]["chromosomes"] == [ + "a", + "b", + "c", + ] assert ( example_toml["genomic_info"]["chromosome_data"]["chromosomes"].trivia.comment == EXAMPLE_COMMENT ) assert example_toml["genomic_info"]["gene_data"]["genes"] == {"key": "value"} - assert example_toml["genomic_info"]["gene_data"]["genes"].trivia.comment == EXAMPLE_COMMENT + assert ( + example_toml["genomic_info"]["gene_data"]["genes"].trivia.comment + == EXAMPLE_COMMENT + ) diff --git a/tests/cli/test_deploy_command.py b/tests/cli/test_deploy_command.py index 685921ca6e..32155bb2c2 100644 --- a/tests/cli/test_deploy_command.py +++ b/tests/cli/test_deploy_command.py @@ -23,7 +23,10 @@ DEPLOY_PARAMS = [ - ("github-action", {"schedule": "*/30 * * * *", "run_on_push": True, "run_manually": True}), + ( + "github-action", + {"schedule": "*/30 * * * *", "run_on_push": True, "run_manually": True}, + ), ("airflow-composer", {"secrets_format": "toml"}), ("airflow-composer", {"secrets_format": "env"}), ] @@ -64,7 +67,9 @@ def test_deploy_command( p = dlt.pipeline(pipeline_name="debug_pipeline") p._wipe_working_folder() - shutil.copytree("tests/cli/cases/deploy_pipeline", TEST_STORAGE_ROOT, dirs_exist_ok=True) + shutil.copytree( + "tests/cli/cases/deploy_pipeline", TEST_STORAGE_ROOT, dirs_exist_ok=True + ) with set_working_dir(TEST_STORAGE_ROOT): from git import Repo, Remote diff --git a/tests/cli/test_init_command.py b/tests/cli/test_init_command.py index bf5c21c80f..de3d29f3e7 100644 --- a/tests/cli/test_init_command.py +++ b/tests/cli/test_init_command.py @@ -16,7 +16,11 @@ from dlt.common import git from dlt.common.configuration.paths import make_dlt_settings_path -from dlt.common.configuration.providers import CONFIG_TOML, SECRETS_TOML, SecretsTomlProvider +from dlt.common.configuration.providers import ( + CONFIG_TOML, + SECRETS_TOML, + SecretsTomlProvider, +) from dlt.common.runners import Venv from dlt.common.storages.file_storage import FileStorage from dlt.common.source import _SOURCES @@ -53,21 +57,27 @@ def get_verified_source_candidates(repo_dir: str) -> List[str]: return files_ops.get_verified_source_names(sources_storage) -def test_init_command_pipeline_template(repo_dir: str, project_files: FileStorage) -> None: +def test_init_command_pipeline_template( + repo_dir: str, project_files: FileStorage +) -> None: init_command.init_command("debug_pipeline", "bigquery", False, repo_dir) visitor = assert_init_files(project_files, "debug_pipeline", "bigquery") # single resource assert len(visitor.known_resource_calls) == 1 -def test_init_command_pipeline_generic(repo_dir: str, project_files: FileStorage) -> None: +def test_init_command_pipeline_generic( + repo_dir: str, project_files: FileStorage +) -> None: init_command.init_command("generic_pipeline", "redshift", True, repo_dir) visitor = assert_init_files(project_files, "generic_pipeline", "redshift") # multiple resources assert len(visitor.known_resource_calls) > 1 -def test_init_command_new_pipeline_same_name(repo_dir: str, project_files: FileStorage) -> None: +def test_init_command_new_pipeline_same_name( + repo_dir: str, project_files: FileStorage +) 
-> None: init_command.init_command("debug_pipeline", "bigquery", False, repo_dir) with io.StringIO() as buf, contextlib.redirect_stdout(buf): init_command.init_command("debug_pipeline", "bigquery", False, repo_dir) @@ -75,7 +85,9 @@ def test_init_command_new_pipeline_same_name(repo_dir: str, project_files: FileS assert "already exist, exiting" in _out -def test_init_command_chess_verified_source(repo_dir: str, project_files: FileStorage) -> None: +def test_init_command_chess_verified_source( + repo_dir: str, project_files: FileStorage +) -> None: init_command.init_command("chess", "duckdb", False, repo_dir) assert_source_files(project_files, "chess", "duckdb", has_source_section=True) assert_requirements_txt(project_files, "duckdb") @@ -107,7 +119,9 @@ def test_init_command_chess_verified_source(repo_dir: str, project_files: FileSt raise -def test_init_list_verified_pipelines(repo_dir: str, project_files: FileStorage) -> None: +def test_init_list_verified_pipelines( + repo_dir: str, project_files: FileStorage +) -> None: sources = init_command._list_verified_sources(repo_dir) # a few known sources must be there known_sources = ["chess", "sql_database", "google_sheets", "pipedrive"] @@ -123,7 +137,9 @@ def test_init_list_verified_pipelines_update_warning( repo_dir: str, project_files: FileStorage ) -> None: """Sources listed include a warning if a different dlt version is required""" - with mock.patch.object(SourceRequirements, "current_dlt_version", return_value="0.0.1"): + with mock.patch.object( + SourceRequirements, "current_dlt_version", return_value="0.0.1" + ): with io.StringIO() as buf, contextlib.redirect_stdout(buf): init_command.list_verified_sources_command(repo_dir) _out = buf.getvalue() @@ -140,7 +156,9 @@ def test_init_list_verified_pipelines_update_warning( assert "0.0.1" not in parsed_requirement.specifier -def test_init_all_verified_sources_together(repo_dir: str, project_files: FileStorage) -> None: +def test_init_all_verified_sources_together( + repo_dir: str, project_files: FileStorage +) -> None: source_candidates = get_verified_source_candidates(repo_dir) # source_candidates = [source_name for source_name in source_candidates if source_name == "salesforce"] for source_name in source_candidates: @@ -158,7 +176,9 @@ def test_init_all_verified_sources_together(repo_dir: str, project_files: FileSt assert files_ops.load_verified_sources_local_index(source_name) is not None # credentials for all destinations for destination_name in ["bigquery", "postgres", "redshift"]: - assert secrets.get_value(destination_name, type, None, "destination") is not None + assert ( + secrets.get_value(destination_name, type, None, "destination") is not None + ) # create pipeline template on top init_command.init_command("debug_pipeline", "postgres", False, repo_dir) @@ -196,7 +216,9 @@ def test_init_all_destinations( def test_init_code_update_index_diff(repo_dir: str, project_files: FileStorage) -> None: sources_storage = FileStorage(os.path.join(repo_dir, SOURCES_MODULE_NAME)) new_content = '"""New docstrings"""' - new_content_hash = hashlib.sha3_256(bytes(new_content, encoding="ascii")).hexdigest() + new_content_hash = hashlib.sha3_256( + bytes(new_content, encoding="ascii") + ).hexdigest() init_command.init_command("pipedrive", "duckdb", False, repo_dir) # modify existing file, no commit @@ -227,7 +249,10 @@ def test_init_code_update_index_diff(repo_dir: str, project_files: FileStorage) # remote file entry in modified assert modified[mod_file_path] == remote_index["files"][mod_file_path] # 
git sha didn't change (not committed) - assert modified[mod_file_path]["git_sha"] == local_index["files"][mod_file_path]["git_sha"] + assert ( + modified[mod_file_path]["git_sha"] + == local_index["files"][mod_file_path]["git_sha"] + ) # local entry in deleted assert deleted[del_file_path] == local_index["files"][del_file_path] @@ -240,7 +265,9 @@ def test_init_code_update_index_diff(repo_dir: str, project_files: FileStorage) # merge into local index modified.update(new) - local_index = files_ops._merge_remote_index(local_index, remote_index, modified, deleted) + local_index = files_ops._merge_remote_index( + local_index, remote_index, modified, deleted + ) assert new_file_path in local_index["files"] assert del_file_path not in local_index["files"] assert local_index["files"][mod_file_path]["sha3_256"] == new_content_hash @@ -270,7 +297,11 @@ def test_init_code_update_index_diff(repo_dir: str, project_files: FileStorage) # resolve conflicts in three different ways # skip option (the default) res, sel_modified, sel_deleted = _select_source_files( - "pipedrive", deepcopy(modified), deepcopy(deleted), conflict_modified, conflict_deleted + "pipedrive", + deepcopy(modified), + deepcopy(deleted), + conflict_modified, + conflict_deleted, ) # noting is written, including non-conflicting file assert res == "s" @@ -279,7 +310,11 @@ def test_init_code_update_index_diff(repo_dir: str, project_files: FileStorage) # Apply option - local changes will be lost with echo.always_choose(False, "a"): res, sel_modified, sel_deleted = _select_source_files( - "pipedrive", deepcopy(modified), deepcopy(deleted), conflict_modified, conflict_deleted + "pipedrive", + deepcopy(modified), + deepcopy(deleted), + conflict_modified, + conflict_deleted, ) assert res == "a" assert sel_modified == modified @@ -287,7 +322,11 @@ def test_init_code_update_index_diff(repo_dir: str, project_files: FileStorage) # merge only non conflicting changes are applied with echo.always_choose(False, "m"): res, sel_modified, sel_deleted = _select_source_files( - "pipedrive", deepcopy(modified), deepcopy(deleted), conflict_modified, conflict_deleted + "pipedrive", + deepcopy(modified), + deepcopy(deleted), + conflict_modified, + conflict_deleted, ) assert res == "m" assert len(sel_modified) == 1 and mod_file_path_2 in sel_modified @@ -321,7 +360,9 @@ def test_init_code_update_index_diff(repo_dir: str, project_files: FileStorage) assert conflict_modified == [mod_file_path] -def test_init_code_update_no_conflict(repo_dir: str, project_files: FileStorage) -> None: +def test_init_code_update_no_conflict( + repo_dir: str, project_files: FileStorage +) -> None: init_command.init_command("pipedrive", "duckdb", False, repo_dir) with git.get_repo(repo_dir) as repo: assert git.is_clean_and_synced(repo) is True @@ -357,7 +398,9 @@ def test_init_code_update_no_conflict(repo_dir: str, project_files: FileStorage) != local_index["files"][mod_local_path]["git_sha"] ) # all the other files must keep the old hashes - for old_f, new_f in zip(local_index["files"].items(), new_local_index["files"].items()): + for old_f, new_f in zip( + local_index["files"].items(), new_local_index["files"].items() + ): # assert new_f[1]["commit_sha"] == commit.hexsha if old_f[0] != mod_local_path: assert old_f[1]["git_sha"] == new_f[1]["git_sha"] @@ -463,20 +506,25 @@ def test_pipeline_template_sources_in_single_file( # _SOURCES now contains the sources from pipeline.py which simulates loading from two places with pytest.raises(CliCommandException) as cli_ex: 
init_command.init_command("generic_pipeline", "redshift", True, repo_dir) - assert "In init scripts you must declare all sources and resources in single file." in str( - cli_ex.value + assert ( + "In init scripts you must declare all sources and resources in single file." + in str(cli_ex.value) ) -def test_incompatible_dlt_version_warning(repo_dir: str, project_files: FileStorage) -> None: - with mock.patch.object(SourceRequirements, "current_dlt_version", return_value="0.1.1"): +def test_incompatible_dlt_version_warning( + repo_dir: str, project_files: FileStorage +) -> None: + with mock.patch.object( + SourceRequirements, "current_dlt_version", return_value="0.1.1" + ): with io.StringIO() as buf, contextlib.redirect_stdout(buf): init_command.init_command("facebook_ads", "bigquery", False, repo_dir) _out = buf.getvalue() assert ( - "WARNING: This pipeline requires a newer version of dlt than your installed version" - " (0.1.1)." + "WARNING: This pipeline requires a newer version of dlt than your installed" + " version (0.1.1)." in _out ) @@ -487,7 +535,9 @@ def assert_init_files( destination_name: str, dependency_destination: Optional[str] = None, ) -> PipelineScriptVisitor: - visitor, _ = assert_common_files(project_files, pipeline_name + ".py", destination_name) + visitor, _ = assert_common_files( + project_files, pipeline_name + ".py", destination_name + ) assert not project_files.has_folder(pipeline_name) assert_requirements_txt(project_files, dependency_destination or destination_name) return visitor @@ -506,7 +556,9 @@ def assert_requirements_txt(project_files: FileStorage, destination_name: str) - assert len(source_requirements.dlt_requirement.specifier) >= 1 -def assert_index_version_constraint(project_files: FileStorage, source_name: str) -> None: +def assert_index_version_constraint( + project_files: FileStorage, source_name: str +) -> None: # check dlt version constraint in .sources index for given source matches the one in requirements.txt local_index = files_ops.load_verified_sources_local_index(source_name) index_constraint = local_index["dlt_version_constraint"] @@ -567,9 +619,19 @@ def assert_common_files( secrets = SecretsTomlProvider() if destination_name not in ["duckdb", "dummy"]: # destination is there - assert secrets.get_value(destination_name, type, None, "destination") is not None + assert ( + secrets.get_value(destination_name, type, None, "destination") is not None + ) # certain values are never there - for not_there in ["destination_name", "default_schema_name", "as_staging", "staging_config"]: - assert secrets.get_value(not_there, type, None, "destination", destination_name)[0] is None + for not_there in [ + "destination_name", + "default_schema_name", + "as_staging", + "staging_config", + ]: + assert ( + secrets.get_value(not_there, type, None, "destination", destination_name)[0] + is None + ) return visitor, secrets diff --git a/tests/cli/test_pipeline_command.py b/tests/cli/test_pipeline_command.py index 1f8e2ff4f3..5b296feeb1 100644 --- a/tests/cli/test_pipeline_command.py +++ b/tests/cli/test_pipeline_command.py @@ -159,7 +159,9 @@ def test_pipeline_command_operations(repo_dir: str, project_files: FileStorage) assert "players_profiles" not in pipeline.default_schema.tables -def test_pipeline_command_failed_jobs(repo_dir: str, project_files: FileStorage) -> None: +def test_pipeline_command_failed_jobs( + repo_dir: str, project_files: FileStorage +) -> None: init_command.init_command("chess", "dummy", False, repo_dir) try: @@ -194,7 +196,9 @@ def 
test_pipeline_command_failed_jobs(repo_dir: str, project_files: FileStorage) assert "JOB file type: jsonl" in _out -def test_pipeline_command_drop_partial_loads(repo_dir: str, project_files: FileStorage) -> None: +def test_pipeline_command_drop_partial_loads( + repo_dir: str, project_files: FileStorage +) -> None: init_command.init_command("chess", "dummy", False, repo_dir) try: @@ -221,13 +225,17 @@ def test_pipeline_command_drop_partial_loads(repo_dir: str, project_files: FileS with io.StringIO() as buf, contextlib.redirect_stdout(buf): with echo.always_choose(False, True): - pipeline_command.pipeline_command("drop-pending-packages", "chess_pipeline", None, 1) + pipeline_command.pipeline_command( + "drop-pending-packages", "chess_pipeline", None, 1 + ) _out = buf.getvalue() assert "Pending packages deleted" in _out print(_out) with io.StringIO() as buf, contextlib.redirect_stdout(buf): - pipeline_command.pipeline_command("drop-pending-packages", "chess_pipeline", None, 1) + pipeline_command.pipeline_command( + "drop-pending-packages", "chess_pipeline", None, 1 + ) _out = buf.getvalue() assert "No pending packages found" in _out print(_out) diff --git a/tests/common/configuration/test_accessors.py b/tests/common/configuration/test_accessors.py index 147d56abec..3d9a7dbe36 100644 --- a/tests/common/configuration/test_accessors.py +++ b/tests/common/configuration/test_accessors.py @@ -16,7 +16,9 @@ GcpServiceAccountCredentialsWithoutDefaults, ConnectionStringCredentials, ) -from dlt.common.configuration.specs.config_providers_context import ConfigProvidersContext +from dlt.common.configuration.specs.config_providers_context import ( + ConfigProvidersContext, +) from dlt.common.configuration.utils import get_resolved_traces, ResolvedValueTrace from dlt.common.runners.configuration import PoolRunnerConfiguration from dlt.common.typing import AnyType, TSecretValue @@ -33,7 +35,9 @@ def test_accessor_singletons() -> None: assert dlt.secrets.value is None -def test_getter_accessor(toml_providers: ConfigProvidersContext, environment: Any) -> None: +def test_getter_accessor( + toml_providers: ConfigProvidersContext, environment: Any +) -> None: with pytest.raises(KeyError) as py_ex: dlt.config["_unknown"] with pytest.raises(ConfigFieldMissingException) as py_ex: @@ -57,21 +61,41 @@ def test_getter_accessor(toml_providers: ConfigProvidersContext, environment: An # get sectioned values assert dlt.config["typecheck.str_val"] == "test string" assert RESOLVED_TRACES["typecheck.str_val"] == ResolvedValueTrace( - "str_val", "test string", None, AnyType, ["typecheck"], ConfigTomlProvider().name, None + "str_val", + "test string", + None, + AnyType, + ["typecheck"], + ConfigTomlProvider().name, + None, ) environment["DLT__THIS__VALUE"] = "embedded" assert dlt.config["dlt.this.value"] == "embedded" assert RESOLVED_TRACES["dlt.this.value"] == ResolvedValueTrace( - "value", "embedded", None, AnyType, ["dlt", "this"], EnvironProvider().name, None + "value", + "embedded", + None, + AnyType, + ["dlt", "this"], + EnvironProvider().name, + None, ) assert dlt.secrets["dlt.this.value"] == "embedded" assert RESOLVED_TRACES["dlt.this.value"] == ResolvedValueTrace( - "value", "embedded", None, TSecretValue, ["dlt", "this"], EnvironProvider().name, None + "value", + "embedded", + None, + TSecretValue, + ["dlt", "this"], + EnvironProvider().name, + None, ) -def test_getter_auto_cast(toml_providers: ConfigProvidersContext, environment: Any) -> None: +def test_getter_auto_cast( + toml_providers: ConfigProvidersContext, 
environment: Any +) -> None: environment["VALUE"] = "{SET}" assert dlt.config["value"] == "{SET}" # bool @@ -123,7 +147,8 @@ def test_getter_auto_cast(toml_providers: ConfigProvidersContext, environment: A ) # equivalent assert ( - dlt.secrets["destination.bigquery.client_email"] == "loader@a7513.iam.gserviceaccount.com" + dlt.secrets["destination.bigquery.client_email"] + == "loader@a7513.iam.gserviceaccount.com" ) assert RESOLVED_TRACES["destination.bigquery.client_email"] == ResolvedValueTrace( "client_email", @@ -136,14 +161,24 @@ def test_getter_auto_cast(toml_providers: ConfigProvidersContext, environment: A ) -def test_getter_accessor_typed(toml_providers: ConfigProvidersContext, environment: Any) -> None: +def test_getter_accessor_typed( + toml_providers: ConfigProvidersContext, environment: Any +) -> None: # get a dict as str - credentials_str = '{"secret_value":"2137","project_id":"mock-project-id-credentials"}' + credentials_str = ( + '{"secret_value":"2137","project_id":"mock-project-id-credentials"}' + ) # the typed version coerces the value into desired type, in this case "dict" -> "str" assert dlt.secrets.get("credentials", str) == credentials_str # note that trace keeps original value of "credentials" which was of dictionary type assert RESOLVED_TRACES[".credentials"] == ResolvedValueTrace( - "credentials", json.loads(credentials_str), None, str, [], SecretsTomlProvider().name, None + "credentials", + json.loads(credentials_str), + None, + str, + [], + SecretsTomlProvider().name, + None, ) # unchanged type assert isinstance(dlt.secrets.get("credentials"), dict) @@ -161,7 +196,9 @@ def test_getter_accessor_typed(toml_providers: ConfigProvidersContext, environme "credentials", credentials_str, None, ConnectionStringCredentials, ["databricks"], SecretsTomlProvider().name, ConnectionStringCredentials # type: ignore[arg-type] ) assert c.drivername == "databricks+connector" - c2 = dlt.secrets.get("destination.credentials", GcpServiceAccountCredentialsWithoutDefaults) + c2 = dlt.secrets.get( + "destination.credentials", GcpServiceAccountCredentialsWithoutDefaults + ) assert c2.client_email == "loader@a7513.iam.gserviceaccount.com" diff --git a/tests/common/configuration/test_configuration.py b/tests/common/configuration/test_configuration.py index 5fbcd86d92..33c5b2c40e 100644 --- a/tests/common/configuration/test_configuration.py +++ b/tests/common/configuration/test_configuration.py @@ -20,8 +20,18 @@ from dlt.common.configuration.specs.gcp_credentials import ( GcpServiceAccountCredentialsWithoutDefaults, ) -from dlt.common.utils import custom_environ, get_exception_trace, get_exception_trace_chain -from dlt.common.typing import AnyType, DictStrAny, StrAny, TSecretValue, extract_inner_type +from dlt.common.utils import ( + custom_environ, + get_exception_trace, + get_exception_trace_chain, +) +from dlt.common.typing import ( + AnyType, + DictStrAny, + StrAny, + TSecretValue, + extract_inner_type, +) from dlt.common.configuration.exceptions import ( ConfigFieldMissingTypeHintException, ConfigFieldTypeHintNotSupported, @@ -72,7 +82,9 @@ INVALID_COERCIONS = { # 'STR_VAL': 'test string', # string always OK "int_val": "a12345", - "bool_val": "not_bool", # bool overridden by string - that is the most common problem + "bool_val": ( + "not_bool" + ), # bool overridden by string - that is the most common problem "list_val": {"2": 1, "3": 3.0}, "dict_val": "{'a': 1, 'b', '2'}", "bytes_val": "Hello World!", @@ -261,7 +273,9 @@ def test_set_default_config_value(environment: Any) -> None: ) 
assert c.to_native_representation() == "h>a>b>he" # set from native form - c = resolve.resolve_configuration(InstrumentedConfiguration(), explicit_value="h>a>b>he") + c = resolve.resolve_configuration( + InstrumentedConfiguration(), explicit_value="h>a>b>he" + ) assert c.head == "h" assert c.tube == ["a", "b"] assert c.heels == "he" @@ -295,7 +309,9 @@ def test_explicit_values(environment: Any) -> None: # unknown field in explicit value dict is ignored c = resolve.resolve_configuration( - CoercionTestConfiguration(), explicit_value={"created_val": "3343"}, accept_partial=True + CoercionTestConfiguration(), + explicit_value={"created_val": "3343"}, + accept_partial=True, ) assert "created_val" not in c @@ -303,7 +319,8 @@ def test_explicit_values(environment: Any) -> None: def test_explicit_values_false_when_bool() -> None: # values like 0, [], "" all coerce to bool False c = resolve.resolve_configuration( - InstrumentedConfiguration(), explicit_value={"head": "", "tube": [], "heels": ""} + InstrumentedConfiguration(), + explicit_value={"head": "", "tube": [], "heels": ""}, ) assert c.head == "" assert c.tube == [] @@ -354,7 +371,9 @@ def test_explicit_native_always_skips_resolve(environment: Any) -> None: with patch.object(InstrumentedConfiguration, "__section__", "ins"): # explicit native representations skips resolve environment["INS__HEELS"] = "xhe" - c = resolve.resolve_configuration(InstrumentedConfiguration(), explicit_value="h>a>b>he") + c = resolve.resolve_configuration( + InstrumentedConfiguration(), explicit_value="h>a>b>he" + ) assert c.heels == "he" # normal resolve (heels from env) @@ -371,7 +390,9 @@ def test_explicit_native_always_skips_resolve(environment: Any) -> None: assert c.heels == "uhe" # also the native explicit value - c = resolve.resolve_configuration(InstrumentedConfiguration(), explicit_value="h>a>b>uhe") + c = resolve.resolve_configuration( + InstrumentedConfiguration(), explicit_value="h>a>b>uhe" + ) assert c.heels == "uhe" @@ -412,7 +433,9 @@ def test_invalid_native_config_value() -> None: def test_on_resolved(environment: Any) -> None: with pytest.raises(RuntimeError): # head over hells - resolve.resolve_configuration(InstrumentedConfiguration(), explicit_value="he>a>b>h") + resolve.resolve_configuration( + InstrumentedConfiguration(), explicit_value="he>a>b>h" + ) def test_embedded_config(environment: Any) -> None: @@ -526,7 +549,9 @@ def test_run_configuration_gen_name(environment: Any) -> None: assert C.pipeline_name.startswith("dlt_") -def test_configuration_is_mutable_mapping(environment: Any, env_provider: ConfigProvider) -> None: +def test_configuration_is_mutable_mapping( + environment: Any, env_provider: ConfigProvider +) -> None: @configspec class _SecretCredentials(RunConfiguration): pipeline_name: Optional[str] = "secret" @@ -632,7 +657,9 @@ class MultiConfiguration( assert C.__section__ == "DLT_TEST" -def test_raises_on_unresolved_field(environment: Any, env_provider: ConfigProvider) -> None: +def test_raises_on_unresolved_field( + environment: Any, env_provider: ConfigProvider +) -> None: # via make configuration with pytest.raises(ConfigFieldMissingException) as cf_missing_exc: resolve.resolve_configuration(WrongConfiguration()) @@ -656,7 +683,9 @@ def test_raises_on_unresolved_field(environment: Any, env_provider: ConfigProvid assert exception_trace["exception_attrs"]["fields"] == ["NoneConfigVar"] -def test_raises_on_many_unresolved_fields(environment: Any, env_provider: ConfigProvider) -> None: +def test_raises_on_many_unresolved_fields( + 
environment: Any, env_provider: ConfigProvider +) -> None: # via make configuration with pytest.raises(ConfigFieldMissingException) as cf_missing_exc: resolve.resolve_configuration(CoercionTestConfiguration()) @@ -666,7 +695,9 @@ def test_raises_on_many_unresolved_fields(environment: Any, env_provider: Config assert cf_missing_exc.value.spec_name == "CoercionTestConfiguration" # get all fields that must be set val_fields = [ - f for f in CoercionTestConfiguration().get_resolvable_fields() if f.lower().endswith("_val") + f + for f in CoercionTestConfiguration().get_resolvable_fields() + if f.lower().endswith("_val") ] traces = cf_missing_exc.value.traces assert len(traces) == len(val_fields) @@ -726,7 +757,11 @@ def test_coercion_to_hint_types(environment: Any) -> None: C = CoercionTestConfiguration() resolve._resolve_config_fields( - C, explicit_values=None, explicit_sections=(), embedded_sections=(), accept_partial=False + C, + explicit_values=None, + explicit_sections=(), + embedded_sections=(), + accept_partial=False, ) for key in COERCIONS: @@ -762,24 +797,34 @@ def test_values_serialization() -> None: # test credentials credentials_str = "databricks+connector://token:-databricks_token-@:443/?conn_timeout=15&search_path=a%2Cb%2Cc" - credentials = deserialize_value("credentials", credentials_str, ConnectionStringCredentials) + credentials = deserialize_value( + "credentials", credentials_str, ConnectionStringCredentials + ) assert credentials.drivername == "databricks+connector" assert credentials.query == {"conn_timeout": "15", "search_path": "a,b,c"} assert credentials.password == "-databricks_token-" assert serialize_value(credentials) == credentials_str # using dict also works credentials_dict = dict(credentials) - credentials_2 = deserialize_value("credentials", credentials_dict, ConnectionStringCredentials) + credentials_2 = deserialize_value( + "credentials", credentials_dict, ConnectionStringCredentials + ) assert serialize_value(credentials_2) == credentials_str # if string is not a valid native representation of credentials but is parsable json dict then it works as well credentials_json = json.dumps(credentials_dict) - credentials_3 = deserialize_value("credentials", credentials_json, ConnectionStringCredentials) + credentials_3 = deserialize_value( + "credentials", credentials_json, ConnectionStringCredentials + ) assert serialize_value(credentials_3) == credentials_str # test config without native representation - secret_config = deserialize_value("credentials", {"secret_value": "a"}, SecretConfiguration) + secret_config = deserialize_value( + "credentials", {"secret_value": "a"}, SecretConfiguration + ) assert secret_config.secret_value == "a" - secret_config = deserialize_value("credentials", '{"secret_value": "a"}', SecretConfiguration) + secret_config = deserialize_value( + "credentials", '{"secret_value": "a"}', SecretConfiguration + ) assert secret_config.secret_value == "a" assert serialize_value(secret_config) == '{"secret_value":"a"}' @@ -803,7 +848,9 @@ def test_invalid_coercions(environment: Any) -> None: # overwrite with valid value and go to next env environment[key.upper()] = serialize_value(COERCIONS[key]) continue - raise AssertionError("%s was coerced with %s which is invalid type" % (key, value)) + raise AssertionError( + "%s was coerced with %s which is invalid type" % (key, value) + ) def test_excepted_coercions(environment: Any) -> None: @@ -811,7 +858,11 @@ def test_excepted_coercions(environment: Any) -> None: add_config_dict_to_env(COERCIONS) 
add_config_dict_to_env(EXCEPTED_COERCIONS, overwrite_keys=True) resolve._resolve_config_fields( - C, explicit_values=None, explicit_sections=(), embedded_sections=(), accept_partial=False + C, + explicit_values=None, + explicit_sections=(), + embedded_sections=(), + accept_partial=False, ) for key in EXCEPTED_COERCIONS: assert getattr(C, key) == COERCED_EXCEPTIONS[key] @@ -911,7 +962,10 @@ def test_is_valid_hint() -> None: # in case of generics, origin will be used and args are not checked assert is_valid_hint(MutableMapping[TSecretValue, Any]) is True # this is valid (args not checked) - assert is_valid_hint(MutableMapping[TSecretValue, ConfigValueCannotBeCoercedException]) is True + assert ( + is_valid_hint(MutableMapping[TSecretValue, ConfigValueCannotBeCoercedException]) + is True + ) assert is_valid_hint(Wei) is True # any class type, except deriving from BaseConfiguration is wrong type assert is_valid_hint(ConfigFieldMissingException) is False @@ -945,7 +999,9 @@ def test_secret_value_not_secret_provider(mock_provider: MockProvider) -> None: # anything derived from CredentialsConfiguration will fail with patch.object(SecretCredentials, "__section__", "credentials"): with pytest.raises(ValueNotSecretException) as py_ex: - resolve.resolve_configuration(WithCredentialsConfiguration(), sections=("mock",)) + resolve.resolve_configuration( + WithCredentialsConfiguration(), sections=("mock",) + ) assert py_ex.value.provider_name == "Mock Provider" assert py_ex.value.key == "-credentials" @@ -1027,7 +1083,13 @@ def test_resolved_trace(environment: Any) -> None: ) # value is before casting assert traces["instrumented.tube"] == ResolvedValueTrace( - "tube", '["tu", "u", "be"]', None, List[str], ["instrumented"], prov_name, c.instrumented + "tube", + '["tu", "u", "be"]', + None, + List[str], + ["instrumented"], + prov_name, + c.instrumented, ) assert deserialize_value( "tube", traces["instrumented.tube"].value, resolve.extract_inner_hint(List[str]) @@ -1053,7 +1115,9 @@ def test_resolved_trace(environment: Any) -> None: c = resolve.resolve_configuration(EmbeddedConfiguration()) resolve.resolve_configuration(InstrumentedConfiguration()) - assert traces[".default"] == ResolvedValueTrace("default", "UNDEF", None, str, [], prov_name, c) + assert traces[".default"] == ResolvedValueTrace( + "default", "UNDEF", None, str, [], prov_name, c + ) assert traces[".instrumented"] == ResolvedValueTrace( "instrumented", "h>t>t>t>he", None, InstrumentedConfiguration, [], prov_name, c ) @@ -1075,7 +1139,10 @@ def test_extract_inner_hint() -> None: # extracts new types assert resolve.extract_inner_hint(TSecretValue) is AnyType # preserves new types on extract - assert resolve.extract_inner_hint(TSecretValue, preserve_new_types=True) is TSecretValue + assert ( + resolve.extract_inner_hint(TSecretValue, preserve_new_types=True) + is TSecretValue + ) def test_is_secret_hint() -> None: @@ -1085,7 +1152,9 @@ def test_is_secret_hint() -> None: assert resolve.is_secret_hint(Optional[TSecretValue]) is True # type: ignore[arg-type] assert resolve.is_secret_hint(InstrumentedConfiguration) is False # do not recognize new types - TTestSecretNt = NewType("TTestSecretNt", GcpServiceAccountCredentialsWithoutDefaults) + TTestSecretNt = NewType( + "TTestSecretNt", GcpServiceAccountCredentialsWithoutDefaults + ) assert resolve.is_secret_hint(TTestSecretNt) is False # recognize unions with credentials assert resolve.is_secret_hint(Union[GcpServiceAccountCredentialsWithoutDefaults, StrAny, str]) is True # type: ignore[arg-type] @@ 
-1127,7 +1196,9 @@ def test_dynamic_type_hint_subclass(environment: Dict[str, str]) -> None: environment["DUMMY__DISCRIMINATOR"] = "c" environment["DUMMY__EMBEDDED_CONFIG__FIELD_FOR_C"] = "some_value" - config = resolve.resolve_configuration(SubclassConfigWithDynamicType(), sections=("dummy",)) + config = resolve.resolve_configuration( + SubclassConfigWithDynamicType(), sections=("dummy",) + ) assert isinstance(config.embedded_config, DynamicConfigC) assert config.embedded_config.field_for_c == "some_value" @@ -1136,7 +1207,9 @@ def test_dynamic_type_hint_subclass(environment: Dict[str, str]) -> None: environment["DUMMY__DISCRIMINATOR"] = "b" environment["DUMMY__EMBEDDED_CONFIG__FIELD_FOR_B"] = "some_value" - config = resolve.resolve_configuration(SubclassConfigWithDynamicType(), sections=("dummy",)) + config = resolve.resolve_configuration( + SubclassConfigWithDynamicType(), sections=("dummy",) + ) assert isinstance(config.embedded_config, DynamicConfigB) assert config.embedded_config.field_for_b == "some_value" @@ -1146,7 +1219,9 @@ def test_dynamic_type_hint_subclass(environment: Dict[str, str]) -> None: environment["DUMMY__DYNAMIC_TYPE_FIELD"] = "some" with pytest.raises(ConfigValueCannotBeCoercedException) as e: - config = resolve.resolve_configuration(SubclassConfigWithDynamicType(), sections=("dummy",)) + config = resolve.resolve_configuration( + SubclassConfigWithDynamicType(), sections=("dummy",) + ) assert e.value.field_name == "dynamic_type_field" assert e.value.hint == int @@ -1219,11 +1294,14 @@ def test_configuration_copy() -> None: # try credentials cred = ConnectionStringCredentials() - cred.parse_native_representation("postgresql://loader:loader@localhost:5432/dlt_data") + cred.parse_native_representation( + "postgresql://loader:loader@localhost:5432/dlt_data" + ) copy_cred = cred.copy() assert dict(copy_cred) == dict(cred) assert ( - copy_cred.to_native_representation() == "postgresql://loader:loader@localhost:5432/dlt_data" + copy_cred.to_native_representation() + == "postgresql://loader:loader@localhost:5432/dlt_data" ) # resolve the copy assert not copy_cred.is_resolved() @@ -1235,7 +1313,9 @@ def test_configuration_with_configuration_as_default() -> None: instrumented_default = InstrumentedConfiguration() instrumented_default.parse_native_representation("h>a>b>he") cred = ConnectionStringCredentials() - cred.parse_native_representation("postgresql://loader:loader@localhost:5432/dlt_data") + cred.parse_native_representation( + "postgresql://loader:loader@localhost:5432/dlt_data" + ) @configspec class EmbeddedConfigurationWithDefaults(BaseConfiguration): diff --git a/tests/common/configuration/test_container.py b/tests/common/configuration/test_container.py index eddd0b21dc..b600fa289a 100644 --- a/tests/common/configuration/test_container.py +++ b/tests/common/configuration/test_container.py @@ -73,7 +73,9 @@ def test_singleton(container: Container) -> None: @pytest.mark.parametrize("spec", (InjectableTestContext, GlobalTestContext)) -def test_container_items(container: Container, spec: Type[InjectableTestContext]) -> None: +def test_container_items( + container: Container, spec: Type[InjectableTestContext] +) -> None: # will add InjectableTestContext instance to container container[spec] assert spec in container @@ -150,7 +152,9 @@ def test_container_injectable_context_mangled( @pytest.mark.parametrize("spec", (InjectableTestContext, GlobalTestContext)) -def test_container_thread_affinity(container: Container, spec: Type[InjectableTestContext]) -> None: +def 
test_container_thread_affinity( + container: Container, spec: Type[InjectableTestContext] +) -> None: event = threading.Semaphore(0) thread_item: InjectableTestContext = None @@ -182,7 +186,9 @@ def _thread() -> None: @pytest.mark.parametrize("spec", (InjectableTestContext, GlobalTestContext)) -def test_container_pool_affinity(container: Container, spec: Type[InjectableTestContext]) -> None: +def test_container_pool_affinity( + container: Container, spec: Type[InjectableTestContext] +) -> None: event = threading.Semaphore(0) thread_item: InjectableTestContext = None @@ -194,7 +200,9 @@ def _thread() -> None: thread_item = container[spec] event.release() - threading.Thread(target=_thread, daemon=True, name=Container.thread_pool_prefix()).start() + threading.Thread( + target=_thread, daemon=True, name=Container.thread_pool_prefix() + ).start() event.acquire() # it may be or separate copy (InjectableTestContext) or single copy (GlobalTestContext) main_item = container[spec] @@ -214,7 +222,9 @@ def test_thread_pool_affinity(container: Container) -> None: def _context() -> InjectableTestContext: return container[InjectableTestContext] - main_item = container[InjectableTestContext] = InjectableTestContext(current_value="MAIN") + main_item = container[InjectableTestContext] = InjectableTestContext( + current_value="MAIN" + ) with ThreadPoolExecutor(thread_name_prefix=container.thread_pool_prefix()) as p: future = p.submit(_context) @@ -231,7 +241,9 @@ def _context() -> InjectableTestContext: @pytest.mark.parametrize("spec", (InjectableTestContext, GlobalTestContext)) -def test_container_provider(container: Container, spec: Type[InjectableTestContext]) -> None: +def test_container_provider( + container: Container, spec: Type[InjectableTestContext] +) -> None: provider = ContextProvider() # default value will be created v, k = provider.get_value("n/a", spec, None) @@ -261,9 +273,13 @@ def test_container_provider(container: Container, spec: Type[InjectableTestConte assert k == "typing.Literal['a']" -def test_container_provider_embedded_inject(container: Container, environment: Any) -> None: +def test_container_provider_embedded_inject( + container: Container, environment: Any +) -> None: environment["INJECTED"] = "unparsable" - with container.injectable_context(InjectableTestContext(current_value="Embed")) as injected: + with container.injectable_context( + InjectableTestContext(current_value="Embed") + ) as injected: # must have top precedence - over the environ provider. 
environ provider is returning a value that will cannot be parsed # but the container provider has a precedence and the lookup in environ provider will never happen C = resolve_configuration(EmbeddedWithInjectableContext()) diff --git a/tests/common/configuration/test_credentials.py b/tests/common/configuration/test_credentials.py index 7c184c16e5..39385987f7 100644 --- a/tests/common/configuration/test_credentials.py +++ b/tests/common/configuration/test_credentials.py @@ -69,7 +69,9 @@ def test_connection_string_credentials_native_representation(environment) -> Non ConnectionStringCredentials().parse_native_representation(1) with pytest.raises(InvalidConnectionString): - ConnectionStringCredentials().parse_native_representation("loader@localhost:5432/dlt_data") + ConnectionStringCredentials().parse_native_representation( + "loader@localhost:5432/dlt_data" + ) dsn = "postgres://loader:pass@localhost:5432/dlt_data?a=b&c=d" csc = ConnectionStringCredentials() @@ -111,7 +113,9 @@ def test_connection_string_letter_case(environment: Any) -> None: assert csc.to_native_representation() == dsn -def test_connection_string_resolved_from_native_representation(environment: Any) -> None: +def test_connection_string_resolved_from_native_representation( + environment: Any, +) -> None: destination_dsn = "mysql+pymsql://localhost:5432/dlt_data" c = ConnectionStringCredentials() c.parse_native_representation(destination_dsn) @@ -142,7 +146,9 @@ def test_connection_string_resolved_from_native_representation(environment: Any) assert c.password == "pwd" -def test_connection_string_resolved_from_native_representation_env(environment: Any) -> None: +def test_connection_string_resolved_from_native_representation_env( + environment: Any, +) -> None: environment["CREDENTIALS"] = "mysql+pymsql://USER@/dlt_data" c = resolve_configuration(ConnectionStringCredentials()) assert not c.is_partial() @@ -159,7 +165,9 @@ def test_connection_string_resolved_from_native_representation_env(environment: def test_connection_string_from_init() -> None: - c = ConnectionStringCredentials("postgres://loader:pass@localhost:5432/dlt_data?a=b&c=d") + c = ConnectionStringCredentials( + "postgres://loader:pass@localhost:5432/dlt_data?a=b&c=d" + ) assert c.drivername == "postgres" assert c.is_resolved() assert not c.is_partial() @@ -198,9 +206,12 @@ def test_gcp_service_credentials_native_representation(environment) -> None: gcpc = GcpServiceAccountCredentials() gcpc.parse_native_representation( SERVICE_JSON - % '"private_key": "-----BEGIN PRIVATE KEY-----\\n\\n-----END PRIVATE KEY-----\\n",' + % '"private_key": "-----BEGIN PRIVATE KEY-----\\n\\n-----END PRIVATE' + ' KEY-----\\n",' + ) + assert ( + gcpc.private_key == "-----BEGIN PRIVATE KEY-----\n\n-----END PRIVATE KEY-----\n" ) - assert gcpc.private_key == "-----BEGIN PRIVATE KEY-----\n\n-----END PRIVATE KEY-----\n" assert gcpc.project_id == "chat-analytics" assert gcpc.client_email == "loader@iam.gserviceaccount.com" # location is present but deprecated @@ -219,7 +230,9 @@ def test_gcp_service_credentials_native_representation(environment) -> None: assert gcpc_2.default_credentials() is None -def test_gcp_service_credentials_resolved_from_native_representation(environment: Any) -> None: +def test_gcp_service_credentials_resolved_from_native_representation( + environment: Any, +) -> None: gcpc = GcpServiceAccountCredentialsWithoutDefaults() # without PK @@ -242,7 +255,9 @@ def test_gcp_oauth_credentials_native_representation(environment) -> None: 
GcpOAuthCredentials().parse_native_representation("notjson") gcoauth = GcpOAuthCredentials() - gcoauth.parse_native_representation(OAUTH_APP_USER_INFO % '"refresh_token": "refresh_token",') + gcoauth.parse_native_representation( + OAUTH_APP_USER_INFO % '"refresh_token": "refresh_token",' + ) # is not resolved, we resolve only when default credentials are present assert gcoauth.is_resolved() is False # but is not partial - all required fields are present @@ -272,11 +287,15 @@ def test_gcp_oauth_credentials_native_representation(environment) -> None: # use OAUTH_USER_INFO without "installed" gcpc_3 = GcpOAuthCredentials() - gcpc_3.parse_native_representation(OAUTH_USER_INFO % '"refresh_token": "refresh_token",') + gcpc_3.parse_native_representation( + OAUTH_USER_INFO % '"refresh_token": "refresh_token",' + ) assert dict(gcpc_3) == dict(gcpc_2) -def test_gcp_oauth_credentials_resolved_from_native_representation(environment: Any) -> None: +def test_gcp_oauth_credentials_resolved_from_native_representation( + environment: Any, +) -> None: gcpc = GcpOAuthCredentialsWithoutDefaults() # without refresh token @@ -328,7 +347,9 @@ def test_run_configuration_slack_credentials(environment: Any) -> None: assert c.slack_incoming_hook == hook # and obfuscated-like but really not - environment["RUNTIME__SLACK_INCOMING_HOOK"] = "DBgAXQFPQVsAAEteXlFRWUoPG0BdHQ-EbAg==" + environment["RUNTIME__SLACK_INCOMING_HOOK"] = ( + "DBgAXQFPQVsAAEteXlFRWUoPG0BdHQ-EbAg==" + ) c = resolve_configuration(RunConfiguration()) assert c.slack_incoming_hook == "DBgAXQFPQVsAAEteXlFRWUoPG0BdHQ-EbAg==" diff --git a/tests/common/configuration/test_environ_provider.py b/tests/common/configuration/test_environ_provider.py index 0608ea1d7a..1b76ba326f 100644 --- a/tests/common/configuration/test_environ_provider.py +++ b/tests/common/configuration/test_environ_provider.py @@ -12,7 +12,11 @@ from dlt.common.configuration.providers import environ as environ_provider from tests.utils import preserve_environ -from tests.common.configuration.utils import WrongConfiguration, SecretConfiguration, environment +from tests.common.configuration.utils import ( + WrongConfiguration, + SecretConfiguration, + environment, +) @configspec @@ -37,7 +41,11 @@ def test_resolves_from_environ(environment: Any) -> None: C = WrongConfiguration() resolve._resolve_config_fields( - C, explicit_values=None, explicit_sections=(), embedded_sections=(), accept_partial=False + C, + explicit_values=None, + explicit_sections=(), + embedded_sections=(), + accept_partial=False, ) assert not C.is_partial() @@ -49,7 +57,11 @@ def test_resolves_from_environ_with_coercion(environment: Any) -> None: C = SimpleRunConfiguration() resolve._resolve_config_fields( - C, explicit_values=None, explicit_sections=(), embedded_sections=(), accept_partial=False + C, + explicit_values=None, + explicit_sections=(), + embedded_sections=(), + accept_partial=False, ) assert not C.is_partial() @@ -102,9 +114,13 @@ def test_secret_kube_fallback(environment: Any) -> None: def test_configuration_files(environment: Any) -> None: # overwrite config file paths - environment["RUNTIME__CONFIG_FILES_STORAGE_PATH"] = "./tests/common/cases/schemas/ev1/" + environment["RUNTIME__CONFIG_FILES_STORAGE_PATH"] = ( + "./tests/common/cases/schemas/ev1/" + ) C = resolve.resolve_configuration(MockProdRunConfigurationVar()) - assert C.config_files_storage_path == environment["RUNTIME__CONFIG_FILES_STORAGE_PATH"] + assert ( + C.config_files_storage_path == environment["RUNTIME__CONFIG_FILES_STORAGE_PATH"] + ) assert 
C.has_configuration_file("hasn't") is False assert C.has_configuration_file("event.schema.json") is True assert ( diff --git a/tests/common/configuration/test_inject.py b/tests/common/configuration/test_inject.py index 1aa52c1919..f2fe79798d 100644 --- a/tests/common/configuration/test_inject.py +++ b/tests/common/configuration/test_inject.py @@ -21,7 +21,9 @@ ConnectionStringCredentials, ) from dlt.common.configuration.specs.base_configuration import configspec, is_secret_hint -from dlt.common.configuration.specs.config_providers_context import ConfigProvidersContext +from dlt.common.configuration.specs.config_providers_context import ( + ConfigProvidersContext, +) from dlt.common.configuration.specs.config_section_context import ConfigSectionContext from dlt.common.reflection.spec import _get_spec_name_from_f from dlt.common.typing import StrAny, TSecretValue, is_newtype_type @@ -94,7 +96,9 @@ def test_inject_from_argument_section(toml_providers: ConfigProvidersContext) -> # `gcp_storage` is a key in `secrets.toml` and the default `credentials` section of GcpServiceAccountCredentialsWithoutDefaults must be replaced with it @with_config - def f_credentials(gcp_storage: GcpServiceAccountCredentialsWithoutDefaults = dlt.secrets.value): + def f_credentials( + gcp_storage: GcpServiceAccountCredentialsWithoutDefaults = dlt.secrets.value, + ): # unique project name assert gcp_storage.project_id == "mock-project-id-gcp-storage" @@ -104,7 +108,9 @@ def f_credentials(gcp_storage: GcpServiceAccountCredentialsWithoutDefaults = dlt def test_inject_secret_value_secret_type(environment: Any) -> None: @with_config def f_custom_secret_type( - _dict: Dict[str, Any] = dlt.secrets.value, _int: int = dlt.secrets.value, **kwargs: Any + _dict: Dict[str, Any] = dlt.secrets.value, + _int: int = dlt.secrets.value, + **kwargs: Any, ): # secret values were coerced into types assert _dict == {"a": 1} @@ -192,7 +198,9 @@ def test_sections(value=dlt.config.value): return value # a section context that prefers existing context - @with_config(sections=("test",), sections_merge_style=ConfigSectionContext.prefer_existing) + @with_config( + sections=("test",), sections_merge_style=ConfigSectionContext.prefer_existing + ) def test_sections_pref_existing(value=dlt.config.value): return value @@ -409,12 +417,16 @@ def test_use_most_specific_union_type( ) -> None: @with_config def postgres_union( - local_credentials: Union[ConnectionStringCredentials, str, StrAny] = dlt.secrets.value + local_credentials: Union[ + ConnectionStringCredentials, str, StrAny + ] = dlt.secrets.value ): return local_credentials @with_config - def postgres_direct(local_credentials: ConnectionStringCredentials = dlt.secrets.value): + def postgres_direct( + local_credentials: ConnectionStringCredentials = dlt.secrets.value, + ): return local_credentials conn_str = "postgres://loader:loader@localhost:5432/dlt_data" @@ -497,6 +509,11 @@ def stuff_test(pos_par, /, kw_par) -> None: # synthesized spec present in current module assert "TestAutoDerivedSpecTypeNameAutoNameTestInitConfiguration" in globals() # instantiate - C: BaseConfiguration = globals()["TestAutoDerivedSpecTypeNameAutoNameTestInitConfiguration"]() + C: BaseConfiguration = globals()[ + "TestAutoDerivedSpecTypeNameAutoNameTestInitConfiguration" + ]() # pos_par converted to secrets, kw_par converted to optional - assert C.get_resolvable_fields() == {"pos_par": TSecretValue, "kw_par": Optional[Any]} + assert C.get_resolvable_fields() == { + "pos_par": TSecretValue, + "kw_par": Optional[Any], + } 
diff --git a/tests/common/configuration/test_sections.py b/tests/common/configuration/test_sections.py index bf6780e087..f0de56424c 100644 --- a/tests/common/configuration/test_sections.py +++ b/tests/common/configuration/test_sections.py @@ -54,7 +54,9 @@ class EmbeddedWithIgnoredEmbeddedConfiguration(BaseConfiguration): ignored_embedded: EmbeddedIgnoredWithSectionedConfiguration = None -def test_sectioned_configuration(environment: Any, env_provider: ConfigProvider) -> None: +def test_sectioned_configuration( + environment: Any, env_provider: ConfigProvider +) -> None: with pytest.raises(ConfigFieldMissingException) as exc_val: resolve.resolve_configuration(SectionedConfiguration()) @@ -71,7 +73,9 @@ def test_sectioned_configuration(environment: Any, env_provider: ConfigProvider) # assert traces[2] == LookupTrace("config.toml", ["DLT_TEST"], "DLT_TEST.password", None) # init vars work without section - C = resolve.resolve_configuration(SectionedConfiguration(), explicit_value={"password": "PASS"}) + C = resolve.resolve_configuration( + SectionedConfiguration(), explicit_value={"password": "PASS"} + ) assert C.password == "PASS" # env var must be prefixed @@ -119,7 +123,12 @@ def test_explicit_sections_with_sectioned_config(mock_provider: MockProvider) -> # sectioned config is always innermost mock_provider.reset_stats() resolve.resolve_configuration(SectionedConfiguration(), sections=("ns1",)) - assert mock_provider.last_sections == [("ns1",), (), ("ns1", "DLT_TEST"), ("DLT_TEST",)] + assert mock_provider.last_sections == [ + ("ns1",), + (), + ("ns1", "DLT_TEST"), + ("DLT_TEST",), + ] mock_provider.reset_stats() resolve.resolve_configuration(SectionedConfiguration(), sections=("ns1", "ns2")) assert mock_provider.last_sections == [ @@ -218,7 +227,9 @@ def test_injected_sections(mock_provider: MockProvider) -> None: assert mock_provider.last_sections == [("inj-ns1", "sv_config"), ("sv_config",)] # multiple injected sections - with container.injectable_context(ConfigSectionContext(sections=("inj-ns1", "inj-ns2"))): + with container.injectable_context( + ConfigSectionContext(sections=("inj-ns1", "inj-ns2")) + ): mock_provider.reset_stats() resolve.resolve_configuration(SingleValConfiguration()) assert mock_provider.last_sections == [("inj-ns1", "inj-ns2"), ("inj-ns1",), ()] @@ -235,7 +246,10 @@ def test_section_context() -> None: with pytest.raises(ValueError): ConfigSectionContext(sections=("sources", "modules")).source_name() - assert ConfigSectionContext(sections=("sources", "modules", "func")).source_name() == "func" + assert ( + ConfigSectionContext(sections=("sources", "modules", "func")).source_name() + == "func" + ) # TODO: test merge functions @@ -272,7 +286,12 @@ def test_section_with_pipeline_name(mock_provider: MockProvider) -> None: mock_provider.reset_stats() resolve.resolve_configuration(SectionedConfiguration()) # first the whole SectionedConfiguration is looked under key DLT_TEST (sections: ('PIPE',), ()), then fields of SectionedConfiguration - assert mock_provider.last_sections == [("PIPE",), (), ("PIPE", "DLT_TEST"), ("DLT_TEST",)] + assert mock_provider.last_sections == [ + ("PIPE",), + (), + ("PIPE", "DLT_TEST"), + ("DLT_TEST",), + ] # with pipeline and injected sections with container.injectable_context( @@ -281,7 +300,12 @@ def test_section_with_pipeline_name(mock_provider: MockProvider) -> None: mock_provider.return_value_on = () mock_provider.reset_stats() resolve.resolve_configuration(SingleValConfiguration()) - assert mock_provider.last_sections == [("PIPE", 
"inj-ns1"), ("PIPE",), ("inj-ns1",), ()] + assert mock_provider.last_sections == [ + ("PIPE", "inj-ns1"), + ("PIPE",), + ("inj-ns1",), + (), + ] # def test_sections_with_duplicate(mock_provider: MockProvider) -> None: @@ -304,15 +328,27 @@ def test_section_with_pipeline_name(mock_provider: MockProvider) -> None: def test_inject_section(mock_provider: MockProvider) -> None: mock_provider.value = "value" - with inject_section(ConfigSectionContext(pipeline_name="PIPE", sections=("inj-ns1",))): + with inject_section( + ConfigSectionContext(pipeline_name="PIPE", sections=("inj-ns1",)) + ): resolve.resolve_configuration(SingleValConfiguration()) - assert mock_provider.last_sections == [("PIPE", "inj-ns1"), ("PIPE",), ("inj-ns1",), ()] + assert mock_provider.last_sections == [ + ("PIPE", "inj-ns1"), + ("PIPE",), + ("inj-ns1",), + (), + ] # inject with merge previous with inject_section(ConfigSectionContext(sections=("inj-ns2",))): mock_provider.reset_stats() resolve.resolve_configuration(SingleValConfiguration()) - assert mock_provider.last_sections == [("PIPE", "inj-ns2"), ("PIPE",), ("inj-ns2",), ()] + assert mock_provider.last_sections == [ + ("PIPE", "inj-ns2"), + ("PIPE",), + ("inj-ns2",), + (), + ] # inject without merge mock_provider.reset_stats() diff --git a/tests/common/configuration/test_spec_union.py b/tests/common/configuration/test_spec_union.py index b1e316734d..7e20713c39 100644 --- a/tests/common/configuration/test_spec_union.py +++ b/tests/common/configuration/test_spec_union.py @@ -4,13 +4,18 @@ from typing import Optional, Union, Any import dlt -from dlt.common.configuration.exceptions import InvalidNativeValue, ConfigFieldMissingException +from dlt.common.configuration.exceptions import ( + InvalidNativeValue, + ConfigFieldMissingException, +) from dlt.common.configuration.providers import EnvironProvider from dlt.common.configuration.specs import CredentialsConfiguration, BaseConfiguration from dlt.common.configuration import configspec, resolve_configuration from dlt.common.configuration.specs.gcp_credentials import GcpServiceAccountCredentials from dlt.common.typing import TSecretValue -from dlt.common.configuration.specs.connection_string_credentials import ConnectionStringCredentials +from dlt.common.configuration.specs.connection_string_credentials import ( + ConnectionStringCredentials, +) from dlt.common.configuration.resolve import initialize_credentials from dlt.common.configuration.specs.exceptions import NativeValueError @@ -163,7 +168,9 @@ def test_union_decorator() -> None: # this will generate equivalent of ZenConfig @dlt.source def zen_source( - credentials: Union[ZenApiKeyCredentials, ZenEmailCredentials, str] = dlt.secrets.value, + credentials: Union[ + ZenApiKeyCredentials, ZenEmailCredentials, str + ] = dlt.secrets.value, some_option: bool = False, ): # depending on what the user provides in config, ZenApiKeyCredentials or ZenEmailCredentials will be injected in credentials @@ -240,7 +247,9 @@ class Engine: @dlt.source -def sql_database(credentials: Union[ConnectionStringCredentials, Engine, str] = dlt.secrets.value): +def sql_database( + credentials: Union[ConnectionStringCredentials, Engine, str] = dlt.secrets.value +): yield dlt.resource([credentials], name="creds") @@ -274,7 +283,9 @@ def test_initialize_credentials(environment: Any) -> None: assert not zen_cred.is_resolved() zen_cred = initialize_credentials(ZenEmailCredentials, "email:rfix:pass") assert zen_cred.is_resolved() - zen_cred = initialize_credentials(ZenEmailCredentials, {"email": "rfix", 
"password": "pass"}) + zen_cred = initialize_credentials( + ZenEmailCredentials, {"email": "rfix", "password": "pass"} + ) assert zen_cred.is_resolved() with pytest.raises(NativeValueError): initialize_credentials(ZenEmailCredentials, "email") @@ -288,7 +299,9 @@ def test_initialize_credentials(environment: Any) -> None: assert isinstance(zen_cred, ZenEmailCredentials) assert zen_cred.is_resolved() # resolve from dict - zen_cred = initialize_credentials(ZenUnion, {"api_key": "key", "api_secret": "secret"}) + zen_cred = initialize_credentials( + ZenUnion, {"api_key": "key", "api_secret": "secret"} + ) assert isinstance(zen_cred, ZenApiKeyCredentials) assert zen_cred.is_resolved() # does not fit any native format diff --git a/tests/common/configuration/test_toml_provider.py b/tests/common/configuration/test_toml_provider.py index 4f2219716a..6a624b441f 100644 --- a/tests/common/configuration/test_toml_provider.py +++ b/tests/common/configuration/test_toml_provider.py @@ -19,7 +19,9 @@ StringTomlProvider, TomlProviderReadException, ) -from dlt.common.configuration.specs.config_providers_context import ConfigProvidersContext +from dlt.common.configuration.specs.config_providers_context import ( + ConfigProvidersContext, +) from dlt.common.configuration.specs import ( BaseConfiguration, GcpServiceAccountCredentialsWithoutDefaults, @@ -71,7 +73,9 @@ def test_secrets_from_toml_secrets(toml_providers: ConfigProvidersContext) -> No def test_toml_types(toml_providers: ConfigProvidersContext) -> None: # resolve CoercionTestConfiguration from typecheck section - c = resolve.resolve_configuration(CoercionTestConfiguration(), sections=("typecheck",)) + c = resolve.resolve_configuration( + CoercionTestConfiguration(), sections=("typecheck",) + ) for k, v in COERCIONS.items(): # toml does not know tuples if isinstance(v, tuple): @@ -81,7 +85,9 @@ def test_toml_types(toml_providers: ConfigProvidersContext) -> None: assert v == c[k] -def test_config_provider_order(toml_providers: ConfigProvidersContext, environment: Any) -> None: +def test_config_provider_order( + toml_providers: ConfigProvidersContext, environment: Any +) -> None: # add env provider @with_config(sections=("api",)) @@ -125,18 +131,25 @@ def test_toml_sections(toml_providers: ConfigProvidersContext) -> None: cfg = toml_providers["config.toml"] assert cfg.get_value("api_type", str, None) == ("REST", "api_type") assert cfg.get_value("port", int, None, "api") == (1024, "api.port") - assert cfg.get_value("param1", str, None, "api", "params") == ("a", "api.params.param1") + assert cfg.get_value("param1", str, None, "api", "params") == ( + "a", + "api.params.param1", + ) -def test_secrets_toml_credentials(environment: Any, toml_providers: ConfigProvidersContext) -> None: +def test_secrets_toml_credentials( + environment: Any, toml_providers: ConfigProvidersContext +) -> None: # there are credentials exactly under destination.bigquery.credentials c = resolve.resolve_configuration( - GcpServiceAccountCredentialsWithoutDefaults(), sections=("destination", "bigquery") + GcpServiceAccountCredentialsWithoutDefaults(), + sections=("destination", "bigquery"), ) assert c.project_id.endswith("destination.bigquery.credentials") # there are no destination.gcp_storage.credentials so it will fallback to "destination"."credentials" c = resolve.resolve_configuration( - GcpServiceAccountCredentialsWithoutDefaults(), sections=("destination", "gcp_storage") + GcpServiceAccountCredentialsWithoutDefaults(), + sections=("destination", "gcp_storage"), ) assert 
c.project_id.endswith("destination.credentials") # also explicit @@ -146,7 +159,13 @@ def test_secrets_toml_credentials(environment: Any, toml_providers: ConfigProvid assert c.project_id.endswith("destination.credentials") # there's "credentials" key but does not contain valid gcp credentials with pytest.raises(ConfigFieldMissingException): - print(dict(resolve.resolve_configuration(GcpServiceAccountCredentialsWithoutDefaults()))) + print( + dict( + resolve.resolve_configuration( + GcpServiceAccountCredentialsWithoutDefaults() + ) + ) + ) # also try postgres credentials c2 = ConnectionStringCredentials() c2.update({"drivername": "postgres"}) @@ -222,7 +241,9 @@ def test_secrets_toml_credentials_from_native_repr( # but project id got overridden from credentials.project_id assert c.project_id.endswith("-credentials") # also try sql alchemy url (native repr) - c2 = resolve.resolve_configuration(ConnectionStringCredentials(), sections=("databricks",)) + c2 = resolve.resolve_configuration( + ConnectionStringCredentials(), sections=("databricks",) + ) assert c2.drivername == "databricks+connector" assert c2.username == "token" assert c2.password == "" @@ -292,11 +313,19 @@ def test_write_value(toml_providers: ConfigProvidersContext) -> None: # set single key provider.set_value("_new_key_bool", True, None) TAny: Type[Any] = Any # type: ignore[assignment] - assert provider.get_value("_new_key_bool", TAny, None) == (True, "_new_key_bool") + assert provider.get_value("_new_key_bool", TAny, None) == ( + True, + "_new_key_bool", + ) provider.set_value("_new_key_literal", TSecretValue("literal"), None) - assert provider.get_value("_new_key_literal", TAny, None) == ("literal", "_new_key_literal") + assert provider.get_value("_new_key_literal", TAny, None) == ( + "literal", + "_new_key_literal", + ) # this will create path of tables - provider.set_value("deep_int", 2137, "deep_pipeline", "deep", "deep", "deep", "deep") + provider.set_value( + "deep_int", 2137, "deep_pipeline", "deep", "deep", "deep", "deep" + ) assert provider._toml["deep_pipeline"]["deep"]["deep"]["deep"]["deep"]["deep_int"] == 2137 # type: ignore[index] assert provider.get_value( "deep_int", TAny, "deep_pipeline", "deep", "deep", "deep", "deep" @@ -304,7 +333,9 @@ def test_write_value(toml_providers: ConfigProvidersContext) -> None: # same without the pipeline now = pendulum.now() provider.set_value("deep_date", now, None, "deep", "deep", "deep", "deep") - assert provider.get_value("deep_date", TAny, None, "deep", "deep", "deep", "deep") == ( + assert provider.get_value( + "deep_date", TAny, None, "deep", "deep", "deep", "deep" + ) == ( now, "deep.deep.deep.deep.deep_date", ) @@ -315,7 +346,9 @@ def test_write_value(toml_providers: ConfigProvidersContext) -> None: "deep.deep.deep.deep_list", ) # still there - assert provider.get_value("deep_date", TAny, None, "deep", "deep", "deep", "deep") == ( + assert provider.get_value( + "deep_date", TAny, None, "deep", "deep", "deep", "deep" + ) == ( now, "deep.deep.deep.deep.deep_date", ) @@ -327,7 +360,9 @@ def test_write_value(toml_providers: ConfigProvidersContext) -> None: ) # invalid type with pytest.raises(ValueError): - provider.set_value("deep_decimal", Decimal("1.2"), None, "deep", "deep", "deep", "deep") + provider.set_value( + "deep_decimal", Decimal("1.2"), None, "deep", "deep", "deep", "deep" + ) # write new dict to a new key test_d1 = {"key": "top", "embed": {"inner": "bottom", "inner_2": True}} @@ -343,7 +378,9 @@ def test_write_value(toml_providers: ConfigProvidersContext) -> 
None: "dict_test.deep_dict", ) # get a fragment - assert provider.get_value("inner_2", TAny, None, "dict_test", "deep_dict", "embed") == ( + assert provider.get_value( + "inner_2", TAny, None, "dict_test", "deep_dict", "embed" + ) == ( True, "dict_test.deep_dict.embed.inner_2", ) @@ -354,7 +391,11 @@ def test_write_value(toml_providers: ConfigProvidersContext) -> None: "deep.deep.deep.deep_list", ) # merge dicts - test_d2 = {"key": "_top", "key2": "new2", "embed": {"inner": "_bottom", "inner_3": 2121}} + test_d2 = { + "key": "_top", + "key2": "new2", + "embed": {"inner": "_bottom", "inner_3": 2121}, + } provider.set_value("deep_dict", test_d2, None, "dict_test") test_m_d1_d2 = { "key": "_top", diff --git a/tests/common/configuration/utils.py b/tests/common/configuration/utils.py index 670dcac87a..1ff0f343fb 100644 --- a/tests/common/configuration/utils.py +++ b/tests/common/configuration/utils.py @@ -27,7 +27,9 @@ SecretsTomlProvider, ) from dlt.common.configuration.utils import get_resolved_traces -from dlt.common.configuration.specs.config_providers_context import ConfigProvidersContext +from dlt.common.configuration.specs.config_providers_context import ( + ConfigProvidersContext, +) from dlt.common.typing import TSecretValue, StrAny diff --git a/tests/common/data_writers/test_data_writers.py b/tests/common/data_writers/test_data_writers.py index ac4f118229..3ab8a3ca27 100644 --- a/tests/common/data_writers/test_data_writers.py +++ b/tests/common/data_writers/test_data_writers.py @@ -30,7 +30,11 @@ from tests.common.utils import load_json_case, row_to_column_schemas -ALL_LITERAL_ESCAPE = [escape_redshift_literal, escape_postgres_literal, escape_duckdb_literal] +ALL_LITERAL_ESCAPE = [ + escape_redshift_literal, + escape_postgres_literal, + escape_duckdb_literal, +] class _StringIOWriter(DataWriter): @@ -110,9 +114,18 @@ def test_unicode_insert_writer(insert_writer: _StringIOWriter) -> None: def test_string_literal_escape() -> None: - assert escape_redshift_literal(", NULL'); DROP TABLE --") == "', NULL''); DROP TABLE --'" - assert escape_redshift_literal(", NULL');\n DROP TABLE --") == "', NULL'');\\n DROP TABLE --'" - assert escape_redshift_literal(", NULL);\n DROP TABLE --") == "', NULL);\\n DROP TABLE --'" + assert ( + escape_redshift_literal(", NULL'); DROP TABLE --") + == "', NULL''); DROP TABLE --'" + ) + assert ( + escape_redshift_literal(", NULL');\n DROP TABLE --") + == "', NULL'');\\n DROP TABLE --'" + ) + assert ( + escape_redshift_literal(", NULL);\n DROP TABLE --") + == "', NULL);\\n DROP TABLE --'" + ) assert ( escape_redshift_literal(", NULL);\\n DROP TABLE --\\") == "', NULL);\\\\n DROP TABLE --\\\\'" @@ -153,7 +166,10 @@ def test_identifier_escape_bigquery() -> None: def test_string_literal_escape_unicode() -> None: # test on some unicode characters - assert escape_redshift_literal(", NULL);\n DROP TABLE --") == "', NULL);\\n DROP TABLE --'" + assert ( + escape_redshift_literal(", NULL);\n DROP TABLE --") + == "', NULL);\\n DROP TABLE --'" + ) assert ( escape_redshift_literal("イロハニホヘト チリヌルヲ ワカヨタレソ ツネナラム") == "'イロハニホヘト チリヌルヲ ワカヨタレソ ツネナラム'" @@ -171,9 +187,9 @@ def test_data_writer_metrics_add() -> None: add_m: DataWriterMetrics = metrics + EMPTY_DATA_WRITER_METRICS # type: ignore[assignment] assert add_m == DataWriterMetrics("", 10, 100, now, now + 10) assert metrics + metrics == DataWriterMetrics("", 20, 200, now, now + 10) - assert sum((metrics, metrics, metrics), EMPTY_DATA_WRITER_METRICS) == DataWriterMetrics( - "", 30, 300, now, now + 10 - ) + assert sum( + (metrics, 
metrics, metrics), EMPTY_DATA_WRITER_METRICS + ) == DataWriterMetrics("", 30, 300, now, now + 10) # time range extends when added add_m = metrics + DataWriterMetrics("file", 99, 120, now - 10, now + 20) # type: ignore[assignment] assert add_m == DataWriterMetrics("", 109, 220, now - 10, now + 20) diff --git a/tests/common/normalizers/custom_normalizers.py b/tests/common/normalizers/custom_normalizers.py index 3ae65c8b53..b8a04c633f 100644 --- a/tests/common/normalizers/custom_normalizers.py +++ b/tests/common/normalizers/custom_normalizers.py @@ -1,6 +1,10 @@ from dlt.common.normalizers.json import TNormalizedRowIterator -from dlt.common.normalizers.json.relational import DataItemNormalizer as RelationalNormalizer -from dlt.common.normalizers.naming.snake_case import NamingConvention as SnakeCaseNamingConvention +from dlt.common.normalizers.json.relational import ( + DataItemNormalizer as RelationalNormalizer, +) +from dlt.common.normalizers.naming.snake_case import ( + NamingConvention as SnakeCaseNamingConvention, +) from dlt.common.typing import TDataItem diff --git a/tests/common/normalizers/test_import_normalizers.py b/tests/common/normalizers/test_import_normalizers.py index df6b973943..ff4c419af3 100644 --- a/tests/common/normalizers/test_import_normalizers.py +++ b/tests/common/normalizers/test_import_normalizers.py @@ -5,10 +5,15 @@ from dlt.common.configuration.container import Container from dlt.common.destination import DestinationCapabilitiesContext from dlt.common.normalizers import explicit_normalizers, import_normalizers -from dlt.common.normalizers.json.relational import DataItemNormalizer as RelationalNormalizer +from dlt.common.normalizers.json.relational import ( + DataItemNormalizer as RelationalNormalizer, +) from dlt.common.normalizers.naming import snake_case from dlt.common.normalizers.naming import direct -from dlt.common.normalizers.naming.exceptions import InvalidNamingModule, UnknownNamingModule +from dlt.common.normalizers.naming.exceptions import ( + InvalidNamingModule, + UnknownNamingModule, +) from tests.common.normalizers.custom_normalizers import ( DataItemNormalizer as CustomRelationalNormalizer, diff --git a/tests/common/normalizers/test_json_relational.py b/tests/common/normalizers/test_json_relational.py index 502ce619dd..7adf4f3c21 100644 --- a/tests/common/normalizers/test_json_relational.py +++ b/tests/common/normalizers/test_json_relational.py @@ -75,7 +75,9 @@ def test_preserve_complex_value(norm: RelationalNormalizer) -> None: def test_preserve_complex_value_with_hint(norm: RelationalNormalizer) -> None: # add preferred type for "value" - norm.schema._settings.setdefault("preferred_types", {})[TSimpleRegex("re:^value$")] = "complex" + norm.schema._settings.setdefault("preferred_types", {})[ + TSimpleRegex("re:^value$") + ] = "complex" norm.schema._compile_settings() row_1 = {"value": 1} @@ -137,7 +139,9 @@ def test_child_table_linking(norm: RelationalNormalizer) -> None: def test_child_table_linking_primary_key(norm: RelationalNormalizer) -> None: row = { "id": "level0", - "f": [{"id": "level1", "l": ["a", "b", "c"], "v": 120, "o": [{"a": 1}, {"a": 2}]}], + "f": [ + {"id": "level1", "l": ["a", "b", "c"], "v": 120, "o": [{"a": 1}, {"a": 2}]} + ], } norm.schema.merge_hints({"primary_key": [TSimpleRegex("id")]}) norm.schema._compile_settings() @@ -158,18 +162,24 @@ def test_child_table_linking_primary_key(norm: RelationalNormalizer) -> None: list_rows = [t for t in rows if t[0][0] == "table__f__l"] assert all( - e[1]["_dlt_parent_id"] != 
digest128("level1", DLT_ID_LENGTH_BYTES) for e in list_rows + e[1]["_dlt_parent_id"] != digest128("level1", DLT_ID_LENGTH_BYTES) + for e in list_rows ) assert all(r[0][1] == "table__f" for r in list_rows) obj_rows = [t for t in rows if t[0][0] == "table__f__o"] - assert all(e[1]["_dlt_parent_id"] != digest128("level1", DLT_ID_LENGTH_BYTES) for e in obj_rows) + assert all( + e[1]["_dlt_parent_id"] != digest128("level1", DLT_ID_LENGTH_BYTES) + for e in obj_rows + ) assert all(r[0][1] == "table__f" for r in obj_rows) def test_yields_parents_first(norm: RelationalNormalizer) -> None: row = { "id": "level0", - "f": [{"id": "level1", "l": ["a", "b", "c"], "v": 120, "o": [{"a": 1}, {"a": 2}]}], + "f": [ + {"id": "level1", "l": ["a", "b", "c"], "v": 120, "o": [{"a": 1}, {"a": 2}]} + ], "g": [{"id": "level2_g", "l": ["a"]}], } rows = list(norm._normalize_row(row, {}, ("table",))) # type: ignore[arg-type] @@ -277,7 +287,9 @@ def test_yields_parent_relation(norm: RelationalNormalizer) -> None: def test_list_position(norm: RelationalNormalizer) -> None: row: StrAny = { - "f": [{"l": ["a", "b", "c"], "v": 120, "lo": [{"e": "a"}, {"e": "b"}, {"e": "c"}]}] + "f": [ + {"l": ["a", "b", "c"], "v": 120, "lo": [{"e": "a"}, {"e": "b"}, {"e": "c"}]} + ] } rows = list(norm._normalize_row(row, {}, ("table",))) # type: ignore[arg-type] # root has no pos @@ -290,12 +302,16 @@ def test_list_position(norm: RelationalNormalizer) -> None: # f_l must be ordered as it appears in the list for pos, elem in enumerate(["a", "b", "c"]): - row = next(t[1] for t in rows if t[0][0] == "table__f__l" and t[1]["value"] == elem) + row = next( + t[1] for t in rows if t[0][0] == "table__f__l" and t[1]["value"] == elem + ) assert row["_dlt_list_idx"] == pos # f_lo must be ordered - list of objects for pos, elem in enumerate(["a", "b", "c"]): - row = next(t[1] for t in rows if t[0][0] == "table__f__lo" and t[1]["e"] == elem) + row = next( + t[1] for t in rows if t[0][0] == "table__f__lo" and t[1]["e"] == elem + ) assert row["_dlt_list_idx"] == pos @@ -315,7 +331,13 @@ def test_list_position(norm: RelationalNormalizer) -> None: def test_control_descending(norm: RelationalNormalizer) -> None: row: StrAny = { - "f": [{"l": ["a", "b", "c"], "v": 120, "lo": [[{"e": "a"}, {"e": "b"}, {"e": "c"}]]}], + "f": [ + { + "l": ["a", "b", "c"], + "v": 120, + "lo": [[{"e": "a"}, {"e": "b"}, {"e": "c"}]], + } + ], "g": "val", } @@ -369,8 +391,14 @@ def test_list_in_list() -> None: "ended_at": "2023-05-12T13:14:32Z", "webpath": [ [ - {"url": "https://www.website.com/", "timestamp": "2023-05-12T12:35:01Z"}, - {"url": "https://www.website.com/products", "timestamp": "2023-05-12T12:38:45Z"}, + { + "url": "https://www.website.com/", + "timestamp": "2023-05-12T12:35:01Z", + }, + { + "url": "https://www.website.com/products", + "timestamp": "2023-05-12T12:38:45Z", + }, { "url": "https://www.website.com/products/item123", "timestamp": "2023-05-12T12:42:22Z", @@ -405,15 +433,22 @@ def test_list_in_list() -> None: assert len(zen__webpath__list) == 7 assert zen__webpath__list[0][1]["_dlt_parent_id"] == zen__webpath[0][1]["_dlt_id"] # 4th list is itself a list - zen__webpath__list__list = [row for row in rows if row[0][0] == "zen__webpath__list__list"] - assert zen__webpath__list__list[0][1]["_dlt_parent_id"] == zen__webpath__list[3][1]["_dlt_id"] + zen__webpath__list__list = [ + row for row in rows if row[0][0] == "zen__webpath__list__list" + ] + assert ( + zen__webpath__list__list[0][1]["_dlt_parent_id"] + == zen__webpath__list[3][1]["_dlt_id"] + ) # test the 
same setting webpath__list to complex zen_table = new_table("zen") schema.update_table(zen_table) path_table = new_table( - "zen__webpath", parent_table_name="zen", columns=[{"name": "list", "data_type": "complex"}] + "zen__webpath", + parent_table_name="zen", + columns=[{"name": "list", "data_type": "complex"}], ) schema.update_table(path_table) rows = list(schema.normalize_data_item(chats, "1762162.1212", "zen")) @@ -428,7 +463,9 @@ def test_child_row_deterministic_hash(norm: RelationalNormalizer) -> None: # directly set record hash so it will be adopted in normalizer as top level hash row = { "_dlt_id": row_id, - "f": [{"l": ["a", "b", "c"], "v": 120, "lo": [{"e": "a"}, {"e": "b"}, {"e": "c"}]}], + "f": [ + {"l": ["a", "b", "c"], "v": 120, "lo": [{"e": "a"}, {"e": "b"}, {"e": "c"}]} + ], } rows = list(norm._normalize_row(row, {}, ("table",))) # type: ignore[arg-type] children = [t for t in rows if t[0][0] != "table"] @@ -444,26 +481,38 @@ def test_child_row_deterministic_hash(norm: RelationalNormalizer) -> None: assert ch["_dlt_id"] == expected_hash # direct compute one of the - el_f = next(t[1] for t in rows if t[0][0] == "table__f" and t[1]["_dlt_list_idx"] == 0) - f_lo_p2 = next(t[1] for t in rows if t[0][0] == "table__f__lo" and t[1]["_dlt_list_idx"] == 2) - assert f_lo_p2["_dlt_id"] == digest128(f"{el_f['_dlt_id']}_table__f__lo_2", DLT_ID_LENGTH_BYTES) + el_f = next( + t[1] for t in rows if t[0][0] == "table__f" and t[1]["_dlt_list_idx"] == 0 + ) + f_lo_p2 = next( + t[1] for t in rows if t[0][0] == "table__f__lo" and t[1]["_dlt_list_idx"] == 2 + ) + assert f_lo_p2["_dlt_id"] == digest128( + f"{el_f['_dlt_id']}_table__f__lo_2", DLT_ID_LENGTH_BYTES + ) # same data with same table and row_id rows_2 = list(norm._normalize_row(row, {}, ("table",))) # type: ignore[arg-type] children_2 = [t for t in rows_2 if t[0][0] != "table"] # corresponding hashes must be identical - assert all(ch[0][1]["_dlt_id"] == ch[1][1]["_dlt_id"] for ch in zip(children, children_2)) + assert all( + ch[0][1]["_dlt_id"] == ch[1][1]["_dlt_id"] for ch in zip(children, children_2) + ) # change parent table and all child hashes must be different rows_4 = list(norm._normalize_row(row, {}, ("other_table",))) # type: ignore[arg-type] children_4 = [t for t in rows_4 if t[0][0] != "other_table"] - assert all(ch[0][1]["_dlt_id"] != ch[1][1]["_dlt_id"] for ch in zip(children, children_4)) + assert all( + ch[0][1]["_dlt_id"] != ch[1][1]["_dlt_id"] for ch in zip(children, children_4) + ) # change parent hash and all child hashes must be different row["_dlt_id"] = uniq_id() rows_3 = list(norm._normalize_row(row, {}, ("table",))) # type: ignore[arg-type] children_3 = [t for t in rows_3 if t[0][0] != "table"] - assert all(ch[0][1]["_dlt_id"] != ch[1][1]["_dlt_id"] for ch in zip(children, children_3)) + assert all( + ch[0][1]["_dlt_id"] != ch[1][1]["_dlt_id"] for ch in zip(children, children_3) + ) def test_keeps_dlt_id(norm: RelationalNormalizer) -> None: @@ -514,14 +563,16 @@ def test_propagates_root_context(norm: RelationalNormalizer) -> None: assert all("__not_found" not in r[1] for r in non_root) -@pytest.mark.parametrize("add_pk,add_dlt_id", [(False, False), (True, False), (True, True)]) +@pytest.mark.parametrize( + "add_pk,add_dlt_id", [(False, False), (True, False), (True, True)] +) def test_propagates_table_context( norm: RelationalNormalizer, add_pk: bool, add_dlt_id: bool ) -> None: add_dlt_root_id_propagation(norm) - prop_config: RelationalNormalizerConfigPropagation = norm.schema._normalizers_config["json"][ - 
"config" - ]["propagation"] + prop_config: RelationalNormalizerConfigPropagation = ( + norm.schema._normalizers_config["json"]["config"]["propagation"] + ) prop_config["root"]["timestamp"] = "_partition_ts" # type: ignore[index] # for table "table__lvl1" request to propagate "vx" and "partition_ovr" as "_partition_ts" (should overwrite root) prop_config["tables"]["table__lvl1"] = { # type: ignore[index] @@ -538,7 +589,11 @@ def test_propagates_table_context( "_dlt_id": "###", "timestamp": 12918291.1212, "lvl1": [ - {"vx": "ax", "partition_ovr": 1283.12, "lvl2": [{"_partition_ts": "overwritten"}]} + { + "vx": "ax", + "partition_ovr": 1283.12, + "lvl2": [{"_partition_ts": "overwritten"}], + } ], } if add_dlt_id: @@ -552,7 +607,11 @@ def test_propagates_table_context( # __not_found nowhere assert all("__not_found" not in r[1] for r in non_root) # _partition_ts == timestamp only at lvl1 - assert all(r[1]["_partition_ts"] == 12918291.1212 for r in non_root if r[0][0] == "table__lvl1") + assert all( + r[1]["_partition_ts"] == 12918291.1212 + for r in non_root + if r[0][0] == "table__lvl1" + ) # _partition_ts == partition_ovr and __vx only at lvl2 assert all( r[1]["_partition_ts"] == 1283.12 and r[1]["__vx"] == "ax" @@ -571,9 +630,9 @@ def test_propagates_table_context( def test_propagates_table_context_to_lists(norm: RelationalNormalizer) -> None: add_dlt_root_id_propagation(norm) - prop_config: RelationalNormalizerConfigPropagation = norm.schema._normalizers_config["json"][ - "config" - ]["propagation"] + prop_config: RelationalNormalizerConfigPropagation = ( + norm.schema._normalizers_config["json"]["config"]["propagation"] + ) prop_config["root"]["timestamp"] = "_partition_ts" # type: ignore[index] row = {"_dlt_id": "###", "timestamp": 12918291.1212, "lvl1": [1, 2, 3, [4, 5, 6]]} @@ -712,7 +771,11 @@ def test_extract_with_table_name_meta() -> None: "permission_overwrites": [], } # force table name - rows = list(create_schema_with_name("discord").normalize_data_item(row, "load_id", "channel")) + rows = list( + create_schema_with_name("discord").normalize_data_item( + row, "load_id", "channel" + ) + ) # table is channel assert rows[0][0][0] == "channel" normalized_row = rows[0][1] @@ -727,7 +790,9 @@ def test_table_name_meta_normalized() -> None: } # force table name rows = list( - create_schema_with_name("discord").normalize_data_item(row, "load_id", "channelSURFING") + create_schema_with_name("discord").normalize_data_item( + row, "load_id", "channelSURFING" + ) ) # table is channel assert rows[0][0][0] == "channel_surfing" @@ -739,7 +804,10 @@ def test_parse_with_primary_key() -> None: schema._compile_settings() add_dlt_root_id_propagation(schema.data_item_normalizer) # type: ignore[arg-type] - row = {"id": "817949077341208606", "w_id": [{"id": 9128918293891111, "wo_id": [1, 2, 3]}]} + row = { + "id": "817949077341208606", + "w_id": [{"id": 9128918293891111, "wo_id": [1, 2, 3]}], + } rows = list(schema.normalize_data_item(row, "load_id", "discord")) # get root root = next(t[1] for t in rows if t[0][0] == "discord") @@ -758,11 +826,17 @@ def test_parse_with_primary_key() -> None: # this must have deterministic child key f_wo_id = next( - t[1] for t in rows if t[0][0] == "discord__w_id__wo_id" and t[1]["_dlt_list_idx"] == 2 + t[1] + for t in rows + if t[0][0] == "discord__w_id__wo_id" and t[1]["_dlt_list_idx"] == 2 ) assert f_wo_id["value"] == 3 - assert f_wo_id["_dlt_root_id"] != digest128("817949077341208606", DLT_ID_LENGTH_BYTES) - assert f_wo_id["_dlt_parent_id"] != 
digest128("9128918293891111", DLT_ID_LENGTH_BYTES) + assert f_wo_id["_dlt_root_id"] != digest128( + "817949077341208606", DLT_ID_LENGTH_BYTES + ) + assert f_wo_id["_dlt_parent_id"] != digest128( + "9128918293891111", DLT_ID_LENGTH_BYTES + ) assert f_wo_id["_dlt_id"] == RelationalNormalizer._get_child_row_hash( f_wo_id["_dlt_parent_id"], "discord__w_id__wo_id", 2 ) @@ -770,7 +844,11 @@ def test_parse_with_primary_key() -> None: def test_keeps_none_values() -> None: row = {"a": None, "timestamp": 7} - rows = list(create_schema_with_name("other").normalize_data_item(row, "1762162.1212", "other")) + rows = list( + create_schema_with_name("other").normalize_data_item( + row, "1762162.1212", "other" + ) + ) table_name = rows[0][0][0] assert table_name == "other" normalized_row = rows[0][1] @@ -803,7 +881,8 @@ def test_normalize_and_shorten_deterministically() -> None: root_data_keys = list(root_data.keys()) # "short:ident:2": "a" will be flattened into root tag = NamingConvention._compute_tag( - "short_ident_1__short_ident_2__short_ident_3", NamingConvention._DEFAULT_COLLISION_PROB + "short_ident_1__short_ident_2__short_ident_3", + NamingConvention._DEFAULT_COLLISION_PROB, ) assert tag in root_data_keys[0] # long:SO+LONG:_>16 shortened on normalized name @@ -814,7 +893,8 @@ def test_normalize_and_shorten_deterministically() -> None: # table name in second row table_name = rows[1][0][0] tag = NamingConvention._compute_tag( - "s__lis_txident_1__lis_txident_2__lis_txident_3", NamingConvention._DEFAULT_COLLISION_PROB + "s__lis_txident_1__lis_txident_2__lis_txident_3", + NamingConvention._DEFAULT_COLLISION_PROB, ) assert tag in table_name @@ -853,13 +933,16 @@ def test_propagation_update_on_table_change(norm: RelationalNormalizer): table_2 = new_table("table_2", parent_table_name="table_1") norm.schema.update_table(table_2) assert ( - "table_2" not in norm.schema._normalizers_config["json"]["config"]["propagation"]["tables"] + "table_2" + not in norm.schema._normalizers_config["json"]["config"]["propagation"][ + "tables" + ] ) # test merging into existing propagation - norm.schema._normalizers_config["json"]["config"]["propagation"]["tables"]["table_3"] = { - "prop1": "prop2" - } + norm.schema._normalizers_config["json"]["config"]["propagation"]["tables"][ + "table_3" + ] = {"prop1": "prop2"} table_3 = new_table("table_3", write_disposition="merge") norm.schema.update_table(table_3) assert norm.schema._normalizers_config["json"]["config"]["propagation"]["tables"][ @@ -868,7 +951,9 @@ def test_propagation_update_on_table_change(norm: RelationalNormalizer): def set_max_nesting(norm: RelationalNormalizer, max_nesting: int) -> None: - RelationalNormalizer.update_normalizer_config(norm.schema, {"max_nesting": max_nesting}) + RelationalNormalizer.update_normalizer_config( + norm.schema, {"max_nesting": max_nesting} + ) norm._reset() diff --git a/tests/common/normalizers/test_naming.py b/tests/common/normalizers/test_naming.py index 3bf4762c35..03af77acfd 100644 --- a/tests/common/normalizers/test_naming.py +++ b/tests/common/normalizers/test_naming.py @@ -3,8 +3,12 @@ from typing import List, Type from dlt.common.normalizers.naming import NamingConvention -from dlt.common.normalizers.naming.snake_case import NamingConvention as SnakeCaseNamingConvention -from dlt.common.normalizers.naming.direct import NamingConvention as DirectNamingConvention +from dlt.common.normalizers.naming.snake_case import ( + NamingConvention as SnakeCaseNamingConvention, +) +from dlt.common.normalizers.naming.direct import ( + 
NamingConvention as DirectNamingConvention, +) from dlt.common.typing import DictStrStr from dlt.common.utils import uniq_id @@ -115,11 +119,17 @@ def test_tag_placement() -> None: def test_shorten_identifier() -> None: # no limit long_ident = 8 * LONG_PATH - assert NamingConvention.shorten_identifier(long_ident, long_ident, None) == long_ident + assert ( + NamingConvention.shorten_identifier(long_ident, long_ident, None) == long_ident + ) # within limit - assert NamingConvention.shorten_identifier("012345678", "xxx012345678xxx", 10) == "012345678" assert ( - NamingConvention.shorten_identifier("0123456789", "xxx012345678xx?", 10) == "0123456789" + NamingConvention.shorten_identifier("012345678", "xxx012345678xxx", 10) + == "012345678" + ) + assert ( + NamingConvention.shorten_identifier("0123456789", "xxx012345678xx?", 10) + == "0123456789" ) # max_length # tag based on original string placed in the middle tag = NamingConvention._compute_tag( @@ -134,12 +144,16 @@ def test_shorten_identifier() -> None: tag = NamingConvention._compute_tag( raw_content, collision_prob=NamingConvention._DEFAULT_COLLISION_PROB ) - norm_ident = NamingConvention.shorten_identifier(IDENT_20_CHARS, raw_content, 20) + norm_ident = NamingConvention.shorten_identifier( + IDENT_20_CHARS, raw_content, 20 + ) assert tag in norm_ident assert len(norm_ident) == 20 -@pytest.mark.parametrize("convention", (SnakeCaseNamingConvention, DirectNamingConvention)) +@pytest.mark.parametrize( + "convention", (SnakeCaseNamingConvention, DirectNamingConvention) +) def test_normalize_with_shorten_identifier(convention: Type[NamingConvention]) -> None: naming = convention() # None/empty ident raises @@ -151,7 +165,9 @@ def test_normalize_with_shorten_identifier(convention: Type[NamingConvention]) - # normalized string is different assert naming.normalize_identifier(RAW_IDENT) != RAW_IDENT # strip spaces - assert naming.normalize_identifier(RAW_IDENT) == naming.normalize_identifier(RAW_IDENT_W_SPACES) + assert naming.normalize_identifier(RAW_IDENT) == naming.normalize_identifier( + RAW_IDENT_W_SPACES + ) # force to shorten naming = convention(len(RAW_IDENT) // 2) @@ -160,11 +176,15 @@ def test_normalize_with_shorten_identifier(convention: Type[NamingConvention]) - RAW_IDENT, collision_prob=NamingConvention._DEFAULT_COLLISION_PROB ) # spaces are stripped - assert naming.normalize_identifier(RAW_IDENT) == naming.normalize_identifier(RAW_IDENT_W_SPACES) + assert naming.normalize_identifier(RAW_IDENT) == naming.normalize_identifier( + RAW_IDENT_W_SPACES + ) assert tag in naming.normalize_identifier(RAW_IDENT) -@pytest.mark.parametrize("convention", (SnakeCaseNamingConvention, DirectNamingConvention)) +@pytest.mark.parametrize( + "convention", (SnakeCaseNamingConvention, DirectNamingConvention) +) def test_normalize_path_shorting(convention: Type[NamingConvention]) -> None: naming = convention() path = naming.make_path(*LONG_PATH.split("__")) @@ -207,7 +227,9 @@ def test_normalize_path_shorting(convention: Type[NamingConvention]) -> None: assert len(naming.break_path(norm_path)) == 1 -@pytest.mark.parametrize("convention", (SnakeCaseNamingConvention, DirectNamingConvention)) +@pytest.mark.parametrize( + "convention", (SnakeCaseNamingConvention, DirectNamingConvention) +) def test_normalize_path(convention: Type[NamingConvention]) -> None: naming = convention() raw_path_str = naming.make_path(*RAW_PATH) @@ -239,7 +261,9 @@ def test_normalize_path(convention: Type[NamingConvention]) -> None: == naming.normalize_path(tagged_raw_path_str) == 
naming.normalize_path(naming.normalize_path(tagged_raw_path_str)) ) - assert tagged_raw_path_str == naming.make_path(*naming.break_path(tagged_raw_path_str)) + assert tagged_raw_path_str == naming.make_path( + *naming.break_path(tagged_raw_path_str) + ) # also cut idents naming = convention(len(RAW_IDENT) - 4) @@ -248,7 +272,9 @@ def test_normalize_path(convention: Type[NamingConvention]) -> None: assert tag in tagged_raw_path_str -@pytest.mark.parametrize("convention", (SnakeCaseNamingConvention, DirectNamingConvention)) +@pytest.mark.parametrize( + "convention", (SnakeCaseNamingConvention, DirectNamingConvention) +) def test_shorten_fragments(convention: Type[NamingConvention]) -> None: # max length around the length of the path naming = convention() @@ -273,4 +299,6 @@ def test_shorten_fragments(convention: Type[NamingConvention]) -> None: def assert_short_path(norm_path: str, naming: NamingConvention) -> None: assert len(norm_path) == naming.max_length assert naming.normalize_path(norm_path) == norm_path - assert all(len(ident) <= naming.max_length for ident in naming.break_path(norm_path)) + assert all( + len(ident) <= naming.max_length for ident in naming.break_path(norm_path) + ) diff --git a/tests/common/normalizers/test_naming_duck_case.py b/tests/common/normalizers/test_naming_duck_case.py index 099134ca2f..08b9e0df5e 100644 --- a/tests/common/normalizers/test_naming_duck_case.py +++ b/tests/common/normalizers/test_naming_duck_case.py @@ -1,7 +1,9 @@ import pytest from dlt.common.normalizers.naming.duck_case import NamingConvention -from dlt.common.normalizers.naming.snake_case import NamingConvention as SnakeNamingConvention +from dlt.common.normalizers.naming.snake_case import ( + NamingConvention as SnakeNamingConvention, +) @pytest.fixture diff --git a/tests/common/normalizers/test_naming_snake_case.py b/tests/common/normalizers/test_naming_snake_case.py index 6d619b5257..daca486908 100644 --- a/tests/common/normalizers/test_naming_snake_case.py +++ b/tests/common/normalizers/test_naming_snake_case.py @@ -2,8 +2,12 @@ import pytest from dlt.common.normalizers.naming import NamingConvention -from dlt.common.normalizers.naming.snake_case import NamingConvention as SnakeCaseNamingConvention -from dlt.common.normalizers.naming.duck_case import NamingConvention as DuckCaseNamingConvention +from dlt.common.normalizers.naming.snake_case import ( + NamingConvention as SnakeCaseNamingConvention, +) +from dlt.common.normalizers.naming.duck_case import ( + NamingConvention as DuckCaseNamingConvention, +) @pytest.fixture @@ -25,8 +29,14 @@ def test_normalize_identifier(naming_unlimited: NamingConvention) -> None: assert naming_unlimited.normalize_identifier("BAN_ANA") == "ban_ana" assert naming_unlimited.normalize_identifier("BANaNA") == "ba_na_na" # handling spaces - assert naming_unlimited.normalize_identifier("Small Love Potion") == "small_love_potion" - assert naming_unlimited.normalize_identifier(" Small Love Potion ") == "small_love_potion" + assert ( + naming_unlimited.normalize_identifier("Small Love Potion") + == "small_love_potion" + ) + assert ( + naming_unlimited.normalize_identifier(" Small Love Potion ") + == "small_love_potion" + ) # removes trailing _ assert naming_unlimited.normalize_identifier("BANANA_") == "bananax" assert naming_unlimited.normalize_identifier("BANANA____") == "bananaxxxx" @@ -39,16 +49,22 @@ def test_normalize_identifier(naming_unlimited: NamingConvention) -> None: def test_alphabet_reduction(naming_unlimited: NamingConvention) -> None: assert ( - 
naming_unlimited.normalize_identifier(SnakeCaseNamingConvention._REDUCE_ALPHABET[0]) + naming_unlimited.normalize_identifier( + SnakeCaseNamingConvention._REDUCE_ALPHABET[0] + ) == SnakeCaseNamingConvention._REDUCE_ALPHABET[1] ) def test_normalize_path(naming_unlimited: NamingConvention) -> None: assert naming_unlimited.normalize_path("small_love_potion") == "small_love_potion" - assert naming_unlimited.normalize_path("small__love__potion") == "small__love__potion" + assert ( + naming_unlimited.normalize_path("small__love__potion") == "small__love__potion" + ) assert naming_unlimited.normalize_path("Small_Love_Potion") == "small_love_potion" - assert naming_unlimited.normalize_path("Small__Love__Potion") == "small__love__potion" + assert ( + naming_unlimited.normalize_path("Small__Love__Potion") == "small__love__potion" + ) assert naming_unlimited.normalize_path("Small Love Potion") == "small_love_potion" assert naming_unlimited.normalize_path("Small Love Potion") == "small_love_potion" @@ -56,10 +72,14 @@ def test_normalize_path(naming_unlimited: NamingConvention) -> None: def test_normalize_non_alpha_single_underscore() -> None: assert SnakeCaseNamingConvention._RE_NON_ALPHANUMERIC.sub("_", "-=!*") == "_" assert SnakeCaseNamingConvention._RE_NON_ALPHANUMERIC.sub("_", "1-=!0*-") == "1_0_" - assert SnakeCaseNamingConvention._RE_NON_ALPHANUMERIC.sub("_", "1-=!_0*-") == "1__0_" + assert ( + SnakeCaseNamingConvention._RE_NON_ALPHANUMERIC.sub("_", "1-=!_0*-") == "1__0_" + ) -@pytest.mark.parametrize("convention", (SnakeCaseNamingConvention, DuckCaseNamingConvention)) +@pytest.mark.parametrize( + "convention", (SnakeCaseNamingConvention, DuckCaseNamingConvention) +) def test_normalize_break_path(convention: Type[NamingConvention]) -> None: naming_unlimited = convention() assert naming_unlimited.break_path("A__B__C") == ["A", "B", "C"] @@ -71,7 +91,9 @@ def test_normalize_break_path(convention: Type[NamingConvention]) -> None: assert naming_unlimited.break_path("_a__ \t\r__b") == ["_a", "b"] -@pytest.mark.parametrize("convention", (SnakeCaseNamingConvention, DuckCaseNamingConvention)) +@pytest.mark.parametrize( + "convention", (SnakeCaseNamingConvention, DuckCaseNamingConvention) +) def test_normalize_make_path(convention: Type[NamingConvention]) -> None: naming_unlimited = convention() assert naming_unlimited.make_path("A", "B") == "A__B" @@ -82,6 +104,10 @@ def test_normalize_make_path(convention: Type[NamingConvention]) -> None: def test_normalizes_underscores(naming_unlimited: NamingConvention) -> None: assert ( - naming_unlimited.normalize_identifier("event__value_value2____") == "event_value_value2xxxx" + naming_unlimited.normalize_identifier("event__value_value2____") + == "event_value_value2xxxx" + ) + assert ( + naming_unlimited.normalize_path("e_vent__value_value2___") + == "e_vent__value_value2__x" ) - assert naming_unlimited.normalize_path("e_vent__value_value2___") == "e_vent__value_value2__x" diff --git a/tests/common/reflection/test_reflect_spec.py b/tests/common/reflection/test_reflect_spec.py index 952d0fc596..e70019b34d 100644 --- a/tests/common/reflection/test_reflect_spec.py +++ b/tests/common/reflection/test_reflect_spec.py @@ -100,7 +100,9 @@ def f_untyped_default( ) -> None: pass - SPEC, _ = spec_from_signature(f_untyped_default, inspect.signature(f_untyped_default)) + SPEC, _ = spec_from_signature( + f_untyped_default, inspect.signature(f_untyped_default) + ) assert SPEC.untyped_p1 == "str" assert SPEC.untyped_p2 == _DECIMAL_DEFAULT assert isinstance(SPEC().untyped_p3, 
ConnectionStringCredentials) @@ -157,7 +159,9 @@ def f_variadic(var_1: str = "A", *args, kw_var_1: str, **kwargs) -> None: SPEC, _ = spec_from_signature(f_variadic, inspect.signature(f_variadic)) assert SPEC.var_1 == "A" - assert not hasattr(SPEC, "kw_var_1") # kw parameters that must be explicitly passed are removed + assert not hasattr( + SPEC, "kw_var_1" + ) # kw parameters that must be explicitly passed are removed assert not hasattr(SPEC, "args") fields = SPEC.get_resolvable_fields() assert fields == {"var_1": str} @@ -167,7 +171,9 @@ def test_spec_when_no_fields() -> None: def f_default_only(arg1, arg2=None): pass - SPEC, fields = spec_from_signature(f_default_only, inspect.signature(f_default_only)) + SPEC, fields = spec_from_signature( + f_default_only, inspect.signature(f_default_only) + ) assert len(fields) > 0 del globals()[SPEC.__name__] @@ -223,7 +229,11 @@ def f_defaults( @with_config def f_kw_defaults( - *, kw1=dlt.config.value, kw_lit="12131", kw_secret_val=dlt.secrets.value, **kwargs + *, + kw1=dlt.config.value, + kw_lit="12131", + kw_secret_val=dlt.secrets.value, + **kwargs, ): pass diff --git a/tests/common/runners/test_pipes.py b/tests/common/runners/test_pipes.py index 6db7c2d0e2..57fe8aeb14 100644 --- a/tests/common/runners/test_pipes.py +++ b/tests/common/runners/test_pipes.py @@ -90,19 +90,27 @@ def test_synth_pickler_unknown_types() -> None: def test_iter_stdout() -> None: with Venv.create(tempfile.mkdtemp()) as venv: expected = ["0", "1", "2", "3", "4", "exit"] - for i, l in enumerate(iter_stdout(venv, "python", "tests/common/scripts/counter.py")): + for i, l in enumerate( + iter_stdout(venv, "python", "tests/common/scripts/counter.py") + ): assert expected[i] == l lines = list(iter_stdout(venv, "python", "tests/common/scripts/empty.py")) assert lines == [] with pytest.raises(CalledProcessError) as cpe: list( - iter_stdout(venv, "python", "tests/common/scripts/no_stdout_no_stderr_with_fail.py") + iter_stdout( + venv, + "python", + "tests/common/scripts/no_stdout_no_stderr_with_fail.py", + ) ) # empty stdout assert cpe.value.output == "" assert cpe.value.stderr == "" # three lines with 1 MB size + newline - for _i, l in enumerate(iter_stdout(venv, "python", "tests/common/scripts/long_lines.py")): + for _i, l in enumerate( + iter_stdout(venv, "python", "tests/common/scripts/long_lines.py") + ): assert len(l) == 1024 * 1024 assert _i == 2 @@ -123,7 +131,11 @@ def test_iter_stdout_raises() -> None: # we actually consumed part of the iterator up until "2" assert i == 2 with pytest.raises(CalledProcessError) as cpe: - list(iter_stdout(venv, "python", "tests/common/scripts/no_stdout_exception.py")) + list( + iter_stdout( + venv, "python", "tests/common/scripts/no_stdout_exception.py" + ) + ) # empty stdout assert cpe.value.output == "" assert "no stdout" in cpe.value.stderr @@ -149,25 +161,37 @@ def test_iter_stdout_raises() -> None: def test_stdout_encode_result() -> None: # use current venv to execute so we have dlt venv = Venv.restore_current() - lines = list(iter_stdout(venv, "python", "tests/common/scripts/stdout_encode_result.py")) + lines = list( + iter_stdout(venv, "python", "tests/common/scripts/stdout_encode_result.py") + ) # last line contains results assert decode_obj(lines[-1]) == ("this is string", TRunMetrics(True, 300)) # stderr will contain pickled exception somewhere with pytest.raises(CalledProcessError) as cpe: - list(iter_stdout(venv, "python", "tests/common/scripts/stdout_encode_exception.py")) + list( + iter_stdout( + venv, "python", 
"tests/common/scripts/stdout_encode_exception.py" + ) + ) assert isinstance(decode_last_obj(cpe.value.stderr.split("\n")), Exception) # this script returns something that it cannot pickle - lines = list(iter_stdout(venv, "python", "tests/common/scripts/stdout_encode_unpicklable.py")) + lines = list( + iter_stdout(venv, "python", "tests/common/scripts/stdout_encode_unpicklable.py") + ) assert decode_last_obj(lines) is None def test_iter_stdout_with_result() -> None: venv = Venv.restore_current() - i = iter_stdout_with_result(venv, "python", "tests/common/scripts/stdout_encode_result.py") + i = iter_stdout_with_result( + venv, "python", "tests/common/scripts/stdout_encode_result.py" + ) assert iter_until_returns(i) == ("this is string", TRunMetrics(True, 300)) - i = iter_stdout_with_result(venv, "python", "tests/common/scripts/stdout_encode_unpicklable.py") + i = iter_stdout_with_result( + venv, "python", "tests/common/scripts/stdout_encode_unpicklable.py" + ) assert iter_until_returns(i) is None # it just excepts without encoding exception with pytest.raises(CalledProcessError): diff --git a/tests/common/runners/test_runnable.py b/tests/common/runners/test_runnable.py index e25f28e521..707a3a6586 100644 --- a/tests/common/runners/test_runnable.py +++ b/tests/common/runners/test_runnable.py @@ -89,7 +89,9 @@ def test_configuredworker(method: str) -> None: p.map(_worker_1, *zip(*[(config, "PX1", "PX2")])) -def _worker_1(CONFIG: SchemaStorageConfiguration, par1: str, par2: str = "DEFAULT") -> None: +def _worker_1( + CONFIG: SchemaStorageConfiguration, par1: str, par2: str = "DEFAULT" +) -> None: # a correct type was passed assert type(CONFIG) is SchemaStorageConfiguration # check if config values are restored diff --git a/tests/common/runners/test_venv.py b/tests/common/runners/test_venv.py index ee62df3c83..f134aa7245 100644 --- a/tests/common/runners/test_venv.py +++ b/tests/common/runners/test_venv.py @@ -33,7 +33,9 @@ def test_restore_venv() -> None: assert venv.context.env_dir == restored_venv.context.env_dir assert venv.context.env_exe == restored_venv.context.env_exe script = "print('success')" - assert restored_venv.run_command(venv.context.env_exe, "-c", script) == "success\n" + assert ( + restored_venv.run_command(venv.context.env_exe, "-c", script) == "success\n" + ) # restored env will fail - venv deleted with pytest.raises(FileNotFoundError): restored_venv.run_command(venv.context.env_exe, "-c", script) @@ -98,7 +100,9 @@ def test_venv_working_dir() -> None: print(os.getcwd()) """ - assert venv.run_command(venv.context.env_exe, "-c", script).strip() == os.getcwd() + assert ( + venv.run_command(venv.context.env_exe, "-c", script).strip() == os.getcwd() + ) def test_run_command_with_error() -> None: @@ -150,7 +154,9 @@ def test_run_script() -> None: assert lines[-1] == "exit" # argv - result = venv.run_script(os.path.abspath("tests/common/scripts/args.py"), "--with-arg") + result = venv.run_script( + os.path.abspath("tests/common/scripts/args.py"), "--with-arg" + ) lines = result.splitlines() assert lines[0] == "2" assert "'--with-arg'" in lines[1] @@ -164,7 +170,9 @@ def test_run_script() -> None: # non exiting script with pytest.raises(FileNotFoundError): - venv.run_script(os.path.abspath("tests/common/scripts/_non_existing_.py"), "--with-arg") + venv.run_script( + os.path.abspath("tests/common/scripts/_non_existing_.py"), "--with-arg" + ) # raising script with pytest.raises(CalledProcessError) as cpe: @@ -202,7 +210,9 @@ def test_current_venv() -> None: assert "dlt" in freeze # 
use command - with venv.start_command("pip", "freeze", "--all", stdout=PIPE, text=True) as process: + with venv.start_command( + "pip", "freeze", "--all", stdout=PIPE, text=True + ) as process: output, _ = process.communicate() assert process.poll() == 0 assert "pip" in output @@ -220,7 +230,9 @@ def test_current_base_python() -> None: assert "dlt" in freeze # use command - with venv.start_command("pip", "freeze", "--all", stdout=PIPE, text=True) as process: + with venv.start_command( + "pip", "freeze", "--all", stdout=PIPE, text=True + ) as process: output, _ = process.communicate() assert process.poll() == 0 assert "pip" in output @@ -228,7 +240,9 @@ def test_current_base_python() -> None: def test_start_command() -> None: with Venv.create(tempfile.mkdtemp()) as venv: - with venv.start_command("pip", "freeze", "--all", stdout=PIPE, text=True) as process: + with venv.start_command( + "pip", "freeze", "--all", stdout=PIPE, text=True + ) as process: output, _ = process.communicate() assert process.poll() == 0 assert "pip" in output @@ -247,5 +261,7 @@ def test_start_command() -> None: venv.start_command("blip", "freeze", "--all", stdout=PIPE, text=True) # command exit code - with venv.start_command("pip", "wrong_command", stdout=PIPE, text=True) as process: + with venv.start_command( + "pip", "wrong_command", stdout=PIPE, text=True + ) as process: assert process.wait() == 1 diff --git a/tests/common/runners/utils.py b/tests/common/runners/utils.py index 3d6adbf70c..17742ace8c 100644 --- a/tests/common/runners/utils.py +++ b/tests/common/runners/utils.py @@ -10,7 +10,9 @@ from dlt.common.utils import uniq_id # remove fork-server because it hangs the tests no CI -ALL_METHODS = set(multiprocessing.get_all_start_methods()).intersection(["fork", "spawn"]) +ALL_METHODS = set(multiprocessing.get_all_start_methods()).intersection( + ["fork", "spawn"] +) @pytest.fixture(autouse=True) @@ -39,7 +41,10 @@ def _run(self, pool: Executor) -> List[Tuple[int, str, int]]: rid = id(self) assert rid in _TestRunnableWorkerMethod.RUNNING self.rv = rv = list( - pool.map(_TestRunnableWorkerMethod.worker, *zip(*[(rid, i) for i in range(self.tasks)])) + pool.map( + _TestRunnableWorkerMethod.worker, + *zip(*[(rid, i) for i in range(self.tasks)]), + ) ) assert rid in _TestRunnableWorkerMethod.RUNNING return rv @@ -65,7 +70,9 @@ def worker(v: int) -> Tuple[int, int]: def _run(self, pool: Executor) -> List[Tuple[int, int]]: self.rv = rv = list( - pool.map(_TestRunnableWorker.worker, *zip(*[(i,) for i in range(self.tasks)])) + pool.map( + _TestRunnableWorker.worker, *zip(*[(i,) for i in range(self.tasks)]) + ) ) return rv diff --git a/tests/common/runtime/test_collector.py b/tests/common/runtime/test_collector.py index dbe4b8c94d..34e12205ff 100644 --- a/tests/common/runtime/test_collector.py +++ b/tests/common/runtime/test_collector.py @@ -35,7 +35,9 @@ def test_dict_collector_context_manager(): def test_dict_collector_no_labels(): with DictCollector()("test") as collector: - with pytest.raises(AssertionError, match="labels not supported in dict collector"): + with pytest.raises( + AssertionError, match="labels not supported in dict collector" + ): collector.update("counter1", inc=1, label="label1") diff --git a/tests/common/runtime/test_telemetry.py b/tests/common/runtime/test_telemetry.py index e67f7e8360..a5a6f2fb83 100644 --- a/tests/common/runtime/test_telemetry.py +++ b/tests/common/runtime/test_telemetry.py @@ -25,9 +25,7 @@ @configspec class SentryLoggerConfiguration(RunConfiguration): pipeline_name: str = 
"logger" - sentry_dsn: str = ( - "https://6f6f7b6f8e0f458a89be4187603b55fe@o1061158.ingest.sentry.io/4504819859914752" - ) + sentry_dsn: str = "https://6f6f7b6f8e0f458a89be4187603b55fe@o1061158.ingest.sentry.io/4504819859914752" dlthub_telemetry_segment_write_key: str = "TLJiyRkGVZGCi2TtjClamXpFcxAA1rSB" diff --git a/tests/common/schema/test_coercion.py b/tests/common/schema/test_coercion.py index 34b62f9564..eba688fce3 100644 --- a/tests/common/schema/test_coercion.py +++ b/tests/common/schema/test_coercion.py @@ -33,7 +33,8 @@ def test_coerce_type_to_text() -> None: assert coerce_value("text", "binary", b"binary string") == "YmluYXJ5IHN0cmluZw==" # HexBytes to text (hex with prefix) assert ( - coerce_value("text", "binary", HexBytes(b"binary string")) == "0x62696e61727920737472696e67" + coerce_value("text", "binary", HexBytes(b"binary string")) + == "0x62696e61727920737472696e67" ) # Str enum value @@ -171,15 +172,15 @@ def test_coerce_type_to_timestamp() -> None: "2021-10-04T10:54:58.741524+00:00" ) # if text is ISO string it will be coerced - assert coerce_value("timestamp", "text", "2022-05-10T03:41:31.466000+00:00") == pendulum.parse( - "2022-05-10T03:41:31.466000+00:00" - ) - assert coerce_value("timestamp", "text", "2022-05-10T03:41:31.466+02:00") == pendulum.parse( - "2022-05-10T01:41:31.466Z" - ) - assert coerce_value("timestamp", "text", "2022-05-10T03:41:31.466+0200") == pendulum.parse( - "2022-05-10T01:41:31.466Z" - ) + assert coerce_value( + "timestamp", "text", "2022-05-10T03:41:31.466000+00:00" + ) == pendulum.parse("2022-05-10T03:41:31.466000+00:00") + assert coerce_value( + "timestamp", "text", "2022-05-10T03:41:31.466+02:00" + ) == pendulum.parse("2022-05-10T01:41:31.466Z") + assert coerce_value( + "timestamp", "text", "2022-05-10T03:41:31.466+0200" + ) == pendulum.parse("2022-05-10T01:41:31.466Z") # parse almost ISO compliant string assert coerce_value("timestamp", "text", "2022-04-26 10:36+02") == pendulum.parse( "2022-04-26T10:36:00+02:00" @@ -188,11 +189,13 @@ def test_coerce_type_to_timestamp() -> None: "2022-04-26T10:36:00+00:00" ) # parse date string - assert coerce_value("timestamp", "text", "2021-04-25") == pendulum.parse("2021-04-25") - # from date type - assert coerce_value("timestamp", "date", datetime.date(2023, 2, 27)) == pendulum.parse( - "2023-02-27" + assert coerce_value("timestamp", "text", "2021-04-25") == pendulum.parse( + "2021-04-25" ) + # from date type + assert coerce_value( + "timestamp", "date", datetime.date(2023, 2, 27) + ) == pendulum.parse("2023-02-27") # fails on "now" - yes pendulum by default parses "now" as .now() with pytest.raises(ValueError): @@ -246,19 +249,23 @@ def test_coerce_type_to_date() -> None: assert coerce_value("date", "double", 1677546399.494264) == pendulum.parse( "2023-02-28", exact=True ) - assert coerce_value("date", "text", " 1677546399 ") == pendulum.parse("2023-02-28", exact=True) - # ISO date string - assert coerce_value("date", "text", "2023-02-27") == pendulum.parse("2023-02-27", exact=True) - # ISO datetime string - assert coerce_value("date", "text", "2022-05-10T03:41:31.466000+00:00") == pendulum.parse( - "2022-05-10", exact=True - ) - assert coerce_value("date", "text", "2022-05-10T03:41:31.466+02:00") == pendulum.parse( - "2022-05-10", exact=True + assert coerce_value("date", "text", " 1677546399 ") == pendulum.parse( + "2023-02-28", exact=True ) - assert coerce_value("date", "text", "2022-05-10T03:41:31.466+0200") == pendulum.parse( - "2022-05-10", exact=True + # ISO date string + assert 
coerce_value("date", "text", "2023-02-27") == pendulum.parse( + "2023-02-27", exact=True ) + # ISO datetime string + assert coerce_value( + "date", "text", "2022-05-10T03:41:31.466000+00:00" + ) == pendulum.parse("2022-05-10", exact=True) + assert coerce_value( + "date", "text", "2022-05-10T03:41:31.466+02:00" + ) == pendulum.parse("2022-05-10", exact=True) + assert coerce_value( + "date", "text", "2022-05-10T03:41:31.466+0200" + ) == pendulum.parse("2022-05-10", exact=True) # almost ISO compliant string assert coerce_value("date", "text", "2022-04-26 10:36+02") == pendulum.parse( "2022-04-26", exact=True @@ -282,12 +289,14 @@ def test_coerce_type_to_time() -> None: "03:41:31.466000", exact=True ) # time object returns same value - assert coerce_value("time", "time", pendulum.time(3, 41, 31, 466000)) == pendulum.time( - 3, 41, 31, 466000 - ) + assert coerce_value( + "time", "time", pendulum.time(3, 41, 31, 466000) + ) == pendulum.time(3, 41, 31, 466000) # from datetime object fails with pytest.raises(TypeError): - coerce_value("time", "timestamp", pendulum.datetime(1995, 5, 6, 00, 1, 1, tz=UTC)) + coerce_value( + "time", "timestamp", pendulum.datetime(1995, 5, 6, 00, 1, 1, tz=UTC) + ) # from unix timestamp fails with pytest.raises(TypeError): @@ -300,7 +309,9 @@ def test_coerce_type_to_time() -> None: ) # ISO date string fails with pytest.raises(ValueError): - assert coerce_value("time", "text", "2023-02-27") == pendulum.parse("00:00:00", exact=True) + assert coerce_value("time", "text", "2023-02-27") == pendulum.parse( + "00:00:00", exact=True + ) # ISO datetime string fails with pytest.raises(ValueError): assert coerce_value("time", "text", "2022-05-10T03:41:31.466000+00:00") diff --git a/tests/common/schema/test_detections.py b/tests/common/schema/test_detections.py index 61ce0ede45..d3f8a9a32c 100644 --- a/tests/common/schema/test_detections.py +++ b/tests/common/schema/test_detections.py @@ -94,15 +94,25 @@ def test_wei_to_double() -> None: def test_detection_function() -> None: assert autodetect_sc_type(None, str, str(pendulum.now())) is None - assert autodetect_sc_type(["iso_timestamp"], str, str(pendulum.now())) == "timestamp" + assert ( + autodetect_sc_type(["iso_timestamp"], str, str(pendulum.now())) == "timestamp" + ) assert autodetect_sc_type(["iso_timestamp"], float, str(pendulum.now())) is None assert autodetect_sc_type(["iso_date"], str, str(pendulum.now().date())) == "date" assert autodetect_sc_type(["iso_date"], float, str(pendulum.now().date())) is None assert autodetect_sc_type(["timestamp"], str, str(pendulum.now())) is None assert ( - autodetect_sc_type(["timestamp", "iso_timestamp"], float, pendulum.now().timestamp()) + autodetect_sc_type( + ["timestamp", "iso_timestamp"], float, pendulum.now().timestamp() + ) == "timestamp" ) assert autodetect_sc_type(["timestamp", "large_integer"], int, 2**64) == "wei" - assert autodetect_sc_type(["large_integer", "hexbytes_to_text"], HexBytes, b"hey") == "text" - assert autodetect_sc_type(["large_integer", "wei_to_double"], Wei, Wei(10**18)) == "double" + assert ( + autodetect_sc_type(["large_integer", "hexbytes_to_text"], HexBytes, b"hey") + == "text" + ) + assert ( + autodetect_sc_type(["large_integer", "wei_to_double"], Wei, Wei(10**18)) + == "double" + ) diff --git a/tests/common/schema/test_filtering.py b/tests/common/schema/test_filtering.py index 8cfac9309f..c5d1b153c7 100644 --- a/tests/common/schema/test_filtering.py +++ b/tests/common/schema/test_filtering.py @@ -35,10 +35,14 @@ def test_whole_row_filter(schema: Schema) 
-> None: _add_excludes(schema) bot_case: DictStrAny = load_json_case("mod_bot_case") # the whole row should be eliminated if the exclude matches all the rows - filtered_case = schema.filter_row("event_bot__metadata", deepcopy(bot_case)["metadata"]) + filtered_case = schema.filter_row( + "event_bot__metadata", deepcopy(bot_case)["metadata"] + ) assert filtered_case == {} # also child rows will be excluded - filtered_case = schema.filter_row("event_bot__metadata__user", deepcopy(bot_case)["metadata"]) + filtered_case = schema.filter_row( + "event_bot__metadata__user", deepcopy(bot_case)["metadata"] + ) assert filtered_case == {} @@ -46,7 +50,9 @@ def test_whole_row_filter_with_exception(schema: Schema) -> None: _add_excludes(schema) bot_case: DictStrAny = load_json_case("mod_bot_case") # whole row will be eliminated - filtered_case = schema.filter_row("event_bot__custom_data", deepcopy(bot_case)["custom_data"]) + filtered_case = schema.filter_row( + "event_bot__custom_data", deepcopy(bot_case)["custom_data"] + ) # mind that path event_bot__custom_data__included_object was also eliminated assert filtered_case == {} # this child of the row has exception (^event_bot__custom_data__included_object__ - the __ at the end select all childern but not the parent) @@ -56,7 +62,8 @@ def test_whole_row_filter_with_exception(schema: Schema) -> None: ) assert filtered_case == bot_case["custom_data"]["included_object"] filtered_case = schema.filter_row( - "event_bot__custom_data__excluded_path", deepcopy(bot_case)["custom_data"]["excluded_path"] + "event_bot__custom_data__excluded_path", + deepcopy(bot_case)["custom_data"]["excluded_path"], ) assert filtered_case == {} @@ -99,7 +106,10 @@ def test_filter_parent_table_schema_update(schema: Schema) -> None: schema = Schema("event") _add_excludes(schema) schema.get_table("event_bot")["filters"]["includes"].extend( - [TSimpleRegex("re:^metadata___dlt_"), TSimpleRegex("re:^metadata__elvl1___dlt_")] + [ + TSimpleRegex("re:^metadata___dlt_"), + TSimpleRegex("re:^metadata__elvl1___dlt_"), + ] ) schema._compile_settings() for (t, p), row in schema.normalize_data_item(source_row, "load_id", "event_bot"): @@ -108,7 +118,9 @@ def test_filter_parent_table_schema_update(schema: Schema) -> None: assert "_dlt_id" in row else: # full linking not wiped out - assert set(row.keys()).issuperset(["_dlt_id", "_dlt_parent_id", "_dlt_list_idx"]) + assert set(row.keys()).issuperset( + ["_dlt_id", "_dlt_parent_id", "_dlt_list_idx"] + ) row, partial_table = schema.coerce_row(t, p, row) updates.append(partial_table) schema.update_table(partial_table) diff --git a/tests/common/schema/test_inference.py b/tests/common/schema/test_inference.py index da5c809827..984aecee26 100644 --- a/tests/common/schema/test_inference.py +++ b/tests/common/schema/test_inference.py @@ -44,7 +44,10 @@ def test_map_column_preferred_type(schema: Schema) -> None: # timestamp from coercable type assert schema._infer_column_type(18271, "timestamp") == "timestamp" assert schema._infer_column_type("18271.11", "timestamp") == "timestamp" - assert schema._infer_column_type("2022-05-10T00:54:38.237000+00:00", "timestamp") == "timestamp" + assert ( + schema._infer_column_type("2022-05-10T00:54:38.237000+00:00", "timestamp") + == "timestamp" + ) # value should be wei assert schema._infer_column_type(" 0xfe ", "value") == "wei" @@ -118,7 +121,10 @@ def test_coerce_row(schema: Schema) -> None: row_2 = {"timestamp": timestamp_float, "confidence": 0.18721} new_row_2, new_table = schema.coerce_row("event_user", None, row_2) 
assert new_table is None - assert new_row_2 == {"timestamp": pendulum.parse(timestamp_str), "confidence": 0.18721} + assert new_row_2 == { + "timestamp": pendulum.parse(timestamp_str), + "confidence": 0.18721, + } # all coerced row_3 = {"timestamp": "78172.128", "confidence": 1} @@ -134,13 +140,19 @@ def test_coerce_row(schema: Schema) -> None: assert new_columns[0]["data_type"] == "text" assert new_columns[0]["name"] == "confidence__v_text" assert new_columns[0]["variant"] is True - assert new_row_4 == {"timestamp": pendulum.parse(timestamp_str), "confidence__v_text": "STR"} + assert new_row_4 == { + "timestamp": pendulum.parse(timestamp_str), + "confidence__v_text": "STR", + } schema.update_table(new_table) # add against variant new_row_4, new_table = schema.coerce_row("event_user", None, row_4) assert new_table is None - assert new_row_4 == {"timestamp": pendulum.parse(timestamp_str), "confidence__v_text": "STR"} + assert new_row_4 == { + "timestamp": pendulum.parse(timestamp_str), + "confidence__v_text": "STR", + } # another variant new_row_5, new_table = schema.coerce_row("event_user", None, {"confidence": False}) @@ -153,7 +165,9 @@ def test_coerce_row(schema: Schema) -> None: # variant column clashes with existing column - create new_colbool_v_binary column that would be created for binary variant, but give it a type datetime _, new_table = schema.coerce_row( - "event_user", None, {"new_colbool": False, "new_colbool__v_timestamp": b"not fit"} + "event_user", + None, + {"new_colbool": False, "new_colbool__v_timestamp": b"not fit"}, ) schema.update_table(new_table) with pytest.raises(CannotCoerceColumnException) as exc_val: @@ -275,7 +289,9 @@ def test_supports_variant_pua_decode(schema: Schema) -> None: # use actual encoding for wei from dlt.common.json import _WEI, _HEXBYTES - rows[0]["_tx_transactionHash"] = rows[0]["_tx_transactionHash"].replace("", _HEXBYTES) + rows[0]["_tx_transactionHash"] = rows[0]["_tx_transactionHash"].replace( + "", _HEXBYTES + ) rows[0]["wad"] = rows[0]["wad"].replace("", _WEI) normalized_row = list(schema.normalize_data_item(rows[0], "0912uhj222", "event")) @@ -395,7 +411,9 @@ def test_corece_null_value_over_not_null(schema: Schema) -> None: row = {"timestamp": 82178.1298812} _, new_table = schema.coerce_row("event_user", None, row) schema.update_table(new_table) - schema.get_table_columns("event_user", include_incomplete=True)["timestamp"]["nullable"] = False + schema.get_table_columns("event_user", include_incomplete=True)["timestamp"][ + "nullable" + ] = False row = {"timestamp": None} with pytest.raises(CannotCoerceNullException): schema.coerce_row("event_user", None, row) diff --git a/tests/common/schema/test_merges.py b/tests/common/schema/test_merges.py index fe9e4b1476..b928b373fb 100644 --- a/tests/common/schema/test_merges.py +++ b/tests/common/schema/test_merges.py @@ -85,9 +85,9 @@ def test_remove_defaults_stored_schema() -> None: "x-top-level": True, } # mock the case in table_copy where resource == table_name - stored_schema["tables"]["table_copy"]["resource"] = stored_schema["tables"]["table_copy"][ - "name" - ] = "table_copy" + stored_schema["tables"]["table_copy"]["resource"] = stored_schema["tables"][ + "table_copy" + ]["name"] = "table_copy" default_stored = utils.remove_defaults(stored_schema) # nullability always present @@ -96,7 +96,9 @@ def test_remove_defaults_stored_schema() -> None: # not removed in complete column (as it was explicitly set to False) assert default_stored["tables"]["table"]["columns"]["test"]["cluster"] is 
False # not removed in incomplete one - assert default_stored["tables"]["table"]["columns"]["test_2"]["primary_key"] is False + assert ( + default_stored["tables"]["table"]["columns"]["test_2"]["primary_key"] is False + ) # resource present assert default_stored["tables"]["table"]["resource"] == "🦚Table" # resource removed because identical to table name @@ -132,7 +134,9 @@ def test_new_incomplete_column() -> None: def test_merge_columns() -> None: # tab_b overrides non default - col_a = utils.merge_column(copy(COL_1_HINTS), copy(COL_2_HINTS), merge_defaults=False) + col_a = utils.merge_column( + copy(COL_1_HINTS), copy(COL_2_HINTS), merge_defaults=False + ) # nullable is False - tab_b has it as default and those are not merged assert col_a == { "name": "test_2", @@ -146,7 +150,9 @@ def test_merge_columns() -> None: "prop": None, } - col_a = utils.merge_column(copy(COL_1_HINTS), copy(COL_2_HINTS), merge_defaults=True) + col_a = utils.merge_column( + copy(COL_1_HINTS), copy(COL_2_HINTS), merge_defaults=True + ) # nullable is True and primary_key is present - default values are merged assert col_a == { "name": "test_2", @@ -186,7 +192,11 @@ def test_diff_tables() -> None: changed["name"] = "new name" partial = utils.diff_table(deepcopy(table), changed) print(partial) - assert partial == {"name": "new name", "description": "new description", "columns": {}} + assert partial == { + "name": "new name", + "description": "new description", + "columns": {}, + } # ignore identical table props existing = deepcopy(table) @@ -203,7 +213,11 @@ def test_diff_tables() -> None: existing["write_disposition"] = "append" existing["schema_contract"] = "freeze" partial = utils.diff_table(deepcopy(existing), changed) - assert partial == {"name": "new name", "description": "new description", "columns": {}} + assert partial == { + "name": "new name", + "description": "new description", + "columns": {}, + } # detect changed column existing = deepcopy(table) @@ -212,7 +226,11 @@ def test_diff_tables() -> None: partial = utils.diff_table(existing, changed) assert "test" in partial["columns"] assert "test_2" not in partial["columns"] - assert existing["columns"]["test"] == table["columns"]["test"] != partial["columns"]["test"] + assert ( + existing["columns"]["test"] + == table["columns"]["test"] + != partial["columns"]["test"] + ) # defaults are not ignored existing = deepcopy(table) diff --git a/tests/common/schema/test_schema.py b/tests/common/schema/test_schema.py index 887b0aa9a0..116779b4b7 100644 --- a/tests/common/schema/test_schema.py +++ b/tests/common/schema/test_schema.py @@ -75,7 +75,10 @@ def cn_schema() -> Schema: def test_normalize_schema_name(schema: Schema) -> None: assert schema.naming.normalize_table_identifier("BAN_ANA") == "ban_ana" assert schema.naming.normalize_table_identifier("event-.!:value") == "event_value" - assert schema.naming.normalize_table_identifier("123event-.!:value") == "_123event_value" + assert ( + schema.naming.normalize_table_identifier("123event-.!:value") + == "_123event_value" + ) with pytest.raises(ValueError): assert schema.naming.normalize_table_identifier("") with pytest.raises(ValueError): @@ -107,7 +110,9 @@ def test_new_schema_custom_normalizers(cn_schema: Schema) -> None: assert_new_schema_props_custom_normalizers(cn_schema) -def test_schema_config_normalizers(schema: Schema, schema_storage_no_import: SchemaStorage) -> None: +def test_schema_config_normalizers( + schema: Schema, schema_storage_no_import: SchemaStorage +) -> None: # save snake case schema 
schema_storage_no_import.save_schema(schema) # config direct naming convention @@ -141,7 +146,10 @@ def test_simple_regex_validator() -> None: # validate regex assert ( - utils.simple_regex_validator(".", "k", TSimpleRegex("re:^_record$"), TSimpleRegex) is True + utils.simple_regex_validator( + ".", "k", TSimpleRegex("re:^_record$"), TSimpleRegex + ) + is True ) # invalid regex with pytest.raises(DictValidationException) as e: @@ -164,16 +172,31 @@ def test_load_corrupted_schema() -> None: def test_column_name_validator(schema: Schema) -> None: assert utils.column_name_validator(schema.naming)(".", "k", "v", str) is False - assert utils.column_name_validator(schema.naming)(".", "k", "v", TColumnName) is True + assert ( + utils.column_name_validator(schema.naming)(".", "k", "v", TColumnName) is True + ) - assert utils.column_name_validator(schema.naming)(".", "k", "snake_case", TColumnName) is True + assert ( + utils.column_name_validator(schema.naming)(".", "k", "snake_case", TColumnName) + is True + ) # double underscores are accepted - assert utils.column_name_validator(schema.naming)(".", "k", "snake__case", TColumnName) is True + assert ( + utils.column_name_validator(schema.naming)(".", "k", "snake__case", TColumnName) + is True + ) # triple underscores are accepted - assert utils.column_name_validator(schema.naming)(".", "k", "snake___case", TColumnName) is True + assert ( + utils.column_name_validator(schema.naming)( + ".", "k", "snake___case", TColumnName + ) + is True + ) # quadruple underscores generate empty identifier with pytest.raises(DictValidationException) as e: - utils.column_name_validator(schema.naming)(".", "k", "snake____case", TColumnName) + utils.column_name_validator(schema.naming)( + ".", "k", "snake____case", TColumnName + ) assert "not a valid column name" in str(e.value) # this name is invalid with pytest.raises(DictValidationException) as e: @@ -203,7 +226,9 @@ def test_create_schema_with_normalize_name() -> None: def test_schema_descriptions_and_annotations(schema_storage: SchemaStorage): schema = SchemaStorage.load_schema_file( - os.path.join(COMMON_TEST_CASES_PATH, "schemas/local"), "event", extensions=("yaml",) + os.path.join(COMMON_TEST_CASES_PATH, "schemas/local"), + "event", + extensions=("yaml",), ) assert schema.tables["blocks"]["description"] == "Ethereum blocks" assert schema.tables["blocks"]["x-annotation"] == "this will be preserved on save" # type: ignore[typeddict-item] @@ -223,9 +248,9 @@ def test_schema_descriptions_and_annotations(schema_storage: SchemaStorage): loaded_schema = schema_storage.load_schema("event") assert loaded_schema.tables["blocks"]["description"].endswith("Saved") assert loaded_schema.tables["blocks"]["x-annotation"].endswith("Saved") # type: ignore[typeddict-item] - assert loaded_schema.tables["blocks"]["columns"]["_dlt_load_id"]["description"].endswith( - "Saved" - ) + assert loaded_schema.tables["blocks"]["columns"]["_dlt_load_id"][ + "description" + ].endswith("Saved") assert loaded_schema.tables["blocks"]["columns"]["_dlt_load_id"]["x-column-annotation"].endswith("Saved") # type: ignore[typeddict-item] @@ -357,7 +382,13 @@ def test_clone(schema: Schema) -> None: "columns,hint,value", [ ( - ["_dlt_id", "_dlt_root_id", "_dlt_load_id", "_dlt_parent_id", "_dlt_list_idx"], + [ + "_dlt_id", + "_dlt_root_id", + "_dlt_load_id", + "_dlt_parent_id", + "_dlt_list_idx", + ], "nullable", False, ), @@ -484,7 +515,11 @@ def test_preserve_column_order(schema: Schema, schema_storage: SchemaStorage) -> 
schema.update_table(utils.new_table("event_test_order", columns=update)) def verify_items(table, update) -> None: - assert [i[0] for i in table.items()] == list(table.keys()) == [u["name"] for u in update] + assert ( + [i[0] for i in table.items()] + == list(table.keys()) + == [u["name"] for u in update] + ) assert [i[1] for i in table.items()] == list(table.values()) == update table = schema.get_table_columns("event_test_order") @@ -661,7 +696,9 @@ def test_default_table_resource() -> None: def test_data_tables(schema: Schema, schema_storage: SchemaStorage) -> None: assert schema.data_tables() == [] dlt_tables = schema.dlt_tables() - assert set([t["name"] for t in dlt_tables]) == set([LOADS_TABLE_NAME, VERSION_TABLE_NAME]) + assert set([t["name"] for t in dlt_tables]) == set( + [LOADS_TABLE_NAME, VERSION_TABLE_NAME] + ) # with tables schema = schema_storage.load_schema("event") # some of them are incomplete @@ -679,7 +716,9 @@ def test_data_tables(schema: Schema, schema_storage: SchemaStorage) -> None: schema.update_table( { "name": "event_user", - "columns": {"name": {"name": "name", "primary_key": True, "nullable": False}}, + "columns": { + "name": {"name": "name", "primary_key": True, "nullable": False} + }, } ) assert [t["name"] for t in schema.data_tables()] == ["event_slot"] @@ -689,7 +728,10 @@ def test_data_tables(schema: Schema, schema_storage: SchemaStorage) -> None: # make it complete schema.update_table( - {"name": "event_user", "columns": {"name": {"name": "name", "data_type": "text"}}} + { + "name": "event_user", + "columns": {"name": {"name": "name", "data_type": "text"}}, + } ) assert [t["name"] for t in schema.data_tables()] == ["event_slot", "event_user"] assert [t["name"] for t in schema.data_tables(include_incomplete=True)] == [ @@ -711,7 +753,9 @@ def test_write_disposition(schema_storage: SchemaStorage) -> None: schema.get_table("event_user")["write_disposition"] = "replace" schema.update_table(utils.new_table("event_user__intents", "event_user")) assert schema.get_table("event_user__intents").get("write_disposition") is None - assert utils.get_write_disposition(schema.tables, "event_user__intents") == "replace" + assert ( + utils.get_write_disposition(schema.tables, "event_user__intents") == "replace" + ) schema.get_table("event_user__intents")["write_disposition"] = "append" assert utils.get_write_disposition(schema.tables, "event_user__intents") == "append" @@ -740,22 +784,34 @@ def test_compare_columns() -> None: for c in table["columns"].values(): assert utils.compare_complete_columns(c, c) is True assert ( - utils.compare_complete_columns(table["columns"]["col3"], table["columns"]["col4"]) is False + utils.compare_complete_columns( + table["columns"]["col3"], table["columns"]["col4"] + ) + is False ) # data type may not differ assert ( - utils.compare_complete_columns(table["columns"]["col1"], table["columns"]["col3"]) is False + utils.compare_complete_columns( + table["columns"]["col1"], table["columns"]["col3"] + ) + is False ) # nullability may differ assert ( - utils.compare_complete_columns(table["columns"]["col1"], table2["columns"]["col1"]) is True + utils.compare_complete_columns( + table["columns"]["col1"], table2["columns"]["col1"] + ) + is True ) # any of the hints may differ for hint in COLUMN_HINTS: table["columns"]["col3"][hint] = True # type: ignore[typeddict-unknown-key] # name may not differ assert ( - utils.compare_complete_columns(table["columns"]["col3"], table["columns"]["col4"]) is False + utils.compare_complete_columns( + 
table["columns"]["col3"], table["columns"]["col4"] + ) + is False ) @@ -775,7 +831,12 @@ def test_normalize_table_identifiers() -> None: def test_normalize_table_identifiers_merge_columns() -> None: # create conflicting columns table_create = [ - {"name": "case", "data_type": "bigint", "nullable": False, "x-description": "desc"}, + { + "name": "case", + "data_type": "bigint", + "nullable": False, + "x-description": "desc", + }, {"name": "Case", "data_type": "double", "nullable": True, "primary_key": True}, ] # schema normalizing to snake case will conflict on case and Case @@ -794,7 +855,10 @@ def test_normalize_table_identifiers_merge_columns() -> None: def assert_new_schema_props_custom_normalizers(schema: Schema) -> None: # check normalizers config - assert schema._normalizers_config["names"] == "tests.common.normalizers.custom_normalizers" + assert ( + schema._normalizers_config["names"] + == "tests.common.normalizers.custom_normalizers" + ) assert ( schema._normalizers_config["json"]["module"] == "tests.common.normalizers.custom_normalizers" @@ -828,11 +892,16 @@ def assert_new_schema_props(schema: Schema) -> None: assert len(schema.settings["default_hints"]) > 0 # check settings assert ( - utils.standard_type_detections() == schema.settings["detections"] == schema._type_detections + utils.standard_type_detections() + == schema.settings["detections"] + == schema._type_detections ) # check normalizers config assert schema._normalizers_config["names"] == "snake_case" - assert schema._normalizers_config["json"]["module"] == "dlt.common.normalizers.json.relational" + assert ( + schema._normalizers_config["json"]["module"] + == "dlt.common.normalizers.json.relational" + ) assert isinstance(schema.naming, snake_case.NamingConvention) # check if schema was extended by json normalizer assert set( @@ -859,11 +928,15 @@ def test_group_tables_by_resource(schema: Schema) -> None: schema.update_table(utils.new_table("a_events", columns=[])) schema.update_table(utils.new_table("b_events", columns=[])) schema.update_table(utils.new_table("c_products", columns=[], resource="products")) - schema.update_table(utils.new_table("a_events__1", columns=[], parent_table_name="a_events")) + schema.update_table( + utils.new_table("a_events__1", columns=[], parent_table_name="a_events") + ) schema.update_table( utils.new_table("a_events__1__2", columns=[], parent_table_name="a_events__1") ) - schema.update_table(utils.new_table("b_events__1", columns=[], parent_table_name="b_events")) + schema.update_table( + utils.new_table("b_events__1", columns=[], parent_table_name="b_events") + ) # All resources without filter expected_tables = { @@ -882,7 +955,8 @@ def test_group_tables_by_resource(schema: Schema) -> None: # With resource filter result = utils.group_tables_by_resource( - schema.tables, pattern=utils.compile_simple_regex(TSimpleRegex("re:[a-z]_events")) + schema.tables, + pattern=utils.compile_simple_regex(TSimpleRegex("re:[a-z]_events")), ) assert result == { "a_events": [ diff --git a/tests/common/schema/test_schema_contract.py b/tests/common/schema/test_schema_contract.py index 32f9583b26..861a39ebb5 100644 --- a/tests/common/schema/test_schema_contract.py +++ b/tests/common/schema/test_schema_contract.py @@ -29,13 +29,21 @@ def get_schema() -> Schema: s.update_table(cast(TTableSchema, {"name": "tables", "columns": columns})) s.update_table( - cast(TTableSchema, {"name": "child_table", "parent": "tables", "columns": columns}) + cast( + TTableSchema, + {"name": "child_table", "parent": "tables", 
"columns": columns}, + ) ) - s.update_table(cast(TTableSchema, {"name": "incomplete_table", "columns": incomplete_columns})) + s.update_table( + cast(TTableSchema, {"name": "incomplete_table", "columns": incomplete_columns}) + ) s.update_table( - cast(TTableSchema, {"name": "mixed_table", "columns": {**incomplete_columns, **columns}}) + cast( + TTableSchema, + {"name": "mixed_table", "columns": {**incomplete_columns, **columns}}, + ) ) s.update_table( @@ -55,8 +63,14 @@ def get_schema() -> Schema: def test_resolve_contract_settings() -> None: # defaults schema = get_schema() - assert schema.resolve_contract_settings_for_table("tables") == DEFAULT_SCHEMA_CONTRACT_MODE - assert schema.resolve_contract_settings_for_table("child_table") == DEFAULT_SCHEMA_CONTRACT_MODE + assert ( + schema.resolve_contract_settings_for_table("tables") + == DEFAULT_SCHEMA_CONTRACT_MODE + ) + assert ( + schema.resolve_contract_settings_for_table("child_table") + == DEFAULT_SCHEMA_CONTRACT_MODE + ) # table specific full setting schema = get_schema() @@ -143,7 +157,11 @@ def test_resolve_contract_settings() -> None: base_settings = [ {"tables": "evolve", "columns": "evolve", "data_type": "evolve"}, {"tables": "discard_row", "columns": "discard_row", "data_type": "discard_row"}, - {"tables": "discard_value", "columns": "discard_value", "data_type": "discard_value"}, + { + "tables": "discard_value", + "columns": "discard_value", + "data_type": "discard_value", + }, {"tables": "freeze", "columns": "freeze", "data_type": "freeze"}, ] @@ -162,11 +180,13 @@ def test_check_adding_table(base_settings) -> None: ) assert (partial, filters) == (new_table, []) partial, filters = schema.apply_schema_contract( - cast(TSchemaContractDict, {**base_settings, **{"tables": "discard_row"}}), new_table + cast(TSchemaContractDict, {**base_settings, **{"tables": "discard_row"}}), + new_table, ) assert (partial, filters) == (None, [("tables", "new_table", "discard_row")]) partial, filters = schema.apply_schema_contract( - cast(TSchemaContractDict, {**base_settings, **{"tables": "discard_value"}}), new_table + cast(TSchemaContractDict, {**base_settings, **{"tables": "discard_value"}}), + new_table, ) assert (partial, filters) == (None, [("tables", "new_table", "discard_value")]) partial, filters = schema.apply_schema_contract( @@ -187,7 +207,9 @@ def test_check_adding_table(base_settings) -> None: assert val_ex.value.column_name is None assert val_ex.value.schema_entity == "tables" assert val_ex.value.contract_mode == "freeze" - assert val_ex.value.table_schema is None # there's no validating schema on new table + assert ( + val_ex.value.table_schema is None + ) # there's no validating schema on new table assert val_ex.value.data_item == {"item": 1} @@ -213,7 +235,9 @@ def assert_new_column(table_update: TTableSchema, column_name: str) -> None: [("columns", column_name, "discard_row")], ) partial, filters = schema.apply_schema_contract( - cast(TSchemaContractDict, {**base_settings, **{"columns": "discard_value"}}), + cast( + TSchemaContractDict, {**base_settings, **{"columns": "discard_value"}} + ), copy.deepcopy(table_update), ) assert (partial, filters) == ( @@ -225,7 +249,10 @@ def assert_new_column(table_update: TTableSchema, column_name: str) -> None: copy.deepcopy(table_update), raise_on_freeze=False, ) - assert (partial, filters) == (popped_table_update, [("columns", column_name, "freeze")]) + assert (partial, filters) == ( + popped_table_update, + [("columns", column_name, "freeze")], + ) with pytest.raises(DataValidationError) as 
val_ex: schema.apply_schema_contract( @@ -275,7 +302,9 @@ def assert_new_column(table_update: TTableSchema, column_name: str) -> None: }, }, } - partial, filters = schema.apply_schema_contract(base_settings, copy.deepcopy(table_update)) + partial, filters = schema.apply_schema_contract( + base_settings, copy.deepcopy(table_update) + ) assert (partial, filters) == (table_update, []) @@ -288,19 +317,29 @@ def test_check_adding_new_variant() -> None: table_update: TTableSchema = { "name": "tables", "columns": { - "column_2_variant": {"name": "column_2_variant", "data_type": "bigint", "variant": True} + "column_2_variant": { + "name": "column_2_variant", + "data_type": "bigint", + "variant": True, + } }, } popped_table_update = copy.deepcopy(table_update) popped_table_update["columns"].pop("column_2_variant") partial, filters = schema.apply_schema_contract( - cast(TSchemaContractDict, {**DEFAULT_SCHEMA_CONTRACT_MODE, **{"data_type": "evolve"}}), + cast( + TSchemaContractDict, + {**DEFAULT_SCHEMA_CONTRACT_MODE, **{"data_type": "evolve"}}, + ), copy.deepcopy(table_update), ) assert (partial, filters) == (table_update, []) partial, filters = schema.apply_schema_contract( - cast(TSchemaContractDict, {**DEFAULT_SCHEMA_CONTRACT_MODE, **{"data_type": "discard_row"}}), + cast( + TSchemaContractDict, + {**DEFAULT_SCHEMA_CONTRACT_MODE, **{"data_type": "discard_row"}}, + ), copy.deepcopy(table_update), ) assert (partial, filters) == ( @@ -309,7 +348,8 @@ def test_check_adding_new_variant() -> None: ) partial, filters = schema.apply_schema_contract( cast( - TSchemaContractDict, {**DEFAULT_SCHEMA_CONTRACT_MODE, **{"data_type": "discard_value"}} + TSchemaContractDict, + {**DEFAULT_SCHEMA_CONTRACT_MODE, **{"data_type": "discard_value"}}, ), copy.deepcopy(table_update), ) @@ -318,15 +358,24 @@ def test_check_adding_new_variant() -> None: [("columns", "column_2_variant", "discard_value")], ) partial, filters = schema.apply_schema_contract( - cast(TSchemaContractDict, {**DEFAULT_SCHEMA_CONTRACT_MODE, **{"data_type": "freeze"}}), + cast( + TSchemaContractDict, + {**DEFAULT_SCHEMA_CONTRACT_MODE, **{"data_type": "freeze"}}, + ), copy.deepcopy(table_update), raise_on_freeze=False, ) - assert (partial, filters) == (popped_table_update, [("columns", "column_2_variant", "freeze")]) + assert (partial, filters) == ( + popped_table_update, + [("columns", "column_2_variant", "freeze")], + ) with pytest.raises(DataValidationError) as val_ex: schema.apply_schema_contract( - cast(TSchemaContractDict, {**DEFAULT_SCHEMA_CONTRACT_MODE, **{"data_type": "freeze"}}), + cast( + TSchemaContractDict, + {**DEFAULT_SCHEMA_CONTRACT_MODE, **{"data_type": "freeze"}}, + ), copy.deepcopy(table_update), ) assert val_ex.value.schema_name == schema.name @@ -341,7 +390,10 @@ def test_check_adding_new_variant() -> None: partial, filters = schema.apply_schema_contract( cast( TSchemaContractDict, - {**DEFAULT_SCHEMA_CONTRACT_MODE, **{"data_type": "evolve", "columns": "freeze"}}, + { + **DEFAULT_SCHEMA_CONTRACT_MODE, + **{"data_type": "evolve", "columns": "freeze"}, + }, ), copy.deepcopy(table_update), ) @@ -351,6 +403,9 @@ def test_check_adding_new_variant() -> None: table_update["name"] = "evolve_once_table" with pytest.raises(DataValidationError): schema.apply_schema_contract( - cast(TSchemaContractDict, {**DEFAULT_SCHEMA_CONTRACT_MODE, **{"data_type": "freeze"}}), + cast( + TSchemaContractDict, + {**DEFAULT_SCHEMA_CONTRACT_MODE, **{"data_type": "freeze"}}, + ), copy.deepcopy(table_update), ) diff --git a/tests/common/schema/test_versioning.py 
b/tests/common/schema/test_versioning.py index b67b028161..df95d64198 100644 --- a/tests/common/schema/test_versioning.py +++ b/tests/common/schema/test_versioning.py @@ -156,6 +156,10 @@ def test_create_ancestry() -> None: assert schema._stored_version == version + i # we never have more than 10 previous_hashes - assert len(schema._stored_previous_hashes) == i + hash_count if i + hash_count <= 10 else 10 + assert ( + len(schema._stored_previous_hashes) == i + hash_count + if i + hash_count <= 10 + else 10 + ) assert len(schema._stored_previous_hashes) == 10 diff --git a/tests/common/storages/test_file_storage.py b/tests/common/storages/test_file_storage.py index 9f212070e8..dcfb239cfa 100644 --- a/tests/common/storages/test_file_storage.py +++ b/tests/common/storages/test_file_storage.py @@ -8,7 +8,12 @@ from dlt.common.storages.file_storage import FileStorage from dlt.common.utils import encoding_for_mode, set_working_dir, uniq_id -from tests.utils import TEST_STORAGE_ROOT, autouse_test_storage, test_storage, skipifnotwindows +from tests.utils import ( + TEST_STORAGE_ROOT, + autouse_test_storage, + test_storage, + skipifnotwindows, +) def test_storage_init(test_storage: FileStorage) -> None: @@ -70,7 +75,9 @@ def test_in_storage(test_storage: FileStorage) -> None: assert test_storage.in_storage(os.curdir) is True assert test_storage.in_storage(os.path.realpath(os.curdir)) is False assert ( - test_storage.in_storage(os.path.join(os.path.realpath(os.curdir), TEST_STORAGE_ROOT)) + test_storage.in_storage( + os.path.join(os.path.realpath(os.curdir), TEST_STORAGE_ROOT) + ) is True ) @@ -130,7 +137,9 @@ def test_validate_file_name_component() -> None: FileStorage.validate_file_name_component("BAN__ANA is allowed") -@pytest.mark.parametrize("action", ("rename_tree_files", "rename_tree", "atomic_rename")) +@pytest.mark.parametrize( + "action", ("rename_tree_files", "rename_tree", "atomic_rename") +) def test_rename_nested_tree(test_storage: FileStorage, action: str) -> None: source_dir = os.path.join(test_storage.storage_path, "source") nested_dir_1 = os.path.join(source_dir, "nested1") diff --git a/tests/common/storages/test_load_package.py b/tests/common/storages/test_load_package.py index d61029c8cf..0ed07e9489 100644 --- a/tests/common/storages/test_load_package.py +++ b/tests/common/storages/test_load_package.py @@ -10,7 +10,11 @@ from dlt.common.storages import PackageStorage, LoadStorage, ParsedLoadJobFileName from dlt.common.utils import uniq_id -from tests.common.storages.utils import start_loading_file, assert_package_info, load_storage +from tests.common.storages.utils import ( + start_loading_file, + assert_package_info, + load_storage, +) from tests.utils import autouse_test_storage from dlt.common.pendulum import pendulum from dlt.common.configuration.container import Container @@ -44,7 +48,9 @@ def test_is_partially_loaded(load_storage: LoadStorage) -> None: assert PackageStorage.is_package_partially_loaded(info) is False # abort package - load_id, file_name = start_loading_file(load_storage, [{"content": "a"}, {"content": "b"}]) + load_id, file_name = start_loading_file( + load_storage, [{"content": "a"}, {"content": "b"}] + ) load_storage.complete_load_package(load_id, True) info = load_storage.get_load_package_info(load_id) assert PackageStorage.is_package_partially_loaded(info) is True @@ -59,7 +65,9 @@ def test_save_load_schema(load_storage: LoadStorage) -> None: saved_file_name = load_storage.new_packages.save_schema("copy", schema) assert saved_file_name.endswith( 
os.path.join( - load_storage.new_packages.storage.storage_path, "copy", PackageStorage.SCHEMA_FILE_NAME + load_storage.new_packages.storage.storage_path, + "copy", + PackageStorage.SCHEMA_FILE_NAME, ) ) assert load_storage.new_packages.storage.has_file( @@ -109,15 +117,24 @@ def test_loadpackage_state_injectable_context(load_storage: LoadStorage) -> None injected_state["state"]["new_key"] = "new_value" # type: ignore # not persisted yet - assert load_storage.new_packages.get_load_package_state("copy").get("new_key") is None + assert ( + load_storage.new_packages.get_load_package_state("copy").get("new_key") + is None + ) # commit commit_load_package_state() # now it should be persisted assert ( - load_storage.new_packages.get_load_package_state("copy").get("new_key") == "new_value" + load_storage.new_packages.get_load_package_state("copy").get("new_key") + == "new_value" + ) + assert ( + load_storage.new_packages.get_load_package_state("copy").get( + "_state_version" + ) + == 1 ) - assert load_storage.new_packages.get_load_package_state("copy").get("_state_version") == 1 # check that second injection is the same as first second_injected_instance = load_package() @@ -125,7 +142,9 @@ def test_loadpackage_state_injectable_context(load_storage: LoadStorage) -> None # check scoped destination states assert ( - load_storage.new_packages.get_load_package_state("copy").get("destination_state") + load_storage.new_packages.get_load_package_state("copy").get( + "destination_state" + ) is None ) dstate = destination_state() @@ -141,7 +160,9 @@ def test_loadpackage_state_injectable_context(load_storage: LoadStorage) -> None # clear destination state clear_destination_state() assert ( - load_storage.new_packages.get_load_package_state("copy").get("destination_state") + load_storage.new_packages.get_load_package_state("copy").get( + "destination_state" + ) is None ) diff --git a/tests/common/storages/test_load_storage.py b/tests/common/storages/test_load_storage.py index 0fe112581e..908133f2a2 100644 --- a/tests/common/storages/test_load_storage.py +++ b/tests/common/storages/test_load_storage.py @@ -6,22 +6,34 @@ from dlt.common.storages import PackageStorage, LoadStorage from dlt.common.storages.exceptions import LoadPackageNotFound, NoMigrationPathException -from tests.common.storages.utils import start_loading_file, assert_package_info, load_storage +from tests.common.storages.utils import ( + start_loading_file, + assert_package_info, + load_storage, +) from tests.utils import write_version, autouse_test_storage def test_complete_successful_package(load_storage: LoadStorage) -> None: # should delete package in full load_storage.config.delete_completed_jobs = True - load_id, file_name = start_loading_file(load_storage, [{"content": "a"}, {"content": "b"}]) - assert load_storage.storage.has_folder(load_storage.get_normalized_package_path(load_id)) + load_id, file_name = start_loading_file( + load_storage, [{"content": "a"}, {"content": "b"}] + ) + assert load_storage.storage.has_folder( + load_storage.get_normalized_package_path(load_id) + ) load_storage.normalized_packages.complete_job(load_id, file_name) assert_package_info(load_storage, load_id, "normalized", "completed_jobs") load_storage.complete_load_package(load_id, False) # deleted from loading - assert not load_storage.storage.has_folder(load_storage.get_normalized_package_path(load_id)) + assert not load_storage.storage.has_folder( + load_storage.get_normalized_package_path(load_id) + ) # has package - assert 
load_storage.storage.has_folder(load_storage.get_loaded_package_path(load_id)) + assert load_storage.storage.has_folder( + load_storage.get_loaded_package_path(load_id) + ) assert load_storage.storage.has_file( os.path.join( load_storage.get_loaded_package_path(load_id), @@ -36,16 +48,24 @@ def test_complete_successful_package(load_storage: LoadStorage) -> None: assert_package_info(load_storage, load_id, "loaded", "completed_jobs", jobs_count=0) # delete completed package load_storage.delete_loaded_package(load_id) - assert not load_storage.storage.has_folder(load_storage.get_loaded_package_path(load_id)) + assert not load_storage.storage.has_folder( + load_storage.get_loaded_package_path(load_id) + ) # do not delete completed jobs load_storage.config.delete_completed_jobs = False - load_id, file_name = start_loading_file(load_storage, [{"content": "a"}, {"content": "b"}]) + load_id, file_name = start_loading_file( + load_storage, [{"content": "a"}, {"content": "b"}] + ) load_storage.normalized_packages.complete_job(load_id, file_name) load_storage.complete_load_package(load_id, False) # deleted from loading - assert not load_storage.storage.has_folder(load_storage.get_normalized_package_path(load_id)) + assert not load_storage.storage.has_folder( + load_storage.get_normalized_package_path(load_id) + ) # has load preserved - assert load_storage.storage.has_folder(load_storage.get_loaded_package_path(load_id)) + assert load_storage.storage.has_folder( + load_storage.get_loaded_package_path(load_id) + ) assert load_storage.storage.has_file( os.path.join( load_storage.get_loaded_package_path(load_id), @@ -57,11 +77,15 @@ def test_complete_successful_package(load_storage: LoadStorage) -> None: load_storage.loaded_packages.get_job_folder_path(load_id, "completed_jobs") ) load_storage.delete_loaded_package(load_id) - assert not load_storage.storage.has_folder(load_storage.get_loaded_package_path(load_id)) + assert not load_storage.storage.has_folder( + load_storage.get_loaded_package_path(load_id) + ) def test_wipe_normalized_packages(load_storage: LoadStorage) -> None: - load_id, file_name = start_loading_file(load_storage, [{"content": "a"}, {"content": "b"}]) + load_id, file_name = start_loading_file( + load_storage, [{"content": "a"}, {"content": "b"}] + ) load_storage.wipe_normalized_packages() assert not load_storage.storage.has_folder(load_storage.NORMALIZED_FOLDER) @@ -69,15 +93,23 @@ def test_wipe_normalized_packages(load_storage: LoadStorage) -> None: def test_complete_package_failed_jobs(load_storage: LoadStorage) -> None: # loads with failed jobs are always persisted load_storage.config.delete_completed_jobs = True - load_id, file_name = start_loading_file(load_storage, [{"content": "a"}, {"content": "b"}]) - assert load_storage.storage.has_folder(load_storage.get_normalized_package_path(load_id)) + load_id, file_name = start_loading_file( + load_storage, [{"content": "a"}, {"content": "b"}] + ) + assert load_storage.storage.has_folder( + load_storage.get_normalized_package_path(load_id) + ) load_storage.normalized_packages.fail_job(load_id, file_name, "EXCEPTION") assert_package_info(load_storage, load_id, "normalized", "failed_jobs") load_storage.complete_load_package(load_id, False) # deleted from loading - assert not load_storage.storage.has_folder(load_storage.get_normalized_package_path(load_id)) + assert not load_storage.storage.has_folder( + load_storage.get_normalized_package_path(load_id) + ) # present in completed loads folder - assert 
load_storage.storage.has_folder(load_storage.get_loaded_package_path(load_id)) + assert load_storage.storage.has_folder( + load_storage.get_loaded_package_path(load_id) + ) # has completed loads assert load_storage.loaded_packages.storage.has_folder( load_storage.loaded_packages.get_job_folder_path(load_id, "completed_jobs") @@ -90,9 +122,9 @@ def test_complete_package_failed_jobs(load_storage: LoadStorage) -> None: assert len(failed_files) == 2 assert load_storage.loaded_packages.storage.has_file(failed_files[0]) failed_info = load_storage.list_failed_jobs_in_loaded_package(load_id) - assert failed_info[0].file_path == load_storage.loaded_packages.storage.make_full_path( - failed_files[0] - ) + assert failed_info[ + 0 + ].file_path == load_storage.loaded_packages.storage.make_full_path(failed_files[0]) assert failed_info[0].failed_message == "EXCEPTION" assert failed_info[0].job_file_info.table_name == "mock_table" # a few stats @@ -109,8 +141,12 @@ def test_complete_package_failed_jobs(load_storage: LoadStorage) -> None: def test_abort_package(load_storage: LoadStorage) -> None: # loads with failed jobs are always persisted load_storage.config.delete_completed_jobs = True - load_id, file_name = start_loading_file(load_storage, [{"content": "a"}, {"content": "b"}]) - assert load_storage.storage.has_folder(load_storage.get_normalized_package_path(load_id)) + load_id, file_name = start_loading_file( + load_storage, [{"content": "a"}, {"content": "b"}] + ) + assert load_storage.storage.has_folder( + load_storage.get_normalized_package_path(load_id) + ) load_storage.normalized_packages.fail_job(load_id, file_name, "EXCEPTION") assert_package_info(load_storage, load_id, "normalized", "failed_jobs") load_storage.complete_load_package(load_id, True) @@ -140,7 +176,9 @@ def test_process_schema_update(load_storage: LoadStorage) -> None: assert load_storage.storage.has_file(applied_update_path) is True assert json.loads(load_storage.storage.load(applied_update_path)) == applied_update # verify info package - package_info = assert_package_info(load_storage, load_id, "normalized", "started_jobs") + package_info = assert_package_info( + load_storage, load_id, "normalized", "started_jobs" + ) # applied update is present assert package_info.schema_update == applied_update # should be in dict diff --git a/tests/common/storages/test_schema_storage.py b/tests/common/storages/test_schema_storage.py index 6cb76fba9d..b789597f2d 100644 --- a/tests/common/storages/test_schema_storage.py +++ b/tests/common/storages/test_schema_storage.py @@ -106,7 +106,9 @@ def test_import_overwrites_existing_if_modified( assert_schema_imported(synced_storage, storage) -def test_skip_import_if_not_modified(synced_storage: SchemaStorage, storage: SchemaStorage) -> None: +def test_skip_import_if_not_modified( + synced_storage: SchemaStorage, storage: SchemaStorage +) -> None: storage_schema = assert_schema_imported(synced_storage, storage) assert not storage_schema.is_modified initial_version = storage_schema.stored_version @@ -125,13 +127,17 @@ def test_skip_import_if_not_modified(synced_storage: SchemaStorage, storage: Sch assert "event_user" in reloaded_schema.tables assert storage_schema.version == reloaded_schema.stored_version assert storage_schema.version_hash == reloaded_schema.stored_version_hash - assert storage_schema._imported_version_hash == reloaded_schema._imported_version_hash + assert ( + storage_schema._imported_version_hash == reloaded_schema._imported_version_hash + ) assert storage_schema.previous_hashes == 
reloaded_schema.previous_hashes # the import schema gets modified storage_schema.tables["_dlt_loads"]["write_disposition"] = "append" storage_schema.tables.pop("event_user") # we save the import schema (using export method) - synced_storage._export_schema(storage_schema, synced_storage.config.export_schema_path) + synced_storage._export_schema( + storage_schema, synced_storage.config.export_schema_path + ) # now load will import again reloaded_schema = synced_storage.load_schema("ethereum") # we have overwritten storage schema @@ -143,10 +149,14 @@ def test_skip_import_if_not_modified(synced_storage: SchemaStorage, storage: Sch assert storage_schema.previous_hashes == reloaded_schema.previous_hashes # but original version has increased twice (because it was modified twice) - assert reloaded_schema.stored_version == storage_schema.version == initial_version + 2 + assert ( + reloaded_schema.stored_version == storage_schema.version == initial_version + 2 + ) -def test_store_schema_tampered(synced_storage: SchemaStorage, storage: SchemaStorage) -> None: +def test_store_schema_tampered( + synced_storage: SchemaStorage, storage: SchemaStorage +) -> None: storage_schema = assert_schema_imported(synced_storage, storage) # break hash stored_schema = storage_schema.to_dict() @@ -247,7 +257,9 @@ def test_getter_with_import(ie_storage: SchemaStorage) -> None: assert not schema.is_modified # now load via getter schema_copy = ie_storage["ethereum"] - assert schema_copy.version_hash == schema_copy.stored_version_hash == mod_version_hash + assert ( + schema_copy.version_hash == schema_copy.stored_version_hash == mod_version_hash + ) assert schema_copy._imported_version_hash == version_hash # now save the schema as import @@ -334,14 +346,18 @@ def test_schema_from_file() -> None: assert schema.name == "event" schema = SchemaStorage.load_schema_file( - os.path.join(COMMON_TEST_CASES_PATH, "schemas/local"), "event", extensions=("yaml",) + os.path.join(COMMON_TEST_CASES_PATH, "schemas/local"), + "event", + extensions=("yaml",), ) assert schema.name == "event" assert "blocks" in schema.tables with pytest.raises(SchemaNotFoundError): SchemaStorage.load_schema_file( - os.path.join(COMMON_TEST_CASES_PATH, "schemas/local"), "eth", extensions=("yaml",) + os.path.join(COMMON_TEST_CASES_PATH, "schemas/local"), + "eth", + extensions=("yaml",), ) # file name and schema content mismatch @@ -481,7 +497,9 @@ def prepare_import_folder(storage: SchemaStorage) -> None: ) -def assert_schema_imported(synced_storage: SchemaStorage, storage: SchemaStorage) -> Schema: +def assert_schema_imported( + synced_storage: SchemaStorage, storage: SchemaStorage +) -> Schema: prepare_import_folder(synced_storage) eth_V9: TStoredSchema = load_yml_case("schemas/eth/ethereum_schema_v9") schema = synced_storage.load_schema("ethereum") diff --git a/tests/common/storages/test_transactional_file.py b/tests/common/storages/test_transactional_file.py index 7afdf10c38..14cb7a6d9b 100644 --- a/tests/common/storages/test_transactional_file.py +++ b/tests/common/storages/test_transactional_file.py @@ -69,7 +69,9 @@ def test_file_transaction_no_content(fs: fsspec.AbstractFileSystem, file_name: s file.release_lock() -def test_file_transaction_multiple_writers(fs: fsspec.AbstractFileSystem, file_name: str): +def test_file_transaction_multiple_writers( + fs: fsspec.AbstractFileSystem, file_name: str +): writer_1 = TransactionalFile(file_name, fs) writer_2 = TransactionalFile(file_name, fs) writer_3 = TransactionalFile(file_name, fs) @@ -139,7 +141,9 @@ def 
test_file_transaction_simultaneous(fs: fsspec.AbstractFileSystem): assert sum(results) == 1 -def test_file_transaction_ttl_expiry(fs: fsspec.AbstractFileSystem, monkeypatch, file_name: str): +def test_file_transaction_ttl_expiry( + fs: fsspec.AbstractFileSystem, monkeypatch, file_name: str +): monkeypatch.setattr(TransactionalFile, "LOCK_TTL_SECONDS", 1) writer_1 = TransactionalFile(file_name, fs) writer_2 = TransactionalFile(file_name, fs) @@ -153,13 +157,17 @@ def test_file_transaction_ttl_expiry(fs: fsspec.AbstractFileSystem, monkeypatch, @skipifwindows -def test_file_transaction_maintain_lock(fs: fsspec.AbstractFileSystem, monkeypatch, file_name: str): +def test_file_transaction_maintain_lock( + fs: fsspec.AbstractFileSystem, monkeypatch, file_name: str +): monkeypatch.setattr(TransactionalFile, "LOCK_TTL_SECONDS", 1) writer_1 = TransactionalFile(file_name, fs) writer_2 = TransactionalFile(file_name, fs) writer_1.acquire_lock() - thread = Thread(target=functools.partial(writer_2.acquire_lock, timeout=5), daemon=True) + thread = Thread( + target=functools.partial(writer_2.acquire_lock, timeout=5), daemon=True + ) try: thread.start() time.sleep(2.5) diff --git a/tests/common/storages/test_versioned_storage.py b/tests/common/storages/test_versioned_storage.py index 2859c7662c..0a16099ce2 100644 --- a/tests/common/storages/test_versioned_storage.py +++ b/tests/common/storages/test_versioned_storage.py @@ -2,7 +2,10 @@ import semver from dlt.common.storages.file_storage import FileStorage -from dlt.common.storages.exceptions import NoMigrationPathException, WrongStorageVersionException +from dlt.common.storages.exceptions import ( + NoMigrationPathException, + WrongStorageVersionException, +) from dlt.common.storages.versioned_storage import VersionedStorage from tests.utils import write_version, test_storage diff --git a/tests/common/test_arithmetics.py b/tests/common/test_arithmetics.py index 87c0a94751..06014ec8d3 100644 --- a/tests/common/test_arithmetics.py +++ b/tests/common/test_arithmetics.py @@ -1,6 +1,10 @@ import pytest from dlt.common import Decimal -from dlt.common.arithmetics import numeric_default_context, numeric_default_quantize, Inexact +from dlt.common.arithmetics import ( + numeric_default_context, + numeric_default_quantize, + Inexact, +) def test_default_numeric_quantize() -> None: diff --git a/tests/common/test_destination.py b/tests/common/test_destination.py index 24b0928463..d9054b6176 100644 --- a/tests/common/test_destination.py +++ b/tests/common/test_destination.py @@ -1,8 +1,14 @@ import pytest -from dlt.common.destination.reference import DestinationClientDwhConfiguration, Destination +from dlt.common.destination.reference import ( + DestinationClientDwhConfiguration, + Destination, +) from dlt.common.destination import DestinationCapabilitiesContext -from dlt.common.destination.exceptions import InvalidDestinationReference, UnknownDestinationModule +from dlt.common.destination.exceptions import ( + InvalidDestinationReference, + UnknownDestinationModule, +) from dlt.common.schema import Schema from tests.utils import ACTIVE_DESTINATIONS @@ -19,7 +25,9 @@ def test_import_unknown_destination() -> None: def test_invalid_destination_reference() -> None: with pytest.raises(InvalidDestinationReference): - Destination.from_reference("tests.load.cases.fake_destination.not_a_destination") + Destination.from_reference( + "tests.load.cases.fake_destination.not_a_destination" + ) def test_custom_destination_module() -> None: @@ -75,7 +83,9 @@ def dest_callable(items, 
table) -> None: def test_import_destination_config() -> None: # importing destination by type will work - dest = Destination.from_reference(ref="dlt.destinations.duckdb", environment="stage") + dest = Destination.from_reference( + ref="dlt.destinations.duckdb", environment="stage" + ) assert dest.destination_type == "dlt.destinations.duckdb" assert dest.config_params["environment"] == "stage" config = dest.configuration(dest.spec()._bind_dataset_name(dataset_name="dataset")) # type: ignore @@ -84,7 +94,9 @@ def test_import_destination_config() -> None: assert config.environment == "stage" # importing destination by will work - dest = Destination.from_reference(ref=None, destination_name="duckdb", environment="production") + dest = Destination.from_reference( + ref=None, destination_name="duckdb", environment="production" + ) assert dest.destination_type == "dlt.destinations.duckdb" assert dest.config_params["environment"] == "production" config = dest.configuration(dest.spec()._bind_dataset_name(dataset_name="dataset")) # type: ignore @@ -113,14 +125,18 @@ def test_normalize_dataset_name() -> None: assert ( DestinationClientDwhConfiguration() - ._bind_dataset_name(dataset_name="ban_ana_dataset", default_schema_name="default") + ._bind_dataset_name( + dataset_name="ban_ana_dataset", default_schema_name="default" + ) .normalize_dataset_name(Schema("banana")) == "ban_ana_dataset_banana" ) # without schema name appended assert ( DestinationClientDwhConfiguration() - ._bind_dataset_name(dataset_name="ban_ana_dataset", default_schema_name="default") + ._bind_dataset_name( + dataset_name="ban_ana_dataset", default_schema_name="default" + ) .normalize_dataset_name(Schema("default")) == "ban_ana_dataset" ) diff --git a/tests/common/test_git.py b/tests/common/test_git.py index 10bc05970e..e30f6133f0 100644 --- a/tests/common/test_git.py +++ b/tests/common/test_git.py @@ -14,7 +14,11 @@ ) from tests.utils import test_storage, skipifwindows -from tests.common.utils import load_secret, modify_and_commit_file, restore_secret_storage_path +from tests.common.utils import ( + load_secret, + modify_and_commit_file, + restore_secret_storage_path, +) AWESOME_REPO = "https://github.com/sindresorhus/awesome.git" @@ -66,7 +70,9 @@ def test_clone_with_wrong_branch(test_storage: FileStorage) -> None: repo_path = test_storage.make_full_path("awesome_repo") # clone a small public repo with pytest.raises(GitCommandError): - clone_repo(AWESOME_REPO, repo_path, with_git_command=None, branch="wrong_branch") + clone_repo( + AWESOME_REPO, repo_path, with_git_command=None, branch="wrong_branch" + ) def test_clone_with_deploy_key_access_denied(test_storage: FileStorage) -> None: @@ -82,7 +88,9 @@ def test_clone_with_deploy_key(test_storage: FileStorage) -> None: secret = load_secret("deploy_key") repo_path = test_storage.make_full_path("private_repo_access") with git_custom_key_command(secret) as git_command: - clone_repo(PRIVATE_REPO_WITH_ACCESS, repo_path, with_git_command=git_command).close() + clone_repo( + PRIVATE_REPO_WITH_ACCESS, repo_path, with_git_command=git_command + ).close() ensure_remote_head(repo_path, with_git_command=git_command) @@ -91,7 +99,9 @@ def test_repo_status_update(test_storage: FileStorage) -> None: secret = load_secret("deploy_key") repo_path = test_storage.make_full_path("private_repo_access") with git_custom_key_command(secret) as git_command: - clone_repo(PRIVATE_REPO_WITH_ACCESS, repo_path, with_git_command=git_command).close() + clone_repo( + PRIVATE_REPO_WITH_ACCESS, repo_path, 
with_git_command=git_command + ).close() # modify README.md readme_path, _ = modify_and_commit_file(repo_path, "README.md") assert test_storage.has_file(readme_path) @@ -100,13 +110,17 @@ def test_repo_status_update(test_storage: FileStorage) -> None: def test_fresh_repo_files_branch_change(test_storage: FileStorage) -> None: - repo_storage = get_fresh_repo_files(AWESOME_REPO, test_storage.storage_path, branch="gh-pages") + repo_storage = get_fresh_repo_files( + AWESOME_REPO, test_storage.storage_path, branch="gh-pages" + ) with get_repo(repo_storage.storage_path) as repo: assert repo.active_branch.name == "gh-pages" assert not is_dirty(repo) assert is_clean_and_synced(repo) # change to main - repo_storage = get_fresh_repo_files(AWESOME_REPO, test_storage.storage_path, branch="main") + repo_storage = get_fresh_repo_files( + AWESOME_REPO, test_storage.storage_path, branch="main" + ) with get_repo(repo_storage.storage_path) as repo: assert repo.active_branch.name == "main" assert not is_dirty(repo) @@ -114,7 +128,9 @@ def test_fresh_repo_files_branch_change(test_storage: FileStorage) -> None: def test_refresh_repo_files_local_mod(test_storage: FileStorage) -> None: - repo_storage = get_fresh_repo_files(AWESOME_REPO, test_storage.storage_path, branch="main") + repo_storage = get_fresh_repo_files( + AWESOME_REPO, test_storage.storage_path, branch="main" + ) with get_repo(repo_storage.storage_path) as repo: origin_head_sha = repo.head.commit.hexsha repo_storage.save("addition.py", "# new file") @@ -131,7 +147,9 @@ def test_refresh_repo_files_local_mod(test_storage: FileStorage) -> None: # we are in non-synced state, folder does not refresh not clean assert commit.hexsha == repo.head.commit.hexsha # this should reset to the origin - repo_storage = get_fresh_repo_files(AWESOME_REPO, test_storage.storage_path, branch="main") + repo_storage = get_fresh_repo_files( + AWESOME_REPO, test_storage.storage_path, branch="main" + ) with get_repo(repo_storage.storage_path) as repo: assert origin_head_sha == repo.head.commit.hexsha diff --git a/tests/common/test_json.py b/tests/common/test_json.py index 79037ebf93..0ed6c235f4 100644 --- a/tests/common/test_json.py +++ b/tests/common/test_json.py @@ -225,7 +225,8 @@ def test_json_named_tuple(json_impl: SupportsJson) -> None: with io.BytesIO() as b: json_impl.typed_dump(NamedTupleTest("STR", Decimal("1.3333")), b) assert ( - b.getvalue().decode("utf-8") == '{"str_field":"STR","dec_field":"%s1.3333"}' % _DECIMAL + b.getvalue().decode("utf-8") + == '{"str_field":"STR","dec_field":"%s1.3333"}' % _DECIMAL ) diff --git a/tests/common/test_pipeline_state.py b/tests/common/test_pipeline_state.py index 2c6a89b978..024c790faf 100644 --- a/tests/common/test_pipeline_state.py +++ b/tests/common/test_pipeline_state.py @@ -57,7 +57,9 @@ def test_get_matching_resources() -> None: assert sorted(results) == ["events_a", "events_b"] # with state context - with mock.patch.object(ps, "source_state", autospec=True, return_value=_fake_source_state): + with mock.patch.object( + ps, "source_state", autospec=True, return_value=_fake_source_state + ): results = ps._get_matching_resources(pattern, _fake_source_state) assert sorted(results) == ["events_a", "events_b"] diff --git a/tests/common/test_time.py b/tests/common/test_time.py index 7568e84046..15357696a3 100644 --- a/tests/common/test_time.py +++ b/tests/common/test_time.py @@ -14,7 +14,10 @@ def test_timestamp_within() -> None: - assert timestamp_within(1643470504.782716, 1643470504.782716, 1643470504.782716) is False + assert 
( + timestamp_within(1643470504.782716, 1643470504.782716, 1643470504.782716) + is False + ) # true for all timestamps assert timestamp_within(1643470504.782716, None, None) is True # upper bound inclusive @@ -22,9 +25,15 @@ def test_timestamp_within() -> None: # lower bound exclusive assert timestamp_within(1643470504.782716, 1643470504.782716, None) is False assert timestamp_within(1643470504.782716, 1643470504.782715, None) is True - assert timestamp_within(1643470504.782716, 1643470504.782715, 1643470504.782716) is True + assert ( + timestamp_within(1643470504.782716, 1643470504.782715, 1643470504.782716) + is True + ) # typical case - assert timestamp_within(1643470504.782716, 1543470504.782716, 1643570504.782716) is True + assert ( + timestamp_within(1643470504.782716, 1543470504.782716, 1643570504.782716) + is True + ) def test_before() -> None: @@ -79,11 +88,15 @@ def test_before() -> None: def test_parse_iso_like_datetime() -> None: # naive datetime is still naive - assert parse_iso_like_datetime("2021-01-01T05:02:32") == pendulum.DateTime(2021, 1, 1, 5, 2, 32) + assert parse_iso_like_datetime("2021-01-01T05:02:32") == pendulum.DateTime( + 2021, 1, 1, 5, 2, 32 + ) @pytest.mark.parametrize("date_value, expected", test_params) -def test_ensure_pendulum_datetime(date_value: TAnyDateTime, expected: pendulum.DateTime) -> None: +def test_ensure_pendulum_datetime( + date_value: TAnyDateTime, expected: pendulum.DateTime +) -> None: dt = ensure_pendulum_datetime(date_value) assert dt == expected # always UTC @@ -94,7 +107,9 @@ def test_ensure_pendulum_datetime(date_value: TAnyDateTime, expected: pendulum.D def test_ensure_pendulum_date_utc() -> None: # when converting from datetimes make sure to shift to UTC before doing date - assert ensure_pendulum_date("2021-01-01T00:00:00+05:00") == pendulum.date(2020, 12, 31) + assert ensure_pendulum_date("2021-01-01T00:00:00+05:00") == pendulum.date( + 2020, 12, 31 + ) assert ensure_pendulum_date( datetime(2021, 1, 1, 0, 0, 0, tzinfo=timezone(timedelta(hours=8))) ) == pendulum.date(2020, 12, 31) diff --git a/tests/common/test_utils.py b/tests/common/test_utils.py index 229ce17085..95a16a2ce7 100644 --- a/tests/common/test_utils.py +++ b/tests/common/test_utils.py @@ -40,7 +40,11 @@ def test_digest128_length() -> None: def test_map_dicts_in_place() -> None: - _d = {"a": "1", "b": ["a", "b", ["a", "b"], {"a": "c"}], "c": {"d": "e", "e": ["a", 2]}} + _d = { + "a": "1", + "b": ["a", "b", ["a", "b"], {"a": "c"}], + "c": {"d": "e", "e": ["a", 2]}, + } exp_d = { "a": "11", "b": ["aa", "bb", ["aa", "bb"], {"a": "cc"}], @@ -80,7 +84,9 @@ def test_get_module_name() -> None: assert get_module_name(m) == "uniq_mod_121" # use exec to get __main__ exception - mod_name = Venv.restore_current().run_script("tests/common/cases/modules/uniq_mod_121.py") + mod_name = Venv.restore_current().run_script( + "tests/common/cases/modules/uniq_mod_121.py" + ) assert mod_name.strip() == "uniq_mod_121" @@ -88,7 +94,9 @@ def test_concat_strings_with_limit() -> None: assert list(concat_strings_with_limit([], " ", 15)) == [] philosopher = ["Bertrand Russell"] - assert list(concat_strings_with_limit(philosopher, ";\n", 15)) == ["Bertrand Russell"] + assert list(concat_strings_with_limit(philosopher, ";\n", 15)) == [ + "Bertrand Russell" + ] # only two strings will be merged (22 chars total) philosophers = [ @@ -116,13 +124,23 @@ def test_concat_strings_with_limit() -> None: # again 1 assert list(concat_strings_with_limit(philosophers, ";\n", 23)) == moore_merged_2 # all merged - 
assert list(concat_strings_with_limit(philosophers, ";\n", 1024)) == [";\n".join(philosophers)] + assert list(concat_strings_with_limit(philosophers, ";\n", 1024)) == [ + ";\n".join(philosophers) + ] # none will be merged, all below limit assert list(concat_strings_with_limit(philosophers, ";\n", 1)) == philosophers def test_find_scc_nodes() -> None: - edges = [("A", "B"), ("B", "C"), ("D", "E"), ("F", "G"), ("G", "H"), ("I", "I"), ("J", "J")] + edges = [ + ("A", "B"), + ("B", "C"), + ("D", "E"), + ("F", "G"), + ("G", "H"), + ("I", "I"), + ("J", "J"), + ] def _comp(s): return sorted([tuple(sorted(c)) for c in s]) @@ -211,7 +229,9 @@ def test_merge_row_counts() -> None: def test_extend_list_deduplicated() -> None: - assert extend_list_deduplicated(["one", "two", "three"], ["four", "five", "six"]) == [ + assert extend_list_deduplicated( + ["one", "two", "three"], ["four", "five", "six"] + ) == [ "one", "two", "three", @@ -222,12 +242,18 @@ def test_extend_list_deduplicated() -> None: assert extend_list_deduplicated( ["one", "two", "three", "six"], ["two", "four", "five", "six"] ) == ["one", "two", "three", "six", "four", "five"] - assert extend_list_deduplicated(["one", "two", "three"], ["one", "two", "three"]) == [ + assert extend_list_deduplicated( + ["one", "two", "three"], ["one", "two", "three"] + ) == [ + "one", + "two", + "three", + ] + assert extend_list_deduplicated([], ["one", "two", "three"]) == [ "one", "two", "three", ] - assert extend_list_deduplicated([], ["one", "two", "three"]) == ["one", "two", "three"] def test_exception_traces() -> None: @@ -245,7 +271,10 @@ def test_exception_traces() -> None: raise IdentifierTooLongException("postgres", "table", "too_long_table", 8) except Exception as exc: trace = get_exception_trace(exc) - assert trace["exception_type"] == "dlt.common.destination.exceptions.IdentifierTooLongException" + assert ( + trace["exception_type"] + == "dlt.common.destination.exceptions.IdentifierTooLongException" + ) assert isinstance(trace["stack_trace"], list) assert trace["exception_attrs"] == { "destination_name": "postgres", @@ -293,6 +322,6 @@ def test_nested_dict_merge() -> None: assert update_dict_nested(dict(dict_1), dict_2) == {"a": 2, "b": 2, "c": 4} assert update_dict_nested(dict(dict_2), dict_1) == {"a": 1, "b": 2, "c": 4} - assert update_dict_nested(dict(dict_1), dict_2, keep_dst_values=True) == update_dict_nested( - dict_2, dict_1 - ) + assert update_dict_nested( + dict(dict_1), dict_2, keep_dst_values=True + ) == update_dict_nested(dict_2, dict_1) diff --git a/tests/common/test_validation.py b/tests/common/test_validation.py index 3297df1038..18cf2354e4 100644 --- a/tests/common/test_validation.py +++ b/tests/common/test_validation.py @@ -1,7 +1,17 @@ from copy import deepcopy import pytest import yaml -from typing import Callable, List, Literal, Mapping, Sequence, TypedDict, TypeVar, Optional, Union +from typing import ( + Callable, + List, + Literal, + Mapping, + Sequence, + TypedDict, + TypeVar, + Optional, + Union, +) from dlt.common import Decimal from dlt.common.exceptions import DictValidationException @@ -91,7 +101,9 @@ def test_doc() -> TTestRecord: def test_validate_schema_cases() -> None: with open( - "tests/common/cases/schemas/eth/ethereum_schema_v8.yml", mode="r", encoding="utf-8" + "tests/common/cases/schemas/eth/ethereum_schema_v8.yml", + mode="r", + encoding="utf-8", ) as f: schema_dict: TStoredSchema = yaml.safe_load(f) @@ -282,7 +294,10 @@ def f(item: Union[TDataItem, TDynHintType]) -> TDynHintType: test_item = {"prop": f} 
validate_dict( - TTestRecordCallable, test_item, path=".", validator_f=lambda p, pk, pv, t: callable(pv) + TTestRecordCallable, + test_item, + path=".", + validator_f=lambda p, pk, pv, t: callable(pv), ) diff --git a/tests/common/test_versioned_state.py b/tests/common/test_versioned_state.py index e1f31a8a92..18d2d72123 100644 --- a/tests/common/test_versioned_state.py +++ b/tests/common/test_versioned_state.py @@ -19,7 +19,9 @@ def test_versioned_state() -> None: # change attr, but exclude while generating state["foo"] = "bar" # type: ignore - version, hash_, previous_hash = bump_state_version_if_modified(state, exclude_attrs=["foo"]) + version, hash_, previous_hash = bump_state_version_if_modified( + state, exclude_attrs=["foo"] + ) assert version == 0 assert hash_ == previous_hash diff --git a/tests/common/test_wei.py b/tests/common/test_wei.py index 1f15978ddc..cd5a2d79b6 100644 --- a/tests/common/test_wei.py +++ b/tests/common/test_wei.py @@ -32,13 +32,17 @@ def test_wei_variant() -> None: assert callable(Wei(1)) # we get variant value when we call Wei - assert Wei(578960446186580977117854925043439539266)() == 578960446186580977117854925043439539266 + assert ( + Wei(578960446186580977117854925043439539266)() + == 578960446186580977117854925043439539266 + ) assert Wei(578960446186580977117854925043439539267)() == ( "str", "578960446186580977117854925043439539267", ) assert ( - Wei(-578960446186580977117854925043439539267)() == -578960446186580977117854925043439539267 + Wei(-578960446186580977117854925043439539267)() + == -578960446186580977117854925043439539267 ) assert Wei(-578960446186580977117854925043439539268)() == ( "str", diff --git a/tests/common/utils.py b/tests/common/utils.py index a234937e56..02f4b61fe1 100644 --- a/tests/common/utils.py +++ b/tests/common/utils.py @@ -18,9 +18,7 @@ # for import schema tests, change when upgrading the schema version IMPORTED_VERSION_HASH_ETH_V9 = "PgEHvn5+BHV1jNzNYpx9aDpq6Pq1PSSetufj/h0hKg4=" # test sentry DSN -TEST_SENTRY_DSN = ( - "https://797678dd0af64b96937435326c7d30c1@o1061158.ingest.sentry.io/4504306172821504" -) +TEST_SENTRY_DSN = "https://797678dd0af64b96937435326c7d30c1@o1061158.ingest.sentry.io/4504306172821504" # preserve secrets path to be able to restore it SECRET_STORAGE_PATH = environ_provider.SECRET_STORAGE_PATH @@ -74,7 +72,9 @@ def modify_and_commit_file( # one file modified index = repo.index.entries assert len(index) > 0 - assert any(e for e in index.keys() if os.path.join(*Path(e[0]).parts) == file_name) + assert any( + e for e in index.keys() if os.path.join(*Path(e[0]).parts) == file_name + ) repo.index.add(file_name) commit = repo.index.commit(f"mod {file_name}") diff --git a/tests/conftest.py b/tests/conftest.py index ab22d6ca7a..dc219db588 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -60,8 +60,8 @@ def pytest_configure(config): storage_configuration.LoadStorageConfiguration, init=True, repr=False ) - storage_configuration.NormalizeStorageConfiguration.normalize_volume_path = os.path.join( - test_storage_root, "normalize" + storage_configuration.NormalizeStorageConfiguration.normalize_volume_path = ( + os.path.join(test_storage_root, "normalize") ) # delete __init__, otherwise it will not be recreated by dataclass delattr(storage_configuration.NormalizeStorageConfiguration, "__init__") @@ -80,8 +80,9 @@ def pytest_configure(config): assert run_configuration.RunConfiguration.config_files_storage_path == os.path.join( test_storage_root, "config/" ) - assert 
run_configuration.RunConfiguration().config_files_storage_path == os.path.join( - test_storage_root, "config/" + assert ( + run_configuration.RunConfiguration().config_files_storage_path + == os.path.join(test_storage_root, "config/") ) # path pipeline instance id up to millisecond @@ -96,7 +97,12 @@ def _create_pipeline_instance_id(self) -> str: # os.environ["RUNTIME__SENTRY_DSN"] = "https://6f6f7b6f8e0f458a89be4187603b55fe@o1061158.ingest.sentry.io/4504819859914752" # disable sqlfluff logging - for log in ["sqlfluff.parser", "sqlfluff.linter", "sqlfluff.templater", "sqlfluff.lexer"]: + for log in [ + "sqlfluff.parser", + "sqlfluff.linter", + "sqlfluff.templater", + "sqlfluff.lexer", + ]: logging.getLogger(log).setLevel("ERROR") # disable snowflake logging diff --git a/tests/destinations/test_custom_destination.py b/tests/destinations/test_custom_destination.py index cfefceac88..40af3164c8 100644 --- a/tests/destinations/test_custom_destination.py +++ b/tests/destinations/test_custom_destination.py @@ -19,7 +19,9 @@ from dlt.common.configuration.specs import BaseConfiguration from dlt.destinations.impl.destination.factory import _DESTINATIONS -from dlt.destinations.impl.destination.configuration import CustomDestinationClientConfiguration +from dlt.destinations.impl.destination.configuration import ( + CustomDestinationClientConfiguration, +) from dlt.pipeline.exceptions import PipelineStepFailed from tests.load.utils import ( @@ -131,7 +133,9 @@ def test_instantiation() -> None: calls: List[Tuple[TDataItems, TTableSchema]] = [] # NOTE: we also test injection of config vars here - def local_sink_func(items: TDataItems, table: TTableSchema, my_val=dlt.config.value, /) -> None: + def local_sink_func( + items: TDataItems, table: TTableSchema, my_val=dlt.config.value, / + ) -> None: nonlocal calls assert my_val == "something" calls.append((items, table)) @@ -150,7 +154,9 @@ def local_sink_func(items: TDataItems, table: TTableSchema, my_val=dlt.config.va calls = [] p = dlt.pipeline( "sink_test", - destination=Destination.from_reference("destination", destination_callable=local_sink_func), + destination=Destination.from_reference( + "destination", destination_callable=local_sink_func + ), full_refresh=True, ) p.run([1, 2, 3], table_name="items") @@ -165,7 +171,9 @@ def local_sink_func(items: TDataItems, table: TTableSchema, my_val=dlt.config.va "sink_test", destination=Destination.from_reference( "destination", - destination_callable="tests.destinations.test_custom_destination.global_sink_func", + destination_callable=( + "tests.destinations.test_custom_destination.global_sink_func" + ), ), full_refresh=True, ) @@ -174,14 +182,18 @@ def local_sink_func(items: TDataItems, table: TTableSchema, my_val=dlt.config.va # global func will create an entry assert _DESTINATIONS["global_sink_func"] - assert issubclass(_DESTINATIONS["global_sink_func"][0], CustomDestinationClientConfiguration) + assert issubclass( + _DESTINATIONS["global_sink_func"][0], CustomDestinationClientConfiguration + ) assert _DESTINATIONS["global_sink_func"][1] == global_sink_func assert _DESTINATIONS["global_sink_func"][2] == inspect.getmodule(global_sink_func) # pass None as callable arg will fail on load p = dlt.pipeline( "sink_test", - destination=Destination.from_reference("destination", destination_callable=None), + destination=Destination.from_reference( + "destination", destination_callable=None + ), full_refresh=True, ) with pytest.raises(PipelineStepFailed): @@ -213,7 +225,9 @@ def simple_decorator_sink(items, table, 
my_val=dlt.config.value): @pytest.mark.parametrize("loader_file_format", SUPPORTED_LOADER_FORMATS) @pytest.mark.parametrize("batch_size", [1, 10, 23]) -def test_batched_transactions(loader_file_format: TLoaderFileFormat, batch_size: int) -> None: +def test_batched_transactions( + loader_file_format: TLoaderFileFormat, batch_size: int +) -> None: calls: Dict[str, List[TDataItems]] = {} # provoke errors on resources provoke_error: Dict[str, int] = {} @@ -345,7 +359,9 @@ def direct_sink(items, table): assert table["columns"]["snake_case"]["name"] == "snake_case" assert table["columns"]["camelCase"]["name"] == "camelCase" - dlt.pipeline("sink_test", destination=direct_sink, full_refresh=True).run(resource()) + dlt.pipeline("sink_test", destination=direct_sink, full_refresh=True).run( + resource() + ) def test_file_batch() -> None: @@ -366,7 +382,9 @@ def direct_sink(file_path, table): assert table["name"] in ["person", "address"] with pyarrow.parquet.ParquetFile(file_path) as reader: - assert reader.metadata.num_rows == (100 if table["name"] == "person" else 50) + assert reader.metadata.num_rows == ( + 100 if table["name"] == "person" else 50 + ) dlt.pipeline("sink_test", destination=direct_sink, full_refresh=True).run( [resource1(), resource2()] @@ -389,9 +407,9 @@ def my_sink(file_path, table, my_val=dlt.config.value): ) # we may give the value via __callable__ function - dlt.pipeline("sink_test", destination=my_sink(my_val="something"), full_refresh=True).run( - [1, 2, 3], table_name="items" - ) + dlt.pipeline( + "sink_test", destination=my_sink(my_val="something"), full_refresh=True + ).run([1, 2, 3], table_name="items") # right value will pass os.environ["DESTINATION__MY_SINK__MY_VAL"] = "something" @@ -471,16 +489,16 @@ def sink_func_with_spec( # call fails because `my_predefined_val` is required part of spec, even if not injected with pytest.raises(ConfigFieldMissingException): - info = dlt.pipeline("sink_test", destination=sink_func_with_spec(), full_refresh=True).run( - [1, 2, 3], table_name="items" - ) + info = dlt.pipeline( + "sink_test", destination=sink_func_with_spec(), full_refresh=True + ).run([1, 2, 3], table_name="items") info.raise_on_failed_jobs() # call happens now os.environ["MY_PREDEFINED_VAL"] = "VAL" - info = dlt.pipeline("sink_test", destination=sink_func_with_spec(), full_refresh=True).run( - [1, 2, 3], table_name="items" - ) + info = dlt.pipeline( + "sink_test", destination=sink_func_with_spec(), full_refresh=True + ).run([1, 2, 3], table_name="items") info.raise_on_failed_jobs() # check destination with additional config params @@ -564,7 +582,9 @@ def test_max_nesting_level(nesting: int) -> None: data = [ { "level": 1, - "children": [{"level": 2, "children": [{"level": 3, "children": [{"level": 4}]}]}], + "children": [ + {"level": 2, "children": [{"level": 3, "children": [{"level": 4}]}]} + ], } ] @@ -579,7 +599,9 @@ def nesting_sink(items, table): def source(): yield dlt.resource(data, name="data") - p = dlt.pipeline("sink_test_max_nesting", destination=nesting_sink, full_refresh=True) + p = dlt.pipeline( + "sink_test_max_nesting", destination=nesting_sink, full_refresh=True + ) p.run(source()) # fall back to source setting diff --git a/tests/destinations/test_destination_name_and_config.py b/tests/destinations/test_destination_name_and_config.py index 930a72d95d..6aeff21923 100644 --- a/tests/destinations/test_destination_name_and_config.py +++ b/tests/destinations/test_destination_name_and_config.py @@ -32,13 +32,17 @@ def test_set_name_and_environment() -> 
None: ) p = dlt.pipeline(pipeline_name="quack_pipeline", destination=duck) assert ( - p.destination.destination_type == "dlt.destinations.duckdb" == p.state["destination_type"] + p.destination.destination_type + == "dlt.destinations.duckdb" + == p.state["destination_type"] ) assert p.destination.destination_name == "duck1" == p.state["destination_name"] load_info = p.run([1, 2, 3], table_name="table", dataset_name="dataset") assert ( - p.destination.destination_type == "dlt.destinations.duckdb" == p.state["destination_type"] + p.destination.destination_type + == "dlt.destinations.duckdb" + == p.state["destination_type"] ) assert p.destination.destination_name == "duck1" == p.state["destination_name"] @@ -63,7 +67,9 @@ def test_preserve_destination_instance() -> None: destination_name="local_fs", environment="devel", ) - p = dlt.pipeline(pipeline_name="dummy_pipeline", destination=dummy1, staging=filesystem1) + p = dlt.pipeline( + pipeline_name="dummy_pipeline", destination=dummy1, staging=filesystem1 + ) destination_id = id(p.destination) staging_id = id(p.staging) import os @@ -87,7 +93,11 @@ def test_preserve_destination_instance() -> None: == p.state["destination_type"] == load_info.destination_type ) - assert p.destination.config_params["environment"] == "dev/null/1" == load_info.environment + assert ( + p.destination.config_params["environment"] + == "dev/null/1" + == load_info.environment + ) assert ( p.staging.destination_name == "local_fs" @@ -105,9 +115,17 @@ def test_preserve_destination_instance() -> None: # attach pipeline p = dlt.attach(pipeline_name="dummy_pipeline") assert p.destination.destination_name == "dummy1" == p.state["destination_name"] - assert p.destination.destination_type == "dlt.destinations.dummy" == p.state["destination_type"] + assert ( + p.destination.destination_type + == "dlt.destinations.dummy" + == p.state["destination_type"] + ) assert p.staging.destination_name == "local_fs" == p.state["staging_name"] - assert p.staging.destination_type == "dlt.destinations.filesystem" == p.state["staging_type"] + assert ( + p.staging.destination_type + == "dlt.destinations.filesystem" + == p.state["staging_type"] + ) # config args should not contain self assert "self" not in p.destination.config_params @@ -147,7 +165,8 @@ def test_config_respects_dataset_name(environment: DictStrStr) -> None: # duck1 will be staging duck = duckdb( - credentials=os.path.join(TEST_STORAGE_ROOT, "quack.duckdb"), destination_name="duck1" + credentials=os.path.join(TEST_STORAGE_ROOT, "quack.duckdb"), + destination_name="duck1", ) p = dlt.pipeline(pipeline_name="quack_pipeline_staging", destination=duck) load_info = p.run([1, 2, 3], table_name="table") @@ -158,7 +177,8 @@ def test_config_respects_dataset_name(environment: DictStrStr) -> None: # duck2 will be production duck = duckdb( - credentials=os.path.join(TEST_STORAGE_ROOT, "quack.duckdb"), destination_name="duck2" + credentials=os.path.join(TEST_STORAGE_ROOT, "quack.duckdb"), + destination_name="duck2", ) p = dlt.pipeline(pipeline_name="quack_pipeline_production", destination=duck) load_info = p.run([1, 2, 3], table_name="table") @@ -209,7 +229,7 @@ def test_destination_config_in_name(environment: DictStrStr) -> None: with pytest.raises(ConfigFieldMissingException): p.destination_client() - environment["DESTINATION__FILESYSTEM-PROD__BUCKET_URL"] = "file://" + posixpath.abspath( - TEST_STORAGE_ROOT + environment["DESTINATION__FILESYSTEM-PROD__BUCKET_URL"] = ( + "file://" + posixpath.abspath(TEST_STORAGE_ROOT) ) assert 
p.destination_client().fs_path.endswith(TEST_STORAGE_ROOT) # type: ignore[attr-defined] diff --git a/tests/destinations/test_path_utils.py b/tests/destinations/test_path_utils.py index 1cf2b17d76..4dd673adec 100644 --- a/tests/destinations/test_path_utils.py +++ b/tests/destinations/test_path_utils.py @@ -55,7 +55,9 @@ def test_get_table_prefix_layout() -> None: # disallow other params before table_name with pytest.raises(CantExtractTablePrefix): - path_utils.get_table_prefix_layout("{file_id}some_random{table_name}/stuff_in_between/") + path_utils.get_table_prefix_layout( + "{file_id}some_random{table_name}/stuff_in_between/" + ) # disallow any placeholders before table name (ie. Athena) with pytest.raises(CantExtractTablePrefix): @@ -66,4 +68,6 @@ def test_get_table_prefix_layout() -> None: # disallow table_name without following separator with pytest.raises(CantExtractTablePrefix): - path_utils.get_table_prefix_layout("{schema_name}/{table_name}{load_id}.{file_id}.{ext}") + path_utils.get_table_prefix_layout( + "{schema_name}/{table_name}{load_id}.{file_id}.{ext}" + ) diff --git a/tests/extract/data_writers/test_buffered_writer.py b/tests/extract/data_writers/test_buffered_writer.py index aff49e06ac..e375a15c81 100644 --- a/tests/extract/data_writers/test_buffered_writer.py +++ b/tests/extract/data_writers/test_buffered_writer.py @@ -62,7 +62,9 @@ def c3_doc(count: int) -> Iterator[DictStrAny]: assert writer.closed_files[0].items_count == 9 assert writer.closed_files[0].file_size > 0 # check the content, mind that we swapped the columns - with FileStorage.open_zipsafe_ro(writer.closed_files[0].file_path, "r", encoding="utf-8") as f: + with FileStorage.open_zipsafe_ro( + writer.closed_files[0].file_path, "r", encoding="utf-8" + ) as f: content = f.readlines() assert "col2,col1" in content[0] assert "NULL,0" in content[2] @@ -114,7 +116,9 @@ def c3_doc(count: int) -> Iterator[DictStrAny]: assert len(writer.closed_files) == 2 assert writer._buffered_items == [] # the last file must contain text value of the column3 - with FileStorage.open_zipsafe_ro(writer.closed_files[-1].file_path, "r", encoding="utf-8") as f: + with FileStorage.open_zipsafe_ro( + writer.closed_files[-1].file_path, "r", encoding="utf-8" + ) as f: content = f.readlines() assert "(col3_value" in content[-1] # check metrics @@ -149,7 +153,9 @@ def c2_doc(count: int) -> Iterator[DictStrAny]: # only the initial 15 items written assert writer._writer.items_count == 15 # all written - with FileStorage.open_zipsafe_ro(writer.closed_files[-1].file_path, "r", encoding="utf-8") as f: + with FileStorage.open_zipsafe_ro( + writer.closed_files[-1].file_path, "r", encoding="utf-8" + ) as f: content = f.readlines() assert content[-1] == '{"col1":1,"col2":3}\n' @@ -183,7 +189,9 @@ def test_writer_optional_schema(disable_compression: bool) -> None: "disable_compression", [True, False], ids=["no_compression", "compression"] ) @pytest.mark.parametrize("format_", ALL_WRITERS - {"arrow"}) -def test_write_empty_file(disable_compression: bool, format_: TLoaderFileFormat) -> None: +def test_write_empty_file( + disable_compression: bool, format_: TLoaderFileFormat +) -> None: # just single schema is enough c1 = new_column("col1", "bigint") t1 = {"col1": c1} @@ -226,7 +234,10 @@ def test_gather_metrics(disable_compression: bool, format_: TLoaderFileFormat) - c1 = new_column("col1", "bigint") t1 = {"col1": c1} with get_writer( - format_, disable_compression=disable_compression, buffer_max_items=2, file_max_items=2 + format_, + 
disable_compression=disable_compression, + buffer_max_items=2, + file_max_items=2, ) as writer: time.sleep(0.55) count = writer.write_data_item([{"col1": 182812}, {"col1": -1}], t1) @@ -238,7 +249,9 @@ def test_gather_metrics(disable_compression: bool, format_: TLoaderFileFormat) - assert metrics.last_modified - metrics.created >= 0.55 assert metrics.created >= now time.sleep(0.35) - count = writer.write_data_item([{"col1": 182812}, {"col1": -1}, {"col1": 182811}], t1) + count = writer.write_data_item( + [{"col1": 182812}, {"col1": -1}, {"col1": 182811}], t1 + ) assert count == 3 # file rotated assert len(writer.closed_files) == 2 @@ -254,11 +267,16 @@ def test_gather_metrics(disable_compression: bool, format_: TLoaderFileFormat) - "disable_compression", [True, False], ids=["no_compression", "compression"] ) @pytest.mark.parametrize("format_", ALL_WRITERS - {"arrow"}) -def test_special_write_rotates(disable_compression: bool, format_: TLoaderFileFormat) -> None: +def test_special_write_rotates( + disable_compression: bool, format_: TLoaderFileFormat +) -> None: c1 = new_column("col1", "bigint") t1 = {"col1": c1} with get_writer( - format_, disable_compression=disable_compression, buffer_max_items=100, file_max_items=100 + format_, + disable_compression=disable_compression, + buffer_max_items=100, + file_max_items=100, ) as writer: writer.write_data_item([{"col1": 182812}, {"col1": -1}], t1) assert len(writer.closed_files) == 0 diff --git a/tests/extract/data_writers/test_data_item_storage.py b/tests/extract/data_writers/test_data_item_storage.py index 1e6327a3ba..0345afa5da 100644 --- a/tests/extract/data_writers/test_data_item_storage.py +++ b/tests/extract/data_writers/test_data_item_storage.py @@ -3,7 +3,10 @@ from dlt.common.configuration.container import Container from dlt.common.data_writers.writers import DataWriterMetrics -from dlt.common.destination.capabilities import TLoaderFileFormat, DestinationCapabilitiesContext +from dlt.common.destination.capabilities import ( + TLoaderFileFormat, + DestinationCapabilitiesContext, +) from dlt.common.schema.utils import new_column from tests.common.data_writers.utils import ALL_WRITERS from dlt.common.storages.data_item_storage import DataItemStorage @@ -12,8 +15,12 @@ class ItemTestStorage(DataItemStorage): - def _get_data_item_path_template(self, load_id: str, schema_name: str, table_name: str) -> str: - return os.path.join(TEST_STORAGE_ROOT, f"{load_id}.{schema_name}.{table_name}.%s") + def _get_data_item_path_template( + self, load_id: str, schema_name: str, table_name: str + ) -> str: + return os.path.join( + TEST_STORAGE_ROOT, f"{load_id}.{schema_name}.{table_name}.%s" + ) @pytest.mark.parametrize("format_", ALL_WRITERS - {"arrow"}) @@ -43,8 +50,12 @@ def test_write_items(format_: TLoaderFileFormat) -> None: assert item_storage.closed_files("load_1")[1] == metrics # closed files are separate - item_storage.write_data_item("load_1", "schema", "t1", [{"col1": 182812}, {"col1": -1}], t1) - item_storage.write_data_item("load_2", "schema", "t1", [{"col1": 182812}, {"col1": -1}], t1) + item_storage.write_data_item( + "load_1", "schema", "t1", [{"col1": 182812}, {"col1": -1}], t1 + ) + item_storage.write_data_item( + "load_2", "schema", "t1", [{"col1": 182812}, {"col1": -1}], t1 + ) item_storage.close_writers("load_1") assert len(item_storage.closed_files("load_1")) == 3 assert len(item_storage.closed_files("load_2")) == 1 diff --git a/tests/extract/test_decorators.py b/tests/extract/test_decorators.py index dca4c0be6e..387c42fa79 100644 --- 
a/tests/extract/test_decorators.py +++ b/tests/extract/test_decorators.py @@ -257,12 +257,16 @@ def get_users(): def test_apply_hints_columns() -> None: - @dlt.resource(name="user", columns={"tags": {"data_type": "complex", "primary_key": True}}) + @dlt.resource( + name="user", columns={"tags": {"data_type": "complex", "primary_key": True}} + ) def get_users(): yield {"u": "u", "tags": [1, 2, 3]} users = get_users() - assert users.columns == {"tags": {"data_type": "complex", "name": "tags", "primary_key": True}} + assert users.columns == { + "tags": {"data_type": "complex", "name": "tags", "primary_key": True} + } assert ( cast(TTableSchemaColumns, users.columns)["tags"] == users.compute_table_schema()["columns"]["tags"] @@ -372,30 +376,34 @@ def test_source_sections() -> None: assert list(resource_f_2()) == ["SOURCES LEVEL"] # values in module section - os.environ[f"{known_sections.SOURCES.upper()}__SECTION_SOURCE__VAL"] = "SECTION SOURCE LEVEL" + os.environ[f"{known_sections.SOURCES.upper()}__SECTION_SOURCE__VAL"] = ( + "SECTION SOURCE LEVEL" + ) assert list(init_source_f_1()) == ["SECTION SOURCE LEVEL"] assert list(init_resource_f_2()) == ["SECTION SOURCE LEVEL"] # here overridden by __source_name__ - os.environ[f"{known_sections.SOURCES.upper()}__NAME_OVERRIDDEN__VAL"] = "NAME OVERRIDDEN LEVEL" + os.environ[f"{known_sections.SOURCES.upper()}__NAME_OVERRIDDEN__VAL"] = ( + "NAME OVERRIDDEN LEVEL" + ) assert list(source_f_1()) == ["NAME OVERRIDDEN LEVEL"] assert list(resource_f_2()) == ["NAME OVERRIDDEN LEVEL"] # values in function name section - os.environ[f"{known_sections.SOURCES.upper()}__SECTION_SOURCE__INIT_SOURCE_F_1__VAL"] = ( - "SECTION INIT_SOURCE_F_1 LEVEL" - ) + os.environ[ + f"{known_sections.SOURCES.upper()}__SECTION_SOURCE__INIT_SOURCE_F_1__VAL" + ] = "SECTION INIT_SOURCE_F_1 LEVEL" assert list(init_source_f_1()) == ["SECTION INIT_SOURCE_F_1 LEVEL"] - os.environ[f"{known_sections.SOURCES.upper()}__SECTION_SOURCE__INIT_RESOURCE_F_2__VAL"] = ( - "SECTION INIT_RESOURCE_F_2 LEVEL" - ) + os.environ[ + f"{known_sections.SOURCES.upper()}__SECTION_SOURCE__INIT_RESOURCE_F_2__VAL" + ] = "SECTION INIT_RESOURCE_F_2 LEVEL" assert list(init_resource_f_2()) == ["SECTION INIT_RESOURCE_F_2 LEVEL"] - os.environ[f"{known_sections.SOURCES.upper()}__NAME_OVERRIDDEN__SOURCE_F_1__VAL"] = ( - "NAME SOURCE_F_1 LEVEL" - ) + os.environ[ + f"{known_sections.SOURCES.upper()}__NAME_OVERRIDDEN__SOURCE_F_1__VAL" + ] = "NAME SOURCE_F_1 LEVEL" assert list(source_f_1()) == ["NAME SOURCE_F_1 LEVEL"] - os.environ[f"{known_sections.SOURCES.upper()}__NAME_OVERRIDDEN__RESOURCE_F_2__VAL"] = ( - "NAME RESOURCE_F_2 LEVEL" - ) + os.environ[ + f"{known_sections.SOURCES.upper()}__NAME_OVERRIDDEN__RESOURCE_F_2__VAL" + ] = "NAME RESOURCE_F_2 LEVEL" assert list(resource_f_2()) == ["NAME RESOURCE_F_2 LEVEL"] @@ -606,7 +614,9 @@ def schema_test(): @dlt.resource -def standalone_resource(secret=dlt.secrets.value, config=dlt.config.value, opt: str = "A"): +def standalone_resource( + secret=dlt.secrets.value, config=dlt.config.value, opt: str = "A" +): yield 1 @@ -755,7 +765,9 @@ def many_instances(): @dlt.transformer(standalone=True) -def standalone_transformer(item: TDataItem, init: int, secret_end: int = dlt.secrets.value): +def standalone_transformer( + item: TDataItem, init: int, secret_end: int = dlt.secrets.value +): """Has fine transformer docstring""" yield from range(item + init, secret_end) @@ -800,7 +812,9 @@ def test_standalone_transformer() -> None: @dlt.transformer(standalone=True, name=lambda args: 
args["res_name"]) -def standalone_tx_with_name(item: TDataItem, res_name: str, init: int = dlt.config.value): +def standalone_tx_with_name( + item: TDataItem, res_name: str, init: int = dlt.config.value +): return res_name * item * init @@ -917,7 +931,11 @@ async def _assert_source(source_coro_f, expected_data) -> None: # make sure the config injection works with custom_environ( - {f"SOURCES__{source.section.upper()}__{source.name.upper()}__REVERSE": "True"} + { + f"SOURCES__{source.section.upper()}__{source.name.upper()}__REVERSE": ( + "True" + ) + } ): assert list(await source_coro_f()) == list(reversed(expected_data)) diff --git a/tests/extract/test_extract.py b/tests/extract/test_extract.py index 1879eaa9eb..a47eccbb7d 100644 --- a/tests/extract/test_extract.py +++ b/tests/extract/test_extract.py @@ -24,7 +24,9 @@ def extract_step() -> Extract: clean_test_storage(init_normalize=True) schema_storage = SchemaStorage( - SchemaStorageConfiguration(schema_volume_path=os.path.join(TEST_STORAGE_ROOT, "schemas")), + SchemaStorageConfiguration( + schema_volume_path=os.path.join(TEST_STORAGE_ROOT, "schemas") + ), makedirs=True, ) return Extract(schema_storage, NormalizeStorageConfiguration()) @@ -44,7 +46,9 @@ def test_storage_reuse_package() -> None: # we have a new load id (the package with schema moved to extracted) load_id_3 = storage.create_load_package(dlt.Schema("first")) assert load_id != load_id_3 - load_id_4 = storage.create_load_package(dlt.Schema("first"), reuse_exiting_package=False) + load_id_4 = storage.create_load_package( + dlt.Schema("first"), reuse_exiting_package=False + ) assert load_id_4 != load_id_3 # this will fail - not all extracts committed @@ -103,7 +107,9 @@ def test_extract_hints_mark(extract_step: Extract) -> None: def with_table_hints(): yield dlt.mark.with_hints( {"id": 1, "pk": "A"}, - make_hints(columns=[{"name": "id", "data_type": "bigint"}], primary_key="pk"), + make_hints( + columns=[{"name": "id", "data_type": "bigint"}], primary_key="pk" + ), ) schema = dlt.current.source_schema() # table and columns got updated in the schema @@ -125,7 +131,10 @@ def with_table_hints(): {"id": 1, "pk2": "B"}, make_hints( write_disposition="merge", - columns=[{"name": "id", "precision": 16}, {"name": "text", "data_type": "decimal"}], + columns=[ + {"name": "id", "precision": 16}, + {"name": "text", "data_type": "decimal"}, + ], primary_key="pk2", ), ) @@ -146,7 +155,8 @@ def with_table_hints(): # make table name dynamic yield dlt.mark.with_hints( - {"namer": "dynamic"}, make_hints(table_name=lambda item: f"{item['namer']}_table") + {"namer": "dynamic"}, + make_hints(table_name=lambda item: f"{item['namer']}_table"), ) # dynamic table was created in the schema and it contains the newest resource table schema table = schema.tables["dynamic_table"] @@ -172,7 +182,9 @@ def test_extract_hints_table_variant(extract_step: Extract) -> None: def with_table_hints(): yield dlt.mark.with_hints( {"id": 1, "pk": "A"}, - make_hints(table_name="table_a", columns=[{"name": "id", "data_type": "bigint"}]), + make_hints( + table_name="table_a", columns=[{"name": "id", "data_type": "bigint"}] + ), create_table_variant=True, ) # get the resource @@ -240,7 +252,9 @@ def tx_step(item): input_tx = DltResource.from_data(tx_step, data_from=DltResource.Empty) source = DltSource( - dlt.Schema("selectables"), "module", [input_r, (input_r | input_tx).with_name("tx_clone")] + dlt.Schema("selectables"), + "module", + [input_r, (input_r | input_tx).with_name("tx_clone")], ) extract_step.extract(source, 20, 
1) assert "input_gen" in source.schema._schema_tables @@ -263,10 +277,16 @@ def expect_tables(extract_step: Extract, resource: DltResource) -> dlt.Schema: # check resulting files assert len(extract_step.extract_storage.list_files_to_normalize_sorted()) == 2 expect_extracted_file( - extract_step.extract_storage, "selectables", "odd_table", json.dumps([1, 3, 5, 7, 9]) + extract_step.extract_storage, + "selectables", + "odd_table", + json.dumps([1, 3, 5, 7, 9]), ) expect_extracted_file( - extract_step.extract_storage, "selectables", "even_table", json.dumps([0, 2, 4, 6, 8]) + extract_step.extract_storage, + "selectables", + "even_table", + json.dumps([0, 2, 4, 6, 8]), ) schema = source.schema diff --git a/tests/extract/test_extract_pipe.py b/tests/extract/test_extract_pipe.py index 68c1c82124..18acccae3a 100644 --- a/tests/extract/test_extract_pipe.py +++ b/tests/extract/test_extract_pipe.py @@ -9,7 +9,11 @@ import dlt from dlt.common import sleep from dlt.common.typing import TDataItems -from dlt.extract.exceptions import CreatePipeException, ResourceExtractionError, UnclosablePipe +from dlt.extract.exceptions import ( + CreatePipeException, + ResourceExtractionError, + UnclosablePipe, +) from dlt.extract.items import DataItemWithMeta, FilterItem, MapItem, YieldMapItem from dlt.extract.pipe import Pipe from dlt.extract.pipe_iterator import PipeIterator, ManagedPipeIterator, PipeItem @@ -44,7 +48,11 @@ def get_pipes(): assert [pi.item for pi in _l] == [1, 2, 3, 4, 10, 5, 6, 8, 7, 9, 11, 12, 13, 14, 15] # force fifo, no rotation at all when crossing the initial source count - _l = list(PipeIterator.from_pipes(get_pipes(), next_item_mode="fifo", max_parallel_items=1)) + _l = list( + PipeIterator.from_pipes( + get_pipes(), next_item_mode="fifo", max_parallel_items=1 + ) + ) # order the same as above - same rules apply assert [pi.item for pi in _l] == [1, 2, 3, 4, 10, 5, 6, 8, 7, 9, 11, 12, 13, 14, 15] @@ -56,7 +64,9 @@ def get_pipes(): # round robin with max parallel items triggers strict fifo in some cases (after gen2 and 3 are exhausted we already have the first yielded gen, # items appear in order as sources are processed strictly from front) _l = list( - PipeIterator.from_pipes(get_pipes(), next_item_mode="round_robin", max_parallel_items=1) + PipeIterator.from_pipes( + get_pipes(), next_item_mode="round_robin", max_parallel_items=1 + ) ) # items will be in order of the pipes, nested iterator items appear inline, None triggers rotation # NOTE: 4, 10, 5 - after 4 there's NONE in fifo so we do next element (round robin style) @@ -459,7 +469,9 @@ def test_map_step() -> None: def test_yield_map_step() -> None: p = Pipe.from_data("data", [1, 2, 3]) # this creates number of rows as passed by the data - p.append_step(YieldMapItem(lambda item: (yield from [f"item_{x}" for x in range(item)]))) + p.append_step( + YieldMapItem(lambda item: (yield from [f"item_{x}" for x in range(item)])) + ) assert _f_items(list(PipeIterator.from_pipe(p))) == [ "item_0", "item_0", @@ -474,7 +486,9 @@ def test_yield_map_step() -> None: meta_data = [DataItemWithMeta(m, d) for m, d in zip(meta, data)] p = Pipe.from_data("data", meta_data) p.append_step( - YieldMapItem(lambda item, meta: (yield from [f"item_{meta}_{x}" for x in range(item)])) + YieldMapItem( + lambda item, meta: (yield from [f"item_{meta}_{x}" for x in range(item)]) + ) ) assert _f_items(list(PipeIterator.from_pipe(p))) == [ "item_A_0", @@ -493,12 +507,20 @@ def test_pipe_copy_on_fork() -> None: child2 = Pipe("tr2", [lambda x: x], parent=parent) # no copy, 
construct iterator - elems = list(PipeIterator.from_pipes([child1, child2], yield_parents=False, copy_on_fork=False)) + elems = list( + PipeIterator.from_pipes( + [child1, child2], yield_parents=False, copy_on_fork=False + ) + ) # those are the same instances assert doc is elems[0].item is elems[1].item # copy item on fork - elems = list(PipeIterator.from_pipes([child1, child2], yield_parents=False, copy_on_fork=True)) + elems = list( + PipeIterator.from_pipes( + [child1, child2], yield_parents=False, copy_on_fork=True + ) + ) # first fork does not copy assert doc is elems[0].item # second fork copies @@ -754,7 +776,9 @@ def assert_pipes_closed(raise_gen, long_gen) -> None: pit: PipeIterator = None with PipeIterator.from_pipe( - Pipe.from_data("failing", raise_gen, parent=Pipe.from_data("endless", long_gen())) + Pipe.from_data( + "failing", raise_gen, parent=Pipe.from_data("endless", long_gen()) + ) ) as pit: with pytest.raises(ResourceExtractionError) as py_ex: list(pit) @@ -768,7 +792,9 @@ def assert_pipes_closed(raise_gen, long_gen) -> None: close_pipe_got_exit = False close_pipe_yielding = False pit = ManagedPipeIterator.from_pipe( - Pipe.from_data("failing", raise_gen, parent=Pipe.from_data("endless", long_gen())) + Pipe.from_data( + "failing", raise_gen, parent=Pipe.from_data("endless", long_gen()) + ) ) with pytest.raises(ResourceExtractionError) as py_ex: list(pit) diff --git a/tests/extract/test_incremental.py b/tests/extract/test_incremental.py index a393706de7..39b436de19 100644 --- a/tests/extract/test_incremental.py +++ b/tests/extract/test_incremental.py @@ -12,7 +12,10 @@ import dlt from dlt.common.configuration.container import Container -from dlt.common.configuration.specs.base_configuration import configspec, BaseConfiguration +from dlt.common.configuration.specs.base_configuration import ( + configspec, + BaseConfiguration, +) from dlt.common.configuration import ConfigurationValueError from dlt.common.pendulum import pendulum, timedelta from dlt.common.pipeline import NormalizeInfo, StateInjectableContext, resource_state @@ -58,7 +61,9 @@ def some_data(created_at=dlt.sources.incremental("created_at")): @pytest.mark.parametrize("item_type", ALL_DATA_ITEM_FORMATS) -def test_single_items_last_value_state_is_updated_transformer(item_type: TDataItemFormat) -> None: +def test_single_items_last_value_state_is_updated_transformer( + item_type: TDataItemFormat, +) -> None: data = [ {"created_at": 425}, {"created_at": 426}, @@ -92,9 +97,9 @@ def some_data(created_at=dlt.sources.incremental("created_at")): p = dlt.pipeline(pipeline_name=uniq_id()) p.extract(some_data()) - s = p.state["sources"][p.default_schema_name]["resources"]["some_data"]["incremental"][ - "created_at" - ] + s = p.state["sources"][p.default_schema_name]["resources"]["some_data"][ + "incremental" + ]["created_at"] assert s["last_value"] == 9 @@ -145,16 +150,28 @@ def some_data(created_at=dlt.sources.incremental("created_at")): yield from source_items2 p = dlt.pipeline( - pipeline_name=uniq_id(), destination="duckdb", credentials=duckdb.connect(":memory:") + pipeline_name=uniq_id(), + destination="duckdb", + credentials=duckdb.connect(":memory:"), ) p.run(some_data()).raise_on_failed_jobs() p.run(some_data()).raise_on_failed_jobs() with p.sql_client() as c: - with c.execute_query("SELECT created_at, id FROM some_data order by created_at, id") as cur: + with c.execute_query( + "SELECT created_at, id FROM some_data order by created_at, id" + ) as cur: rows = cur.fetchall() - assert rows == [(1, "a"), (2, "b"), (3, 
"c"), (3, "d"), (3, "e"), (3, "f"), (4, "g")] + assert rows == [ + (1, "a"), + (2, "b"), + (3, "c"), + (3, "d"), + (3, "e"), + (3, "f"), + (4, "g"), + ] @pytest.mark.parametrize("item_type", ALL_DATA_ITEM_FORMATS) @@ -185,16 +202,28 @@ def some_data(created_at=dlt.sources.incremental("created_at")): yield from source_items2 p = dlt.pipeline( - pipeline_name=uniq_id(), destination="duckdb", credentials=duckdb.connect(":memory:") + pipeline_name=uniq_id(), + destination="duckdb", + credentials=duckdb.connect(":memory:"), ) p.run(some_data()).raise_on_failed_jobs() p.run(some_data()).raise_on_failed_jobs() with p.sql_client() as c: - with c.execute_query("SELECT created_at, id FROM some_data order by created_at, id") as cur: + with c.execute_query( + "SELECT created_at, id FROM some_data order by created_at, id" + ) as cur: rows = cur.fetchall() - assert rows == [(1, "a"), (2, "b"), (3, "c"), (3, "d"), (3, "e"), (3, "f"), (4, "g")] + assert rows == [ + (1, "a"), + (2, "b"), + (3, "c"), + (3, "d"), + (3, "e"), + (3, "f"), + (4, "g"), + ] def test_nested_cursor_path() -> None: @@ -205,9 +234,9 @@ def some_data(created_at=dlt.sources.incremental("data.items[0].created_at")): p = dlt.pipeline(pipeline_name=uniq_id()) p.extract(some_data()) - s = p.state["sources"][p.default_schema_name]["resources"]["some_data"]["incremental"][ - "data.items[0].created_at" - ] + s = p.state["sources"][p.default_schema_name]["resources"]["some_data"][ + "incremental" + ]["data.items[0].created_at"] assert s["last_value"] == 2 @@ -239,9 +268,9 @@ def some_data(created_at=dlt.sources.incremental("created_at")): p = dlt.pipeline(pipeline_name=uniq_id()) p.extract(some_data(created_at=4242)) - s = p.state["sources"][p.default_schema_name]["resources"]["some_data"]["incremental"][ - "created_at" - ] + s = p.state["sources"][p.default_schema_name]["resources"]["some_data"][ + "incremental" + ]["created_at"] assert s["last_value"] == 4242 @@ -257,7 +286,9 @@ def some_data(incremental=dlt.sources.incremental("created_at", initial_value=0) yield from source_items p = dlt.pipeline(pipeline_name=uniq_id()) - p.extract(some_data(incremental=dlt.sources.incremental("inserted_at", initial_value=241))) + p.extract( + some_data(incremental=dlt.sources.incremental("inserted_at", initial_value=241)) + ) @dlt.resource @@ -281,12 +312,12 @@ def some_data_from_config( @pytest.mark.parametrize("item_type", ALL_DATA_ITEM_FORMATS) def test_optional_incremental_from_config(item_type: TDataItemFormat) -> None: - os.environ["SOURCES__TEST_INCREMENTAL__SOME_DATA_FROM_CONFIG__CREATED_AT__CURSOR_PATH"] = ( - "created_at" - ) - os.environ["SOURCES__TEST_INCREMENTAL__SOME_DATA_FROM_CONFIG__CREATED_AT__INITIAL_VALUE"] = ( - "2022-02-03T00:00:00Z" - ) + os.environ[ + "SOURCES__TEST_INCREMENTAL__SOME_DATA_FROM_CONFIG__CREATED_AT__CURSOR_PATH" + ] = "created_at" + os.environ[ + "SOURCES__TEST_INCREMENTAL__SOME_DATA_FROM_CONFIG__CREATED_AT__INITIAL_VALUE" + ] = "2022-02-03T00:00:00Z" p = dlt.pipeline(pipeline_name=uniq_id()) p.extract(some_data_from_config(1, item_type)) @@ -314,7 +345,8 @@ class OptionalIncrementalConfig(BaseConfiguration): @dlt.resource(spec=OptionalIncrementalConfig) def optional_incremental_arg_resource( - item_type: TDataItemFormat, incremental: Optional[dlt.sources.incremental[Any]] = None + item_type: TDataItemFormat, + incremental: Optional[dlt.sources.incremental[Any]] = None, ) -> Any: data = [1, 2, 3] source_items = data_to_item_format(item_type, data) @@ -336,7 +368,8 @@ class 
SomeDataOverrideConfiguration(BaseConfiguration): # provide what to inject via spec. the spec contain the default @dlt.resource(spec=SomeDataOverrideConfiguration) def some_data_override_config( - item_type: TDataItemFormat, created_at: dlt.sources.incremental[str] = dlt.config.value + item_type: TDataItemFormat, + created_at: dlt.sources.incremental[str] = dlt.config.value, ): assert created_at.cursor_path == "created_at" assert created_at.initial_value == "2000-02-03T00:00:00Z" @@ -358,7 +391,10 @@ def test_override_initial_value_from_config(item_type: TDataItemFormat) -> None: @pytest.mark.parametrize("item_type", ALL_DATA_ITEM_FORMATS) def test_override_primary_key_in_pipeline(item_type: TDataItemFormat) -> None: """Primary key hint passed to pipeline is propagated through apply_hints""" - data = [{"created_at": 22, "id": 2, "other_id": 5}, {"created_at": 22, "id": 2, "other_id": 6}] + data = [ + {"created_at": 22, "id": 2, "other_id": 5}, + {"created_at": 22, "id": 2, "other_id": 6}, + ] source_items = data_to_item_format(item_type, data) @dlt.resource(primary_key="id") @@ -390,13 +426,16 @@ def some_data(created_at=dlt.sources.incremental("created_at")): yield from source_items p = dlt.pipeline( - pipeline_name=uniq_id(), destination="duckdb", credentials=duckdb.connect(":memory:") + pipeline_name=uniq_id(), + destination="duckdb", + credentials=duckdb.connect(":memory:"), ) p.run(some_data()).raise_on_failed_jobs() with p.sql_client() as c: with c.execute_query( - "SELECT created_at, isrc, market FROM some_data order by created_at, isrc, market" + "SELECT created_at, isrc, market FROM some_data order by created_at, isrc," + " market" ) as cur: rows = cur.fetchall() @@ -425,15 +464,17 @@ def test_last_value_func_min(item_type: TDataItemFormat) -> None: source_items = data_to_item_format(item_type, data) @dlt.resource - def some_data(created_at=dlt.sources.incremental("created_at", last_value_func=min)): + def some_data( + created_at=dlt.sources.incremental("created_at", last_value_func=min) + ): yield from source_items p = dlt.pipeline(pipeline_name=uniq_id()) p.extract(some_data()) - s = p.state["sources"][p.default_schema_name]["resources"]["some_data"]["incremental"][ - "created_at" - ] + s = p.state["sources"][p.default_schema_name]["resources"]["some_data"][ + "incremental" + ]["created_at"] assert s["last_value"] == 8 @@ -443,16 +484,18 @@ def last_value(values): return max(values) + 1 @dlt.resource - def some_data(created_at=dlt.sources.incremental("created_at", last_value_func=last_value)): + def some_data( + created_at=dlt.sources.incremental("created_at", last_value_func=last_value) + ): yield {"created_at": 9} yield {"created_at": 10} p = dlt.pipeline(pipeline_name=uniq_id()) p.extract(some_data()) - s = p.state["sources"][p.default_schema_name]["resources"]["some_data"]["incremental"][ - "created_at" - ] + s = p.state["sources"][p.default_schema_name]["resources"]["some_data"][ + "incremental" + ]["created_at"] assert s["last_value"] == 11 @@ -476,9 +519,9 @@ def some_data(created_at=dlt.sources.incremental("created_at", initial_value)): p = dlt.pipeline(pipeline_name=uniq_id()) p.extract(some_data()) - s = p.state["sources"][p.default_schema_name]["resources"]["some_data"]["incremental"][ - "created_at" - ] + s = p.state["sources"][p.default_schema_name]["resources"]["some_data"][ + "incremental" + ]["created_at"] assert s["last_value"] == initial_value + timedelta(minutes=4) @@ -497,9 +540,9 @@ def some_data(created_at=dlt.sources.incremental("created_at", 20)): p = 
dlt.pipeline(pipeline_name=uniq_id()) p.extract(some_data()) - s = p.state["sources"][p.default_schema_name]["resources"]["some_data"]["incremental"][ - "created_at" - ] + s = p.state["sources"][p.default_schema_name]["resources"]["some_data"][ + "incremental" + ]["created_at"] last_hash = digest128(json.dumps({"created_at": 24})) @@ -515,14 +558,19 @@ def test_unique_keys_json_identifiers(item_type: TDataItemFormat) -> None: @dlt.resource(primary_key="DelTa") def some_data(last_timestamp=dlt.sources.incremental("ts")): - data = [{"DelTa": i, "ts": pendulum.now().add(days=i).timestamp()} for i in range(-10, 10)] + data = [ + {"DelTa": i, "ts": pendulum.now().add(days=i).timestamp()} + for i in range(-10, 10) + ] source_items = data_to_item_format(item_type, data) yield from source_items p = dlt.pipeline(pipeline_name=uniq_id()) p.run(some_data, destination="duckdb") # check if default schema contains normalized PK - assert p.default_schema.tables["some_data"]["columns"]["del_ta"]["primary_key"] is True + assert ( + p.default_schema.tables["some_data"]["columns"]["del_ta"]["primary_key"] is True + ) with p.sql_client() as c: with c.execute_query("SELECT del_ta FROM some_data") as cur: rows = cur.fetchall() @@ -546,7 +594,10 @@ def some_data(last_timestamp=dlt.sources.incremental("ts")): def test_missing_primary_key(item_type: TDataItemFormat) -> None: @dlt.resource(primary_key="DELTA") def some_data(last_timestamp=dlt.sources.incremental("ts")): - data = [{"delta": i, "ts": pendulum.now().add(days=i).timestamp()} for i in range(-10, 10)] + data = [ + {"delta": i, "ts": pendulum.now().add(days=i).timestamp()} + for i in range(-10, 10) + ] source_items = data_to_item_format(item_type, data) yield from source_items @@ -561,7 +612,10 @@ def test_missing_cursor_field(item_type: TDataItemFormat) -> None: @dlt.resource def some_data(last_timestamp=dlt.sources.incremental("item.timestamp")): - data = [{"delta": i, "ts": pendulum.now().add(days=i).timestamp()} for i in range(-10, 10)] + data = [ + {"delta": i, "ts": pendulum.now().add(days=i).timestamp()} + for i in range(-10, 10) + ] source_items = data_to_item_format(item_type, data) yield from source_items @@ -614,9 +668,9 @@ def some_data( def test_remove_incremental_with_incremental_empty() -> None: @dlt.resource def some_data_optional( - last_timestamp: Optional[dlt.sources.incremental[float]] = dlt.sources.incremental( - "item.timestamp" - ), + last_timestamp: Optional[ + dlt.sources.incremental[float] + ] = dlt.sources.incremental("item.timestamp"), ): assert last_timestamp is None yield 1 @@ -627,7 +681,9 @@ def some_data_optional( @dlt.resource(standalone=True) def some_data( - last_timestamp: dlt.sources.incremental[float] = dlt.sources.incremental("item.timestamp"), + last_timestamp: dlt.sources.incremental[float] = dlt.sources.incremental( + "item.timestamp" + ), ): assert last_timestamp is None yield 1 @@ -643,7 +699,9 @@ def test_filter_processed_items(item_type: TDataItemFormat) -> None: @dlt.resource def standalone_some_data( - item_type: TDataItemFormat, now=None, last_timestamp=dlt.sources.incremental("timestamp") + item_type: TDataItemFormat, + now=None, + last_timestamp=dlt.sources.incremental("timestamp"), ): data = [ {"delta": i, "timestamp": (now or pendulum.now()).add(days=i).timestamp()} @@ -658,7 +716,9 @@ def standalone_some_data( assert len(values) == 20 # provide initial value using max function - values = list(standalone_some_data(item_type, last_timestamp=pendulum.now().timestamp())) + values = list( + 
standalone_some_data(item_type, last_timestamp=pendulum.now().timestamp()) + ) values = data_item_to_list(item_type, values) assert len(values) == 10 # only the future timestamps @@ -668,7 +728,9 @@ def standalone_some_data( values = list( standalone_some_data( item_type, - last_timestamp=dlt.sources.incremental("timestamp", pendulum.now().timestamp(), min), + last_timestamp=dlt.sources.incremental( + "timestamp", pendulum.now().timestamp(), min + ), ) ) values = data_item_to_list(item_type, values) @@ -690,14 +752,18 @@ def some_data(step, last_timestamp=dlt.sources.incremental("ts")): else: # print(last_timestamp.initial_value) # print(now.add(days=step-1).timestamp()) - assert last_timestamp.start_value == last_timestamp.last_value == expected_last + assert ( + last_timestamp.start_value == last_timestamp.last_value == expected_last + ) data = [{"delta": i, "ts": now.add(days=i)} for i in range(-10, 10)] yield from data # after all yielded if step == -10: assert last_timestamp.start_value is None else: - assert last_timestamp.start_value == expected_last != last_timestamp.last_value + assert ( + last_timestamp.start_value == expected_last != last_timestamp.last_value + ) for i in range(-10, 10): r = some_data(i) @@ -721,13 +787,21 @@ def some_data(first: bool, last_timestamp=dlt.sources.incremental("ts")): else: # print(last_timestamp.initial_value) # print(now.add(days=step-1).timestamp()) - assert last_timestamp.start_value == last_timestamp.last_value == data[-1]["ts"] + assert ( + last_timestamp.start_value + == last_timestamp.last_value + == data[-1]["ts"] + ) yield from source_items # after all yielded if first: assert last_timestamp.start_value is None else: - assert last_timestamp.start_value == data[-1]["ts"] == last_timestamp.last_value + assert ( + last_timestamp.start_value + == data[-1]["ts"] + == last_timestamp.last_value + ) p.run(some_data(True)) p.run(some_data(False)) @@ -740,7 +814,9 @@ def test_replace_resets_state(item_type: TDataItemFormat) -> None: @dlt.resource def standalone_some_data( - item_type: TDataItemFormat, now=None, last_timestamp=dlt.sources.incremental("timestamp") + item_type: TDataItemFormat, + now=None, + last_timestamp=dlt.sources.incremental("timestamp"), ): data = [ {"delta": i, "timestamp": (now or pendulum.now()).add(days=i).timestamp()} @@ -806,7 +882,8 @@ def child(item): # there will be a load package to reset the state but also a load package to update the child table assert len(info.load_packages[0].jobs["completed_jobs"]) == 2 assert { - job.job_file_info.table_name for job in info.load_packages[0].jobs["completed_jobs"] + job.job_file_info.table_name + for job in info.load_packages[0].jobs["completed_jobs"] } == {"_dlt_pipeline_state", "child"} # now we add child that has parent_r as parent but we add another instance of standalone_some_data explicitly @@ -823,19 +900,24 @@ def test_incremental_as_transform(item_type: TDataItemFormat) -> None: @dlt.resource def some_data(): - last_value: dlt.sources.incremental[float] = dlt.sources.incremental.from_existing_state( - "some_data", "ts" + last_value: dlt.sources.incremental[float] = ( + dlt.sources.incremental.from_existing_state("some_data", "ts") ) assert last_value.initial_value == now assert last_value.start_value == now assert last_value.cursor_path == "ts" assert last_value.last_value == now - data = [{"delta": i, "ts": pendulum.now().add(days=i).timestamp()} for i in range(-10, 10)] + data = [ + {"delta": i, "ts": pendulum.now().add(days=i).timestamp()} + for i in range(-10, 10) + ] 
source_items = data_to_item_format(item_type, data) yield from source_items - r = some_data().add_step(dlt.sources.incremental("ts", initial_value=now, primary_key="delta")) + r = some_data().add_step( + dlt.sources.incremental("ts", initial_value=now, primary_key="delta") + ) p = dlt.pipeline(pipeline_name=uniq_id()) info = p.run(r, destination="duckdb") assert len(info.loads_ids) == 1 @@ -872,7 +954,9 @@ def some_data(created_at: Optional[dlt.sources.incremental[int]] = None): # the incremental wrapper is created for a resource and the incremental value is provided via apply hints r = some_data() assert r is not some_data - r.apply_hints(incremental=dlt.sources.incremental("created_at", last_value_func=max)) + r.apply_hints( + incremental=dlt.sources.incremental("created_at", last_value_func=max) + ) if item_type == "pandas": assert list(r)[0].equals(source_items[0]) else: @@ -883,7 +967,9 @@ def some_data(created_at: Optional[dlt.sources.incremental[int]] = None): # same thing with explicit None r = some_data(created_at=None).with_name("copy") - r.apply_hints(incremental=dlt.sources.incremental("created_at", last_value_func=max)) + r.apply_hints( + incremental=dlt.sources.incremental("created_at", last_value_func=max) + ) if item_type == "pandas": assert list(r)[0].equals(source_items[0]) else: @@ -896,14 +982,18 @@ def some_data(created_at: Optional[dlt.sources.incremental[int]] = None): p = p.drop() r = some_data(created_at=dlt.sources.incremental("created_at", last_value_func=max)) # explicit has precedence here and hints will be ignored - r.apply_hints(incremental=dlt.sources.incremental("created_at", last_value_func=min)) + r.apply_hints( + incremental=dlt.sources.incremental("created_at", last_value_func=min) + ) p.extract(r) assert "incremental" in r.state # max value assert r.state["incremental"]["created_at"]["last_value"] == 3 @dlt.resource - def some_data_w_default(created_at=dlt.sources.incremental("created_at", last_value_func=min)): + def some_data_w_default( + created_at=dlt.sources.incremental("created_at", last_value_func=min) + ): # make sure that incremental from apply_hints is here assert created_at is not None assert created_at.last_value_func is max @@ -912,7 +1002,9 @@ def some_data_w_default(created_at=dlt.sources.incremental("created_at", last_va # default is overridden by apply hints p = p.drop() r = some_data_w_default() - r.apply_hints(incremental=dlt.sources.incremental("created_at", last_value_func=max)) + r.apply_hints( + incremental=dlt.sources.incremental("created_at", last_value_func=max) + ) p.extract(r) assert "incremental" in r.state # min value @@ -925,7 +1017,9 @@ def some_data_no_incremental(): # we add incremental as a step p = p.drop() r = some_data_no_incremental() - r.apply_hints(incremental=dlt.sources.incremental("created_at", last_value_func=max)) + r.apply_hints( + incremental=dlt.sources.incremental("created_at", last_value_func=max) + ) assert r.incremental is not None p.extract(r) assert "incremental" in r.state @@ -956,7 +1050,9 @@ def _get_shuffled_events( last_created_at=dlt.sources.incremental("$", last_value_func=by_event_type) ): with open( - "tests/normalize/cases/github.events.load_page_1_duck.json", "r", encoding="utf-8" + "tests/normalize/cases/github.events.load_page_1_duck.json", + "r", + encoding="utf-8", ) as f: yield json.load(f) @@ -977,15 +1073,16 @@ def _get_shuffled_events( @pytest.mark.parametrize("item_type", ALL_DATA_ITEM_FORMATS) def test_timezone_naive_datetime(item_type: TDataItemFormat) -> None: """Resource has 
timezone naive datetime objects, but incremental stored state is - converted to tz aware pendulum dates. Can happen when loading e.g. from sql database""" + converted to tz aware pendulum dates. Can happen when loading e.g. from sql database + """ start_dt = datetime.now() pendulum_start_dt = pendulum.instance(start_dt) # With timezone @dlt.resource(standalone=True, primary_key="hour") def some_data( - updated_at: dlt.sources.incremental[pendulum.DateTime] = dlt.sources.incremental( - "updated_at", initial_value=pendulum_start_dt - ), + updated_at: dlt.sources.incremental[ + pendulum.DateTime + ] = dlt.sources.incremental("updated_at", initial_value=pendulum_start_dt), max_hours: int = 2, tz: str = None, ): @@ -1029,7 +1126,9 @@ def some_data( # it should be merged resource.apply_hints( incremental=dlt.sources.incremental( - "updated_at", initial_value=pendulum_start_dt, end_value=pendulum_start_dt.add(hours=3) + "updated_at", + initial_value=pendulum_start_dt, + end_value=pendulum_start_dt.add(hours=3), ) ) extract_info = pipeline.extract(resource) @@ -1041,8 +1140,12 @@ def some_data( ) # initial value is naive - resource = some_data(max_hours=4).with_name("copy_1") # also make new resource state - resource.apply_hints(incremental=dlt.sources.incremental("updated_at", initial_value=start_dt)) + resource = some_data(max_hours=4).with_name( + "copy_1" + ) # also make new resource state + resource.apply_hints( + incremental=dlt.sources.incremental("updated_at", initial_value=start_dt) + ) # and the data is naive. so it will work as expected with naive datetimes in the result set data = list(resource) if item_type == "json": @@ -1050,10 +1153,14 @@ def some_data( assert data[0]["updated_at"].tzinfo is None # end value is naive - resource = some_data(max_hours=4).with_name("copy_2") # also make new resource state + resource = some_data(max_hours=4).with_name( + "copy_2" + ) # also make new resource state resource.apply_hints( incremental=dlt.sources.incremental( - "updated_at", initial_value=start_dt, end_value=start_dt + timedelta(hours=3) + "updated_at", + initial_value=start_dt, + end_value=start_dt + timedelta(hours=3), ) ) data = list(resource) @@ -1061,7 +1168,9 @@ def some_data( assert data[0]["updated_at"].tzinfo is None # now use naive initial value but data is UTC - resource = some_data(max_hours=4, tz="UTC").with_name("copy_3") # also make new resource state + resource = some_data(max_hours=4, tz="UTC").with_name( + "copy_3" + ) # also make new resource state resource.apply_hints( incremental=dlt.sources.incremental( "updated_at", initial_value=start_dt + timedelta(hours=3) @@ -1097,7 +1206,9 @@ def endless_sequence( def test_chunked_ranges(item_type: TDataItemFormat) -> None: """Load chunked ranges with end value along with incremental""" - pipeline = dlt.pipeline(pipeline_name="incremental_" + uniq_id(), destination="duckdb") + pipeline = dlt.pipeline( + pipeline_name="incremental_" + uniq_id(), destination="duckdb" + ) chunks = [ # Load some start/end ranges in and out of order @@ -1117,7 +1228,8 @@ def test_chunked_ranges(item_type: TDataItemFormat) -> None: for start, end in chunks: pipeline.run( endless_sequence( - item_type, updated_at=dlt.sources.incremental(initial_value=start, end_value=end) + item_type, + updated_at=dlt.sources.incremental(initial_value=start, end_value=end), ), write_disposition="append", ) @@ -1162,10 +1274,14 @@ def batched_sequence( data = [{"updated_at": i} for i in range(start + 12, start + 20)] yield data_to_item_format(item_type, data) - pipeline 
= dlt.pipeline(pipeline_name="incremental_" + uniq_id(), destination="duckdb") + pipeline = dlt.pipeline( + pipeline_name="incremental_" + uniq_id(), destination="duckdb" + ) pipeline.run( - batched_sequence(updated_at=dlt.sources.incremental(initial_value=1, end_value=10)), + batched_sequence( + updated_at=dlt.sources.incremental(initial_value=1, end_value=10) + ), write_disposition="append", ) @@ -1180,7 +1296,9 @@ def batched_sequence( assert items == list(range(1, 10)) pipeline.run( - batched_sequence(updated_at=dlt.sources.incremental(initial_value=10, end_value=14)), + batched_sequence( + updated_at=dlt.sources.incremental(initial_value=10, end_value=14) + ), write_disposition="append", ) @@ -1198,11 +1316,14 @@ def batched_sequence( @pytest.mark.parametrize("item_type", ALL_DATA_ITEM_FORMATS) def test_load_with_end_value_does_not_write_state(item_type: TDataItemFormat) -> None: """When loading chunk with initial/end value range. The resource state is untouched.""" - pipeline = dlt.pipeline(pipeline_name="incremental_" + uniq_id(), destination="duckdb") + pipeline = dlt.pipeline( + pipeline_name="incremental_" + uniq_id(), destination="duckdb" + ) pipeline.extract( endless_sequence( - item_type, updated_at=dlt.sources.incremental(initial_value=20, end_value=30) + item_type, + updated_at=dlt.sources.incremental(initial_value=20, end_value=30), ) ) @@ -1213,7 +1334,9 @@ def test_load_with_end_value_does_not_write_state(item_type: TDataItemFormat) -> def test_end_value_initial_value_errors(item_type: TDataItemFormat) -> None: @dlt.resource def some_data( - updated_at: dlt.sources.incremental[int] = dlt.sources.incremental("updated_at"), + updated_at: dlt.sources.incremental[int] = dlt.sources.incremental( + "updated_at" + ), ) -> Any: yield {"updated_at": 1} @@ -1221,11 +1344,17 @@ def some_data( with pytest.raises(ConfigurationValueError) as ex: list(some_data(updated_at=dlt.sources.incremental(end_value=22))) - assert str(ex.value).startswith("Incremental 'end_value' was specified without 'initial_value'") + assert str(ex.value).startswith( + "Incremental 'end_value' was specified without 'initial_value'" + ) # max function and end_value lower than initial_value with pytest.raises(ConfigurationValueError) as ex: - list(some_data(updated_at=dlt.sources.incremental(initial_value=42, end_value=22))) + list( + some_data( + updated_at=dlt.sources.incremental(initial_value=42, end_value=22) + ) + ) assert str(ex.value).startswith( "Incremental 'initial_value' (42) is higher than 'end_value` (22)" @@ -1259,7 +1388,8 @@ def custom_last_value(items): ) assert ( - "The result of 'custom_last_value([end_value, initial_value])' must equal 'end_value'" + "The result of 'custom_last_value([end_value, initial_value])' must equal" + " 'end_value'" in str(ex.value) ) @@ -1330,7 +1460,9 @@ def ascending_single_item( assert updated_at.end_out_of_range is True return - pipeline = dlt.pipeline(pipeline_name="incremental_" + uniq_id(), destination="duckdb") + pipeline = dlt.pipeline( + pipeline_name="incremental_" + uniq_id(), destination="duckdb" + ) pipeline.extract(descending()) @@ -1460,7 +1592,9 @@ def ascending_desc( @pytest.mark.parametrize("order", ["random", "desc", "asc"]) @pytest.mark.parametrize("primary_key", [[], None, "updated_at"]) @pytest.mark.parametrize( - "deterministic", (True, False), ids=("deterministic-record", "non-deterministic-record") + "deterministic", + (True, False), + ids=("deterministic-record", "non-deterministic-record"), ) def test_unique_values_unordered_rows( 
item_type: TDataItemFormat, order: str, primary_key: Any, deterministic: bool @@ -1491,21 +1625,33 @@ def random_ascending_chunks( os.environ["COMPLETED_PROB"] = "1.0" # make it complete immediately pipeline = dlt.pipeline("test_unique_values_unordered_rows", destination="dummy") pipeline.run(random_ascending_chunks(order)) - assert pipeline.last_trace.last_normalize_info.row_counts["random_ascending_chunks"] == 121 + assert ( + pipeline.last_trace.last_normalize_info.row_counts["random_ascending_chunks"] + == 121 + ) # 120 rows (one overlap - incremental reacquires and deduplicates) pipeline.run(random_ascending_chunks(order)) # overlapping element must be deduped when: # 1. we have primary key on just updated at # OR we have a key on full record but the record is deterministic so duplicate may be found - rows = 120 if primary_key == "updated_at" or (deterministic and primary_key != []) else 121 - assert pipeline.last_trace.last_normalize_info.row_counts["random_ascending_chunks"] == rows + rows = ( + 120 + if primary_key == "updated_at" or (deterministic and primary_key != []) + else 121 + ) + assert ( + pipeline.last_trace.last_normalize_info.row_counts["random_ascending_chunks"] + == rows + ) @pytest.mark.parametrize("item_type", ALL_DATA_ITEM_FORMATS) @pytest.mark.parametrize("primary_key", [[], None, "updated_at"]) # [], None, @pytest.mark.parametrize( - "deterministic", (True, False), ids=("deterministic-record", "non-deterministic-record") + "deterministic", + (True, False), + ids=("deterministic-record", "non-deterministic-record"), ) def test_carry_unique_hashes( item_type: TDataItemFormat, primary_key: Any, deterministic: bool @@ -1596,8 +1742,14 @@ def _assert_state(r_: DltResource, day: int, info: NormalizeInfo) -> None: @pytest.mark.parametrize("item_type", ALL_DATA_ITEM_FORMATS) def test_get_incremental_value_type(item_type: TDataItemFormat) -> None: assert dlt.sources.incremental("id").get_incremental_value_type() is Any - assert dlt.sources.incremental("id", initial_value=0).get_incremental_value_type() is int - assert dlt.sources.incremental("id", initial_value=None).get_incremental_value_type() is Any + assert ( + dlt.sources.incremental("id", initial_value=0).get_incremental_value_type() + is int + ) + assert ( + dlt.sources.incremental("id", initial_value=None).get_incremental_value_type() + is Any + ) assert dlt.sources.incremental[int]("id").get_incremental_value_type() is int assert ( dlt.sources.incremental[pendulum.DateTime]("id").get_incremental_value_type() @@ -1640,7 +1792,9 @@ def test_type_3(updated_at: dlt.sources.incremental[int]): data = [{"updated_at": d} for d in [1, 2, 3]] yield data_to_item_format(item_type, data) - r = test_type_3(dlt.sources.incremental[float]("updated_at", allow_external_schedulers=True)) + r = test_type_3( + dlt.sources.incremental[float]("updated_at", allow_external_schedulers=True) + ) list(r) assert r.incremental._incremental.get_incremental_value_type() is float @@ -1652,7 +1806,9 @@ def test_type_4( data = [{"updated_at": d} for d in [1, 2, 3]] yield data_to_item_format(item_type, data) - r = test_type_4(dlt.sources.incremental[str]("updated_at", allow_external_schedulers=True)) + r = test_type_4( + dlt.sources.incremental[str]("updated_at", allow_external_schedulers=True) + ) list(r) assert r.incremental._incremental.get_incremental_value_type() is str @@ -1690,7 +1846,10 @@ def test_type_2( # set start and end values os.environ["DLT_START_VALUE"] = "2" result = list(test_type_2()) - assert data_item_to_list(item_type, 
result) == [{"updated_at": 2}, {"updated_at": 3}] + assert data_item_to_list(item_type, result) == [ + {"updated_at": 2}, + {"updated_at": 3}, + ] os.environ["DLT_END_VALUE"] = "3" result = list(test_type_2()) assert data_item_to_list(item_type, result) == [{"updated_at": 2}] @@ -1733,7 +1892,9 @@ def test_type_2( def test_allow_external_schedulers(item_type: TDataItemFormat) -> None: @dlt.resource() def test_type_2( - updated_at: dlt.sources.incremental[int] = dlt.sources.incremental("updated_at"), + updated_at: dlt.sources.incremental[int] = dlt.sources.incremental( + "updated_at" + ), ): data = [{"updated_at": d} for d in [1, 2, 3]] yield data_to_item_format(item_type, data) diff --git a/tests/extract/test_sources.py b/tests/extract/test_sources.py index 6ff1a0bf5f..b22378864d 100644 --- a/tests/extract/test_sources.py +++ b/tests/extract/test_sources.py @@ -294,7 +294,9 @@ def some_data(param: str): # create two resource instances and extract in single ad hoc resource data1 = some_data("state1") data1._pipe.name = "state1_data" - dlt.pipeline(full_refresh=True).extract([data1, some_data("state2")], schema=Schema("default")) + dlt.pipeline(full_refresh=True).extract( + [data1, some_data("state2")], schema=Schema("default") + ) # both should be extracted. what we test here is the combination of binding the resource by calling it that clones the internal pipe # and then creating a source with both clones. if we keep same pipe id when cloning on call, a single pipe would be created shared by two resources assert all_yields == ["state1", "state2"] @@ -346,7 +348,12 @@ def yield_twice(item): # filter out small caps and insert this before the head tx_stage.add_filter(lambda letter: letter.isupper(), 0) # be got filtered out before duplication - assert list(dlt.resource(["A", "b", "C"], name="data") | tx_stage) == ["A", "A", "C", "C"] + assert list(dlt.resource(["A", "b", "C"], name="data") | tx_stage) == [ + "A", + "A", + "C", + "C", + ] # filter after duplication tx_stage = dlt.transformer()(yield_twice)() @@ -766,7 +773,11 @@ def test_add_transform_steps() -> None: def test_add_transform_steps_pipe() -> None: - r = dlt.resource([1, 2, 3], name="all") | (lambda i: str(i) * i) | (lambda i: (yield from i)) + r = ( + dlt.resource([1, 2, 3], name="all") + | (lambda i: str(i) * i) + | (lambda i: (yield from i)) + ) assert list(r) == ["1", "2", "2", "3", "3", "3"] @@ -1243,7 +1254,13 @@ def empty_gen(): empty_r = empty() # check defaults - assert empty_r.name == empty.name == empty_r.table_name == empty.table_name == "empty_gen" + assert ( + empty_r.name + == empty.name + == empty_r.table_name + == empty.table_name + == "empty_gen" + ) # assert empty_r._table_schema_template is None assert empty_r.compute_table_schema() == empty_table_schema assert empty_r.write_disposition == "append" @@ -1271,7 +1288,11 @@ def empty_gen(): "nullable": False, "primary_key": True, } - assert table["columns"]["b"] == {"name": "b", "nullable": False, "primary_key": True} + assert table["columns"]["b"] == { + "name": "b", + "nullable": False, + "primary_key": True, + } assert table["columns"]["c"] == {"merge_key": True, "name": "c", "nullable": False} assert table["name"] == "table" assert table["parent"] == "parent" @@ -1309,7 +1330,11 @@ def empty_gen(): merge_key="tags", ) # primary key not set here - assert empty_r.columns["tags"] == {"data_type": "complex", "name": "tags", "primary_key": False} + assert empty_r.columns["tags"] == { + "data_type": "complex", + "name": "tags", + "primary_key": False, + } # only in 
     # only in the computed table
     assert empty_r.compute_table_schema()["columns"]["tags"] == {
         "data_type": "complex",
@@ -1327,10 +1352,14 @@ def empty_gen():
     empty_r = empty()

     with pytest.raises(InconsistentTableTemplate):
-        empty_r.apply_hints(parent_table_name=lambda ev: ev["p"], write_disposition=None)
+        empty_r.apply_hints(
+            parent_table_name=lambda ev: ev["p"], write_disposition=None
+        )

     empty_r.apply_hints(
-        table_name=lambda ev: ev["t"], parent_table_name=lambda ev: ev["p"], write_disposition=None
+        table_name=lambda ev: ev["t"],
+        parent_table_name=lambda ev: ev["p"],
+        write_disposition=None,
     )
     assert empty_r._table_name_hint_fun is not None
     assert empty_r._table_has_other_dynamic_hints is True
@@ -1342,7 +1371,9 @@ def empty_gen():
     assert table["parent"] == "parent"

     # try write disposition and primary key
-    empty_r.apply_hints(primary_key=lambda ev: ev["pk"], write_disposition=lambda ev: ev["wd"])
+    empty_r.apply_hints(
+        primary_key=lambda ev: ev["pk"], write_disposition=lambda ev: ev["wd"]
+    )
     table = empty_r.compute_table_schema(
         {"t": "table", "p": "parent", "pk": ["a", "b"], "wd": "skip"}
     )
@@ -1358,7 +1389,13 @@ def empty_gen():
     # dynamic columns
     empty_r.apply_hints(columns=lambda ev: ev["c"])
     table = empty_r.compute_table_schema(
-        {"t": "table", "p": "parent", "pk": ["a", "b"], "wd": "skip", "c": [{"name": "tags"}]}
+        {
+            "t": "table",
+            "p": "parent",
+            "pk": ["a", "b"],
+            "wd": "skip",
+            "c": [{"name": "tags"}],
+        }
     )
     assert table["columns"]["tags"] == {"name": "tags"}

@@ -1374,11 +1411,15 @@ def empty_gen():
     empty.apply_hints(write_disposition="append", create_table_variant=True)
     with pytest.raises(ValueError):
         empty.apply_hints(
-            table_name=lambda ev: ev["t"], write_disposition="append", create_table_variant=True
+            table_name=lambda ev: ev["t"],
+            write_disposition="append",
+            create_table_variant=True,
         )

     # table a with replace
-    empty.apply_hints(table_name="table_a", write_disposition="replace", create_table_variant=True)
+    empty.apply_hints(
+        table_name="table_a", write_disposition="replace", create_table_variant=True
+    )
     table_a = empty.compute_table_schema(meta=TableNameMeta("table_a"))
     assert table_a["name"] == "table_a"
     assert table_a["write_disposition"] == "replace"
@@ -1394,14 +1435,18 @@ def empty_gen():
         incremental=dlt.sources.incremental(cursor_path="x"),
         columns=[{"name": "id", "data_type": "bigint"}],
     )
-    empty.apply_hints(table_name="table_b", write_disposition="merge", create_table_variant=True)
+    empty.apply_hints(
+        table_name="table_b", write_disposition="merge", create_table_variant=True
+    )
     table_b = empty.compute_table_schema(meta=TableNameMeta("table_b"))
     assert table_b["name"] == "table_b"
     assert table_b["write_disposition"] == "merge"
     assert len(table_b["columns"]) == 1
     assert table_b["columns"]["id"]["primary_key"] is True
     # overwrite table_b, remove column def and primary_key
-    empty.apply_hints(table_name="table_b", columns=[], primary_key=(), create_table_variant=True)
+    empty.apply_hints(
+        table_name="table_b", columns=[], primary_key=(), create_table_variant=True
+    )
     table_b = empty.compute_table_schema(meta=TableNameMeta("table_b"))
     assert table_b["name"] == "table_b"
     assert table_b["write_disposition"] == "merge"
@@ -1410,7 +1455,9 @@ def empty_gen():
     # dyn hints not allowed
     with pytest.raises(InconsistentTableTemplate):
         empty.apply_hints(
-            table_name="table_b", write_disposition=lambda ev: ev["wd"], create_table_variant=True
+            table_name="table_b",
+            write_disposition=lambda ev: ev["wd"],
+            create_table_variant=True,
         )


@@ -1448,7 +1495,10 @@ def tx_step(item):
     source = DltSource(
         Schema("dupes"),
         "module",
-        [DltResource.from_data(input_gen), DltResource.from_data(input_gen).with_name("gen_2")],
+        [
+            DltResource.from_data(input_gen),
+            DltResource.from_data(input_gen).with_name("gen_2"),
+        ],
     )
     assert list(source) == [1, 2, 3, 1, 2, 3]
diff --git a/tests/extract/test_validation.py b/tests/extract/test_validation.py
index b9307ab97c..af3f2b81c6 100644
--- a/tests/extract/test_validation.py
+++ b/tests/extract/test_validation.py
@@ -35,7 +35,9 @@ def some_data() -> t.Iterator[TDataItems]:
     # Items are passed through model
     data = list(some_data())
     # compare content-wise. model names change due to extra settings on columns
-    assert json.dumpb(data) == json.dumpb([SimpleModel(a=1, b="2"), SimpleModel(a=2, b="3")])
+    assert json.dumpb(data) == json.dumpb(
+        [SimpleModel(a=1, b="2"), SimpleModel(a=2, b="3")]
+    )


 @pytest.mark.parametrize("yield_list", [True, False])
@@ -55,7 +57,9 @@ def some_data() -> t.Iterator[TDataItems]:
     # Items are passed through model
     data = list(resource)
-    assert json.dumpb(data) == json.dumpb([SimpleModel(a=1, b="2"), SimpleModel(a=2, b="3")])
+    assert json.dumpb(data) == json.dumpb(
+        [SimpleModel(a=1, b="2"), SimpleModel(a=2, b="3")]
+    )


 @pytest.mark.parametrize("yield_list", [True, False])
@@ -131,7 +135,9 @@ class AnotherModel(BaseModel):
         b: str
         c: float = 0.5

-    resource.validator = PydanticValidator(AnotherModel, column_mode="freeze", data_mode="freeze")
+    resource.validator = PydanticValidator(
+        AnotherModel, column_mode="freeze", data_mode="freeze"
+    )

     assert resource.validator and resource.validator.model.__name__.startswith(
         AnotherModel.__name__
@@ -206,7 +212,9 @@ def some_data() -> t.Iterator[TDataItems]:
             yield from items

     # let it evolve
-    r: DltResource = dlt.resource(some_data(), schema_contract="evolve", columns=SimpleModel)
+    r: DltResource = dlt.resource(
+        some_data(), schema_contract="evolve", columns=SimpleModel
+    )
     validator: PydanticValidator[SimpleModel] = r.validator  # type: ignore[assignment]
     assert validator.column_mode == "evolve"
     assert validator.data_mode == "evolve"
diff --git a/tests/extract/utils.py b/tests/extract/utils.py
index 170781ba3c..fb4f6eeff5 100644
--- a/tests/extract/utils.py
+++ b/tests/extract/utils.py
@@ -29,7 +29,9 @@ def expect_extracted_file(
     file = next(gen, None)
     if file is None:
         raise FileNotFoundError(
-            PackageStorage.build_job_file_name(table_name, schema_name, validate_components=False)
+            PackageStorage.build_job_file_name(
+                table_name, schema_name, validate_components=False
+            )
         )
     assert file is not None
     # get remaining file names
@@ -46,7 +48,9 @@ def expect_extracted_file(


 class AssertItems(ItemTransform[TDataItem]):
-    def __init__(self, expected_items: Any, item_type: TDataItemFormat = "json") -> None:
+    def __init__(
+        self, expected_items: Any, item_type: TDataItemFormat = "json"
+    ) -> None:
         self.expected_items = expected_items
         self.item_type = item_type

diff --git a/tests/helpers/airflow_tests/conftest.py b/tests/helpers/airflow_tests/conftest.py
index 3d040b4a11..28a1f75655 100644
--- a/tests/helpers/airflow_tests/conftest.py
+++ b/tests/helpers/airflow_tests/conftest.py
@@ -1,2 +1,7 @@
 from tests.helpers.airflow_tests.utils import initialize_airflow_db
-from tests.utils import preserve_environ, autouse_test_storage, TEST_STORAGE_ROOT, patch_home_dir
+from tests.utils import (
+    preserve_environ,
+    autouse_test_storage,
+    TEST_STORAGE_ROOT,
+    patch_home_dir,
+)
diff --git a/tests/helpers/airflow_tests/test_airflow_provider.py b/tests/helpers/airflow_tests/test_airflow_provider.py
index 68e426deb9..ed4f727ee5 100644
--- a/tests/helpers/airflow_tests/test_airflow_provider.py
+++ b/tests/helpers/airflow_tests/test_airflow_provider.py
@@ -9,7 +9,9 @@ import dlt
 from dlt.common import pendulum
 from dlt.common.configuration.container import Container
-from dlt.common.configuration.specs.config_providers_context import ConfigProvidersContext
+from dlt.common.configuration.specs.config_providers_context import (
+    ConfigProvidersContext,
+)
 from dlt.common.configuration.providers.toml import SECRETS_TOML_KEY

 DEFAULT_DATE = pendulum.datetime(2023, 4, 18, tz="Europe/Berlin")
@@ -23,7 +25,9 @@ def test_airflow_secrets_toml_provider() -> None:
     @dag(start_date=DEFAULT_DATE)
     def test_dag():
-        from dlt.common.configuration.providers.airflow import AirflowSecretsTomlProvider
+        from dlt.common.configuration.providers.airflow import (
+            AirflowSecretsTomlProvider,
+        )

         Variable.set(SECRETS_TOML_KEY, SECRETS_TOML_CONTENT)
         # make sure provider works while creating DAG
@@ -149,7 +153,9 @@ def test_airflow_secrets_toml_provider_is_loaded():
     dag = DAG(dag_id="test_dag", start_date=DEFAULT_DATE)

     def test_task():
-        from dlt.common.configuration.providers.airflow import AirflowSecretsTomlProvider
+        from dlt.common.configuration.providers.airflow import (
+            AirflowSecretsTomlProvider,
+        )

         Variable.set(SECRETS_TOML_KEY, SECRETS_TOML_CONTENT)

@@ -203,13 +209,17 @@ def test_airflow_secrets_toml_provider_missing_variable():
     def test_task():
         from dlt.common.configuration.specs import config_providers_context
-        from dlt.common.configuration.providers.airflow import AirflowSecretsTomlProvider
+        from dlt.common.configuration.providers.airflow import (
+            AirflowSecretsTomlProvider,
+        )

         # Make sure the variable is not set
         Variable.delete(SECRETS_TOML_KEY)
         providers = config_providers_context._extra_providers()
         provider = next(
-            provider for provider in providers if isinstance(provider, AirflowSecretsTomlProvider)
+            provider
+            for provider in providers
+            if isinstance(provider, AirflowSecretsTomlProvider)
         )
         return {
             "airflow_secrets_toml": provider._toml.as_string(),
@@ -239,7 +249,9 @@ def test_airflow_secrets_toml_provider_invalid_content():
     def test_task():
         import tomlkit

-        from dlt.common.configuration.providers.airflow import AirflowSecretsTomlProvider
+        from dlt.common.configuration.providers.airflow import (
+            AirflowSecretsTomlProvider,
+        )

         Variable.set(SECRETS_TOML_KEY, "invalid_content")

diff --git a/tests/helpers/airflow_tests/test_airflow_wrapper.py
pipeline_standalone.default_schema.data_tables()], ) tasks_list: List[PythonOperator] = None - @dag(schedule=None, start_date=DEFAULT_DATE, catchup=False, default_args=default_args) + @dag( + schedule=None, start_date=DEFAULT_DATE, catchup=False, default_args=default_args + ) def dag_regular(): nonlocal tasks_list tasks = PipelineTasksGroup( - "pipeline_dag_regular", local_data_folder=TEST_STORAGE_ROOT, wipe_local_data=False + "pipeline_dag_regular", + local_data_folder=TEST_STORAGE_ROOT, + wipe_local_data=False, ) pipeline_dag_regular = dlt.pipeline( @@ -203,11 +210,15 @@ def dag_regular(): quackdb_path = os.path.join(TEST_STORAGE_ROOT, "pipeline_dag_decomposed.duckdb") - @dag(schedule=None, start_date=DEFAULT_DATE, catchup=False, default_args=default_args) + @dag( + schedule=None, start_date=DEFAULT_DATE, catchup=False, default_args=default_args + ) def dag_decomposed(): nonlocal tasks_list tasks = PipelineTasksGroup( - "pipeline_dag_decomposed", local_data_folder=TEST_STORAGE_ROOT, wipe_local_data=False + "pipeline_dag_decomposed", + local_data_folder=TEST_STORAGE_ROOT, + wipe_local_data=False, ) # set duckdb to be outside of pipeline folder which is dropped on each task @@ -229,8 +240,13 @@ def dag_decomposed(): dag_def = dag_decomposed() assert len(tasks_list) == 3 # task one by one - assert tasks_list[0].task_id == "pipeline_dag_decomposed.mock_data_source__r_init-_t_init_post" - assert tasks_list[1].task_id == "pipeline_dag_decomposed.mock_data_source__t1-_t2-_t3" + assert ( + tasks_list[0].task_id + == "pipeline_dag_decomposed.mock_data_source__r_init-_t_init_post" + ) + assert ( + tasks_list[1].task_id == "pipeline_dag_decomposed.mock_data_source__t1-_t2-_t3" + ) assert tasks_list[2].task_id == "pipeline_dag_decomposed.mock_data_source__r_isolee" dag_def.test() pipeline_dag_decomposed = dlt.attach(pipeline_name="pipeline_dag_decomposed") @@ -252,16 +268,21 @@ def test_run() -> None: ) pipeline_standalone.run(mock_data_source()) pipeline_standalone_counts = load_table_counts( - pipeline_standalone, *[t["name"] for t in pipeline_standalone.default_schema.data_tables()] + pipeline_standalone, + *[t["name"] for t in pipeline_standalone.default_schema.data_tables()], ) quackdb_path = os.path.join(TEST_STORAGE_ROOT, "pipeline_dag_regular.duckdb") - @dag(schedule=None, start_date=DEFAULT_DATE, catchup=False, default_args=default_args) + @dag( + schedule=None, start_date=DEFAULT_DATE, catchup=False, default_args=default_args + ) def dag_regular(): nonlocal task tasks = PipelineTasksGroup( - "pipeline_dag_regular", local_data_folder=TEST_STORAGE_ROOT, wipe_local_data=False + "pipeline_dag_regular", + local_data_folder=TEST_STORAGE_ROOT, + wipe_local_data=False, ) # set duckdb to be outside of pipeline folder which is dropped on each task @@ -297,18 +318,23 @@ def test_parallel_run(): ) pipeline_standalone.run(mock_data_source()) pipeline_standalone_counts = load_table_counts( - pipeline_standalone, *[t["name"] for t in pipeline_standalone.default_schema.data_tables()] + pipeline_standalone, + *[t["name"] for t in pipeline_standalone.default_schema.data_tables()], ) tasks_list: List[PythonOperator] = None quackdb_path = os.path.join(TEST_STORAGE_ROOT, "pipeline_dag_parallel.duckdb") - @dag(schedule=None, start_date=DEFAULT_DATE, catchup=False, default_args=default_args) + @dag( + schedule=None, start_date=DEFAULT_DATE, catchup=False, default_args=default_args + ) def dag_parallel(): nonlocal tasks_list tasks = PipelineTasksGroup( - "pipeline_dag_parallel", 
local_data_folder=TEST_STORAGE_ROOT, wipe_local_data=False + "pipeline_dag_parallel", + local_data_folder=TEST_STORAGE_ROOT, + wipe_local_data=False, ) # set duckdb to be outside of pipeline folder which is dropped on each task @@ -358,11 +384,15 @@ def test_parallel_incremental(): quackdb_path = os.path.join(TEST_STORAGE_ROOT, "pipeline_dag_parallel.duckdb") - @dag(schedule=None, start_date=DEFAULT_DATE, catchup=False, default_args=default_args) + @dag( + schedule=None, start_date=DEFAULT_DATE, catchup=False, default_args=default_args + ) def dag_parallel(): nonlocal tasks_list tasks = PipelineTasksGroup( - "pipeline_dag_parallel", local_data_folder=TEST_STORAGE_ROOT, wipe_local_data=False + "pipeline_dag_parallel", + local_data_folder=TEST_STORAGE_ROOT, + wipe_local_data=False, ) # set duckdb to be outside of pipeline folder which is dropped on each task @@ -396,18 +426,23 @@ def test_parallel_isolated_run(): ) pipeline_standalone.run(mock_data_source()) pipeline_standalone_counts = load_table_counts( - pipeline_standalone, *[t["name"] for t in pipeline_standalone.default_schema.data_tables()] + pipeline_standalone, + *[t["name"] for t in pipeline_standalone.default_schema.data_tables()], ) tasks_list: List[PythonOperator] = None quackdb_path = os.path.join(TEST_STORAGE_ROOT, "pipeline_dag_parallel.duckdb") - @dag(schedule=None, start_date=DEFAULT_DATE, catchup=False, default_args=default_args) + @dag( + schedule=None, start_date=DEFAULT_DATE, catchup=False, default_args=default_args + ) def dag_parallel(): nonlocal tasks_list tasks = PipelineTasksGroup( - "pipeline_dag_parallel", local_data_folder=TEST_STORAGE_ROOT, wipe_local_data=False + "pipeline_dag_parallel", + local_data_folder=TEST_STORAGE_ROOT, + wipe_local_data=False, ) # set duckdb to be outside of pipeline folder which is dropped on each task @@ -461,18 +496,23 @@ def test_parallel_run_single_resource(): ) pipeline_standalone.run(mock_data_single_resource()) pipeline_standalone_counts = load_table_counts( - pipeline_standalone, *[t["name"] for t in pipeline_standalone.default_schema.data_tables()] + pipeline_standalone, + *[t["name"] for t in pipeline_standalone.default_schema.data_tables()], ) tasks_list: List[PythonOperator] = None quackdb_path = os.path.join(TEST_STORAGE_ROOT, "pipeline_dag_parallel.duckdb") - @dag(schedule=None, start_date=DEFAULT_DATE, catchup=False, default_args=default_args) + @dag( + schedule=None, start_date=DEFAULT_DATE, catchup=False, default_args=default_args + ) def dag_parallel(): nonlocal tasks_list tasks = PipelineTasksGroup( - "pipeline_dag_parallel", local_data_folder=TEST_STORAGE_ROOT, wipe_local_data=False + "pipeline_dag_parallel", + local_data_folder=TEST_STORAGE_ROOT, + wipe_local_data=False, ) # set duckdb to be outside of pipeline folder which is dropped on each task @@ -535,11 +575,15 @@ def _fail_3(): raise Exception(f"Failed on retry #{retries}") yield from "ABC" - @dag(schedule=None, start_date=DEFAULT_DATE, catchup=False, default_args=default_args) + @dag( + schedule=None, start_date=DEFAULT_DATE, catchup=False, default_args=default_args + ) def dag_fail_3(): # by default we do not retry so this will fail tasks = PipelineTasksGroup( - "pipeline_fail_3", local_data_folder=TEST_STORAGE_ROOT, wipe_local_data=False + "pipeline_fail_3", + local_data_folder=TEST_STORAGE_ROOT, + wipe_local_data=False, ) pipeline_fail_3 = dlt.pipeline( @@ -549,7 +593,11 @@ def dag_fail_3(): credentials=":pipeline:", ) tasks.add_run( - pipeline_fail_3, _fail_3, trigger_rule="all_done", retries=0, 
provide_context=True + pipeline_fail_3, + _fail_3, + trigger_rule="all_done", + retries=0, + provide_context=True, ) dag_def: DAG = dag_fail_3() @@ -559,7 +607,9 @@ def dag_fail_3(): ti._run_raw_task() assert pip_ex.value.step == "extract" - @dag(schedule=None, start_date=DEFAULT_DATE, catchup=False, default_args=default_args) + @dag( + schedule=None, start_date=DEFAULT_DATE, catchup=False, default_args=default_args + ) def dag_fail_4(): # by default we do not retry extract so we fail tasks = PipelineTasksGroup( @@ -576,7 +626,11 @@ def dag_fail_4(): credentials=":pipeline:", ) tasks.add_run( - pipeline_fail_3, _fail_3, trigger_rule="all_done", retries=0, provide_context=True + pipeline_fail_3, + _fail_3, + trigger_rule="all_done", + retries=0, + provide_context=True, ) dag_def = dag_fail_4() @@ -587,7 +641,9 @@ def dag_fail_4(): ti._run_raw_task() assert pip_ex.value.step == "extract" - @dag(schedule=None, start_date=DEFAULT_DATE, catchup=False, default_args=default_args) + @dag( + schedule=None, start_date=DEFAULT_DATE, catchup=False, default_args=default_args + ) def dag_fail_5(): # this will retry tasks = PipelineTasksGroup( @@ -605,7 +661,11 @@ def dag_fail_5(): credentials=":pipeline:", ) tasks.add_run( - pipeline_fail_3, _fail_3, trigger_rule="all_done", retries=0, provide_context=True + pipeline_fail_3, + _fail_3, + trigger_rule="all_done", + retries=0, + provide_context=True, ) dag_def = dag_fail_5() @@ -619,7 +679,9 @@ def test_run_decomposed_with_state_wipe() -> None: dataset_name = "mock_data_" + uniq_id() pipeline_name = "pipeline_dag_regular_" + uniq_id() - @dag(schedule=None, start_date=DEFAULT_DATE, catchup=False, default_args=default_args) + @dag( + schedule=None, start_date=DEFAULT_DATE, catchup=False, default_args=default_args + ) def dag_regular(): tasks = PipelineTasksGroup( pipeline_name, @@ -672,7 +734,9 @@ def test_run_multiple_sources() -> None: dataset_name = "mock_data_" + uniq_id() pipeline_name = "pipeline_dag_regular_" + uniq_id() - @dag(schedule=None, start_date=DEFAULT_DATE, catchup=False, default_args=default_args) + @dag( + schedule=None, start_date=DEFAULT_DATE, catchup=False, default_args=default_args + ) def dag_serialize(): tasks = PipelineTasksGroup( pipeline_name, local_data_folder=TEST_STORAGE_ROOT, wipe_local_data=True @@ -708,14 +772,23 @@ def dag_serialize(): ) pipeline_dag_serial.sync_destination() # we should have two schemas - assert set(pipeline_dag_serial.schema_names) == {"mock_data_source_state", "mock_data_source"} + assert set(pipeline_dag_serial.schema_names) == { + "mock_data_source_state", + "mock_data_source", + } counters_st_tasks = load_table_counts( pipeline_dag_serial, - *[t["name"] for t in pipeline_dag_serial.schemas["mock_data_source_state"].data_tables()], + *[ + t["name"] + for t in pipeline_dag_serial.schemas["mock_data_source_state"].data_tables() + ], ) counters_nst_tasks = load_table_counts( pipeline_dag_serial, - *[t["name"] for t in pipeline_dag_serial.schemas["mock_data_source"].data_tables()], + *[ + t["name"] + for t in pipeline_dag_serial.schemas["mock_data_source"].data_tables() + ], ) # print(counters_st_tasks) # print(counters_nst_tasks) @@ -736,7 +809,9 @@ def dag_serialize(): dataset_name = "mock_data_" + uniq_id() pipeline_name = "pipeline_dag_regular_" + uniq_id() - @dag(schedule=None, start_date=DEFAULT_DATE, catchup=False, default_args=default_args) + @dag( + schedule=None, start_date=DEFAULT_DATE, catchup=False, default_args=default_args + ) def dag_parallel(): tasks = PipelineTasksGroup( pipeline_name, 
local_data_folder=TEST_STORAGE_ROOT, wipe_local_data=True @@ -771,25 +846,40 @@ def dag_parallel(): ) pipeline_dag_parallel.sync_destination() # we should have two schemas - assert set(pipeline_dag_parallel.schema_names) == {"mock_data_source_state", "mock_data_source"} + assert set(pipeline_dag_parallel.schema_names) == { + "mock_data_source_state", + "mock_data_source", + } counters_st_tasks_par = load_table_counts( pipeline_dag_parallel, - *[t["name"] for t in pipeline_dag_parallel.schemas["mock_data_source_state"].data_tables()], + *[ + t["name"] + for t in pipeline_dag_parallel.schemas[ + "mock_data_source_state" + ].data_tables() + ], ) counters_nst_tasks_par = load_table_counts( pipeline_dag_parallel, - *[t["name"] for t in pipeline_dag_parallel.schemas["mock_data_source"].data_tables()], + *[ + t["name"] + for t in pipeline_dag_parallel.schemas["mock_data_source"].data_tables() + ], ) assert counters_st_tasks == counters_st_tasks_par assert counters_nst_tasks == counters_nst_tasks_par - assert pipeline_dag_serial.state["sources"] == pipeline_dag_parallel.state["sources"] + assert ( + pipeline_dag_serial.state["sources"] == pipeline_dag_parallel.state["sources"] + ) # here two runs are mixed together dataset_name = "mock_data_" + uniq_id() pipeline_name = "pipeline_dag_regular_" + uniq_id() - @dag(schedule=None, start_date=DEFAULT_DATE, catchup=False, default_args=default_args) + @dag( + schedule=None, start_date=DEFAULT_DATE, catchup=False, default_args=default_args + ) def dag_mixed(): tasks = PipelineTasksGroup( pipeline_name, local_data_folder=TEST_STORAGE_ROOT, wipe_local_data=True @@ -826,14 +916,23 @@ def dag_mixed(): ) pipeline_dag_mixed.sync_destination() # we should have two schemas - assert set(pipeline_dag_mixed.schema_names) == {"mock_data_source_state", "mock_data_source"} + assert set(pipeline_dag_mixed.schema_names) == { + "mock_data_source_state", + "mock_data_source", + } counters_st_tasks_par = load_table_counts( pipeline_dag_mixed, - *[t["name"] for t in pipeline_dag_mixed.schemas["mock_data_source_state"].data_tables()], + *[ + t["name"] + for t in pipeline_dag_mixed.schemas["mock_data_source_state"].data_tables() + ], ) counters_nst_tasks_par = load_table_counts( pipeline_dag_mixed, - *[t["name"] for t in pipeline_dag_mixed.schemas["mock_data_source"].data_tables()], + *[ + t["name"] + for t in pipeline_dag_mixed.schemas["mock_data_source"].data_tables() + ], ) assert counters_st_tasks == counters_st_tasks_par assert counters_nst_tasks == counters_nst_tasks_par @@ -887,7 +986,10 @@ def dag_parallel(): retries=0, provide_context=True, )[0] - assert task.task_id == "test_pipeline.mock_data_source__r_init-_t_init_post-_t1-_t2-2-more" + assert ( + task.task_id + == "test_pipeline.mock_data_source__r_init-_t_init_post-_t1-_t2-2-more" + ) task = tasks.add_run( pipe, @@ -898,7 +1000,8 @@ def dag_parallel(): provide_context=True, )[0] assert ( - task.task_id == "test_pipeline.mock_data_source__r_init-_t_init_post-_t1-_t2-2-more-2" + task.task_id + == "test_pipeline.mock_data_source__r_init-_t_init_post-_t1-_t2-2-more-2" ) tasks_list = tasks.add_run( diff --git a/tests/helpers/airflow_tests/test_join_airflow_scheduler.py b/tests/helpers/airflow_tests/test_join_airflow_scheduler.py index 8c1992c506..37ad425693 100644 --- a/tests/helpers/airflow_tests/test_join_airflow_scheduler.py +++ b/tests/helpers/airflow_tests/test_join_airflow_scheduler.py @@ -56,10 +56,16 @@ def scheduled() -> None: assert "Europe/Berlin" in str(state["updated_at"].tz) # must have UTC timezone 
assert ( - state["state"]["initial_value"] == CATCHUP_BEGIN == context["data_interval_start"] + state["state"]["initial_value"] + == CATCHUP_BEGIN + == context["data_interval_start"] ) assert state["state"]["initial_value"].tz == UTC - assert state["state"]["last_value"] == CATCHUP_BEGIN == context["data_interval_start"] + assert ( + state["state"]["last_value"] + == CATCHUP_BEGIN + == context["data_interval_start"] + ) assert state["state"]["last_value"].tz == UTC # end date assert r.incremental._incremental.end_value == context["data_interval_end"] @@ -81,7 +87,9 @@ def incremental_datetime( state = list(r)[0] # must have UTC timezone assert ( - state["state"]["initial_value"] == CATCHUP_BEGIN == context["data_interval_start"] + state["state"]["initial_value"] + == CATCHUP_BEGIN + == context["data_interval_start"] ) assert state["state"]["initial_value"].tz == UTC @@ -111,13 +119,20 @@ def incremental_datetime( "updated_at", allow_external_schedulers=True ) ): - yield {"updated_at": CATCHUP_BEGIN.int_timestamp, "state": updated_at.get_state()} + yield { + "updated_at": CATCHUP_BEGIN.int_timestamp, + "state": updated_at.get_state(), + } r = incremental_datetime() state = list(r)[0] - assert state["state"]["initial_value"] == context["data_interval_start"].int_timestamp assert ( - r.incremental._incremental.end_value == context["data_interval_end"].int_timestamp + state["state"]["initial_value"] + == context["data_interval_start"].int_timestamp + ) + assert ( + r.incremental._incremental.end_value + == context["data_interval_end"].int_timestamp ) # coerce to float @@ -127,12 +142,21 @@ def incremental_datetime( "updated_at", allow_external_schedulers=True ) ): - yield {"updated_at": CATCHUP_BEGIN.timestamp(), "state": updated_at.get_state()} + yield { + "updated_at": CATCHUP_BEGIN.timestamp(), + "state": updated_at.get_state(), + } r = incremental_datetime() state = list(r)[0] - assert state["state"]["initial_value"] == context["data_interval_start"].timestamp() - assert r.incremental._incremental.end_value == context["data_interval_end"].timestamp() + assert ( + state["state"]["initial_value"] + == context["data_interval_start"].timestamp() + ) + assert ( + r.incremental._incremental.end_value + == context["data_interval_end"].timestamp() + ) # coerce to str @dlt.resource() # type: ignore[no-redef] @@ -341,7 +365,12 @@ def scheduled() -> None: dag_def.test(execution_date=CATCHUP_BEGIN) assert "sources" not in pipeline.state - @dag(schedule=None, start_date=CATCHUP_BEGIN, catchup=False, default_args=default_args) + @dag( + schedule=None, + start_date=CATCHUP_BEGIN, + catchup=False, + default_args=default_args, + ) def dag_no_schedule(): @task def unscheduled() -> None: diff --git a/tests/helpers/airflow_tests/utils.py b/tests/helpers/airflow_tests/utils.py index 50aab77505..de17a90cc9 100644 --- a/tests/helpers/airflow_tests/utils.py +++ b/tests/helpers/airflow_tests/utils.py @@ -7,7 +7,9 @@ from airflow.models.variable import Variable from dlt.common.configuration.container import Container -from dlt.common.configuration.specs.config_providers_context import ConfigProvidersContext +from dlt.common.configuration.specs.config_providers_context import ( + ConfigProvidersContext, +) from dlt.common.configuration.providers.toml import SECRETS_TOML_KEY diff --git a/tests/helpers/dbt_cloud_tests/test_dbt_cloud.py b/tests/helpers/dbt_cloud_tests/test_dbt_cloud.py index 600a11558b..37e9ca0a81 100644 --- a/tests/helpers/dbt_cloud_tests/test_dbt_cloud.py +++ 
b/tests/helpers/dbt_cloud_tests/test_dbt_cloud.py @@ -16,7 +16,9 @@ def test_trigger_run(wait_outcome): def test_run_status(wait_outcome): # Trigger job run and wait for an outcome run_status = run_dbt_cloud_job(wait_for_outcome=False) - run_status = get_dbt_cloud_run_status(run_id=run_status["id"], wait_for_outcome=wait_outcome) + run_status = get_dbt_cloud_run_status( + run_id=run_status["id"], wait_for_outcome=wait_outcome + ) print(run_status) assert run_status.get("id") is not None diff --git a/tests/helpers/dbt_tests/local/test_dbt_utils.py b/tests/helpers/dbt_tests/local/test_dbt_utils.py index 6c2d28ed23..aeb948af89 100644 --- a/tests/helpers/dbt_tests/local/test_dbt_utils.py +++ b/tests/helpers/dbt_tests/local/test_dbt_utils.py @@ -42,7 +42,9 @@ def test_dbt_commands(test_storage: FileStorage) -> None: dbt_vars = {"dbt_schema": schema_name} # extract postgres creds from env, parse and emit - credentials = resolve_configuration(PostgresCredentials(), sections=("destination", "postgres")) + credentials = resolve_configuration( + PostgresCredentials(), sections=("destination", "postgres") + ) add_config_to_env(credentials, ("dlt",)) repo_path = clone_jaffle_repo(test_storage) @@ -76,13 +78,17 @@ def test_dbt_commands(test_storage: FileStorage) -> None: assert results[0] == "jaffle_shop.not_null_orders_amount" # run debug, that will fail with pytest.raises(DBTProcessingError) as dbt_err: - run_dbt_command(repo_path, "debug", ".", global_args=global_args, package_vars=dbt_vars) + run_dbt_command( + repo_path, "debug", ".", global_args=global_args, package_vars=dbt_vars + ) # results are bool assert dbt_err.value.command == "debug" # we have no database connectivity so tests will fail with pytest.raises(DBTProcessingError) as dbt_err: - run_dbt_command(repo_path, "test", ".", global_args=global_args, package_vars=dbt_vars) + run_dbt_command( + repo_path, "test", ".", global_args=global_args, package_vars=dbt_vars + ) # in that case test results are bool, not list of tests runs assert dbt_err.value.command == "test" @@ -101,7 +107,8 @@ def test_dbt_commands(test_storage: FileStorage) -> None: # copy a correct profile shutil.copy( - "./tests/helpers/dbt_tests/cases/profiles.yml", os.path.join(repo_path, "profiles.yml") + "./tests/helpers/dbt_tests/cases/profiles.yml", + os.path.join(repo_path, "profiles.yml"), ) results = run_dbt_command( diff --git a/tests/helpers/dbt_tests/local/test_runner_destinations.py b/tests/helpers/dbt_tests/local/test_runner_destinations.py index c9e4b7c83b..17d2eea973 100644 --- a/tests/helpers/dbt_tests/local/test_runner_destinations.py +++ b/tests/helpers/dbt_tests/local/test_runner_destinations.py @@ -24,10 +24,14 @@ ALL_DBT_DESTINATIONS_NAMES = ["bigquery"] # "redshift", -@pytest.fixture(scope="module", params=ALL_DBT_DESTINATIONS, ids=ALL_DBT_DESTINATIONS_NAMES) +@pytest.fixture( + scope="module", params=ALL_DBT_DESTINATIONS, ids=ALL_DBT_DESTINATIONS_NAMES +) def destination_info(request: Any) -> Iterator[DBTDestinationInfo]: # this resolves credentials and sets up env for dbt then deletes temp datasets - with setup_rasa_runner_client(request.param.destination_name, DESTINATION_DATASET_NAME): + with setup_rasa_runner_client( + request.param.destination_name, DESTINATION_DATASET_NAME + ): # yield DBTDestinationInfo yield request.param @@ -88,7 +92,9 @@ def test_reinitialize_package() -> None: def test_dbt_test_no_raw_schema(destination_info: DBTDestinationInfo) -> None: # force non existing dataset - runner = 
setup_rasa_runner(destination_info.destination_name, "jm_dev_2" + uniq_id()) + runner = setup_rasa_runner( + destination_info.destination_name, "jm_dev_2" + uniq_id() + ) # source test should not pass with pytest.raises(PrerequisitesException) as prq_ex: runner.run_all( @@ -109,13 +115,23 @@ def test_dbt_run_full_refresh(destination_info: DBTDestinationInfo) -> None: additional_vars={"user_id": "metadata__user_id"}, source_tests_selector="tag:prerequisites", ) - assert all(r.message.startswith(destination_info.replace_strategy) for r in run_results) is True + assert ( + all( + r.message.startswith(destination_info.replace_strategy) for r in run_results + ) + is True + ) assert find_run_result(run_results, "_loads") is not None # all models must be SELECT as we do full refresh assert find_run_result(run_results, "_loads").message.startswith( destination_info.replace_strategy ) - assert all(m.message.startswith(destination_info.replace_strategy) for m in run_results) is True + assert ( + all( + m.message.startswith(destination_info.replace_strategy) for m in run_results + ) + is True + ) # all tests should pass runner.test( @@ -124,7 +140,9 @@ def test_dbt_run_full_refresh(destination_info: DBTDestinationInfo) -> None: ) -def test_dbt_run_error_via_additional_vars(destination_info: DBTDestinationInfo) -> None: +def test_dbt_run_error_via_additional_vars( + destination_info: DBTDestinationInfo, +) -> None: if destination_info.destination_name == "redshift": pytest.skip("redshift disabled due to missing fixtures") # generate with setting external user and session to non existing fields (metadata__sess_id not exists in JM schema) @@ -143,7 +161,9 @@ def test_dbt_run_error_via_additional_vars(destination_info: DBTDestinationInfo) assert "metadata__sess_id" in stg_interactions.message -def test_dbt_incremental_schema_out_of_sync_error(destination_info: DBTDestinationInfo) -> None: +def test_dbt_incremental_schema_out_of_sync_error( + destination_info: DBTDestinationInfo, +) -> None: if destination_info.destination_name == "redshift": pytest.skip("redshift disabled due to missing fixtures") runner = setup_rasa_runner(destination_info.destination_name) diff --git a/tests/helpers/dbt_tests/test_runner_dbt_versions.py b/tests/helpers/dbt_tests/test_runner_dbt_versions.py index a7408f00f3..c722ebc982 100644 --- a/tests/helpers/dbt_tests/test_runner_dbt_versions.py +++ b/tests/helpers/dbt_tests/test_runner_dbt_versions.py @@ -8,7 +8,10 @@ from dlt.common import json from dlt.common.configuration import resolve_configuration -from dlt.common.configuration.specs import GcpServiceAccountCredentials, CredentialsWithDefault +from dlt.common.configuration.specs import ( + GcpServiceAccountCredentials, + CredentialsWithDefault, +) from dlt.common.storages.file_storage import FileStorage from dlt.common.runners import Venv from dlt.common.runners.synth_pickle import decode_obj, encode_obj @@ -71,7 +74,11 @@ def dbt_package_f(request: Any) -> Iterator[Tuple[str, AnyFun]]: def test_infer_venv_deps() -> None: requirements = _create_dbt_deps(["postgres", "mssql"]) - assert requirements[:3] == [f"dbt-core{DEFAULT_DBT_VERSION}", "dbt-postgres", "dbt-sqlserver"] + assert requirements[:3] == [ + f"dbt-core{DEFAULT_DBT_VERSION}", + "dbt-postgres", + "dbt-sqlserver", + ] # should lead to here assert os.path.isdir(requirements[-1]) # provide exact version @@ -87,7 +94,9 @@ def test_infer_venv_deps() -> None: def test_default_profile_name() -> None: - bigquery_config = 
BigQueryClientConfiguration(credentials=GcpServiceAccountCredentials()) + bigquery_config = BigQueryClientConfiguration( + credentials=GcpServiceAccountCredentials() + ) assert isinstance(bigquery_config.credentials, CredentialsWithDefault) # default credentials are not present assert _default_profile_name(bigquery_config) == "bigquery" @@ -193,10 +202,15 @@ def test_runner_dbt_destinations( jaffle_base_dir = "jaffle_" + destination_name test_storage.create_folder(jaffle_base_dir) results = dbt_func( - client.config, test_storage.make_full_path(jaffle_base_dir), JAFFLE_SHOP_REPO + client.config, + test_storage.make_full_path(jaffle_base_dir), + JAFFLE_SHOP_REPO, ).run_all(["--fail-fast", "--full-refresh"]) assert_jaffle_completed( - test_storage, results, destination_name, jaffle_dir=jaffle_base_dir + "/jaffle_shop" + test_storage, + results, + destination_name, + jaffle_dir=jaffle_base_dir + "/jaffle_shop", ) @@ -212,18 +226,25 @@ def test_run_jaffle_from_folder_incremental( os.path.join(repo_path, "models", "customers.sql"), ) results = dbt_func(client.config, None, repo_path).run_all(run_params=None) - assert_jaffle_completed(test_storage, results, destination_name, jaffle_dir="jaffle_shop") + assert_jaffle_completed( + test_storage, results, destination_name, jaffle_dir="jaffle_shop" + ) results = dbt_func(client.config, None, repo_path).run_all() # out of 100 records 0 was inserted customers = find_run_result(results, "customers") - assert customers.message in JAFFLE_MESSAGES_INCREMENTAL[destination_name]["customers"] + assert ( + customers.message + in JAFFLE_MESSAGES_INCREMENTAL[destination_name]["customers"] + ) # change the column name. that will force dbt to fail (on_schema_change='fail'). the runner should do a full refresh shutil.copy( "./tests/helpers/dbt_tests/cases/jaffle_customers_incremental_new_column.sql", os.path.join(repo_path, "models", "customers.sql"), ) results = dbt_func(client.config, None, repo_path).run_all(run_params=None) - assert_jaffle_completed(test_storage, results, destination_name, jaffle_dir="jaffle_shop") + assert_jaffle_completed( + test_storage, results, destination_name, jaffle_dir="jaffle_shop" + ) def test_run_jaffle_fail_prerequisites( @@ -239,7 +260,9 @@ def test_run_jaffle_fail_prerequisites( ).run_all(["--fail-fast", "--full-refresh"], source_tests_selector="*") proc_err = pr_exc.value.args[0] assert isinstance(proc_err, DBTProcessingError) - customers = find_run_result(proc_err.run_results, "unique_customers_customer_id") + customers = find_run_result( + proc_err.run_results, "unique_customers_customer_id" + ) assert customers.status == "error" assert len(proc_err.run_results) == 20 assert all(r.status == "error" for r in proc_err.run_results) @@ -257,7 +280,10 @@ def test_run_jaffle_invalid_run_args( client.config, test_storage.make_full_path("jaffle"), JAFFLE_SHOP_REPO ).run_all(["--wrong_flag"]) # dbt < 1.5 raises systemexit, dbt >= 1.5 just returns success False - assert isinstance(pr_exc.value.dbt_results, SystemExit) or pr_exc.value.dbt_results is None + assert ( + isinstance(pr_exc.value.dbt_results, SystemExit) + or pr_exc.value.dbt_results is None + ) def test_run_jaffle_failed_run( diff --git a/tests/helpers/providers/test_google_secrets_provider.py b/tests/helpers/providers/test_google_secrets_provider.py index 00c54b5705..c623015a45 100644 --- a/tests/helpers/providers/test_google_secrets_provider.py +++ b/tests/helpers/providers/test_google_secrets_provider.py @@ -3,7 +3,9 @@ from dlt.common.configuration.specs import 
GcpServiceAccountCredentials from dlt.common.configuration.providers import GoogleSecretsProvider from dlt.common.configuration.accessors import secrets -from dlt.common.configuration.specs.config_providers_context import _google_secrets_provider +from dlt.common.configuration.specs.config_providers_context import ( + _google_secrets_provider, +) from dlt.common.configuration.specs.run_configuration import RunConfiguration from dlt.common.configuration.specs import GcpServiceAccountCredentials, known_sections from dlt.common.typing import AnyType @@ -26,7 +28,8 @@ def test_regular_keys() -> None: init_logging(RunConfiguration()) # copy bigquery credentials into providers credentials c = resolve_configuration( - GcpServiceAccountCredentials(), sections=(known_sections.DESTINATION, "bigquery") + GcpServiceAccountCredentials(), + sections=(known_sections.DESTINATION, "bigquery"), ) secrets[f"{known_sections.PROVIDERS}.google_secrets.credentials"] = dict(c) # c = secrets.get("destination.credentials", GcpServiceAccountCredentials) @@ -38,7 +41,10 @@ def test_regular_keys() -> None: "pipelinex-secret_value", ) assert provider.get_value("secret_value", AnyType, None) == (2137, "secret_value") - assert provider.get_value("secret_key", AnyType, None, "api") == ("ABCD", "api-secret_key") + assert provider.get_value("secret_key", AnyType, None, "api") == ( + "ABCD", + "api-secret_key", + ) # load secrets toml per pipeline provider.get_value("secret_key", AnyType, "pipeline", "api") @@ -52,12 +58,16 @@ def test_regular_keys() -> None: ) # load source test_source which should also load "sources", "pipeline-sources", "sources-test_source" and "pipeline-sources-test_source" - assert provider.get_value("only_pipeline", AnyType, "pipeline", "sources", "test_source") == ( + assert provider.get_value( + "only_pipeline", AnyType, "pipeline", "sources", "test_source" + ) == ( "ONLY", "pipeline-sources-test_source-only_pipeline", ) # we set sources.test_source.secret_prop_1="OVR_A" in pipeline-sources to override value in sources - assert provider.get_value("secret_prop_1", AnyType, None, "sources", "test_source") == ( + assert provider.get_value( + "secret_prop_1", AnyType, None, "sources", "test_source" + ) == ( "OVR_A", "sources-test_source-secret_prop_1", ) @@ -72,19 +82,26 @@ def test_regular_keys() -> None: "sources-all_sources_present", ) # get element unique to sources-test_source - assert provider.get_value("secret_prop_2", AnyType, None, "sources", "test_source") == ( + assert provider.get_value( + "secret_prop_2", AnyType, None, "sources", "test_source" + ) == ( "B", "sources-test_source-secret_prop_2", ) # this destination will not be found - assert provider.get_value("url", AnyType, "pipeline", "destination", "filesystem") == ( + assert provider.get_value( + "url", AnyType, "pipeline", "destination", "filesystem" + ) == ( None, "pipeline-destination-filesystem-url", ) # try a single secret value - assert provider.get_value("secret", TSecretValue, "pipeline") == (None, "pipeline-secret") + assert provider.get_value("secret", TSecretValue, "pipeline") == ( + None, + "pipeline-secret", + ) # enable the single secrets provider.only_toml_fragments = False @@ -99,8 +116,14 @@ def test_regular_keys() -> None: # request json # print(provider._toml.as_string()) - assert provider.get_value("halo", str, None, "halo") == ({"halo": True}, "halo-halo") - assert provider.get_value("halo", str, None, "halo", "halo") == (True, "halo-halo-halo") + assert provider.get_value("halo", str, None, "halo") == ( + {"halo": 
True}, + "halo-halo", + ) + assert provider.get_value("halo", str, None, "halo", "halo") == ( + True, + "halo-halo-halo", + ) # def test_special_sections() -> None: diff --git a/tests/helpers/streamlit_tests/test_streamlit_show_resources.py b/tests/helpers/streamlit_tests/test_streamlit_show_resources.py index dd807260fe..9743bec7b9 100644 --- a/tests/helpers/streamlit_tests/test_streamlit_show_resources.py +++ b/tests/helpers/streamlit_tests/test_streamlit_show_resources.py @@ -84,11 +84,18 @@ def test_multiple_resources_pipeline(): assert set(load_info.pipeline.schema_names) == set(["source2", "source1"]) # type: ignore[attr-defined] assert source1_schema.data_tables()[0]["name"] == "one" - assert source1_schema.data_tables()[0]["columns"]["column_1"].get("primary_key") is True - assert source1_schema.data_tables()[0]["columns"]["column_1"].get("merge_key") is True + assert ( + source1_schema.data_tables()[0]["columns"]["column_1"].get("primary_key") + is True + ) + assert ( + source1_schema.data_tables()[0]["columns"]["column_1"].get("merge_key") is True + ) assert source1_schema.data_tables()[0]["write_disposition"] == "merge" os.environ["DLT_TEST_PIPELINE_NAME"] = "test_resources_pipeline" - streamlit_app = AppTest.from_file(str(streamlit_app_path / "index.py"), default_timeout=5) + streamlit_app = AppTest.from_file( + str(streamlit_app_path / "index.py"), default_timeout=5 + ) streamlit_app.run() assert not streamlit_app.exception diff --git a/tests/libs/test_parquet_writer.py b/tests/libs/test_parquet_writer.py index b1c19114fe..0e6608cbab 100644 --- a/tests/libs/test_parquet_writer.py +++ b/tests/libs/test_parquet_writer.py @@ -13,7 +13,12 @@ from dlt.common.time import ensure_pendulum_date, ensure_pendulum_datetime from tests.cases import TABLE_UPDATE_COLUMNS_SCHEMA, TABLE_ROW_ALL_DATA_TYPES -from tests.utils import TEST_STORAGE_ROOT, write_version, autouse_test_storage, preserve_environ +from tests.utils import ( + TEST_STORAGE_ROOT, + write_version, + autouse_test_storage, + preserve_environ, +) def get_writer( @@ -47,7 +52,15 @@ def test_parquet_writer_schema_evolution_with_big_buffer() -> None: [{"col1": 1, "col2": 2, "col3": "3"}], {"col1": c1, "col2": c2, "col3": c3} ) writer.write_data_item( - [{"col1": 1, "col2": 2, "col3": "3", "col4": "4", "col5": {"hello": "marcin"}}], + [ + { + "col1": 1, + "col2": 2, + "col3": "3", + "col4": "4", + "col5": {"hello": "marcin"}, + } + ], {"col1": c1, "col2": c2, "col3": c3, "col4": c4}, ) @@ -68,11 +81,20 @@ def test_parquet_writer_schema_evolution_with_small_buffer() -> None: with get_writer("parquet", buffer_max_items=4, file_max_items=50) as writer: for _ in range(0, 20): writer.write_data_item( - [{"col1": 1, "col2": 2, "col3": "3"}], {"col1": c1, "col2": c2, "col3": c3} + [{"col1": 1, "col2": 2, "col3": "3"}], + {"col1": c1, "col2": c2, "col3": c3}, ) for _ in range(0, 20): writer.write_data_item( - [{"col1": 1, "col2": 2, "col3": "3", "col4": "4", "col5": {"hello": "marcin"}}], + [ + { + "col1": 1, + "col2": 2, + "col3": "3", + "col4": "4", + "col5": {"hello": "marcin"}, + } + ], {"col1": c1, "col2": c2, "col3": c3, "col4": c4}, ) @@ -193,12 +215,17 @@ def test_parquet_writer_config() -> None: os.environ["NORMALIZE__DATA_WRITER__DATA_PAGE_SIZE"] = str(1024 * 512) os.environ["NORMALIZE__DATA_WRITER__TIMESTAMP_TIMEZONE"] = "America/New York" - with inject_section(ConfigSectionContext(pipeline_name=None, sections=("normalize",))): + with inject_section( + ConfigSectionContext(pipeline_name=None, sections=("normalize",)) + ): with 
get_writer("parquet", file_max_bytes=2**8, buffer_max_items=2) as writer: for i in range(0, 5): writer.write_data_item( [{"col1": i, "col2": pendulum.now()}], - {"col1": new_column("col1", "bigint"), "col2": new_column("col2", "timestamp")}, + { + "col1": new_column("col1", "bigint"), + "col2": new_column("col2", "timestamp"), + }, ) # force the parquet writer to be created writer._flush_items() @@ -222,7 +249,13 @@ def test_parquet_writer_schema_from_caps() -> None: with get_writer("parquet", file_max_bytes=2**8, buffer_max_items=2) as writer: for _ in range(0, 5): writer.write_data_item( - [{"col1": Decimal("2617.27"), "col2": pendulum.now(), "col3": Decimal(2**250)}], + [ + { + "col1": Decimal("2617.27"), + "col2": pendulum.now(), + "col3": Decimal(2**250), + } + ], { "col1": new_column("col1", "decimal"), "col2": new_column("col2", "timestamp"), diff --git a/tests/libs/test_pyarrow.py b/tests/libs/test_pyarrow.py index 68541e96e0..9f98811f45 100644 --- a/tests/libs/test_pyarrow.py +++ b/tests/libs/test_pyarrow.py @@ -21,10 +21,14 @@ def test_py_arrow_to_table_schema_columns(): caps = DestinationCapabilitiesContext.generic_capabilities() # The arrow schema will add precision dlt_schema["col4"]["precision"] = caps.timestamp_precision - dlt_schema["col6"]["precision"], dlt_schema["col6"]["scale"] = caps.decimal_precision + dlt_schema["col6"]["precision"], dlt_schema["col6"]["scale"] = ( + caps.decimal_precision + ) dlt_schema["col11"]["precision"] = caps.timestamp_precision dlt_schema["col4_null"]["precision"] = caps.timestamp_precision - dlt_schema["col6_null"]["precision"], dlt_schema["col6_null"]["scale"] = caps.decimal_precision + dlt_schema["col6_null"]["precision"], dlt_schema["col6_null"]["scale"] = ( + caps.decimal_precision + ) dlt_schema["col11_null"]["precision"] = caps.timestamp_precision # Ignoring wei as we can't distinguish from decimal @@ -63,9 +67,9 @@ def test_to_arrow_scalar() -> None: naive_dt = get_py_arrow_timestamp(6, tz=None) # print(naive_dt) # naive datetimes are converted as UTC when time aware python objects are used - assert to_arrow_scalar(datetime(2021, 1, 1, 5, 2, 32), naive_dt).as_py() == datetime( - 2021, 1, 1, 5, 2, 32 - ) + assert to_arrow_scalar( + datetime(2021, 1, 1, 5, 2, 32), naive_dt + ).as_py() == datetime(2021, 1, 1, 5, 2, 32) assert to_arrow_scalar( datetime(2021, 1, 1, 5, 2, 32, tzinfo=timezone.utc), naive_dt ).as_py() == datetime(2021, 1, 1, 5, 2, 32) diff --git a/tests/libs/test_pydantic.py b/tests/libs/test_pydantic.py index d6dc29e0c8..8f8050fd2c 100644 --- a/tests/libs/test_pydantic.py +++ b/tests/libs/test_pydantic.py @@ -153,7 +153,9 @@ class User(BaseModel): account_id: UUID4 optional_uuid: Optional[UUID4] name: Annotated[str, "PII", "name"] - favorite_book: Annotated[Union[Annotated[BookInfo, "meta"], BookGenre, None], "union metadata"] + favorite_book: Annotated[ + Union[Annotated[BookInfo, "meta"], BookGenre, None], "union metadata" + ] created_at: Optional[datetime] labels: List[str] user_label: UserLabel @@ -358,7 +360,9 @@ def test_nested_model_config_propagation() -> None: # print(model_freeze.__fields__["address"].annotation) -@pytest.mark.skipif(sys.version_info < (3, 10), reason="Runs only on Python 3.10 and later") +@pytest.mark.skipif( + sys.version_info < (3, 10), reason="Runs only on Python 3.10 and later" +) def test_nested_model_config_propagation_optional_with_pipe(): """We would like to test that using Optional and new | syntax works as expected when generating a schema thus two versions of user model are defined and 
both instantiated @@ -427,7 +431,9 @@ class ItemModel(BaseModel): dlt_config: ClassVar[DltConfig] = {"skip_complex_types": False} # non validating items removed from the list (both extra and declared) - discard_model = apply_schema_contract_to_model(ItemModel, "discard_row", "discard_row") + discard_model = apply_schema_contract_to_model( + ItemModel, "discard_row", "discard_row" + ) discard_list_model = create_list_model(discard_model) # violate data type items = validate_items( @@ -490,7 +496,9 @@ class ItemModel(BaseModel): assert val_ex.value.data_item == {"a": 2, "b": False} # discard values - discard_value_model = apply_schema_contract_to_model(ItemModel, "discard_value", "freeze") + discard_value_model = apply_schema_contract_to_model( + ItemModel, "discard_value", "freeze" + ) discard_list_model = create_list_model(discard_value_model) # violate extra field items = validate_items( @@ -564,13 +572,22 @@ class ItemModel(BaseModel): dlt_config: ClassVar[DltConfig] = {"skip_complex_types": False} # non validating items removed from the list (both extra and declared) - discard_model = apply_schema_contract_to_model(ItemModel, "discard_row", "discard_row") + discard_model = apply_schema_contract_to_model( + ItemModel, "discard_row", "discard_row" + ) # violate data type - assert validate_item("items", discard_model, {"b": 2}, "discard_row", "discard_row") is None + assert ( + validate_item("items", discard_model, {"b": 2}, "discard_row", "discard_row") + is None + ) # violate extra field assert ( validate_item( - "items", discard_model, {"b": False, "a": False}, "discard_row", "discard_row" + "items", + discard_model, + {"b": False, "a": False}, + "discard_row", + "discard_row", ) is None ) @@ -599,10 +616,16 @@ class ItemModel(BaseModel): assert val_ex.value.data_item == {"a": 2, "b": False} # discard values - discard_value_model = apply_schema_contract_to_model(ItemModel, "discard_value", "freeze") + discard_value_model = apply_schema_contract_to_model( + ItemModel, "discard_value", "freeze" + ) # violate extra field item = validate_item( - "items", discard_value_model, {"b": False, "a": False}, "discard_value", "freeze" + "items", + discard_value_model, + {"b": False, "a": False}, + "discard_value", + "freeze", ) # "a" extra got removed assert item.dict() == {"b": False} @@ -613,7 +636,9 @@ class ItemModel(BaseModel): item = validate_item("items", evolve_model, {"b": 2}, "evolve", "evolve") assert item.b == 2 # extra fields allowed - item = validate_item("items", evolve_model, {"b": False, "a": False}, "evolve", "evolve") + item = validate_item( + "items", evolve_model, {"b": False, "a": False}, "evolve", "evolve" + ) assert item.b is False assert item.a is False # type: ignore[attr-defined] @@ -624,7 +649,9 @@ class ItemModel(BaseModel): assert item.b == 3 # extra fields forbidden - full rows discarded assert ( - validate_item("items", mixed_model, {"b": False, "a": False}, "discard_row", "evolve") + validate_item( + "items", mixed_model, {"b": False, "a": False}, "discard_row", "evolve" + ) is None ) @@ -673,7 +700,11 @@ class MyParent(Parent): assert schema == { "child": {"data_type": "complex", "name": "child", "nullable": False}, - "data_dictionary": {"data_type": "complex", "name": "data_dictionary", "nullable": False}, + "data_dictionary": { + "data_type": "complex", + "name": "data_dictionary", + "nullable": False, + }, "optional_parent_attribute": { "data_type": "text", "name": "optional_parent_attribute", diff --git a/tests/load/athena_iceberg/test_athena_iceberg.py 
b/tests/load/athena_iceberg/test_athena_iceberg.py index 6804b98427..61b8cf6fdb 100644 --- a/tests/load/athena_iceberg/test_athena_iceberg.py +++ b/tests/load/athena_iceberg/test_athena_iceberg.py @@ -37,14 +37,19 @@ def items() -> Iterator[Any]: yield { "id": 1, "name": "item", - "sub_items": [{"id": 101, "name": "sub item 101"}, {"id": 101, "name": "sub item 102"}], + "sub_items": [ + {"id": 101, "name": "sub item 101"}, + {"id": 101, "name": "sub item 102"}, + ], } @dlt.resource(name="items_normal", write_disposition="append") def items_normal(): yield from items() - @dlt.resource(name="items_iceberg", write_disposition="append", table_format="iceberg") + @dlt.resource( + name="items_iceberg", write_disposition="append", table_format="iceberg" + ) def items_iceberg(): yield from items() @@ -67,10 +72,18 @@ def items_iceberg(): # modifying regular athena table will fail with pytest.raises(DatabaseTerminalException) as dbex: client.execute_sql("UPDATE items_normal SET name='new name'") - assert "Modifying Hive table rows is only supported for transactional tables" in str(dbex) + assert ( + "Modifying Hive table rows is only supported for transactional tables" + in str(dbex) + ) with pytest.raises(DatabaseTerminalException) as dbex: - client.execute_sql("UPDATE items_normal__sub_items SET name='super new name'") - assert "Modifying Hive table rows is only supported for transactional tables" in str(dbex) + client.execute_sql( + "UPDATE items_normal__sub_items SET name='super new name'" + ) + assert ( + "Modifying Hive table rows is only supported for transactional tables" + in str(dbex) + ) # modifying iceberg table will succeed client.execute_sql("UPDATE items_iceberg SET name='new name'") diff --git a/tests/load/bigquery/test_bigquery_client.py b/tests/load/bigquery/test_bigquery_client.py index a97b612ad0..62e2d13e1f 100644 --- a/tests/load/bigquery/test_bigquery_client.py +++ b/tests/load/bigquery/test_bigquery_client.py @@ -19,8 +19,14 @@ from dlt.common.storages import FileStorage from dlt.common.utils import digest128, uniq_id, custom_environ -from dlt.destinations.impl.bigquery.bigquery import BigQueryClient, BigQueryClientConfiguration -from dlt.destinations.exceptions import LoadJobNotExistsException, LoadJobTerminalException +from dlt.destinations.impl.bigquery.bigquery import ( + BigQueryClient, + BigQueryClientConfiguration, +) +from dlt.destinations.exceptions import ( + LoadJobNotExistsException, + LoadJobTerminalException, +) from tests.utils import TEST_STORAGE_ROOT, delete_test_storage, preserve_environ from tests.common.utils import json_case_path as common_json_case_path @@ -94,7 +100,11 @@ def test_service_credentials_native_credentials_object(environment: Any) -> None def _assert_credentials(gcp_credentials): assert gcp_credentials.to_native_credentials() is credentials # check props - assert gcp_credentials.project_id == credentials.project_id == "level-dragon-333019" + assert ( + gcp_credentials.project_id + == credentials.project_id + == "level-dragon-333019" + ) assert gcp_credentials.client_email == credentials.service_account_email assert gcp_credentials.private_key is credentials @@ -121,7 +131,12 @@ def test_oauth_credentials_with_default(environment: Any) -> None: # resolve will miss values and try to find default credentials on the machine with pytest.raises(ConfigFieldMissingException) as py_ex: resolve_configuration(gcoauth) - assert py_ex.value.fields == ["client_id", "client_secret", "refresh_token", "project_id"] + assert py_ex.value.fields == [ + 
"client_id", + "client_secret", + "refresh_token", + "project_id", + ] # prepare real service.json oauth_str, _ = prepare_oauth_json() @@ -161,7 +176,9 @@ def test_oauth_credentials_native_credentials_object(environment: Any) -> None: oauth_dict = json.loads(oauth_str) # must add refresh_token oauth_dict["installed"]["refresh_token"] = "REFRESH TOKEN" - credentials = GoogleOAuth2Credentials.from_authorized_user_info(oauth_dict["installed"]) + credentials = GoogleOAuth2Credentials.from_authorized_user_info( + oauth_dict["installed"] + ) def _assert_credentials(gcp_credentials): # check props @@ -238,7 +255,10 @@ def test_bigquery_configuration() -> None: # default fingerprint is empty assert ( - BigQueryClientConfiguration()._bind_dataset_name(dataset_name="dataset").fingerprint() == "" + BigQueryClientConfiguration() + ._bind_dataset_name(dataset_name="dataset") + .fingerprint() + == "" ) @@ -264,7 +284,9 @@ def test_bigquery_job_errors(client: BigQueryClient, file_storage: FileStorage) # start a job with invalid name dest_path = file_storage.save("!!aaaa", b"data") with pytest.raises(LoadJobTerminalException): - client.start_file_load(client.schema.get_table(user_table_name), dest_path, uniq_id()) + client.start_file_load( + client.schema.get_table(user_table_name), dest_path, uniq_id() + ) user_table_name = prepare_table(client) load_json = { @@ -296,7 +318,9 @@ def test_bigquery_location(location: str, file_storage: FileStorage, client) -> "sender_id": "90238094809sajlkjxoiewjhduuiuehd", "timestamp": str(pendulum.now()), } - job = expect_load_file(client, file_storage, json.dumps(load_json), user_table_name) + job = expect_load_file( + client, file_storage, json.dumps(load_json), user_table_name + ) # start a job from the same file. it should be a fallback to retrieve a job silently client.start_file_load( @@ -304,7 +328,9 @@ def test_bigquery_location(location: str, file_storage: FileStorage, client) -> file_storage.make_full_path(job.file_name()), uniq_id(), ) - canonical_name = client.sql_client.make_qualified_table_name(user_table_name, escape=False) + canonical_name = client.sql_client.make_qualified_table_name( + user_table_name, escape=False + ) t = client.sql_client.native_connection.get_table(canonical_name) assert t.location == location @@ -331,7 +357,9 @@ def test_loading_errors(client: BigQueryClient, file_storage: FileStorage) -> No job = expect_load_file( client, file_storage, json.dumps(insert_json), user_table_name, status="failed" ) - assert "Only optional fields can be set to NULL. Field: timestamp;" in job.exception() + assert ( + "Only optional fields can be set to NULL. 
Field: timestamp;" in job.exception() + ) # insert a wrong type insert_json = copy(load_json) @@ -365,7 +393,8 @@ def test_loading_errors(client: BigQueryClient, file_storage: FileStorage) -> No client, file_storage, json.dumps(insert_json), user_table_name, status="failed" ) assert ( - "Invalid NUMERIC value: 100000000000000000000000000000 Field: parse_data__intent__id;" + "Invalid NUMERIC value: 100000000000000000000000000000 Field:" + " parse_data__intent__id;" in job.exception() ) @@ -379,8 +408,8 @@ def test_loading_errors(client: BigQueryClient, file_storage: FileStorage) -> No ) assert ( "Invalid BIGNUMERIC value:" - " 578960446186580977117854925043439539266.34992332820282019728792003956564819968 Field:" - " parse_data__metadata__rasa_x_id;" + " 578960446186580977117854925043439539266.34992332820282019728792003956564819968" + " Field: parse_data__metadata__rasa_x_id;" in job.exception() ) @@ -388,7 +417,9 @@ def test_loading_errors(client: BigQueryClient, file_storage: FileStorage) -> No def prepare_oauth_json() -> Tuple[str, str]: # prepare real service.json storage = FileStorage("_secrets", makedirs=True) - with open(common_json_case_path("oauth_client_secret_929384042504"), encoding="utf-8") as f: + with open( + common_json_case_path("oauth_client_secret_929384042504"), encoding="utf-8" + ) as f: oauth_str = f.read() dest_path = storage.save("oauth_client_secret_929384042504.json", oauth_str) return oauth_str, dest_path @@ -397,7 +428,9 @@ def prepare_oauth_json() -> Tuple[str, str]: def prepare_service_json() -> Tuple[str, str]: # prepare real service.json storage = FileStorage("_secrets", makedirs=True) - with open(common_json_case_path("level-dragon-333019-707809ee408a") + ".b64", mode="rb") as f: + with open( + common_json_case_path("level-dragon-333019-707809ee408a") + ".b64", mode="rb" + ) as f: services_str = base64.b64decode(f.read().strip(), validate=True).decode() dest_path = storage.save("level-dragon-333019-707809ee408a.json", services_str) return services_str, dest_path diff --git a/tests/load/bigquery/test_bigquery_table_builder.py b/tests/load/bigquery/test_bigquery_table_builder.py index fd58a6e033..6dbd3eaa65 100644 --- a/tests/load/bigquery/test_bigquery_table_builder.py +++ b/tests/load/bigquery/test_bigquery_table_builder.py @@ -41,11 +41,15 @@ def test_configuration() -> None: # check names normalised with custom_environ({"MYBG__CREDENTIALS__PRIVATE_KEY": "---NO NEWLINE---\n"}): - c = resolve_configuration(GcpServiceAccountCredentialsWithoutDefaults(), sections=("mybg",)) + c = resolve_configuration( + GcpServiceAccountCredentialsWithoutDefaults(), sections=("mybg",) + ) assert c.private_key == "---NO NEWLINE---\n" with custom_environ({"MYBG__CREDENTIALS__PRIVATE_KEY": "---WITH NEWLINE---\n"}): - c = resolve_configuration(GcpServiceAccountCredentialsWithoutDefaults(), sections=("mybg",)) + c = resolve_configuration( + GcpServiceAccountCredentialsWithoutDefaults(), sections=("mybg",) + ) assert c.private_key == "---WITH NEWLINE---\n" @@ -186,7 +190,11 @@ def test_create_table_with_integer_partition(gcp_client: BigQueryClient) -> None mod_update[0]["partition"] = True sql = gcp_client._get_table_update_sql("event_test_table", mod_update, False)[0] sqlfluff.parse(sql, dialect="bigquery") - assert "PARTITION BY RANGE_BUCKET(`col1`, GENERATE_ARRAY(-172800000, 691200000, 86400))" in sql + assert ( + "PARTITION BY RANGE_BUCKET(`col1`, GENERATE_ARRAY(-172800000, 691200000," + " 86400))" + in sql + ) @pytest.mark.parametrize( @@ -194,18 +202,30 @@ def 
test_create_table_with_integer_partition(gcp_client: BigQueryClient) -> None destinations_configs(all_staging_configs=True, subset=["bigquery"]), ids=lambda x: x.name, ) -def test_bigquery_partition_by_date(destination_config: DestinationTestConfiguration) -> None: - pipeline = destination_config.setup_pipeline(f"bigquery_{uniq_id()}", full_refresh=True) +def test_bigquery_partition_by_date( + destination_config: DestinationTestConfiguration, +) -> None: + pipeline = destination_config.setup_pipeline( + f"bigquery_{uniq_id()}", full_refresh=True + ) @dlt.resource( write_disposition="merge", primary_key="my_date_column", - columns={"my_date_column": {"data_type": "date", "partition": True, "nullable": False}}, + columns={ + "my_date_column": { + "data_type": "date", + "partition": True, + "nullable": False, + } + }, ) def demo_resource() -> Iterator[Dict[str, pendulum.Date]]: for i in range(10): yield { - "my_date_column": pendulum.from_timestamp(1700784000 + i * 50_000).date(), + "my_date_column": pendulum.from_timestamp( + 1700784000 + i * 50_000 + ).date(), } @dlt.source(max_table_nesting=0) @@ -216,8 +236,8 @@ def demo_source() -> DltResource: with pipeline.sql_client() as c: with c.execute_query( - "SELECT EXISTS (SELECT 1 FROM INFORMATION_SCHEMA.PARTITIONS WHERE partition_id IS NOT" - " NULL);" + "SELECT EXISTS (SELECT 1 FROM INFORMATION_SCHEMA.PARTITIONS WHERE" + " partition_id IS NOT NULL);" ) as cur: has_partitions = cur.fetchone()[0] assert isinstance(has_partitions, bool) @@ -229,18 +249,30 @@ def demo_source() -> DltResource: destinations_configs(all_staging_configs=True, subset=["bigquery"]), ids=lambda x: x.name, ) -def test_bigquery_no_partition_by_date(destination_config: DestinationTestConfiguration) -> None: - pipeline = destination_config.setup_pipeline(f"bigquery_{uniq_id()}", full_refresh=True) +def test_bigquery_no_partition_by_date( + destination_config: DestinationTestConfiguration, +) -> None: + pipeline = destination_config.setup_pipeline( + f"bigquery_{uniq_id()}", full_refresh=True + ) @dlt.resource( write_disposition="merge", primary_key="my_date_column", - columns={"my_date_column": {"data_type": "date", "partition": False, "nullable": False}}, + columns={ + "my_date_column": { + "data_type": "date", + "partition": False, + "nullable": False, + } + }, ) def demo_resource() -> Iterator[Dict[str, pendulum.Date]]: for i in range(10): yield { - "my_date_column": pendulum.from_timestamp(1700784000 + i * 50_000).date(), + "my_date_column": pendulum.from_timestamp( + 1700784000 + i * 50_000 + ).date(), } @dlt.source(max_table_nesting=0) @@ -251,8 +283,8 @@ def demo_source() -> DltResource: with pipeline.sql_client() as c: with c.execute_query( - "SELECT EXISTS (SELECT 1 FROM INFORMATION_SCHEMA.PARTITIONS WHERE partition_id IS NOT" - " NULL);" + "SELECT EXISTS (SELECT 1 FROM INFORMATION_SCHEMA.PARTITIONS WHERE" + " partition_id IS NOT NULL);" ) as cur: has_partitions = cur.fetchone()[0] assert isinstance(has_partitions, bool) @@ -264,14 +296,22 @@ def demo_source() -> DltResource: destinations_configs(all_staging_configs=True, subset=["bigquery"]), ids=lambda x: x.name, ) -def test_bigquery_partition_by_timestamp(destination_config: DestinationTestConfiguration) -> None: - pipeline = destination_config.setup_pipeline(f"bigquery_{uniq_id()}", full_refresh=True) +def test_bigquery_partition_by_timestamp( + destination_config: DestinationTestConfiguration, +) -> None: + pipeline = destination_config.setup_pipeline( + f"bigquery_{uniq_id()}", full_refresh=True + ) 
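The hunks above split long SQL literals, such as the RANGE_BUCKET partition clause, into adjacent string literals so each physical line fits within 88 characters. A minimal sketch (plain Python, not taken from the patch) showing that adjacent literals are joined by the parser, so the reformatted strings compare equal to the originals:

    # Adjacent string literals inside parentheses are concatenated at compile
    # time, so splitting a long SQL string across lines does not change its value.
    long_sql = (
        "SELECT EXISTS (SELECT 1 FROM INFORMATION_SCHEMA.PARTITIONS WHERE"
        " partition_id IS NOT NULL);"
    )
    assert "WHERE partition_id" in long_sql  # the space comes from the second literal
    assert long_sql.endswith("partition_id IS NOT NULL);")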
@dlt.resource( write_disposition="merge", primary_key="my_timestamp_column", columns={ - "my_timestamp_column": {"data_type": "timestamp", "partition": True, "nullable": False} + "my_timestamp_column": { + "data_type": "timestamp", + "partition": True, + "nullable": False, + } }, ) def demo_resource() -> Iterator[Dict[str, pendulum.DateTime]]: @@ -288,8 +328,8 @@ def demo_source() -> DltResource: with pipeline.sql_client() as c: with c.execute_query( - "SELECT EXISTS (SELECT 1 FROM INFORMATION_SCHEMA.PARTITIONS WHERE partition_id IS NOT" - " NULL);" + "SELECT EXISTS (SELECT 1 FROM INFORMATION_SCHEMA.PARTITIONS WHERE" + " partition_id IS NOT NULL);" ) as cur: has_partitions = cur.fetchone()[0] assert isinstance(has_partitions, bool) @@ -304,13 +344,19 @@ def demo_source() -> DltResource: def test_bigquery_no_partition_by_timestamp( destination_config: DestinationTestConfiguration, ) -> None: - pipeline = destination_config.setup_pipeline(f"bigquery_{uniq_id()}", full_refresh=True) + pipeline = destination_config.setup_pipeline( + f"bigquery_{uniq_id()}", full_refresh=True + ) @dlt.resource( write_disposition="merge", primary_key="my_timestamp_column", columns={ - "my_timestamp_column": {"data_type": "timestamp", "partition": False, "nullable": False} + "my_timestamp_column": { + "data_type": "timestamp", + "partition": False, + "nullable": False, + } }, ) def demo_resource() -> Iterator[Dict[str, pendulum.DateTime]]: @@ -327,8 +373,8 @@ def demo_source() -> DltResource: with pipeline.sql_client() as c: with c.execute_query( - "SELECT EXISTS (SELECT 1 FROM INFORMATION_SCHEMA.PARTITIONS WHERE partition_id IS NOT" - " NULL);" + "SELECT EXISTS (SELECT 1 FROM INFORMATION_SCHEMA.PARTITIONS WHERE" + " partition_id IS NOT NULL);" ) as cur: has_partitions = cur.fetchone()[0] assert isinstance(has_partitions, bool) @@ -340,11 +386,17 @@ def demo_source() -> DltResource: destinations_configs(all_staging_configs=True, subset=["bigquery"]), ids=lambda x: x.name, ) -def test_bigquery_partition_by_integer(destination_config: DestinationTestConfiguration) -> None: - pipeline = destination_config.setup_pipeline(f"bigquery_{uniq_id()}", full_refresh=True) +def test_bigquery_partition_by_integer( + destination_config: DestinationTestConfiguration, +) -> None: + pipeline = destination_config.setup_pipeline( + f"bigquery_{uniq_id()}", full_refresh=True + ) @dlt.resource( - columns={"some_int": {"data_type": "bigint", "partition": True, "nullable": False}}, + columns={ + "some_int": {"data_type": "bigint", "partition": True, "nullable": False} + }, ) def demo_resource() -> Iterator[Dict[str, int]]: for i in range(10): @@ -360,8 +412,8 @@ def demo_source() -> DltResource: with pipeline.sql_client() as c: with c.execute_query( - "SELECT EXISTS (SELECT 1 FROM INFORMATION_SCHEMA.PARTITIONS WHERE partition_id IS NOT" - " NULL);" + "SELECT EXISTS (SELECT 1 FROM INFORMATION_SCHEMA.PARTITIONS WHERE" + " partition_id IS NOT NULL);" ) as cur: has_partitions = cur.fetchone()[0] assert isinstance(has_partitions, bool) @@ -373,11 +425,17 @@ def demo_source() -> DltResource: destinations_configs(all_staging_configs=True, subset=["bigquery"]), ids=lambda x: x.name, ) -def test_bigquery_no_partition_by_integer(destination_config: DestinationTestConfiguration) -> None: - pipeline = destination_config.setup_pipeline(f"bigquery_{uniq_id()}", full_refresh=True) +def test_bigquery_no_partition_by_integer( + destination_config: DestinationTestConfiguration, +) -> None: + pipeline = destination_config.setup_pipeline( + 
f"bigquery_{uniq_id()}", full_refresh=True + ) @dlt.resource( - columns={"some_int": {"data_type": "bigint", "partition": False, "nullable": False}}, + columns={ + "some_int": {"data_type": "bigint", "partition": False, "nullable": False} + }, ) def demo_resource() -> Iterator[Dict[str, int]]: for i in range(10): @@ -393,8 +451,8 @@ def demo_source() -> DltResource: with pipeline.sql_client() as c: with c.execute_query( - "SELECT EXISTS (SELECT 1 FROM INFORMATION_SCHEMA.PARTITIONS WHERE partition_id IS NOT" - " NULL);" + "SELECT EXISTS (SELECT 1 FROM INFORMATION_SCHEMA.PARTITIONS WHERE" + " partition_id IS NOT NULL);" ) as cur: has_partitions = cur.fetchone()[0] assert isinstance(has_partitions, bool) @@ -419,7 +477,10 @@ def some_data() -> Iterator[Dict[str, str]]: def test_adapter_hints_parsing_partitioning_more_than_one_column() -> None: @dlt.resource( - columns=[{"name": "col1", "data_type": "bigint"}, {"name": "col2", "data_type": "bigint"}] + columns=[ + {"name": "col1", "data_type": "bigint"}, + {"name": "col2", "data_type": "bigint"}, + ] ) def some_data() -> Iterator[Dict[str, Any]]: yield from [{"col1": str(i), "col2": i} for i in range(3)] @@ -429,7 +490,9 @@ def some_data() -> Iterator[Dict[str, Any]]: "col2": {"data_type": "bigint", "name": "col2"}, } - with pytest.raises(ValueError, match="^`partition` must be a single column name as a string.$"): + with pytest.raises( + ValueError, match="^`partition` must be a single column name as a string.$" + ): bigquery_adapter(some_data, partition=["col1", "col2"]) @@ -440,7 +503,11 @@ def some_data() -> Iterator[Dict[str, str]]: bigquery_adapter(some_data, partition="int_col") assert some_data.columns == { - "int_col": {"name": "int_col", "data_type": "bigint", "x-bigquery-partition": True}, + "int_col": { + "name": "int_col", + "data_type": "bigint", + "x-bigquery-partition": True, + }, } @@ -449,7 +516,9 @@ def some_data() -> Iterator[Dict[str, str]]: destinations_configs(all_staging_configs=True, subset=["bigquery"]), ids=lambda x: x.name, ) -def test_adapter_hints_partitioning(destination_config: DestinationTestConfiguration) -> None: +def test_adapter_hints_partitioning( + destination_config: DestinationTestConfiguration, +) -> None: @dlt.resource(columns=[{"name": "col1", "data_type": "bigint"}]) def no_hints() -> Iterator[Dict[str, int]]: yield from [{"col1": i} for i in range(10)] @@ -476,7 +545,9 @@ def sources() -> List[DltResource]: no_hints_table = nc.get_table(fqtn_no_hints) hints_table = nc.get_table(fqtn_hints) - assert not no_hints_table.range_partitioning, "`no_hints` table IS clustered on a column." + assert ( + not no_hints_table.range_partitioning + ), "`no_hints` table IS clustered on a column." 
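The asserts above are wrapped by parenthesizing only the condition while the failure message stays after the comma. A small illustrative sketch (generic Python, nothing dlt-specific) of why the message cannot simply be moved inside the parentheses:

    value = None

    # 88-char friendly wrapping: parentheses around the condition only,
    # the message stays outside after the comma.
    assert (
        value is None
    ), "value should be None"

    # Anti-pattern: parenthesizing condition and message together creates a
    # two-element tuple, which is always truthy, so the assert can never fail
    # (recent CPython versions emit a SyntaxWarning for this).
    # assert (value is None, "value should be None")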
if not hints_table.range_partitioning: raise ValueError("`hints` table IS NOT clustered on a column.") @@ -514,7 +585,9 @@ def test_adapter_hints_round_half_away_from_zero( def no_hints() -> Iterator[Dict[str, float]]: yield from [{"col1": float(i)} for i in range(10)] - hints = bigquery_adapter(no_hints._clone(new_name="hints"), round_half_away_from_zero="col1") + hints = bigquery_adapter( + no_hints._clone(new_name="hints"), round_half_away_from_zero="col1" + ) @dlt.source(max_table_nesting=0) def sources() -> List[DltResource]: @@ -569,7 +642,9 @@ def some_data() -> Iterator[Dict[str, float]]: destinations_configs(all_staging_configs=True, subset=["bigquery"]), ids=lambda x: x.name, ) -def test_adapter_hints_round_half_even(destination_config: DestinationTestConfiguration) -> None: +def test_adapter_hints_round_half_even( + destination_config: DestinationTestConfiguration, +) -> None: @dlt.resource(columns=[{"name": "col1", "data_type": "wei"}]) def no_hints() -> Iterator[Dict[str, float]]: yield from [{"col1": float(i)} for i in range(10)] @@ -604,7 +679,9 @@ def sources() -> List[DltResource]: elif row["table_name"] == "hints": # type: ignore hints_rounding_mode = row["rounding_mode"] # type: ignore - assert (no_hints_rounding_mode is None) and (hints_rounding_mode == "ROUND_HALF_EVEN") + assert (no_hints_rounding_mode is None) and ( + hints_rounding_mode == "ROUND_HALF_EVEN" + ) def test_adapter_hints_parsing_clustering() -> None: @@ -614,13 +691,20 @@ def some_data() -> Iterator[Dict[str, str]]: bigquery_adapter(some_data, cluster="int_col") assert some_data.columns == { - "int_col": {"name": "int_col", "data_type": "bigint", "x-bigquery-cluster": True}, + "int_col": { + "name": "int_col", + "data_type": "bigint", + "x-bigquery-cluster": True, + }, } def test_adapter_hints_parsing_multiple_clustering() -> None: @dlt.resource( - columns=[{"name": "col1", "data_type": "bigint"}, {"name": "col2", "data_type": "text"}] + columns=[ + {"name": "col1", "data_type": "bigint"}, + {"name": "col2", "data_type": "text"}, + ] ) def some_data() -> Iterator[Dict[str, Any]]: yield from [{"col1": i, "col2": str(i)} for i in range(10)] @@ -722,10 +806,14 @@ def sources() -> List[DltResource]: hints_table = nc.get_table(fqtn_hints) no_hints_cluster_fields = ( - [] if no_hints_table.clustering_fields is None else no_hints_table.clustering_fields + [] + if no_hints_table.clustering_fields is None + else no_hints_table.clustering_fields ) hints_cluster_fields = ( - [] if hints_table.clustering_fields is None else hints_table.clustering_fields + [] + if hints_table.clustering_fields is None + else hints_table.clustering_fields ) assert not no_hints_cluster_fields, "`no_hints` table IS clustered some column." 
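Several hunks in this file wrap conditional expressions, like the clustering_fields fallbacks above, by breaking before the if and else keywords. A short sketch with stand-in values (not from the patch) confirming that the wrapped form evaluates exactly like the one-liner:

    for clustering_fields in (None, ["col1"]):
        wrapped = (
            []
            if clustering_fields is None
            else clustering_fields
        )
        one_liner = [] if clustering_fields is None else clustering_fields
        assert wrapped == one_liner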
@@ -742,7 +830,9 @@ def sources() -> List[DltResource]: destinations_configs(all_staging_configs=True, subset=["bigquery"]), ids=lambda x: x.name, ) -def test_adapter_hints_clustering(destination_config: DestinationTestConfiguration) -> None: +def test_adapter_hints_clustering( + destination_config: DestinationTestConfiguration, +) -> None: @dlt.resource(columns=[{"name": "col1", "data_type": "text"}]) def no_hints() -> Iterator[Dict[str, str]]: yield from [{"col1": str(i)} for i in range(10)] @@ -770,14 +860,20 @@ def sources() -> List[DltResource]: hints_table = nc.get_table(fqtn_hints) no_hints_cluster_fields = ( - [] if no_hints_table.clustering_fields is None else no_hints_table.clustering_fields + [] + if no_hints_table.clustering_fields is None + else no_hints_table.clustering_fields ) hints_cluster_fields = ( - [] if hints_table.clustering_fields is None else hints_table.clustering_fields + [] + if hints_table.clustering_fields is None + else hints_table.clustering_fields ) assert not no_hints_cluster_fields, "`no_hints` table IS clustered by `col1`." - assert ["col1"] == hints_cluster_fields, "`hints` table IS NOT clustered by `col1`." + assert [ + "col1" + ] == hints_cluster_fields, "`hints` table IS NOT clustered by `col1`." def test_adapter_hints_empty() -> None: @@ -805,7 +901,9 @@ def some_data() -> Iterator[Dict[str, str]]: ), ): bigquery_adapter( - some_data, round_half_away_from_zero="double_col", round_half_even="double_col" + some_data, + round_half_away_from_zero="double_col", + round_half_even="double_col", ) @@ -931,7 +1029,9 @@ def hints() -> Iterator[Dict[str, Any]]: bigquery_adapter(hints, table_expiration_datetime="2030-01-01", cluster=["col1"]) bigquery_adapter( - hints, table_description="A small table somewhere in the cosmos...", partition="col2" + hints, + table_description="A small table somewhere in the cosmos...", + partition="col2", ) pipeline = destination_config.setup_pipeline( @@ -948,11 +1048,15 @@ def hints() -> Iterator[Dict[str, Any]]: table: Table = nc.get_table(table_fqtn) - table_cluster_fields = [] if table.clustering_fields is None else table.clustering_fields + table_cluster_fields = ( + [] if table.clustering_fields is None else table.clustering_fields + ) # Test merging behaviour. assert table.expires == pendulum.datetime(2030, 1, 1, 0) - assert ["col1"] == table_cluster_fields, "`hints` table IS NOT clustered by `col1`." + assert [ + "col1" + ] == table_cluster_fields, "`hints` table IS NOT clustered by `col1`." assert table.description == "A small table somewhere in the cosmos..." 
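The exploded keyword arguments and column dicts in these hunks end with a trailing comma. As a hedged note on the design choice: formatters such as Black treat that trailing comma as a request to keep the collection one element per line, and the exploded literal builds the same object as the compact one. A minimal sketch with made-up values:

    compact = {"data_type": "date", "partition": True, "nullable": False}
    exploded = {
        "data_type": "date",
        "partition": True,
        "nullable": False,
    }
    # Same dict either way; only the source layout differs.
    assert compact == exploded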
if not table.range_partitioning: diff --git a/tests/load/databricks/test_databricks_configuration.py b/tests/load/databricks/test_databricks_configuration.py index 8d30d05e42..4d5f083c20 100644 --- a/tests/load/databricks/test_databricks_configuration.py +++ b/tests/load/databricks/test_databricks_configuration.py @@ -15,7 +15,9 @@ def test_databricks_credentials_to_connector_params(): os.environ["CREDENTIALS__ACCESS_TOKEN"] = "my-token" os.environ["CREDENTIALS__CATALOG"] = "my-catalog" # JSON encoded dict of extra args - os.environ["CREDENTIALS__CONNECTION_PARAMETERS"] = '{"extra_a": "a", "extra_b": "b"}' + os.environ["CREDENTIALS__CONNECTION_PARAMETERS"] = ( + '{"extra_a": "a", "extra_b": "b"}' + ) config = resolve_configuration( DatabricksClientConfiguration()._bind_dataset_name(dataset_name="my-dataset") diff --git a/tests/load/duckdb/test_duckdb_client.py b/tests/load/duckdb/test_duckdb_client.py index 3deed7a77d..3b833b120c 100644 --- a/tests/load/duckdb/test_duckdb_client.py +++ b/tests/load/duckdb/test_duckdb_client.py @@ -15,7 +15,12 @@ from dlt.destinations import duckdb from tests.load.pipeline.utils import drop_pipeline, assert_table -from tests.utils import patch_home_dir, autouse_test_storage, preserve_environ, TEST_STORAGE_ROOT +from tests.utils import ( + patch_home_dir, + autouse_test_storage, + preserve_environ, + TEST_STORAGE_ROOT, +) @pytest.fixture(autouse=True) @@ -110,7 +115,9 @@ def test_duckdb_database_path() -> None: # provide absolute path db_path = os.path.abspath("_storage/abs_test_quack.duckdb") c = resolve_configuration( - DuckDbClientConfiguration(credentials=f"duckdb:///{db_path}")._bind_dataset_name( + DuckDbClientConfiguration( + credentials=f"duckdb:///{db_path}" + )._bind_dataset_name( dataset_name="test_dataset", ) ) @@ -161,27 +168,41 @@ def test_duckdb_database_path() -> None: def test_keeps_initial_db_path() -> None: db_path = "_storage/path_test_quack.duckdb" - p = dlt.pipeline(pipeline_name="quack_pipeline", credentials=db_path, destination="duckdb") + p = dlt.pipeline( + pipeline_name="quack_pipeline", credentials=db_path, destination="duckdb" + ) print(p.pipelines_dir) with p.sql_client() as conn: # still cwd assert conn.credentials._conn_str().lower() == os.path.abspath(db_path).lower() # but it is kept in the local state - assert p.get_local_state_val("duckdb_database").lower() == os.path.abspath(db_path).lower() + assert ( + p.get_local_state_val("duckdb_database").lower() + == os.path.abspath(db_path).lower() + ) # attach the pipeline p = dlt.attach(pipeline_name="quack_pipeline") - assert p.get_local_state_val("duckdb_database").lower() == os.path.abspath(db_path).lower() + assert ( + p.get_local_state_val("duckdb_database").lower() + == os.path.abspath(db_path).lower() + ) with p.sql_client() as conn: # still cwd - assert p.get_local_state_val("duckdb_database").lower() == os.path.abspath(db_path).lower() + assert ( + p.get_local_state_val("duckdb_database").lower() + == os.path.abspath(db_path).lower() + ) assert conn.credentials._conn_str().lower() == os.path.abspath(db_path).lower() # now create a new pipeline dlt.pipeline(pipeline_name="not_quack", destination="dummy") with p.sql_client() as conn: # still cwd - assert p.get_local_state_val("duckdb_database").lower() == os.path.abspath(db_path).lower() + assert ( + p.get_local_state_val("duckdb_database").lower() + == os.path.abspath(db_path).lower() + ) # new pipeline context took over # TODO: restore pipeline context on each call assert conn.credentials._conn_str().lower() != 
os.path.abspath(db_path).lower() @@ -189,7 +210,9 @@ def test_keeps_initial_db_path() -> None: def test_duckdb_database_delete() -> None: db_path = "_storage/path_test_quack.duckdb" - p = dlt.pipeline(pipeline_name="quack_pipeline", destination=duckdb(credentials=db_path)) + p = dlt.pipeline( + pipeline_name="quack_pipeline", destination=duckdb(credentials=db_path) + ) p.run([1, 2, 3], table_name="table", dataset_name="dataset") # attach the pipeline p = dlt.attach(pipeline_name="quack_pipeline") @@ -209,7 +232,9 @@ def test_duck_database_path_delete() -> None: db_folder = "_storage/db_path" os.makedirs(db_folder) db_path = f"{db_folder}/path_test_quack.duckdb" - p = dlt.pipeline(pipeline_name="deep_quack_pipeline", credentials=db_path, destination="duckdb") + p = dlt.pipeline( + pipeline_name="deep_quack_pipeline", credentials=db_path, destination="duckdb" + ) p.run([1, 2, 3], table_name="table", dataset_name="dataset") # attach the pipeline p = dlt.attach(pipeline_name="deep_quack_pipeline") @@ -230,7 +255,9 @@ def test_case_sensitive_database_name() -> None: cs_quack = os.path.join(TEST_STORAGE_ROOT, "QuAcK") os.makedirs(cs_quack, exist_ok=True) db_path = os.path.join(cs_quack, "path_TEST_quack.duckdb") - p = dlt.pipeline(pipeline_name="NOT_QUAck", credentials=db_path, destination="duckdb") + p = dlt.pipeline( + pipeline_name="NOT_QUAck", credentials=db_path, destination="duckdb" + ) with p.sql_client() as conn: conn.execute_sql("DESCRIBE;") @@ -241,7 +268,9 @@ def test_external_duckdb_database() -> None: # pass explicit in memory database conn = duckdb.connect(":memory:") c = resolve_configuration( - DuckDbClientConfiguration(credentials=conn)._bind_dataset_name(dataset_name="test_dataset") + DuckDbClientConfiguration(credentials=conn)._bind_dataset_name( + dataset_name="test_dataset" + ) ) assert c.credentials._conn_borrows == 0 assert c.credentials._conn is conn diff --git a/tests/load/duckdb/test_duckdb_table_builder.py b/tests/load/duckdb/test_duckdb_table_builder.py index 542b18993c..9d75841fb1 100644 --- a/tests/load/duckdb/test_duckdb_table_builder.py +++ b/tests/load/duckdb/test_duckdb_table_builder.py @@ -16,7 +16,9 @@ def client(empty_schema: Schema) -> DuckDbClient: # return client without opening connection return DuckDbClient( empty_schema, - DuckDbClientConfiguration()._bind_dataset_name(dataset_name="test_" + uniq_id()), + DuckDbClientConfiguration()._bind_dataset_name( + dataset_name="test_" + uniq_id() + ), ) diff --git a/tests/load/filesystem/test_azure_credentials.py b/tests/load/filesystem/test_azure_credentials.py index 093cd6dd19..6d1879d9dd 100644 --- a/tests/load/filesystem/test_azure_credentials.py +++ b/tests/load/filesystem/test_azure_credentials.py @@ -17,7 +17,9 @@ def test_azure_credentials_from_account_key(environment: Dict[str, str]) -> None: environment["CREDENTIALS__AZURE_STORAGE_ACCOUNT_NAME"] = "fake_account_name" - environment["CREDENTIALS__AZURE_STORAGE_ACCOUNT_KEY"] = "QWERTYUIOPASDFGHJKLZXCVBNM1234567890" + environment["CREDENTIALS__AZURE_STORAGE_ACCOUNT_KEY"] = ( + "QWERTYUIOPASDFGHJKLZXCVBNM1234567890" + ) config = resolve_configuration(AzureCredentials()) @@ -33,7 +35,9 @@ def test_azure_credentials_from_account_key(environment: Dict[str, str]) -> None def test_create_azure_sas_token_with_permissions(environment: Dict[str, str]) -> None: environment["CREDENTIALS__AZURE_STORAGE_ACCOUNT_NAME"] = "fake_account_name" - environment["CREDENTIALS__AZURE_STORAGE_ACCOUNT_KEY"] = "QWERTYUIOPASDFGHJKLZXCVBNM1234567890" + 
environment["CREDENTIALS__AZURE_STORAGE_ACCOUNT_KEY"] = ( + "QWERTYUIOPASDFGHJKLZXCVBNM1234567890" + ) environment["CREDENTIALS__AZURE_SAS_TOKEN_PERMISSIONS"] = "rl" config = resolve_configuration(AzureCredentials()) @@ -52,9 +56,13 @@ def test_azure_credentials_from_sas_token(environment: Dict[str, str]) -> None: config = resolve_configuration(AzureCredentials()) - assert config.azure_storage_sas_token == environment["CREDENTIALS__AZURE_STORAGE_SAS_TOKEN"] assert ( - config.azure_storage_account_name == environment["CREDENTIALS__AZURE_STORAGE_ACCOUNT_NAME"] + config.azure_storage_sas_token + == environment["CREDENTIALS__AZURE_STORAGE_SAS_TOKEN"] + ) + assert ( + config.azure_storage_account_name + == environment["CREDENTIALS__AZURE_STORAGE_ACCOUNT_NAME"] ) assert config.azure_storage_account_key is None @@ -80,7 +88,8 @@ def test_azure_credentials_from_default(environment: Dict[str, str]) -> None: config = resolve_configuration(AzureCredentials()) assert ( - config.azure_storage_account_name == environment["CREDENTIALS__AZURE_STORAGE_ACCOUNT_NAME"] + config.azure_storage_account_name + == environment["CREDENTIALS__AZURE_STORAGE_ACCOUNT_NAME"] ) assert config.azure_storage_account_key is None assert config.azure_storage_sas_token is None diff --git a/tests/load/filesystem/test_filesystem_client.py b/tests/load/filesystem/test_filesystem_client.py index 9948e26882..e4c7637a1b 100644 --- a/tests/load/filesystem/test_filesystem_client.py +++ b/tests/load/filesystem/test_filesystem_client.py @@ -48,7 +48,9 @@ def test_filesystem_destination_configuration() -> None: @pytest.mark.parametrize("write_disposition", ("replace", "append", "merge")) @pytest.mark.parametrize("layout", ALL_LAYOUTS) -def test_successful_load(write_disposition: str, layout: str, with_gdrive_buckets_env: str) -> None: +def test_successful_load( + write_disposition: str, layout: str, with_gdrive_buckets_env: str +) -> None: """Test load is successful with an empty destination dataset""" if layout: os.environ["DESTINATION__FILESYSTEM__LAYOUT"] = layout @@ -95,7 +97,9 @@ def test_replace_write_disposition(layout: str, default_buckets_env: str) -> Non os.environ.pop("DESTINATION__FILESYSTEM__LAYOUT", None) dataset_name = "test_" + uniq_id() # NOTE: context manager will delete the dataset at the end so keep it open until the end - with perform_load(dataset_name, NORMALIZED_FILES, write_disposition="replace") as load_info: + with perform_load( + dataset_name, NORMALIZED_FILES, write_disposition="replace" + ) as load_info: client, _, root_path, load_id1 = load_info layout = client.config.layout @@ -141,9 +145,13 @@ def test_append_write_disposition(layout: str, default_buckets_env: str) -> None os.environ.pop("DESTINATION__FILESYSTEM__LAYOUT", None) dataset_name = "test_" + uniq_id() # NOTE: context manager will delete the dataset at the end so keep it open until the end - with perform_load(dataset_name, NORMALIZED_FILES, write_disposition="append") as load_info: + with perform_load( + dataset_name, NORMALIZED_FILES, write_disposition="append" + ) as load_info: client, jobs1, root_path, load_id1 = load_info - with perform_load(dataset_name, NORMALIZED_FILES, write_disposition="append") as load_info: + with perform_load( + dataset_name, NORMALIZED_FILES, write_disposition="append" + ) as load_info: client, jobs2, root_path, load_id2 = load_info layout = client.config.layout expected_files = [ @@ -157,7 +165,9 @@ def test_append_write_disposition(layout: str, default_buckets_env: str) -> None ) for job in jobs2 ] - expected_files = 
sorted([posixpath.join(root_path, fn) for fn in expected_files]) + expected_files = sorted( + [posixpath.join(root_path, fn) for fn in expected_files] + ) paths = [] for basedir, _dirs, files in client.fs_client.walk( diff --git a/tests/load/filesystem/test_filesystem_common.py b/tests/load/filesystem/test_filesystem_common.py index 4c94766097..8170291ee5 100644 --- a/tests/load/filesystem/test_filesystem_common.py +++ b/tests/load/filesystem/test_filesystem_common.py @@ -9,7 +9,10 @@ from dlt.common import pendulum from dlt.common.configuration.inject import with_config -from dlt.common.configuration.specs import AzureCredentials, AzureCredentialsWithoutDefaults +from dlt.common.configuration.specs import ( + AzureCredentials, + AzureCredentialsWithoutDefaults, +) from dlt.common.storages import fsspec_from_config, FilesystemConfiguration from dlt.common.storages.fsspec_filesystem import MTIME_DISPATCH, glob_files from dlt.common.utils import uniq_id @@ -99,7 +102,9 @@ def test_filesystem_dict(with_gdrive_buckets_env: str, load_content: bool) -> No pytest.skip(f"Skipping due to {str(ex)}") -@pytest.mark.skipif("s3" not in ALL_FILESYSTEM_DRIVERS, reason="s3 destination not configured") +@pytest.mark.skipif( + "s3" not in ALL_FILESYSTEM_DRIVERS, reason="s3 destination not configured" +) def test_filesystem_instance_from_s3_endpoint(environment: Dict[str, str]) -> None: """Test that fsspec instance is correctly configured when using endpoint URL. E.g. when using an S3 compatible service such as Cloudflare R2 @@ -124,7 +129,9 @@ def test_filesystem_instance_from_s3_endpoint(environment: Dict[str, str]) -> No def test_filesystem_configuration_with_additional_arguments() -> None: config = FilesystemConfiguration( - bucket_url="az://root", kwargs={"use_ssl": True}, client_kwargs={"verify": "public.crt"} + bucket_url="az://root", + kwargs={"use_ssl": True}, + client_kwargs={"verify": "public.crt"}, ) assert dict(config) == { "read_only": False, @@ -135,10 +142,14 @@ def test_filesystem_configuration_with_additional_arguments() -> None: } -@pytest.mark.skipif("s3" not in ALL_FILESYSTEM_DRIVERS, reason="s3 destination not configured") +@pytest.mark.skipif( + "s3" not in ALL_FILESYSTEM_DRIVERS, reason="s3 destination not configured" +) def test_kwargs_propagate_to_s3_instance(default_buckets_env: str) -> None: os.environ["DESTINATION__FILESYSTEM__KWARGS"] = '{"use_ssl": false}' - os.environ["DESTINATION__FILESYSTEM__CLIENT_KWARGS"] = '{"verify": false, "foo": "bar"}' + os.environ["DESTINATION__FILESYSTEM__CLIENT_KWARGS"] = ( + '{"verify": false, "foo": "bar"}' + ) config = get_config() @@ -154,12 +165,18 @@ def test_kwargs_propagate_to_s3_instance(default_buckets_env: str) -> None: assert ("foo", "bar") in filesystem.client_kwargs.items() -@pytest.mark.skipif("s3" not in ALL_FILESYSTEM_DRIVERS, reason="s3 destination not configured") -def test_s3_wrong_client_certificate(default_buckets_env: str, self_signed_cert: str) -> None: +@pytest.mark.skipif( + "s3" not in ALL_FILESYSTEM_DRIVERS, reason="s3 destination not configured" +) +def test_s3_wrong_client_certificate( + default_buckets_env: str, self_signed_cert: str +) -> None: """Test whether filesystem raises an SSLError when trying to establish a connection with the wrong client certificate.""" os.environ["DESTINATION__FILESYSTEM__KWARGS"] = '{"use_ssl": true}' - os.environ["DESTINATION__FILESYSTEM__CLIENT_KWARGS"] = f'{{"verify": "{self_signed_cert}"}}' + os.environ["DESTINATION__FILESYSTEM__CLIENT_KWARGS"] = ( + f'{{"verify": 
"{self_signed_cert}"}}' + ) config = get_config() diff --git a/tests/load/filesystem/utils.py b/tests/load/filesystem/utils.py index 6e697fdef9..fd54182f9e 100644 --- a/tests/load/filesystem/utils.py +++ b/tests/load/filesystem/utils.py @@ -70,7 +70,9 @@ def perform_load( @pytest.fixture(scope="function", autouse=False) def self_signed_cert() -> Iterator[str]: - key = rsa.generate_private_key(public_exponent=65537, key_size=2048, backend=default_backend()) + key = rsa.generate_private_key( + public_exponent=65537, key_size=2048, backend=default_backend() + ) subject = issuer = x509.Name( [ diff --git a/tests/load/mssql/test_mssql_credentials.py b/tests/load/mssql/test_mssql_credentials.py index 0e38791f22..a18710ceb1 100644 --- a/tests/load/mssql/test_mssql_credentials.py +++ b/tests/load/mssql/test_mssql_credentials.py @@ -13,7 +13,9 @@ def test_mssql_credentials_defaults() -> None: assert creds.connect_timeout == 15 assert MsSqlCredentials.__config_gen_annotations__ == ["port", "connect_timeout"] # port should be optional - resolve_configuration(creds, explicit_value="mssql://loader:loader@localhost/dlt_data") + resolve_configuration( + creds, explicit_value="mssql://loader:loader@localhost/dlt_data" + ) assert creds.port == 1433 @@ -122,7 +124,9 @@ def test_to_odbc_dsn_arbitrary_keys_specified() -> None: } -available_drivers = [d for d in pyodbc.drivers() if d in MsSqlCredentials.SUPPORTED_DRIVERS] +available_drivers = [ + d for d in pyodbc.drivers() if d in MsSqlCredentials.SUPPORTED_DRIVERS +] @pytest.mark.skipif(not available_drivers, reason="no supported driver available") diff --git a/tests/load/mssql/test_mssql_table_builder.py b/tests/load/mssql/test_mssql_table_builder.py index 1b4a77a2ab..2e254dfb2d 100644 --- a/tests/load/mssql/test_mssql_table_builder.py +++ b/tests/load/mssql/test_mssql_table_builder.py @@ -4,10 +4,15 @@ from dlt.common.utils import uniq_id from dlt.common.schema import Schema -pytest.importorskip("dlt.destinations.impl.mssql.mssql", reason="MSSQL ODBC driver not installed") +pytest.importorskip( + "dlt.destinations.impl.mssql.mssql", reason="MSSQL ODBC driver not installed" +) from dlt.destinations.impl.mssql.mssql import MsSqlClient -from dlt.destinations.impl.mssql.configuration import MsSqlClientConfiguration, MsSqlCredentials +from dlt.destinations.impl.mssql.configuration import ( + MsSqlClientConfiguration, + MsSqlCredentials, +) from tests.load.utils import TABLE_UPDATE, empty_schema diff --git a/tests/load/pipeline/test_arrow_loading.py b/tests/load/pipeline/test_arrow_loading.py index 59cd90c535..457bc3f0e5 100644 --- a/tests/load/pipeline/test_arrow_loading.py +++ b/tests/load/pipeline/test_arrow_loading.py @@ -47,7 +47,9 @@ def test_load_item( def some_data(): yield item - load_info = pipeline.run(some_data(), loader_file_format=destination_config.file_format) + load_info = pipeline.run( + some_data(), loader_file_format=destination_config.file_format + ) # assert the table types some_table_columns = pipeline.default_schema.get_table("some_data")["columns"] assert some_table_columns["string"]["data_type"] == "text" @@ -144,13 +146,17 @@ def test_parquet_column_names_are_normalized( def some_data(): yield tbl - pipeline = dlt.pipeline("arrow_" + uniq_id(), destination=destination_config.destination) + pipeline = dlt.pipeline( + "arrow_" + uniq_id(), destination=destination_config.destination + ) pipeline.extract(some_data()) # Find the extracted file norm_storage = pipeline._get_normalize_storage() extract_files = [ - fn for fn in 
norm_storage.list_files_to_normalize_sorted() if fn.endswith(".parquet") + fn + for fn in norm_storage.list_files_to_normalize_sorted() + if fn.endswith(".parquet") ] assert len(extract_files) == 1 @@ -158,7 +164,9 @@ def some_data(): expected_column_names = [ pipeline.default_schema.naming.normalize_path(col) for col in df.columns ] - new_table_name = pipeline.default_schema.naming.normalize_table_identifier("some_data") + new_table_name = pipeline.default_schema.naming.normalize_table_identifier( + "some_data" + ) schema_columns = pipeline.default_schema.get_table_columns(new_table_name) # Schema columns are normalized diff --git a/tests/load/pipeline/test_athena.py b/tests/load/pipeline/test_athena.py index 9c17be318f..38bb3a3a90 100644 --- a/tests/load/pipeline/test_athena.py +++ b/tests/load/pipeline/test_athena.py @@ -18,14 +18,19 @@ ids=lambda x: x.name, ) def test_athena_destinations(destination_config: DestinationTestConfiguration) -> None: - pipeline = destination_config.setup_pipeline("athena_" + uniq_id(), full_refresh=True) + pipeline = destination_config.setup_pipeline( + "athena_" + uniq_id(), full_refresh=True + ) @dlt.resource(name="items", write_disposition="append") def items(): yield { "id": 1, "name": "item", - "sub_items": [{"id": 101, "name": "sub item 101"}, {"id": 101, "name": "sub item 102"}], + "sub_items": [ + {"id": 101, "name": "sub item 101"}, + {"id": 101, "name": "sub item 102"}, + ], } pipeline.run(items, loader_file_format=destination_config.file_format) @@ -76,13 +81,17 @@ def items2(): def test_athena_all_datatypes_and_timestamps( destination_config: DestinationTestConfiguration, ) -> None: - pipeline = destination_config.setup_pipeline("athena_" + uniq_id(), full_refresh=True) + pipeline = destination_config.setup_pipeline( + "athena_" + uniq_id(), full_refresh=True + ) # TIME is not supported column_schemas, data_types = table_update_and_row(exclude_types=["time"]) # apply the exact columns definitions so we process complex and wei types correctly! 
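As a quick sanity check after this kind of reformatting, a hypothetical helper (not part of dlt or of this patch) can list any lines still longer than 88 characters; the @dlt.resource hunk for the columns definitions continues right below.

    from pathlib import Path

    MAX_LEN = 88  # the limit this patch standardizes on


    def overlong_lines(path: str, limit: int = MAX_LEN) -> list:
        """Return (line_number, length) pairs for lines longer than `limit`."""
        text = Path(path).read_text(encoding="utf-8")
        return [
            (number, len(line))
            for number, line in enumerate(text.splitlines(), start=1)
            if len(line) > limit
        ]


    # Hypothetical usage; the path is only an example:
    # print(overlong_lines("tests/load/pipeline/test_athena.py"))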
- @dlt.resource(table_name="data_types", write_disposition="append", columns=column_schemas) + @dlt.resource( + table_name="data_types", write_disposition="append", columns=column_schemas + ) def my_resource() -> Iterator[Any]: nonlocal data_types yield [data_types] * 10 @@ -116,7 +125,8 @@ def my_source() -> Any: assert len(db_rows) == 10 # no rows - TIMESTAMP(6) not supported db_rows = sql_client.execute_sql( - "SELECT * FROM data_types WHERE col4 = TIMESTAMP '2022-05-23 13:26:45.176145'" + "SELECT * FROM data_types WHERE col4 = TIMESTAMP '2022-05-23" + " 13:26:45.176145'" ) assert len(db_rows) == 0 # use pendulum @@ -146,7 +156,9 @@ def my_source() -> Any: assert len(db_rows) == 0 # check date - db_rows = sql_client.execute_sql("SELECT * FROM data_types WHERE col10 = DATE '2023-02-27'") + db_rows = sql_client.execute_sql( + "SELECT * FROM data_types WHERE col10 = DATE '2023-02-27'" + ) assert len(db_rows) == 10 db_rows = sql_client.execute_sql( "SELECT * FROM data_types WHERE col10 = %s", pendulum.date(2023, 2, 27) @@ -163,13 +175,19 @@ def my_source() -> Any: destinations_configs(default_sql_configs=True, subset=["athena"]), ids=lambda x: x.name, ) -def test_athena_blocks_time_column(destination_config: DestinationTestConfiguration) -> None: - pipeline = destination_config.setup_pipeline("athena_" + uniq_id(), full_refresh=True) +def test_athena_blocks_time_column( + destination_config: DestinationTestConfiguration, +) -> None: + pipeline = destination_config.setup_pipeline( + "athena_" + uniq_id(), full_refresh=True + ) column_schemas, data_types = table_update_and_row() # apply the exact columns definitions so we process complex and wei types correctly! - @dlt.resource(table_name="data_types", write_disposition="append", columns=column_schemas) + @dlt.resource( + table_name="data_types", write_disposition="append", columns=column_schemas + ) def my_resource() -> Iterator[Any]: nonlocal data_types yield [data_types] * 10 diff --git a/tests/load/pipeline/test_bigquery.py b/tests/load/pipeline/test_bigquery.py index 711d45fb1f..1aeda0e4a6 100644 --- a/tests/load/pipeline/test_bigquery.py +++ b/tests/load/pipeline/test_bigquery.py @@ -12,18 +12,27 @@ destinations_configs(default_sql_configs=True, subset=["bigquery"]), ids=lambda x: x.name, ) -def test_bigquery_numeric_types(destination_config: DestinationTestConfiguration) -> None: +def test_bigquery_numeric_types( + destination_config: DestinationTestConfiguration, +) -> None: pipeline = destination_config.setup_pipeline("test_bigquery_numeric_types") columns = [ - {"name": "col_big_numeric", "data_type": "decimal", "precision": 47, "scale": 9}, + { + "name": "col_big_numeric", + "data_type": "decimal", + "precision": 47, + "scale": 9, + }, {"name": "col_numeric", "data_type": "decimal", "precision": 38, "scale": 9}, ] data = [ { # Valid BIGNUMERIC and NUMERIC values - "col_big_numeric": Decimal("12345678901234567890123456789012345678.123456789"), + "col_big_numeric": Decimal( + "12345678901234567890123456789012345678.123456789" + ), "col_numeric": Decimal("12345678901234567890123456789.123456789"), }, ] @@ -32,7 +41,9 @@ def test_bigquery_numeric_types(destination_config: DestinationTestConfiguration assert_load_info(info) with pipeline.sql_client() as client: - with client.execute_query("SELECT col_big_numeric, col_numeric FROM big_numeric;") as q: + with client.execute_query( + "SELECT col_big_numeric, col_numeric FROM big_numeric;" + ) as q: row = q.fetchone() assert row[0] == data[0]["col_big_numeric"] assert row[1] == 
data[0]["col_numeric"] diff --git a/tests/load/pipeline/test_dbt_helper.py b/tests/load/pipeline/test_dbt_helper.py index 91318d0f34..28535ba606 100644 --- a/tests/load/pipeline/test_dbt_helper.py +++ b/tests/load/pipeline/test_dbt_helper.py @@ -37,11 +37,14 @@ def test_run_jaffle_package( ) -> None: if destination_config.destination == "athena": pytest.skip( - "dbt-athena requires database to be created and we don't do it in case of Jaffle" + "dbt-athena requires database to be created and we don't do it in case of" + " Jaffle" ) pipeline = destination_config.setup_pipeline("jaffle_jaffle", full_refresh=True) # get runner, pass the env from fixture - dbt = dlt.dbt.package(pipeline, "https://github.com/dbt-labs/jaffle_shop.git", venv=dbt_venv) + dbt = dlt.dbt.package( + pipeline, "https://github.com/dbt-labs/jaffle_shop.git", venv=dbt_venv + ) # no default schema assert pipeline.default_schema_name is None # inject default schema otherwise dataset is not deleted @@ -69,7 +72,9 @@ def test_run_jaffle_package( destinations_configs(default_sql_configs=True, supports_dbt=True), ids=lambda x: x.name, ) -def test_run_chess_dbt(destination_config: DestinationTestConfiguration, dbt_venv: Venv) -> None: +def test_run_chess_dbt( + destination_config: DestinationTestConfiguration, dbt_venv: Venv +) -> None: from docs.examples.chess.chess import chess # provide chess url via environ @@ -80,7 +85,9 @@ def test_run_chess_dbt(destination_config: DestinationTestConfiguration, dbt_ven ) assert pipeline.default_schema_name is None # get the runner for the "dbt_transform" package - transforms = dlt.dbt.package(pipeline, "docs/examples/chess/dbt_transform", venv=dbt_venv) + transforms = dlt.dbt.package( + pipeline, "docs/examples/chess/dbt_transform", venv=dbt_venv + ) assert pipeline.default_schema_name is None # there's no data so the source tests will fail with pytest.raises(PrerequisitesException): @@ -135,7 +142,9 @@ def test_run_chess_dbt_to_other_dataset( pipeline.config.use_single_dataset = False # assert pipeline.default_schema_name is None # get the runner for the "dbt_transform" package - transforms = dlt.dbt.package(pipeline, "docs/examples/chess/dbt_transform", venv=dbt_venv) + transforms = dlt.dbt.package( + pipeline, "docs/examples/chess/dbt_transform", venv=dbt_venv + ) # assert pipeline.default_schema_name is None # load data info = pipeline.run(chess(max_players=5, month=9)) diff --git a/tests/load/pipeline/test_drop.py b/tests/load/pipeline/test_drop.py index afae1c22ca..caa9fdc42e 100644 --- a/tests/load/pipeline/test_drop.py +++ b/tests/load/pipeline/test_drop.py @@ -46,7 +46,9 @@ def droppable_c( ) -> Iterator[Dict[str, Any]]: # Grandchild table yield dict( - asdasd=2424, qe=111, items=[dict(k=2, r=2, labels=[dict(name="abc"), dict(name="www")])] + asdasd=2424, + qe=111, + items=[dict(k=2, r=2, labels=[dict(name="abc"), dict(name="www")])], ) @dlt.resource @@ -60,7 +62,13 @@ def droppable_d( def droppable_no_state(): yield [1, 2, 3] - return [droppable_a(), droppable_b(), droppable_c(), droppable_d(), droppable_no_state] + return [ + droppable_a(), + droppable_b(), + droppable_c(), + droppable_d(), + droppable_no_state, + ] RESOURCE_TABLES = dict( @@ -124,18 +132,26 @@ def assert_destination_state_loaded(pipeline: Pipeline) -> None: @pytest.mark.parametrize( - "destination_config", destinations_configs(default_sql_configs=True), ids=lambda x: x.name + "destination_config", + destinations_configs(default_sql_configs=True), + ids=lambda x: x.name, ) -def 
test_drop_command_resources_and_state(destination_config: DestinationTestConfiguration) -> None: +def test_drop_command_resources_and_state( + destination_config: DestinationTestConfiguration, +) -> None: """Test the drop command with resource and state path options and verify correct data is deleted from destination and locally""" source = droppable_source() - pipeline = destination_config.setup_pipeline("drop_test_" + uniq_id(), full_refresh=True) + pipeline = destination_config.setup_pipeline( + "drop_test_" + uniq_id(), full_refresh=True + ) pipeline.run(source) attached = _attach(pipeline) helpers.drop( - attached, resources=["droppable_c", "droppable_d"], state_paths="data_from_d.*.bar" + attached, + resources=["droppable_c", "droppable_d"], + state_paths="data_from_d.*.bar", ) attached = _attach(pipeline) @@ -150,12 +166,18 @@ def test_drop_command_resources_and_state(destination_config: DestinationTestCon @pytest.mark.parametrize( - "destination_config", destinations_configs(default_sql_configs=True), ids=lambda x: x.name + "destination_config", + destinations_configs(default_sql_configs=True), + ids=lambda x: x.name, ) -def test_drop_command_only_state(destination_config: DestinationTestConfiguration) -> None: +def test_drop_command_only_state( + destination_config: DestinationTestConfiguration, +) -> None: """Test drop command that deletes part of the state and syncs with destination""" source = droppable_source() - pipeline = destination_config.setup_pipeline("drop_test_" + uniq_id(), full_refresh=True) + pipeline = destination_config.setup_pipeline( + "drop_test_" + uniq_id(), full_refresh=True + ) pipeline.run(source) attached = _attach(pipeline) @@ -173,12 +195,18 @@ def test_drop_command_only_state(destination_config: DestinationTestConfiguratio @pytest.mark.parametrize( - "destination_config", destinations_configs(default_sql_configs=True), ids=lambda x: x.name + "destination_config", + destinations_configs(default_sql_configs=True), + ids=lambda x: x.name, ) -def test_drop_command_only_tables(destination_config: DestinationTestConfiguration) -> None: +def test_drop_command_only_tables( + destination_config: DestinationTestConfiguration, +) -> None: """Test drop only tables and makes sure that schema and state are synced""" source = droppable_source() - pipeline = destination_config.setup_pipeline("drop_test_" + uniq_id(), full_refresh=True) + pipeline = destination_config.setup_pipeline( + "drop_test_" + uniq_id(), full_refresh=True + ) pipeline.run(source) sources_state = pipeline.state["sources"] @@ -195,12 +223,18 @@ def test_drop_command_only_tables(destination_config: DestinationTestConfigurati @pytest.mark.parametrize( - "destination_config", destinations_configs(default_sql_configs=True), ids=lambda x: x.name + "destination_config", + destinations_configs(default_sql_configs=True), + ids=lambda x: x.name, ) -def test_drop_destination_tables_fails(destination_config: DestinationTestConfiguration) -> None: +def test_drop_destination_tables_fails( + destination_config: DestinationTestConfiguration, +) -> None: """Fail on drop tables. 
Command runs again.""" source = droppable_source() - pipeline = destination_config.setup_pipeline("drop_test_" + uniq_id(), full_refresh=True) + pipeline = destination_config.setup_pipeline( + "drop_test_" + uniq_id(), full_refresh=True + ) pipeline.run(source) attached = _attach(pipeline) @@ -221,18 +255,26 @@ def test_drop_destination_tables_fails(destination_config: DestinationTestConfig @pytest.mark.parametrize( - "destination_config", destinations_configs(default_sql_configs=True), ids=lambda x: x.name + "destination_config", + destinations_configs(default_sql_configs=True), + ids=lambda x: x.name, ) -def test_fail_after_drop_tables(destination_config: DestinationTestConfiguration) -> None: +def test_fail_after_drop_tables( + destination_config: DestinationTestConfiguration, +) -> None: """Fail directly after drop tables. Command runs again ignoring destination tables missing.""" source = droppable_source() - pipeline = destination_config.setup_pipeline("drop_test_" + uniq_id(), full_refresh=True) + pipeline = destination_config.setup_pipeline( + "drop_test_" + uniq_id(), full_refresh=True + ) pipeline.run(source) attached = _attach(pipeline) with mock.patch.object( - helpers.DropCommand, "_extract_state", side_effect=RuntimeError("Something went wrong") + helpers.DropCommand, + "_extract_state", + side_effect=RuntimeError("Something went wrong"), ): with pytest.raises(RuntimeError): helpers.drop(attached, resources=("droppable_a", "droppable_b")) @@ -245,17 +287,23 @@ def test_fail_after_drop_tables(destination_config: DestinationTestConfiguration @pytest.mark.parametrize( - "destination_config", destinations_configs(default_sql_configs=True), ids=lambda x: x.name + "destination_config", + destinations_configs(default_sql_configs=True), + ids=lambda x: x.name, ) def test_load_step_fails(destination_config: DestinationTestConfiguration) -> None: """Test idempotence. pipeline.load() fails. 
Command can be run again successfully""" source = droppable_source() - pipeline = destination_config.setup_pipeline("drop_test_" + uniq_id(), full_refresh=True) + pipeline = destination_config.setup_pipeline( + "drop_test_" + uniq_id(), full_refresh=True + ) pipeline.run(source) attached = _attach(pipeline) - with mock.patch.object(Load, "run", side_effect=RuntimeError("Something went wrong")): + with mock.patch.object( + Load, "run", side_effect=RuntimeError("Something went wrong") + ): with pytest.raises(PipelineStepFailed) as e: helpers.drop(attached, resources=("droppable_a", "droppable_b")) assert isinstance(e.value.exception, RuntimeError) @@ -268,11 +316,15 @@ def test_load_step_fails(destination_config: DestinationTestConfiguration) -> No @pytest.mark.parametrize( - "destination_config", destinations_configs(default_sql_configs=True), ids=lambda x: x.name + "destination_config", + destinations_configs(default_sql_configs=True), + ids=lambda x: x.name, ) def test_resource_regex(destination_config: DestinationTestConfiguration) -> None: source = droppable_source() - pipeline = destination_config.setup_pipeline("drop_test_" + uniq_id(), full_refresh=True) + pipeline = destination_config.setup_pipeline( + "drop_test_" + uniq_id(), full_refresh=True + ) pipeline.run(source) attached = _attach(pipeline) @@ -286,12 +338,16 @@ def test_resource_regex(destination_config: DestinationTestConfiguration) -> Non @pytest.mark.parametrize( - "destination_config", destinations_configs(default_sql_configs=True), ids=lambda x: x.name + "destination_config", + destinations_configs(default_sql_configs=True), + ids=lambda x: x.name, ) def test_drop_nothing(destination_config: DestinationTestConfiguration) -> None: """No resources, no state keys. Nothing is changed.""" source = droppable_source() - pipeline = destination_config.setup_pipeline("drop_test_" + uniq_id(), full_refresh=True) + pipeline = destination_config.setup_pipeline( + "drop_test_" + uniq_id(), full_refresh=True + ) pipeline.run(source) attached = _attach(pipeline) @@ -304,12 +360,16 @@ def test_drop_nothing(destination_config: DestinationTestConfiguration) -> None: @pytest.mark.parametrize( - "destination_config", destinations_configs(default_sql_configs=True), ids=lambda x: x.name + "destination_config", + destinations_configs(default_sql_configs=True), + ids=lambda x: x.name, ) def test_drop_all_flag(destination_config: DestinationTestConfiguration) -> None: """Using drop_all flag. 
Destination dataset and all local state is deleted""" source = droppable_source() - pipeline = destination_config.setup_pipeline("drop_test_" + uniq_id(), full_refresh=True) + pipeline = destination_config.setup_pipeline( + "drop_test_" + uniq_id(), full_refresh=True + ) pipeline.run(source) dlt_tables = [ t["name"] for t in pipeline.default_schema.dlt_tables() @@ -331,11 +391,17 @@ def test_drop_all_flag(destination_config: DestinationTestConfiguration) -> None @pytest.mark.parametrize( - "destination_config", destinations_configs(default_sql_configs=True), ids=lambda x: x.name + "destination_config", + destinations_configs(default_sql_configs=True), + ids=lambda x: x.name, ) -def test_run_pipeline_after_partial_drop(destination_config: DestinationTestConfiguration) -> None: +def test_run_pipeline_after_partial_drop( + destination_config: DestinationTestConfiguration, +) -> None: """Pipeline can be run again after dropping some resources""" - pipeline = destination_config.setup_pipeline("drop_test_" + uniq_id(), full_refresh=True) + pipeline = destination_config.setup_pipeline( + "drop_test_" + uniq_id(), full_refresh=True + ) pipeline.run(droppable_source()) attached = _attach(pipeline) @@ -344,17 +410,23 @@ def test_run_pipeline_after_partial_drop(destination_config: DestinationTestConf attached = _attach(pipeline) - attached.extract(droppable_source()) # TODO: individual steps cause pipeline.run() never raises + attached.extract( + droppable_source() + ) # TODO: individual steps cause pipeline.run() never raises attached.normalize() attached.load(raise_on_failed_jobs=True) @pytest.mark.parametrize( - "destination_config", destinations_configs(default_sql_configs=True), ids=lambda x: x.name + "destination_config", + destinations_configs(default_sql_configs=True), + ids=lambda x: x.name, ) def test_drop_state_only(destination_config: DestinationTestConfiguration) -> None: """Pipeline can be run again after dropping some resources""" - pipeline = destination_config.setup_pipeline("drop_test_" + uniq_id(), full_refresh=True) + pipeline = destination_config.setup_pipeline( + "drop_test_" + uniq_id(), full_refresh=True + ) pipeline.run(droppable_source()) attached = _attach(pipeline) diff --git a/tests/load/pipeline/test_duckdb.py b/tests/load/pipeline/test_duckdb.py index 6064392976..d692ec28c0 100644 --- a/tests/load/pipeline/test_duckdb.py +++ b/tests/load/pipeline/test_duckdb.py @@ -24,8 +24,12 @@ def test_duck_case_names(destination_config: DestinationTestConfiguration) -> No os.environ["SCHEMA__NAMING"] = "duck_case" pipeline = destination_config.setup_pipeline("test_duck_case_names") # create tables and columns with emojis and other special characters - pipeline.run(airtable_emojis().with_resources("📆 Schedule", "🦚Peacock", "🦚WidePeacock")) - pipeline.run([{"🐾Feet": 2, "1+1": "two", "\nhey": "value"}], table_name="🦚Peacocks🦚") + pipeline.run( + airtable_emojis().with_resources("📆 Schedule", "🦚Peacock", "🦚WidePeacock") + ) + pipeline.run( + [{"🐾Feet": 2, "1+1": "two", "\nhey": "value"}], table_name="🦚Peacocks🦚" + ) table_counts = load_table_counts( pipeline, *[t["name"] for t in pipeline.default_schema.data_tables()] ) @@ -40,11 +44,19 @@ def test_duck_case_names(destination_config: DestinationTestConfiguration) -> No # this will fail - duckdb preserves case but is case insensitive when comparing identifiers with pytest.raises(PipelineStepFailed) as pip_ex: - pipeline.run([{"🐾Feet": 2, "1+1": "two", "🐾feet": "value"}], table_name="🦚peacocks🦚") + pipeline.run( + [{"🐾Feet": 2, "1+1": 
"two", "🐾feet": "value"}], table_name="🦚peacocks🦚" + ) assert isinstance(pip_ex.value.__context__, DatabaseTerminalException) # show tables and columns with pipeline.sql_client() as client: with client.execute_query("DESCRIBE 🦚peacocks🦚;") as q: tables = q.df() - assert tables["column_name"].tolist() == ["🐾Feet", "1+1", "hey", "_dlt_load_id", "_dlt_id"] + assert tables["column_name"].tolist() == [ + "🐾Feet", + "1+1", + "hey", + "_dlt_load_id", + "_dlt_id", + ] diff --git a/tests/load/pipeline/test_filesystem_pipeline.py b/tests/load/pipeline/test_filesystem_pipeline.py index 8fc4adc0c3..e29deef93f 100644 --- a/tests/load/pipeline/test_filesystem_pipeline.py +++ b/tests/load/pipeline/test_filesystem_pipeline.py @@ -4,7 +4,10 @@ import dlt, os from dlt.common.utils import uniq_id from dlt.common.storages.load_storage import LoadJobInfo -from dlt.destinations.impl.filesystem.filesystem import FilesystemClient, LoadFilesystemJob +from dlt.destinations.impl.filesystem.filesystem import ( + FilesystemClient, + LoadFilesystemJob, +) from dlt.common.schema.typing import LOADS_TABLE_NAME from tests.utils import skip_if_not_active @@ -86,8 +89,12 @@ def some_source(): complete_fn = f"{client.schema.name}.{LOADS_TABLE_NAME}.%s" # Test complete_load markers are saved - assert client.fs_client.isfile(posixpath.join(client.dataset_path, complete_fn % load_id1)) - assert client.fs_client.isfile(posixpath.join(client.dataset_path, complete_fn % load_id2)) + assert client.fs_client.isfile( + posixpath.join(client.dataset_path, complete_fn % load_id1) + ) + assert client.fs_client.isfile( + posixpath.join(client.dataset_path, complete_fn % load_id2) + ) # Force replace pipeline.run(some_source(), write_disposition="replace") diff --git a/tests/load/pipeline/test_merge_disposition.py b/tests/load/pipeline/test_merge_disposition.py index 19ee9a34c8..8e31d2eb19 100644 --- a/tests/load/pipeline/test_merge_disposition.py +++ b/tests/load/pipeline/test_merge_disposition.py @@ -27,12 +27,18 @@ @pytest.mark.parametrize( - "destination_config", destinations_configs(default_sql_configs=True), ids=lambda x: x.name + "destination_config", + destinations_configs(default_sql_configs=True), + ids=lambda x: x.name, ) -def test_merge_on_keys_in_schema(destination_config: DestinationTestConfiguration) -> None: +def test_merge_on_keys_in_schema( + destination_config: DestinationTestConfiguration, +) -> None: p = destination_config.setup_pipeline("eth_2", full_refresh=True) - with open("tests/common/cases/schemas/eth/ethereum_schema_v5.yml", "r", encoding="utf-8") as f: + with open( + "tests/common/cases/schemas/eth/ethereum_schema_v5.yml", "r", encoding="utf-8" + ) as f: schema = dlt.Schema.from_dict(yaml.safe_load(f)) # make block uncles unseen to trigger filtering loader in loader for child tables @@ -61,7 +67,9 @@ def test_merge_on_keys_in_schema(destination_config: DestinationTestConfiguratio assert eth_1_counts["blocks"] == 1 # check root key propagation assert ( - p.default_schema.tables["blocks__transactions"]["columns"]["_dlt_root_id"]["root_key"] + p.default_schema.tables["blocks__transactions"]["columns"]["_dlt_root_id"][ + "root_key" + ] is True ) # now we load the whole dataset. 
blocks should be created which adds columns to blocks @@ -73,7 +81,9 @@ def test_merge_on_keys_in_schema(destination_config: DestinationTestConfiguratio schema=schema, loader_file_format=destination_config.file_format, ) - eth_2_counts = load_table_counts(p, *[t["name"] for t in p.default_schema.data_tables()]) + eth_2_counts = load_table_counts( + p, *[t["name"] for t in p.default_schema.data_tables()] + ) # we have 2 blocks in dataset assert eth_2_counts["blocks"] == 2 if destination_config.supports_merge else 3 # make sure we have same record after merging full dataset again @@ -88,18 +98,26 @@ def test_merge_on_keys_in_schema(destination_config: DestinationTestConfiguratio # for non merge destinations we just check that the run passes if not destination_config.supports_merge: return - eth_3_counts = load_table_counts(p, *[t["name"] for t in p.default_schema.data_tables()]) + eth_3_counts = load_table_counts( + p, *[t["name"] for t in p.default_schema.data_tables()] + ) assert eth_2_counts == eth_3_counts @pytest.mark.parametrize( - "destination_config", destinations_configs(default_sql_configs=True), ids=lambda x: x.name + "destination_config", + destinations_configs(default_sql_configs=True), + ids=lambda x: x.name, ) -def test_merge_on_ad_hoc_primary_key(destination_config: DestinationTestConfiguration) -> None: +def test_merge_on_ad_hoc_primary_key( + destination_config: DestinationTestConfiguration, +) -> None: p = destination_config.setup_pipeline("github_1", full_refresh=True) with open( - "tests/normalize/cases/github.issues.load_page_5_duck.json", "r", encoding="utf-8" + "tests/normalize/cases/github.issues.load_page_5_duck.json", + "r", + encoding="utf-8", ) as f: data = json.load(f) # note: NodeId will be normalized to "node_id" which exists in the schema @@ -111,12 +129,18 @@ def test_merge_on_ad_hoc_primary_key(destination_config: DestinationTestConfigur loader_file_format=destination_config.file_format, ) assert_load_info(info) - github_1_counts = load_table_counts(p, *[t["name"] for t in p.default_schema.data_tables()]) + github_1_counts = load_table_counts( + p, *[t["name"] for t in p.default_schema.data_tables()] + ) # 17 issues assert github_1_counts["issues"] == 17 # primary key set on issues - assert p.default_schema.tables["issues"]["columns"]["node_id"]["primary_key"] is True - assert p.default_schema.tables["issues"]["columns"]["node_id"]["data_type"] == "text" + assert ( + p.default_schema.tables["issues"]["columns"]["node_id"]["primary_key"] is True + ) + assert ( + p.default_schema.tables["issues"]["columns"]["node_id"]["data_type"] == "text" + ) assert p.default_schema.tables["issues"]["columns"]["node_id"]["nullable"] is False info = p.run( @@ -130,7 +154,9 @@ def test_merge_on_ad_hoc_primary_key(destination_config: DestinationTestConfigur # for non merge destinations we just check that the run passes if not destination_config.supports_merge: return - github_2_counts = load_table_counts(p, *[t["name"] for t in p.default_schema.data_tables()]) + github_2_counts = load_table_counts( + p, *[t["name"] for t in p.default_schema.data_tables()] + ) # 100 issues total assert github_2_counts["issues"] == 100 # still 100 after the reload @@ -146,7 +172,9 @@ def github(): ) def load_issues(): with open( - "tests/normalize/cases/github.issues.load_page_5_duck.json", "r", encoding="utf-8" + "tests/normalize/cases/github.issues.load_page_5_duck.json", + "r", + encoding="utf-8", ) as f: yield from json.load(f) @@ -154,7 +182,9 @@ def load_issues(): @pytest.mark.parametrize( 
- "destination_config", destinations_configs(default_sql_configs=True), ids=lambda x: x.name + "destination_config", + destinations_configs(default_sql_configs=True), + ids=lambda x: x.name, ) def test_merge_source_compound_keys_and_changes( destination_config: DestinationTestConfiguration, @@ -163,7 +193,9 @@ def test_merge_source_compound_keys_and_changes( info = p.run(github(), loader_file_format=destination_config.file_format) assert_load_info(info) - github_1_counts = load_table_counts(p, *[t["name"] for t in p.default_schema.data_tables()]) + github_1_counts = load_table_counts( + p, *[t["name"] for t in p.default_schema.data_tables()] + ) # 100 issues total assert github_1_counts["issues"] == 100 # check keys created @@ -189,25 +221,35 @@ def test_merge_source_compound_keys_and_changes( assert_load_info(info) assert p.default_schema.tables["issues"]["write_disposition"] == "append" # the counts of all tables must be double - github_2_counts = load_table_counts(p, *[t["name"] for t in p.default_schema.data_tables()]) + github_2_counts = load_table_counts( + p, *[t["name"] for t in p.default_schema.data_tables()] + ) assert {k: v * 2 for k, v in github_1_counts.items()} == github_2_counts # now replace all resources info = p.run( - github(), write_disposition="replace", loader_file_format=destination_config.file_format + github(), + write_disposition="replace", + loader_file_format=destination_config.file_format, ) assert_load_info(info) assert p.default_schema.tables["issues"]["write_disposition"] == "replace" # assert p.default_schema.tables["issues__labels"]["write_disposition"] == "replace" # the counts of all tables must be double - github_3_counts = load_table_counts(p, *[t["name"] for t in p.default_schema.data_tables()]) + github_3_counts = load_table_counts( + p, *[t["name"] for t in p.default_schema.data_tables()] + ) assert github_1_counts == github_3_counts @pytest.mark.parametrize( - "destination_config", destinations_configs(default_sql_configs=True), ids=lambda x: x.name + "destination_config", + destinations_configs(default_sql_configs=True), + ids=lambda x: x.name, ) -def test_merge_no_child_tables(destination_config: DestinationTestConfiguration) -> None: +def test_merge_no_child_tables( + destination_config: DestinationTestConfiguration, +) -> None: p = destination_config.setup_pipeline("github_3", full_refresh=True) github_data = github() assert github_data.max_table_nesting is None @@ -224,7 +266,9 @@ def test_merge_no_child_tables(destination_config: DestinationTestConfiguration) assert len(p.default_schema.data_tables()) == 1 assert "issues" in p.default_schema.tables assert_load_info(info) - github_1_counts = load_table_counts(p, *[t["name"] for t in p.default_schema.data_tables()]) + github_1_counts = load_table_counts( + p, *[t["name"] for t in p.default_schema.data_tables()] + ) assert github_1_counts["issues"] == 15 # load all @@ -232,13 +276,19 @@ def test_merge_no_child_tables(destination_config: DestinationTestConfiguration) github_data.max_table_nesting = 0 info = p.run(github_data, loader_file_format=destination_config.file_format) assert_load_info(info) - github_2_counts = load_table_counts(p, *[t["name"] for t in p.default_schema.data_tables()]) + github_2_counts = load_table_counts( + p, *[t["name"] for t in p.default_schema.data_tables()] + ) # 100 issues total, or 115 if merge is not supported - assert github_2_counts["issues"] == 100 if destination_config.supports_merge else 115 + assert ( + github_2_counts["issues"] == 100 if 
destination_config.supports_merge else 115 + ) @pytest.mark.parametrize( - "destination_config", destinations_configs(default_sql_configs=True), ids=lambda x: x.name + "destination_config", + destinations_configs(default_sql_configs=True), + ids=lambda x: x.name, ) def test_merge_no_merge_keys(destination_config: DestinationTestConfiguration) -> None: p = destination_config.setup_pipeline("github_3", full_refresh=True) @@ -249,7 +299,9 @@ def test_merge_no_merge_keys(destination_config: DestinationTestConfiguration) - github_data.load_issues.add_filter(skip_first(45)) info = p.run(github_data, loader_file_format=destination_config.file_format) assert_load_info(info) - github_1_counts = load_table_counts(p, *[t["name"] for t in p.default_schema.data_tables()]) + github_1_counts = load_table_counts( + p, *[t["name"] for t in p.default_schema.data_tables()] + ) assert github_1_counts["issues"] == 100 - 45 # take first 10 rows. @@ -260,24 +312,38 @@ def test_merge_no_merge_keys(destination_config: DestinationTestConfiguration) - github_data.load_issues.add_filter(take_first(10)) info = p.run(github_data, loader_file_format=destination_config.file_format) assert_load_info(info) - github_1_counts = load_table_counts(p, *[t["name"] for t in p.default_schema.data_tables()]) + github_1_counts = load_table_counts( + p, *[t["name"] for t in p.default_schema.data_tables()] + ) # only ten rows remains. merge falls back to replace when no keys are specified - assert github_1_counts["issues"] == 10 if destination_config.supports_merge else 100 - 45 + assert ( + github_1_counts["issues"] == 10 + if destination_config.supports_merge + else 100 - 45 + ) @pytest.mark.parametrize( - "destination_config", destinations_configs(default_sql_configs=True), ids=lambda x: x.name + "destination_config", + destinations_configs(default_sql_configs=True), + ids=lambda x: x.name, ) -def test_merge_keys_non_existing_columns(destination_config: DestinationTestConfiguration) -> None: +def test_merge_keys_non_existing_columns( + destination_config: DestinationTestConfiguration, +) -> None: p = destination_config.setup_pipeline("github_3", full_refresh=True) github_data = github() # set keys names that do not exist in the data - github_data.load_issues.apply_hints(merge_key=("mA1", "Ma2"), primary_key=("123-x",)) + github_data.load_issues.apply_hints( + merge_key=("mA1", "Ma2"), primary_key=("123-x",) + ) # skip first 45 rows github_data.load_issues.add_filter(skip_first(45)) info = p.run(github_data, loader_file_format=destination_config.file_format) assert_load_info(info) - github_1_counts = load_table_counts(p, *[t["name"] for t in p.default_schema.data_tables()]) + github_1_counts = load_table_counts( + p, *[t["name"] for t in p.default_schema.data_tables()] + ) assert github_1_counts["issues"] == 100 - 45 assert ( p.default_schema.tables["issues"]["columns"]["m_a1"].items() @@ -290,11 +356,15 @@ def test_merge_keys_non_existing_columns(destination_config: DestinationTestConf # all the keys are invalid so the merge falls back to replace github_data = github() - github_data.load_issues.apply_hints(merge_key=("mA1", "Ma2"), primary_key=("123-x",)) + github_data.load_issues.apply_hints( + merge_key=("mA1", "Ma2"), primary_key=("123-x",) + ) github_data.load_issues.add_filter(take_first(1)) info = p.run(github_data, loader_file_format=destination_config.file_format) assert_load_info(info) - github_2_counts = load_table_counts(p, *[t["name"] for t in p.default_schema.data_tables()]) + github_2_counts = load_table_counts( + p, 
*[t["name"] for t in p.default_schema.data_tables()] + ) assert github_2_counts["issues"] == 1 with p._sql_job_client(p.default_schema) as job_c: _, table_schema = job_c.get_storage_table("issues") @@ -307,7 +377,9 @@ def test_merge_keys_non_existing_columns(destination_config: DestinationTestConf destinations_configs(default_sql_configs=True, file_format="parquet"), ids=lambda x: x.name, ) -def test_pipeline_load_parquet(destination_config: DestinationTestConfiguration) -> None: +def test_pipeline_load_parquet( + destination_config: DestinationTestConfiguration, +) -> None: p = destination_config.setup_pipeline("github_3", full_refresh=True) github_data = github() # generate some complex types @@ -315,7 +387,9 @@ def test_pipeline_load_parquet(destination_config: DestinationTestConfiguration) github_data_copy = github() github_data_copy.max_table_nesting = 2 info = p.run( - [github_data, github_data_copy], loader_file_format="parquet", write_disposition="merge" + [github_data, github_data_copy], + loader_file_format="parquet", + write_disposition="merge", ) assert_load_info(info) # make sure it was parquet or sql transforms @@ -323,10 +397,14 @@ def test_pipeline_load_parquet(destination_config: DestinationTestConfiguration) if p.staging: # allow references if staging is present expected_formats.append("reference") - files = p.get_load_package_info(p.list_completed_load_packages()[0]).jobs["completed_jobs"] + files = p.get_load_package_info(p.list_completed_load_packages()[0]).jobs[ + "completed_jobs" + ] assert all(f.job_file_info.file_format in expected_formats + ["sql"] for f in files) - github_1_counts = load_table_counts(p, *[t["name"] for t in p.default_schema.data_tables()]) + github_1_counts = load_table_counts( + p, *[t["name"] for t in p.default_schema.data_tables()] + ) expected_rows = 100 if not destination_config.supports_merge: expected_rows *= 2 @@ -339,13 +417,17 @@ def test_pipeline_load_parquet(destination_config: DestinationTestConfiguration) info = p.run(github_data, loader_file_format="parquet", write_disposition="replace") assert_load_info(info) # make sure it was parquet or sql inserts - files = p.get_load_package_info(p.list_completed_load_packages()[1]).jobs["completed_jobs"] + files = p.get_load_package_info(p.list_completed_load_packages()[1]).jobs[ + "completed_jobs" + ] if destination_config.force_iceberg: # iceberg uses sql to copy tables expected_formats.append("sql") assert all(f.job_file_info.file_format in expected_formats for f in files) - github_1_counts = load_table_counts(p, *[t["name"] for t in p.default_schema.data_tables()]) + github_1_counts = load_table_counts( + p, *[t["name"] for t in p.default_schema.data_tables()] + ) assert github_1_counts["issues"] == 100 @@ -375,7 +457,9 @@ def github_repo_events_table_meta( @dlt.resource def _get_shuffled_events(shuffle: bool = dlt.secrets.value): with open( - "tests/normalize/cases/github.events.load_page_1_duck.json", "r", encoding="utf-8" + "tests/normalize/cases/github.events.load_page_1_duck.json", + "r", + encoding="utf-8", ) as f: issues = json.load(f) # random order @@ -385,9 +469,13 @@ def _get_shuffled_events(shuffle: bool = dlt.secrets.value): @pytest.mark.parametrize( - "destination_config", destinations_configs(default_sql_configs=True), ids=lambda x: x.name + "destination_config", + destinations_configs(default_sql_configs=True), + ids=lambda x: x.name, +) +@pytest.mark.parametrize( + "github_resource", [github_repo_events, github_repo_events_table_meta] ) 
-@pytest.mark.parametrize("github_resource", [github_repo_events, github_repo_events_table_meta]) def test_merge_with_dispatch_and_incremental( destination_config: DestinationTestConfiguration, github_resource: DltResource ) -> None: @@ -445,7 +533,8 @@ def _updated_event(node_id): assert_load_info(info) # get top tables counts = load_table_counts( - p, *[t["name"] for t in p.default_schema.data_tables() if t.get("parent") is None] + p, + *[t["name"] for t in p.default_schema.data_tables() if t.get("parent") is None], ) # total number of events in all top tables == 100 assert sum(counts.values()) == 100 @@ -459,11 +548,13 @@ def _updated_event(node_id): # load one more event with a new id info = p.run( - _new_event("new_node") | github_resource, loader_file_format=destination_config.file_format + _new_event("new_node") | github_resource, + loader_file_format=destination_config.file_format, ) assert_load_info(info) counts = load_table_counts( - p, *[t["name"] for t in p.default_schema.data_tables() if t.get("parent") is None] + p, + *[t["name"] for t in p.default_schema.data_tables() if t.get("parent") is None], ) assert sum(counts.values()) == 101 # all the columns have primary keys and merge disposition derived from resource @@ -480,7 +571,8 @@ def _updated_event(node_id): assert_load_info(info) # still 101 counts = load_table_counts( - p, *[t["name"] for t in p.default_schema.data_tables() if t.get("parent") is None] + p, + *[t["name"] for t in p.default_schema.data_tables() if t.get("parent") is None], ) assert sum(counts.values()) == 101 if destination_config.supports_merge else 102 # for non merge destinations we just check that the run passes @@ -489,14 +581,20 @@ def _updated_event(node_id): # but we have it updated with p.sql_client() as c: qual_name = c.make_qualified_table_name("watch_event") - with c.execute_query(f"SELECT node_id FROM {qual_name} WHERE node_id = 'new_node_X'") as q: + with c.execute_query( + f"SELECT node_id FROM {qual_name} WHERE node_id = 'new_node_X'" + ) as q: assert len(list(q.fetchall())) == 1 @pytest.mark.parametrize( - "destination_config", destinations_configs(default_sql_configs=True), ids=lambda x: x.name + "destination_config", + destinations_configs(default_sql_configs=True), + ids=lambda x: x.name, ) -def test_deduplicate_single_load(destination_config: DestinationTestConfiguration) -> None: +def test_deduplicate_single_load( + destination_config: DestinationTestConfiguration, +) -> None: p = destination_config.setup_pipeline("abstract", full_refresh=True) @dlt.resource(write_disposition="merge", primary_key="id") @@ -516,18 +614,29 @@ def duplicates(): @dlt.resource(write_disposition="merge", primary_key=("id", "subkey")) def duplicates_no_child(): - yield [{"id": 1, "subkey": "AX", "name": "row1"}, {"id": 1, "subkey": "AX", "name": "row2"}] + yield [ + {"id": 1, "subkey": "AX", "name": "row1"}, + {"id": 1, "subkey": "AX", "name": "row2"}, + ] - info = p.run(duplicates_no_child(), loader_file_format=destination_config.file_format) + info = p.run( + duplicates_no_child(), loader_file_format=destination_config.file_format + ) assert_load_info(info) counts = load_table_counts(p, "duplicates_no_child") - assert counts["duplicates_no_child"] == 1 if destination_config.supports_merge else 2 + assert ( + counts["duplicates_no_child"] == 1 if destination_config.supports_merge else 2 + ) @pytest.mark.parametrize( - "destination_config", destinations_configs(default_sql_configs=True), ids=lambda x: x.name + "destination_config", + 
destinations_configs(default_sql_configs=True), + ids=lambda x: x.name, ) -def test_no_deduplicate_only_merge_key(destination_config: DestinationTestConfiguration) -> None: +def test_no_deduplicate_only_merge_key( + destination_config: DestinationTestConfiguration, +) -> None: p = destination_config.setup_pipeline("abstract", full_refresh=True) @dlt.resource(write_disposition="merge", merge_key="id") @@ -545,9 +654,14 @@ def duplicates(): @dlt.resource(write_disposition="merge", merge_key=("id", "subkey")) def duplicates_no_child(): - yield [{"id": 1, "subkey": "AX", "name": "row1"}, {"id": 1, "subkey": "AX", "name": "row2"}] + yield [ + {"id": 1, "subkey": "AX", "name": "row1"}, + {"id": 1, "subkey": "AX", "name": "row2"}, + ] - info = p.run(duplicates_no_child(), loader_file_format=destination_config.file_format) + info = p.run( + duplicates_no_child(), loader_file_format=destination_config.file_format + ) assert_load_info(info) counts = load_table_counts(p, "duplicates_no_child") assert counts["duplicates_no_child"] == 2 @@ -558,7 +672,9 @@ def duplicates_no_child(): destinations_configs(default_sql_configs=True, supports_merge=True), ids=lambda x: x.name, ) -def test_complex_column_missing(destination_config: DestinationTestConfiguration) -> None: +def test_complex_column_missing( + destination_config: DestinationTestConfiguration, +) -> None: table_name = "test_complex_column_missing" @dlt.resource(name=table_name, write_disposition="merge", primary_key="id") @@ -587,7 +703,9 @@ def r(data): ids=lambda x: x.name, ) @pytest.mark.parametrize("key_type", ["primary_key", "merge_key"]) -def test_hard_delete_hint(destination_config: DestinationTestConfiguration, key_type: str) -> None: +def test_hard_delete_hint( + destination_config: DestinationTestConfiguration, key_type: str +) -> None: table_name = "test_hard_delete_hint" @dlt.resource( @@ -730,7 +848,9 @@ def data_resource(data): destinations_configs(default_sql_configs=True, supports_merge=True), ids=lambda x: x.name, ) -def test_hard_delete_hint_config(destination_config: DestinationTestConfiguration) -> None: +def test_hard_delete_hint_config( + destination_config: DestinationTestConfiguration, +) -> None: table_name = "test_hard_delete_hint_non_bool" @dlt.resource( @@ -738,7 +858,11 @@ def test_hard_delete_hint_config(destination_config: DestinationTestConfiguratio write_disposition="merge", primary_key="id", columns={ - "deleted_timestamp": {"data_type": "timestamp", "nullable": True, "hard_delete": True} + "deleted_timestamp": { + "data_type": "timestamp", + "nullable": True, + "hard_delete": True, + } }, ) def data_resource(data): @@ -776,7 +900,10 @@ def data_resource(data): @dlt.resource( name="test_hard_delete_hint_too_many_hints", write_disposition="merge", - columns={"deleted_1": {"hard_delete": True}, "deleted_2": {"hard_delete": True}}, + columns={ + "deleted_1": {"hard_delete": True}, + "deleted_2": {"hard_delete": True}, + }, ) def r(): yield {"id": 1, "val": "foo", "deleted_1": True, "deleted_2": False} @@ -910,7 +1037,9 @@ def data_resource(data): {"id": 1, "val": "foo", "sequence": 1}, {"id": 1, "val": "bar", "sequence": 2, "deleted": True}, ] - info = p.run(data_resource(data), loader_file_format=destination_config.file_format) + info = p.run( + data_resource(data), loader_file_format=destination_config.file_format + ) assert_load_info(info) assert load_table_counts(p, table_name)[table_name] == 0 @@ -920,7 +1049,9 @@ def data_resource(data): {"id": 1, "val": "foo", "sequence": 2}, {"id": 1, "val": "bar", 
"sequence": 1, "deleted": True}, ] - info = p.run(data_resource(data), loader_file_format=destination_config.file_format) + info = p.run( + data_resource(data), loader_file_format=destination_config.file_format + ) assert_load_info(info) assert load_table_counts(p, table_name)[table_name] == 1 @@ -939,7 +1070,10 @@ def r(): # more than one "dedup_sort" column hints are provided r.apply_hints( - columns={"dedup_sort_1": {"dedup_sort": "desc"}, "dedup_sort_2": {"dedup_sort": "desc"}} + columns={ + "dedup_sort_1": {"dedup_sort": "desc"}, + "dedup_sort_2": {"dedup_sort": "desc"}, + } ) with pytest.raises(PipelineStepFailed): info = p.run(r(), loader_file_format=destination_config.file_format) diff --git a/tests/load/pipeline/test_pipelines.py b/tests/load/pipeline/test_pipelines.py index 017bef2c01..5e17151a4f 100644 --- a/tests/load/pipeline/test_pipelines.py +++ b/tests/load/pipeline/test_pipelines.py @@ -60,7 +60,9 @@ def test_default_pipeline_names( possible_names = ["dlt_pytest", "dlt_pipeline"] possible_dataset_names = ["dlt_pytest_dataset", "dlt_pipeline_dataset"] assert p.pipeline_name in possible_names - assert p.pipelines_dir == os.path.abspath(os.path.join(TEST_STORAGE_ROOT, ".dlt", "pipelines")) + assert p.pipelines_dir == os.path.abspath( + os.path.join(TEST_STORAGE_ROOT, ".dlt", "pipelines") + ) assert p.dataset_name in possible_dataset_names assert p.destination is None assert p.default_schema_name is None @@ -197,7 +199,9 @@ def _data(): @pytest.mark.parametrize( - "destination_config", destinations_configs(default_sql_configs=True), ids=lambda x: x.name + "destination_config", + destinations_configs(default_sql_configs=True), + ids=lambda x: x.name, ) def test_skip_sync_schema_for_tables_without_columns( destination_config: DestinationTestConfiguration, @@ -210,7 +214,9 @@ def _data(): for d in data: yield d - p = destination_config.setup_pipeline("test_skip_sync_schema_for_tables", full_refresh=True) + p = destination_config.setup_pipeline( + "test_skip_sync_schema_for_tables", full_refresh=True + ) p.extract(_data) schema = p.default_schema assert "data_table" in schema.tables @@ -269,13 +275,17 @@ def _data(): @pytest.mark.parametrize( - "destination_config", destinations_configs(default_sql_configs=True), ids=lambda x: x.name + "destination_config", + destinations_configs(default_sql_configs=True), + ids=lambda x: x.name, ) def test_evolve_schema(destination_config: DestinationTestConfiguration) -> None: dataset_name = "d" + uniq_id() row = { "id": "level0", - "f": [{"id": "level1", "l": ["a", "b", "c"], "v": 120, "o": [{"a": 1}, {"a": 2}]}], + "f": [ + {"id": "level1", "l": ["a", "b", "c"], "v": 120, "o": [{"a": 1}, {"a": 2}]} + ], } @dlt.source(name="parallel") @@ -313,12 +323,18 @@ def extended_rows(): # yield deferred items resolved in threads yield get_item(no + 100) - return simple_rows(), extended_rows(), dlt.resource(["a", "b", "c"], name="simple") + return ( + simple_rows(), + extended_rows(), + dlt.resource(["a", "b", "c"], name="simple"), + ) import_schema_path = os.path.join(TEST_STORAGE_ROOT, "schemas", "import") export_schema_path = os.path.join(TEST_STORAGE_ROOT, "schemas", "export") p = destination_config.setup_pipeline( - "my_pipeline", import_schema_path=import_schema_path, export_schema_path=export_schema_path + "my_pipeline", + import_schema_path=import_schema_path, + export_schema_path=export_schema_path, ) p.extract(source(10).with_resources("simple_rows")) @@ -367,7 +383,8 @@ def extended_rows(): # TODO: test export and import schema # test data 
id_data = sorted( - ["level" + str(n) for n in range(10)] + ["level" + str(n) for n in range(100, 110)] + ["level" + str(n) for n in range(10)] + + ["level" + str(n) for n in range(100, 110)] ) with p.sql_client() as client: simple_rows_table = client.make_qualified_table_name("simple_rows") @@ -409,7 +426,9 @@ def test_pipeline_data_writer_compression( @pytest.mark.parametrize( - "destination_config", destinations_configs(default_sql_configs=True), ids=lambda x: x.name + "destination_config", + destinations_configs(default_sql_configs=True), + ids=lambda x: x.name, ) def test_source_max_nesting(destination_config: DestinationTestConfiguration) -> None: destination_config.setup() @@ -439,7 +458,9 @@ def complex_data(): @pytest.mark.parametrize( - "destination_config", destinations_configs(default_sql_configs=True), ids=lambda x: x.name + "destination_config", + destinations_configs(default_sql_configs=True), + ids=lambda x: x.name, ) def test_dataset_name_change(destination_config: DestinationTestConfiguration) -> None: destination_config.setup() @@ -449,14 +470,20 @@ def test_dataset_name_change(destination_config: DestinationTestConfiguration) - ds_2_name = "IteRation" + uniq_id() # illegal name that will be later normalized ds_3_name = "1it/era 👍 tion__" + uniq_id() - p, s = simple_nested_pipeline(destination_config, dataset_name=ds_1_name, full_refresh=False) + p, s = simple_nested_pipeline( + destination_config, dataset_name=ds_1_name, full_refresh=False + ) try: info = p.run(s(), loader_file_format=destination_config.file_format) assert_load_info(info) assert info.dataset_name == ds_1_name ds_1_counts = load_table_counts(p, "lists", "lists__value") # run to another dataset - info = p.run(s(), dataset_name=ds_2_name, loader_file_format=destination_config.file_format) + info = p.run( + s(), + dataset_name=ds_2_name, + loader_file_format=destination_config.file_format, + ) assert_load_info(info) assert info.dataset_name.startswith("ite_ration") # save normalized dataset name to delete correctly later @@ -493,7 +520,9 @@ def test_pipeline_explicit_destination_credentials( destination=Destination.from_reference("postgres", destination_name="mydest"), credentials="postgresql://loader:loader@localhost:7777/dlt_data", ) - c = p._get_destination_clients(Schema("s"), p._get_destination_client_initial_config())[0] + c = p._get_destination_clients( + Schema("s"), p._get_destination_client_initial_config() + )[0] assert c.config.credentials.port == 7777 # type: ignore[attr-defined] # TODO: may want to clear the env completely and ignore/mock config files somehow to avoid side effects @@ -503,7 +532,9 @@ def test_pipeline_explicit_destination_credentials( destination=Destination.from_reference("postgres", destination_name="mydest"), credentials="postgresql://loader:loader@localhost:5432/dlt_data", ) - c = p._get_destination_clients(Schema("s"), p._get_destination_client_initial_config())[0] + c = p._get_destination_clients( + Schema("s"), p._get_destination_client_initial_config() + )[0] assert c.config.credentials.host == "localhost" # type: ignore[attr-defined] # explicit partial credentials will use config providers @@ -513,7 +544,9 @@ def test_pipeline_explicit_destination_credentials( destination=Destination.from_reference("postgres", destination_name="mydest"), credentials="postgresql://localhost:5432/dlt_data", ) - c = p._get_destination_clients(Schema("s"), p._get_destination_client_initial_config())[0] + c = p._get_destination_clients( + Schema("s"), 
p._get_destination_client_initial_config() + )[0] assert c.config.credentials.username == "UN" # type: ignore[attr-defined] # host is also overridden assert c.config.credentials.host == "HOST" # type: ignore[attr-defined] @@ -619,7 +652,9 @@ def conflict(): destinations_configs(default_sql_configs=True, subset=["postgres"]), ids=lambda x: x.name, ) -def test_many_pipelines_single_dataset(destination_config: DestinationTestConfiguration) -> None: +def test_many_pipelines_single_dataset( + destination_config: DestinationTestConfiguration, +) -> None: schema = Schema("shared") @dlt.source(schema=schema, max_table_nesting=1) @@ -647,22 +682,30 @@ def gen2(): # load source_1 to common dataset p = dlt.pipeline( - pipeline_name="source_1_pipeline", destination="duckdb", dataset_name="shared_dataset" + pipeline_name="source_1_pipeline", + destination="duckdb", + dataset_name="shared_dataset", ) p.run(source_1(), credentials="duckdb:///_storage/test_quack.duckdb") counts = load_table_counts(p, *p.default_schema.tables.keys()) - assert counts.items() >= {"gen1": 1, "_dlt_pipeline_state": 1, "_dlt_loads": 1}.items() + assert ( + counts.items() >= {"gen1": 1, "_dlt_pipeline_state": 1, "_dlt_loads": 1}.items() + ) p._wipe_working_folder() p.deactivate() p = dlt.pipeline( - pipeline_name="source_2_pipeline", destination="duckdb", dataset_name="shared_dataset" + pipeline_name="source_2_pipeline", + destination="duckdb", + dataset_name="shared_dataset", ) p.run(source_2(), credentials="duckdb:///_storage/test_quack.duckdb") # table_names = [t["name"] for t in p.default_schema.data_tables()] counts = load_table_counts(p, *p.default_schema.tables.keys()) # gen1: one record comes from source_1, 1 record from source_2 - assert counts.items() >= {"gen1": 2, "_dlt_pipeline_state": 2, "_dlt_loads": 2}.items() + assert ( + counts.items() >= {"gen1": 2, "_dlt_pipeline_state": 2, "_dlt_loads": 2}.items() + ) # assert counts == {'gen1': 2, 'gen2': 3} p._wipe_working_folder() p.deactivate() @@ -705,10 +748,14 @@ def gen2(): destinations_configs(default_sql_configs=True, subset=["snowflake"]), ids=lambda x: x.name, ) -def test_snowflake_custom_stage(destination_config: DestinationTestConfiguration) -> None: +def test_snowflake_custom_stage( + destination_config: DestinationTestConfiguration, +) -> None: """Using custom stage name instead of the table stage""" os.environ["DESTINATION__SNOWFLAKE__STAGE_NAME"] = "my_non_existing_stage" - pipeline, data = simple_nested_pipeline(destination_config, f"custom_stage_{uniq_id()}", False) + pipeline, data = simple_nested_pipeline( + destination_config, f"custom_stage_{uniq_id()}", False + ) info = pipeline.run(data(), loader_file_format=destination_config.file_format) with pytest.raises(DestinationHasFailedJobs) as f_jobs: info.raise_on_failed_jobs() @@ -721,7 +768,9 @@ def test_snowflake_custom_stage(destination_config: DestinationTestConfiguration # GRANT READ, WRITE ON STAGE DLT_DATA.PUBLIC.MY_CUSTOM_LOCAL_STAGE TO ROLE DLT_LOADER_ROLE; stage_name = "PUBLIC.MY_CUSTOM_LOCAL_STAGE" os.environ["DESTINATION__SNOWFLAKE__STAGE_NAME"] = stage_name - pipeline, data = simple_nested_pipeline(destination_config, f"custom_stage_{uniq_id()}", False) + pipeline, data = simple_nested_pipeline( + destination_config, f"custom_stage_{uniq_id()}", False + ) info = pipeline.run(data(), loader_file_format=destination_config.file_format) assert_load_info(info) @@ -742,7 +791,9 @@ def test_snowflake_custom_stage(destination_config: DestinationTestConfiguration 
destinations_configs(default_sql_configs=True, subset=["snowflake"]), ids=lambda x: x.name, ) -def test_snowflake_delete_file_after_copy(destination_config: DestinationTestConfiguration) -> None: +def test_snowflake_delete_file_after_copy( + destination_config: DestinationTestConfiguration, +) -> None: """Using keep_staged_files = false option to remove staged files after copy""" os.environ["DESTINATION__SNOWFLAKE__KEEP_STAGED_FILES"] = "FALSE" @@ -769,7 +820,9 @@ def test_snowflake_delete_file_after_copy(destination_config: DestinationTestCon # do not remove - it allows us to filter tests by destination @pytest.mark.parametrize( "destination_config", - destinations_configs(default_sql_configs=True, all_staging_configs=True, file_format="parquet"), + destinations_configs( + default_sql_configs=True, all_staging_configs=True, file_format="parquet" + ), ids=lambda x: x.name, ) def test_parquet_loading(destination_config: DestinationTestConfiguration) -> None: @@ -793,14 +846,21 @@ def other_data(): # parquet on bigquery does not support JSON but we still want to run the test if destination_config.destination == "bigquery": - column_schemas["col9_null"]["data_type"] = column_schemas["col9"]["data_type"] = "text" + column_schemas["col9_null"]["data_type"] = column_schemas["col9"][ + "data_type" + ] = "text" # duckdb 0.9.1 does not support TIME other than 6 if destination_config.destination in ["duckdb", "motherduck"]: column_schemas["col11_precision"]["precision"] = 0 # drop TIME from databases not supporting it via parquet - if destination_config.destination in ["redshift", "athena", "synapse", "databricks"]: + if destination_config.destination in [ + "redshift", + "athena", + "synapse", + "databricks", + ]: data_types.pop("col11") data_types.pop("col11_null") data_types.pop("col11_precision") @@ -813,7 +873,9 @@ def other_data(): column_schemas.pop("col7_precision") # apply the exact columns definitions so we process complex and wei types correctly! 
- @dlt.resource(table_name="data_types", write_disposition="merge", columns=column_schemas) + @dlt.resource( + table_name="data_types", write_disposition="merge", columns=column_schemas + ) def my_resource(): nonlocal data_types yield [data_types] * 10 @@ -842,11 +904,15 @@ def some_source(): qual_name = sql_client.make_qualified_table_name assert [ row[0] - for row in sql_client.execute_sql(f"SELECT * FROM {qual_name('other_data')} ORDER BY 1") + for row in sql_client.execute_sql( + f"SELECT * FROM {qual_name('other_data')} ORDER BY 1" + ) ] == [1, 2, 3, 4, 5] assert [ row[0] - for row in sql_client.execute_sql(f"SELECT * FROM {qual_name('some_data')} ORDER BY 1") + for row in sql_client.execute_sql( + f"SELECT * FROM {qual_name('some_data')} ORDER BY 1" + ) ] == [1, 2, 3] db_rows = sql_client.execute_sql(f"SELECT * FROM {qual_name('data_types')}") assert len(db_rows) == 10 @@ -870,7 +936,10 @@ def some_source(): def test_pipeline_upfront_tables_two_loads( destination_config: DestinationTestConfiguration, replace_strategy: str ) -> None: - if not destination_config.supports_merge and replace_strategy != "truncate-and-insert": + if ( + not destination_config.supports_merge + and replace_strategy != "truncate-and-insert" + ): pytest.skip( f"Destination {destination_config.name} does not support merge and thus" f" {replace_strategy}" @@ -932,14 +1001,17 @@ def table_3(make_data=False): ) # load with one empty job, table 3 not created - load_info = pipeline.run(source.table_3, loader_file_format=destination_config.file_format) + load_info = pipeline.run( + source.table_3, loader_file_format=destination_config.file_format + ) assert_load_info(load_info, expected_load_packages=0) with pytest.raises(DatabaseUndefinedRelation): load_table_counts(pipeline, "table_3") # print(pipeline.default_schema.to_pretty_yaml()) load_info_2 = pipeline.run( - [source.table_1, source.table_3], loader_file_format=destination_config.file_format + [source.table_1, source.table_3], + loader_file_format=destination_config.file_format, ) assert_load_info(load_info_2) # 1 record in table 1 @@ -1027,7 +1099,9 @@ def table_3(make_data=False): def simple_nested_pipeline( - destination_config: DestinationTestConfiguration, dataset_name: str, full_refresh: bool + destination_config: DestinationTestConfiguration, + dataset_name: str, + full_refresh: bool, ) -> Tuple[dlt.Pipeline, Callable[[], DltSource]]: data = ["a", ["a", "b", "c"], ["a", "b", "c"]] diff --git a/tests/load/pipeline/test_redshift.py b/tests/load/pipeline/test_redshift.py index 44234ec64b..92d8945556 100644 --- a/tests/load/pipeline/test_redshift.py +++ b/tests/load/pipeline/test_redshift.py @@ -14,13 +14,19 @@ destinations_configs(all_staging_configs=True, subset=["redshift"]), ids=lambda x: x.name, ) -def test_redshift_blocks_time_column(destination_config: DestinationTestConfiguration) -> None: - pipeline = destination_config.setup_pipeline("athena_" + uniq_id(), full_refresh=True) +def test_redshift_blocks_time_column( + destination_config: DestinationTestConfiguration, +) -> None: + pipeline = destination_config.setup_pipeline( + "athena_" + uniq_id(), full_refresh=True + ) column_schemas, data_types = table_update_and_row() # apply the exact columns definitions so we process complex and wei types correctly! 
- @dlt.resource(table_name="data_types", write_disposition="append", columns=column_schemas) + @dlt.resource( + table_name="data_types", write_disposition="append", columns=column_schemas + ) def my_resource() -> Iterator[Any]: nonlocal data_types yield [data_types] * 10 diff --git a/tests/load/pipeline/test_replace_disposition.py b/tests/load/pipeline/test_replace_disposition.py index a69d4440dc..c179ef789f 100644 --- a/tests/load/pipeline/test_replace_disposition.py +++ b/tests/load/pipeline/test_replace_disposition.py @@ -19,7 +19,9 @@ @pytest.mark.parametrize( "destination_config", destinations_configs( - local_filesystem_configs=True, default_staging_configs=True, default_sql_configs=True + local_filesystem_configs=True, + default_staging_configs=True, + default_sql_configs=True, ), ids=lambda x: x.name, ) @@ -27,7 +29,10 @@ def test_replace_disposition( destination_config: DestinationTestConfiguration, replace_strategy: str ) -> None: - if not destination_config.supports_merge and replace_strategy != "truncate-and-insert": + if ( + not destination_config.supports_merge + and replace_strategy != "truncate-and-insert" + ): pytest.skip( f"Destination {destination_config.name} does not support merge and thus" f" {replace_strategy}" @@ -38,10 +43,14 @@ def test_replace_disposition( # use staging tables for replace os.environ["DESTINATION__REPLACE_STRATEGY"] = replace_strategy # make duckdb to reuse database in working folder - os.environ["DESTINATION__DUCKDB__CREDENTIALS"] = "duckdb:///test_replace_disposition.duckdb" + os.environ["DESTINATION__DUCKDB__CREDENTIALS"] = ( + "duckdb:///test_replace_disposition.duckdb" + ) # TODO: start storing _dlt_loads with right json content - increase_loads = lambda x: x if destination_config.destination == "filesystem" else x + 1 + increase_loads = lambda x: ( + x if destination_config.destination == "filesystem" else x + 1 + ) increase_state_loads = lambda info: len( [ job @@ -65,7 +74,9 @@ def norm_table_counts(counts: Dict[str, int], *child_tables: str) -> Dict[str, i offset = 1000 # keep merge key with unknown column to test replace SQL generator - @dlt.resource(name="items", write_disposition="replace", primary_key="id", merge_key="NA") + @dlt.resource( + name="items", write_disposition="replace", primary_key="id", merge_key="NA" + ) def load_items(): # will produce 3 jobs for the main table with 40 items each # 6 jobs for the sub_items @@ -161,7 +172,8 @@ def load_items_none(): yield info = pipeline.run( - [load_items_none, append_items], loader_file_format=destination_config.file_format + [load_items_none, append_items], + loader_file_format=destination_config.file_format, ) assert_load_info(info) state_records += increase_state_loads(info) @@ -194,7 +206,9 @@ def load_items_none(): "test_replace_strategies_2", dataset_name=dataset_name ) info = pipeline_2.run( - load_items, table_name="items_copy", loader_file_format=destination_config.file_format + load_items, + table_name="items_copy", + loader_file_format=destination_config.file_format, ) assert_load_info(info) new_state_records = increase_state_loads(info) @@ -209,14 +223,18 @@ def load_items_none(): "_dlt_pipeline_state": 1, } - info = pipeline_2.run(append_items, loader_file_format=destination_config.file_format) + info = pipeline_2.run( + append_items, loader_file_format=destination_config.file_format + ) assert_load_info(info) new_state_records = increase_state_loads(info) assert new_state_records == 0 dlt_loads = increase_loads(dlt_loads) # new pipeline - table_counts = 
load_table_counts(pipeline_2, *pipeline_2.default_schema.tables.keys()) + table_counts = load_table_counts( + pipeline_2, *pipeline_2.default_schema.tables.keys() + ) assert norm_table_counts(table_counts) == { "append_items": 48, "items_copy": 120, @@ -249,7 +267,9 @@ def load_items_none(): @pytest.mark.parametrize( "destination_config", destinations_configs( - local_filesystem_configs=True, default_staging_configs=True, default_sql_configs=True + local_filesystem_configs=True, + default_staging_configs=True, + default_sql_configs=True, ), ids=lambda x: x.name, ) @@ -257,7 +277,10 @@ def load_items_none(): def test_replace_table_clearing( destination_config: DestinationTestConfiguration, replace_strategy: str ) -> None: - if not destination_config.supports_merge and replace_strategy != "truncate-and-insert": + if ( + not destination_config.supports_merge + and replace_strategy != "truncate-and-insert" + ): pytest.skip( f"Destination {destination_config.name} does not support merge and thus" f" {replace_strategy}" @@ -267,7 +290,9 @@ def test_replace_table_clearing( os.environ["DESTINATION__REPLACE_STRATEGY"] = replace_strategy pipeline = destination_config.setup_pipeline( - "test_replace_table_clearing", dataset_name="test_replace_table_clearing", full_refresh=True + "test_replace_table_clearing", + dataset_name="test_replace_table_clearing", + full_refresh=True, ) @dlt.resource(name="main_resource", write_disposition="replace", primary_key="id") @@ -275,7 +300,10 @@ def items_with_subitems(): data = { "id": 1, "name": "item", - "sub_items": [{"id": 101, "name": "sub item 101"}, {"id": 101, "name": "sub item 102"}], + "sub_items": [ + {"id": 101, "name": "sub item 101"}, + {"id": 101, "name": "sub item 102"}, + ], } yield dlt.mark.with_table_name(data, "items") yield dlt.mark.with_table_name(data, "other_items") @@ -310,7 +338,10 @@ def static_items(): yield { "id": 1, "name": "item", - "sub_items": [{"id": 101, "name": "sub item 101"}, {"id": 101, "name": "sub item 102"}], + "sub_items": [ + {"id": 101, "name": "sub item 101"}, + {"id": 101, "name": "sub item 102"}, + ], } @dlt.resource(name="main_resource", write_disposition="replace", primary_key="id") @@ -329,7 +360,8 @@ def yield_empty_list(): # regular call pipeline.run( - [items_with_subitems, static_items], loader_file_format=destination_config.file_format + [items_with_subitems, static_items], + loader_file_format=destination_config.file_format, ) table_counts = load_table_counts( pipeline, *[t["name"] for t in pipeline.default_schema.data_tables()] @@ -352,7 +384,9 @@ def yield_empty_list(): } # see if child table gets cleared - pipeline.run(items_without_subitems, loader_file_format=destination_config.file_format) + pipeline.run( + items_without_subitems, loader_file_format=destination_config.file_format + ) table_counts = load_table_counts( pipeline, *[t["name"] for t in pipeline.default_schema.data_tables()] ) @@ -363,11 +397,16 @@ def yield_empty_list(): assert table_counts["static_items"] == 1 assert table_counts["static_items__sub_items"] == 2 # check trace - assert pipeline.last_trace.last_normalize_info.row_counts == {"items": 1, "other_items": 1} + assert pipeline.last_trace.last_normalize_info.row_counts == { + "items": 1, + "other_items": 1, + } # see if yield none clears everything for empty_resource in [yield_none, no_yield, yield_empty_list]: - pipeline.run(items_with_subitems, loader_file_format=destination_config.file_format) + pipeline.run( + items_with_subitems, loader_file_format=destination_config.file_format 
+ ) pipeline.run(empty_resource, loader_file_format=destination_config.file_format) table_counts = load_table_counts( pipeline, *[t["name"] for t in pipeline.default_schema.data_tables()] @@ -379,10 +418,16 @@ def yield_empty_list(): assert table_counts["static_items"] == 1 assert table_counts["static_items__sub_items"] == 2 # check trace - assert pipeline.last_trace.last_normalize_info.row_counts == {"items": 0, "other_items": 0} + assert pipeline.last_trace.last_normalize_info.row_counts == { + "items": 0, + "other_items": 0, + } # see if yielding something next to other none entries still goes into db - pipeline.run(items_with_subitems_yield_none, loader_file_format=destination_config.file_format) + pipeline.run( + items_with_subitems_yield_none, + loader_file_format=destination_config.file_format, + ) table_counts = load_table_counts( pipeline, *[t["name"] for t in pipeline.default_schema.data_tables()] ) diff --git a/tests/load/pipeline/test_restore_state.py b/tests/load/pipeline/test_restore_state.py index e50654adcc..9111500133 100644 --- a/tests/load/pipeline/test_restore_state.py +++ b/tests/load/pipeline/test_restore_state.py @@ -22,7 +22,10 @@ from tests.utils import TEST_STORAGE_ROOT from tests.cases import JSON_TYPED_DICT, JSON_TYPED_DICT_DECODED -from tests.common.utils import IMPORTED_VERSION_HASH_ETH_V9, yml_case_path as common_yml_case_path +from tests.common.utils import ( + IMPORTED_VERSION_HASH_ETH_V9, + yml_case_path as common_yml_case_path, +) from tests.common.configuration.utils import environment from tests.load.pipeline.utils import assert_query_data, drop_active_pipeline_data from tests.load.utils import ( @@ -42,7 +45,9 @@ def duckdb_pipeline_location() -> None: @pytest.mark.parametrize( "destination_config", destinations_configs( - default_staging_configs=True, default_sql_configs=True, default_vector_configs=True + default_staging_configs=True, + default_sql_configs=True, + default_vector_configs=True, ), ids=lambda x: x.name, ) @@ -75,11 +80,17 @@ def test_restore_state_utils(destination_config: DestinationTestConfiguration) - resource.apply_hints( columns={ "_dlt_id": {"name": "_dlt_id", "data_type": "text", "nullable": False}, - "_dlt_load_id": {"name": "_dlt_load_id", "data_type": "text", "nullable": False}, + "_dlt_load_id": { + "name": "_dlt_load_id", + "data_type": "text", + "nullable": False, + }, **STATE_TABLE_COLUMNS, } ) - schema.update_table(schema.normalize_table_identifiers(resource.compute_table_schema())) + schema.update_table( + schema.normalize_table_identifiers(resource.compute_table_schema()) + ) # do not bump version here or in sync_schema, dlt won't recognize that schema changed and it won't update it in storage # so dlt in normalize stage infers _state_version table again but with different column order and the column order in schema is different # then in database. parquet is created in schema order and in Redshift it must exactly match the order. 
@@ -120,7 +131,9 @@ def test_restore_state_utils(destination_config: DestinationTestConfiguration) - p.normalize(loader_file_format=destination_config.file_format) info = p.load() assert len(info.loads_ids) == 0 - new_stored_state = load_pipeline_state_from_destination(p.pipeline_name, job_client) + new_stored_state = load_pipeline_state_from_destination( + p.pipeline_name, job_client + ) # new state should not be stored assert new_stored_state == stored_state @@ -133,7 +146,10 @@ def test_restore_state_utils(destination_config: DestinationTestConfiguration) - # version increased assert local_state["_state_version"] + 1 == new_local_state["_state_version"] # last extracted hash does not match current version hash - assert new_local_state_local["_last_extracted_hash"] != new_local_state["_version_hash"] + assert ( + new_local_state_local["_last_extracted_hash"] + != new_local_state["_version_hash"] + ) # use the state context manager again but do not change state # because _last_extracted_hash is not present (or different), the version will not change but state will be extracted anyway @@ -145,16 +161,24 @@ def test_restore_state_utils(destination_config: DestinationTestConfiguration) - # there's extraction timestamp assert "_last_extracted_at" in new_local_state_2_local # and extract hash is == hash - assert new_local_state_2_local["_last_extracted_hash"] == new_local_state_2["_version_hash"] + assert ( + new_local_state_2_local["_last_extracted_hash"] + == new_local_state_2["_version_hash"] + ) # but the version didn't change assert new_local_state["_state_version"] == new_local_state_2["_state_version"] p.normalize(loader_file_format=destination_config.file_format) info = p.load() assert len(info.loads_ids) == 1 - new_stored_state_2 = load_pipeline_state_from_destination(p.pipeline_name, job_client) + new_stored_state_2 = load_pipeline_state_from_destination( + p.pipeline_name, job_client + ) # the stored state changed to next version assert new_stored_state != new_stored_state_2 - assert new_stored_state["_state_version"] + 1 == new_stored_state_2["_state_version"] + assert ( + new_stored_state["_state_version"] + 1 + == new_stored_state_2["_state_version"] + ) @pytest.mark.parametrize( @@ -174,7 +198,9 @@ def test_silently_skip_on_invalid_credentials( pipeline_name = "pipe_" + uniq_id() dataset_name = "state_test_" + uniq_id() # NOTE: we are not restoring the state in __init__ anymore but the test should stay: init should not fail on lack of credentials - destination_config.setup_pipeline(pipeline_name=pipeline_name, dataset_name=dataset_name) + destination_config.setup_pipeline( + pipeline_name=pipeline_name, dataset_name=dataset_name + ) @pytest.mark.parametrize( @@ -189,7 +215,9 @@ def test_get_schemas_from_destination( pipeline_name = "pipe_" + uniq_id() dataset_name = "state_test_" + uniq_id() - p = destination_config.setup_pipeline(pipeline_name=pipeline_name, dataset_name=dataset_name) + p = destination_config.setup_pipeline( + pipeline_name=pipeline_name, dataset_name=dataset_name + ) p.config.use_single_dataset = use_single_dataset def _make_dn_name(schema_name: str) -> str: @@ -227,7 +255,9 @@ def _make_dn_name(schema_name: str) -> str: # wipe and restore p._wipe_working_folder() - p = destination_config.setup_pipeline(pipeline_name=pipeline_name, dataset_name=dataset_name) + p = destination_config.setup_pipeline( + pipeline_name=pipeline_name, dataset_name=dataset_name + ) p.config.use_single_dataset = use_single_dataset assert not p.default_schema_name @@ -237,26 +267,34 @@ 
def _make_dn_name(schema_name: str) -> str: restored_schemas = p._get_schemas_from_destination([], always_download=False) assert restored_schemas == [] # restore unknown schema - restored_schemas = p._get_schemas_from_destination(["_unknown"], always_download=False) + restored_schemas = p._get_schemas_from_destination( + ["_unknown"], always_download=False + ) assert restored_schemas == [] # restore default schema p.default_schema_name = "state" p.schema_names = ["state"] - restored_schemas = p._get_schemas_from_destination(p.schema_names, always_download=False) + restored_schemas = p._get_schemas_from_destination( + p.schema_names, always_download=False + ) assert len(restored_schemas) == 1 assert restored_schemas[0].name == "state" p._schema_storage.save_schema(restored_schemas[0]) assert p._schema_storage.list_schemas() == ["state"] # restore all the rest p.schema_names = ["state", "two", "three"] - restored_schemas = p._get_schemas_from_destination(p.schema_names, always_download=False) + restored_schemas = p._get_schemas_from_destination( + p.schema_names, always_download=False + ) # only two restored schemas, state is already present assert len(restored_schemas) == 2 for schema in restored_schemas: p._schema_storage.save_schema(schema) assert set(p._schema_storage.list_schemas()) == set(p.schema_names) # force download - all three schemas are restored - restored_schemas = p._get_schemas_from_destination(p.schema_names, always_download=True) + restored_schemas = p._get_schemas_from_destination( + p.schema_names, always_download=True + ) assert len(restored_schemas) == 3 @@ -265,11 +303,15 @@ def _make_dn_name(schema_name: str) -> str: destinations_configs(default_sql_configs=True, default_vector_configs=True), ids=lambda x: x.name, ) -def test_restore_state_pipeline(destination_config: DestinationTestConfiguration) -> None: +def test_restore_state_pipeline( + destination_config: DestinationTestConfiguration, +) -> None: os.environ["RESTORE_FROM_DESTINATION"] = "True" pipeline_name = "pipe_" + uniq_id() dataset_name = "state_test_" + uniq_id() - p = destination_config.setup_pipeline(pipeline_name=pipeline_name, dataset_name=dataset_name) + p = destination_config.setup_pipeline( + pipeline_name=pipeline_name, dataset_name=dataset_name + ) def some_data_gen(param: str) -> Any: dlt.current.source_state()[param] = param @@ -318,14 +360,18 @@ def some_data(): # wipe and restore p._wipe_working_folder() os.environ["RESTORE_FROM_DESTINATION"] = "False" - p = destination_config.setup_pipeline(pipeline_name=pipeline_name, dataset_name=dataset_name) + p = destination_config.setup_pipeline( + pipeline_name=pipeline_name, dataset_name=dataset_name + ) p.run(loader_file_format=destination_config.file_format) # restore was not requested so schema is empty assert p.default_schema_name is None p._wipe_working_folder() # request restore os.environ["RESTORE_FROM_DESTINATION"] = "True" - p = destination_config.setup_pipeline(pipeline_name=pipeline_name, dataset_name=dataset_name) + p = destination_config.setup_pipeline( + pipeline_name=pipeline_name, dataset_name=dataset_name + ) p.run(loader_file_format=destination_config.file_format) assert p.default_schema_name == "default" assert set(p.schema_names) == set(["default", "two", "three", "four"]) @@ -353,7 +399,9 @@ def some_data(): # create pipeline without restore os.environ["RESTORE_FROM_DESTINATION"] = "False" - p = destination_config.setup_pipeline(pipeline_name=pipeline_name, dataset_name=dataset_name) + p = destination_config.setup_pipeline( + 
pipeline_name=pipeline_name, dataset_name=dataset_name + ) # now attach locally os.environ["RESTORE_FROM_DESTINATION"] = "True" p = dlt.attach(pipeline_name=pipeline_name) @@ -389,17 +437,23 @@ def some_data(): destinations_configs(default_sql_configs=True, default_vector_configs=True), ids=lambda x: x.name, ) -def test_ignore_state_unfinished_load(destination_config: DestinationTestConfiguration) -> None: +def test_ignore_state_unfinished_load( + destination_config: DestinationTestConfiguration, +) -> None: pipeline_name = "pipe_" + uniq_id() dataset_name = "state_test_" + uniq_id() - p = destination_config.setup_pipeline(pipeline_name=pipeline_name, dataset_name=dataset_name) + p = destination_config.setup_pipeline( + pipeline_name=pipeline_name, dataset_name=dataset_name + ) @dlt.resource def some_data(param: str) -> Any: dlt.current.source_state()[param] = param yield param - def complete_package_mock(self, load_id: str, schema: Schema, aborted: bool = False): + def complete_package_mock( + self, load_id: str, schema: Schema, aborted: bool = False + ): # complete in local storage but skip call to the database self.load_storage.complete_load_package(load_id, aborted) @@ -455,7 +509,9 @@ def test_restore_schemas_while_import_schemas_exist( # re-attach the pipeline p = dlt.attach(pipeline_name=pipeline_name) p.run( - ["C", "D", "E"], table_name="annotations", loader_file_format=destination_config.file_format + ["C", "D", "E"], + table_name="annotations", + loader_file_format=destination_config.file_format, ) schema = p.schemas["ethereum"] assert normalized_labels in schema.tables @@ -485,7 +541,9 @@ def test_restore_schemas_while_import_schemas_exist( assert schema._imported_version_hash == IMPORTED_VERSION_HASH_ETH_V9 # extract some data with restored pipeline p.run( - ["C", "D", "E"], table_name="blacklist", loader_file_format=destination_config.file_format + ["C", "D", "E"], + table_name="blacklist", + loader_file_format=destination_config.file_format, ) assert normalized_labels in schema.tables assert normalized_annotations in schema.tables @@ -505,7 +563,9 @@ def test_restore_change_dataset_and_destination(destination_name: str) -> None: destinations_configs(default_sql_configs=True, default_vector_configs=True), ids=lambda x: x.name, ) -def test_restore_state_parallel_changes(destination_config: DestinationTestConfiguration) -> None: +def test_restore_state_parallel_changes( + destination_config: DestinationTestConfiguration, +) -> None: pipeline_name = "pipe_" + uniq_id() dataset_name = "state_test_" + uniq_id() destination_config.setup() @@ -531,7 +591,9 @@ def some_data(param: str) -> Any: orig_state = p.state # create a production pipeline in separate pipelines_dir - production_p = dlt.pipeline(pipeline_name=pipeline_name, pipelines_dir=TEST_STORAGE_ROOT) + production_p = dlt.pipeline( + pipeline_name=pipeline_name, pipelines_dir=TEST_STORAGE_ROOT + ) production_p.run( destination=destination_config.destination, staging=destination_config.staging, @@ -541,7 +603,9 @@ def some_data(param: str) -> Any: assert production_p.default_schema_name == "default" prod_state = production_p.state - assert prod_state["sources"] == {"default": {"state1": "state1", "state2": "state2"}} + assert prod_state["sources"] == { + "default": {"state1": "state1", "state2": "state2"} + } assert prod_state["_state_version"] == orig_state["_state_version"] # generate data on production that modifies the schema but not state data2 = some_data("state1") @@ -569,7 +633,10 @@ def some_data(param: str) -> Any: 
# print(p.default_schema) p.sync_destination() # existing schema got overwritten - assert normalize("state1_data2") in p._schema_storage.load_schema(p.default_schema_name).tables + assert ( + normalize("state1_data2") + in p._schema_storage.load_schema(p.default_schema_name).tables + ) # print(p.default_schema) assert normalize("state1_data2") in p.default_schema.tables @@ -588,7 +655,9 @@ def some_data(param: str) -> Any: prod_state = production_p.state assert p.state["_state_version"] == prod_state["_state_version"] - 1 # re-attach production and sync - ra_production_p = dlt.attach(pipeline_name=pipeline_name, pipelines_dir=TEST_STORAGE_ROOT) + ra_production_p = dlt.attach( + pipeline_name=pipeline_name, pipelines_dir=TEST_STORAGE_ROOT + ) ra_production_p.sync_destination() # state didn't change because production is ahead of local with its version # nevertheless this is potentially dangerous situation 🤷 @@ -597,13 +666,19 @@ def some_data(param: str) -> Any: # get all the states, notice version 4 twice (one from production, the other from local) try: with p.sql_client() as client: - state_table = client.make_qualified_table_name(p.default_schema.state_table_name) + state_table = client.make_qualified_table_name( + p.default_schema.state_table_name + ) assert_query_data( - p, f"SELECT version FROM {state_table} ORDER BY created_at DESC", [5, 4, 4, 3, 2] + p, + f"SELECT version FROM {state_table} ORDER BY created_at DESC", + [5, 4, 4, 3, 2], ) except SqlClientNotAvailable: - pytest.skip(f"destination {destination_config.destination} does not support sql client") + pytest.skip( + f"destination {destination_config.destination} does not support sql client" + ) @pytest.mark.parametrize( @@ -636,7 +711,9 @@ def some_data(param: str) -> Any: ) data5 = some_data("state4") data5.apply_hints(table_name="state1_data5") - p.run(data5, schema=Schema("sch2"), loader_file_format=destination_config.file_format) + p.run( + data5, schema=Schema("sch2"), loader_file_format=destination_config.file_format + ) assert p.state["_state_version"] == 3 assert p.first_run is False with p.destination_client() as job_client: @@ -672,7 +749,9 @@ def some_data(param: str) -> Any: p.config.restore_from_destination = True data5 = some_data("state4") data5.apply_hints(table_name="state1_data5") - p.run(data5, schema=Schema("sch2"), loader_file_format=destination_config.file_format) + p.run( + data5, schema=Schema("sch2"), loader_file_format=destination_config.file_format + ) # the pipeline was not wiped out, the actual presence if the dataset was checked assert set(p.schema_names) == set(["sch2", "sch1"]) @@ -681,5 +760,7 @@ def prepare_import_folder(p: Pipeline) -> None: os.makedirs(p._schema_storage.config.import_schema_path, exist_ok=True) shutil.copy( common_yml_case_path("schemas/eth/ethereum_schema_v5"), - os.path.join(p._schema_storage.config.import_schema_path, "ethereum.schema.yaml"), + os.path.join( + p._schema_storage.config.import_schema_path, "ethereum.schema.yaml" + ), ) diff --git a/tests/load/pipeline/test_stage_loading.py b/tests/load/pipeline/test_stage_loading.py index 222171ca61..613c225e48 100644 --- a/tests/load/pipeline/test_stage_loading.py +++ b/tests/load/pipeline/test_stage_loading.py @@ -20,11 +20,16 @@ @dlt.resource( - table_name="issues", write_disposition="merge", primary_key="id", merge_key=("node_id", "url") + table_name="issues", + write_disposition="merge", + primary_key="id", + merge_key=("node_id", "url"), ) def load_modified_issues(): with open( - 
"tests/normalize/cases/github.issues.load_page_5_duck.json", "r", encoding="utf-8" + "tests/normalize/cases/github.issues.load_page_5_duck.json", + "r", + encoding="utf-8", ) as f: issues = json.load(f) @@ -39,11 +44,14 @@ def load_modified_issues(): @pytest.mark.parametrize( - "destination_config", destinations_configs(all_staging_configs=True), ids=lambda x: x.name + "destination_config", + destinations_configs(all_staging_configs=True), + ids=lambda x: x.name, ) def test_staging_load(destination_config: DestinationTestConfiguration) -> None: pipeline = destination_config.setup_pipeline( - pipeline_name="test_stage_loading_5", dataset_name="test_staging_load" + uniq_id() + pipeline_name="test_stage_loading_5", + dataset_name="test_staging_load" + uniq_id(), ) info = pipeline.run(github(), loader_file_format=destination_config.file_format) @@ -100,12 +108,16 @@ def test_staging_load(destination_config: DestinationTestConfiguration) -> None: f"SELECT TOP 1 url FROM {qual_name('issues')} WHERE id = 388089021" ) else: - rows = sql_client.execute_sql("SELECT url FROM issues WHERE id = 388089021 LIMIT 1") + rows = sql_client.execute_sql( + "SELECT url FROM issues WHERE id = 388089021 LIMIT 1" + ) assert rows[0][0] == "https://api.github.com/repos/duckdb/duckdb/issues/71" if destination_config.supports_merge: # test merging in some changed values - info = pipeline.run(load_modified_issues, loader_file_format=destination_config.file_format) + info = pipeline.run( + load_modified_issues, loader_file_format=destination_config.file_format + ) assert_load_info(info) assert pipeline.default_schema.tables["issues"]["write_disposition"] == "merge" merge_counts = load_table_counts( @@ -118,10 +130,12 @@ def test_staging_load(destination_config: DestinationTestConfiguration) -> None: if destination_config.destination in ["mssql", "synapse"]: qual_name = sql_client.make_qualified_table_name rows_1 = sql_client.execute_sql( - f"SELECT TOP 1 number FROM {qual_name('issues')} WHERE id = 1232152492" + f"SELECT TOP 1 number FROM {qual_name('issues')} WHERE id =" + " 1232152492" ) rows_2 = sql_client.execute_sql( - f"SELECT TOP 1 number FROM {qual_name('issues')} WHERE id = 1142699354" + f"SELECT TOP 1 number FROM {qual_name('issues')} WHERE id =" + " 1142699354" ) else: rows_1 = sql_client.execute_sql( @@ -163,7 +177,9 @@ def test_staging_load(destination_config: DestinationTestConfiguration) -> None: @pytest.mark.parametrize( - "destination_config", destinations_configs(all_staging_configs=True), ids=lambda x: x.name + "destination_config", + destinations_configs(all_staging_configs=True), + ids=lambda x: x.name, ) def test_all_data_types(destination_config: DestinationTestConfiguration) -> None: pipeline = destination_config.setup_pipeline( @@ -181,16 +197,26 @@ def test_all_data_types(destination_config: DestinationTestConfiguration) -> Non ) and destination_config.file_format in ("parquet", "jsonl"): # Redshift copy doesn't support TIME column exclude_types.append("time") - if destination_config.destination == "synapse" and destination_config.file_format == "parquet": + if ( + destination_config.destination == "synapse" + and destination_config.file_format == "parquet" + ): # TIME columns are not supported for staged parquet loads into Synapse exclude_types.append("time") - if destination_config.destination == "redshift" and destination_config.file_format in ( - "parquet", - "jsonl", + if ( + destination_config.destination == "redshift" + and destination_config.file_format + in ( + "parquet", + "jsonl", + 
) ): # Redshift can't load fixed width binary columns from parquet exclude_columns.append("col7_precision") - if destination_config.destination == "databricks" and destination_config.file_format == "jsonl": + if ( + destination_config.destination == "databricks" + and destination_config.file_format == "jsonl" + ): exclude_types.extend(["decimal", "binary", "wei", "complex", "date"]) exclude_columns.append("col1_precision") @@ -202,7 +228,9 @@ def test_all_data_types(destination_config: DestinationTestConfiguration) -> Non if destination_config.file_format == "parquet": if destination_config.destination == "bigquery": # change datatype to text and then allow for it in the assert (parse_complex_strings) - column_schemas["col9_null"]["data_type"] = column_schemas["col9"]["data_type"] = "text" + column_schemas["col9_null"]["data_type"] = column_schemas["col9"][ + "data_type" + ] = "text" # redshift cannot load from json into VARBYTE if destination_config.file_format == "jsonl": if destination_config.destination == "redshift": @@ -212,7 +240,9 @@ def test_all_data_types(destination_config: DestinationTestConfiguration) -> Non column_schemas[col]["data_type"] = "text" # apply the exact columns definitions so we process complex and wei types correctly! - @dlt.resource(table_name="data_types", write_disposition="merge", columns=column_schemas) + @dlt.resource( + table_name="data_types", write_disposition="merge", columns=column_schemas + ) def my_resource(): nonlocal data_types yield [data_types] * 10 diff --git a/tests/load/pipeline/test_write_disposition_changes.py b/tests/load/pipeline/test_write_disposition_changes.py index 164243cd55..ff3d3e5687 100644 --- a/tests/load/pipeline/test_write_disposition_changes.py +++ b/tests/load/pipeline/test_write_disposition_changes.py @@ -20,7 +20,9 @@ def data_with_subtables(offset: int) -> Any: @pytest.mark.parametrize( - "destination_config", destinations_configs(default_sql_configs=True), ids=lambda x: x.name + "destination_config", + destinations_configs(default_sql_configs=True), + ids=lambda x: x.name, ) def test_switch_from_merge(destination_config: DestinationTestConfiguration): pipeline = destination_config.setup_pipeline( @@ -34,9 +36,9 @@ def test_switch_from_merge(destination_config: DestinationTestConfiguration): loader_file_format=destination_config.file_format, ) assert_data_table_counts(pipeline, {"items": 100, "items__sub_items": 100}) - assert pipeline.default_schema._normalizers_config["json"]["config"]["propagation"]["tables"][ - "items" - ] == {"_dlt_id": "_dlt_root_id"} + assert pipeline.default_schema._normalizers_config["json"]["config"]["propagation"][ + "tables" + ]["items"] == {"_dlt_id": "_dlt_root_id"} info = pipeline.run( data_with_subtables(10), @@ -52,9 +54,9 @@ def test_switch_from_merge(destination_config: DestinationTestConfiguration): "items__sub_items": 100 if destination_config.supports_merge else 200, }, ) - assert pipeline.default_schema._normalizers_config["json"]["config"]["propagation"]["tables"][ - "items" - ] == {"_dlt_id": "_dlt_root_id"} + assert pipeline.default_schema._normalizers_config["json"]["config"]["propagation"][ + "tables" + ]["items"] == {"_dlt_id": "_dlt_root_id"} info = pipeline.run( data_with_subtables(10), @@ -70,9 +72,9 @@ def test_switch_from_merge(destination_config: DestinationTestConfiguration): "items__sub_items": 200 if destination_config.supports_merge else 300, }, ) - assert pipeline.default_schema._normalizers_config["json"]["config"]["propagation"]["tables"][ - "items" - ] == 
{"_dlt_id": "_dlt_root_id"} + assert pipeline.default_schema._normalizers_config["json"]["config"]["propagation"][ + "tables" + ]["items"] == {"_dlt_id": "_dlt_root_id"} info = pipeline.run( data_with_subtables(10), @@ -82,16 +84,20 @@ def test_switch_from_merge(destination_config: DestinationTestConfiguration): ) assert_load_info(info) assert_data_table_counts(pipeline, {"items": 100, "items__sub_items": 100}) - assert pipeline.default_schema._normalizers_config["json"]["config"]["propagation"]["tables"][ - "items" - ] == {"_dlt_id": "_dlt_root_id"} + assert pipeline.default_schema._normalizers_config["json"]["config"]["propagation"][ + "tables" + ]["items"] == {"_dlt_id": "_dlt_root_id"} @pytest.mark.parametrize( - "destination_config", destinations_configs(default_sql_configs=True), ids=lambda x: x.name + "destination_config", + destinations_configs(default_sql_configs=True), + ids=lambda x: x.name, ) @pytest.mark.parametrize("with_root_key", [True, False]) -def test_switch_to_merge(destination_config: DestinationTestConfiguration, with_root_key: bool): +def test_switch_to_merge( + destination_config: DestinationTestConfiguration, with_root_key: bool +): pipeline = destination_config.setup_pipeline( pipeline_name="test_switch_to_merge", full_refresh=True ) @@ -116,13 +122,13 @@ def source(): assert_data_table_counts(pipeline, {"items": 100, "items__sub_items": 100}) if with_root_key: - assert pipeline.default_schema._normalizers_config["json"]["config"]["propagation"][ - "root" - ] == {"_dlt_id": "_dlt_root_id"} + assert pipeline.default_schema._normalizers_config["json"]["config"][ + "propagation" + ]["root"] == {"_dlt_id": "_dlt_root_id"} else: - assert "propagation" not in pipeline.default_schema._normalizers_config["json"].get( - "config", {} - ) + assert "propagation" not in pipeline.default_schema._normalizers_config[ + "json" + ].get("config", {}) # without a root key this will fail, it is expected if not with_root_key and destination_config.supports_merge: @@ -149,6 +155,6 @@ def source(): "items__sub_items": 100 if destination_config.supports_merge else 200, }, ) - assert pipeline.default_schema._normalizers_config["json"]["config"]["propagation"]["tables"][ - "items" - ] == {"_dlt_id": "_dlt_root_id"} + assert pipeline.default_schema._normalizers_config["json"]["config"]["propagation"][ + "tables" + ]["items"] == {"_dlt_id": "_dlt_root_id"} diff --git a/tests/load/pipeline/utils.py b/tests/load/pipeline/utils.py index 7a5ef02ae6..589d77619b 100644 --- a/tests/load/pipeline/utils.py +++ b/tests/load/pipeline/utils.py @@ -127,7 +127,9 @@ def _assert_table_fs( assert client.fs_client.size(files[0]) > 0 -def select_data(p: dlt.Pipeline, sql: str, schema_name: str = None) -> List[Sequence[Any]]: +def select_data( + p: dlt.Pipeline, sql: str, schema_name: str = None +) -> List[Sequence[Any]]: with p.sql_client(schema_name=schema_name) as c: with c.execute_query(sql) as cur: return list(cur.fetchall()) diff --git a/tests/load/postgres/test_postgres_client.py b/tests/load/postgres/test_postgres_client.py index 896e449b28..52215cc2c5 100644 --- a/tests/load/postgres/test_postgres_client.py +++ b/tests/load/postgres/test_postgres_client.py @@ -3,7 +3,10 @@ import pytest from dlt.common import pendulum, Wei -from dlt.common.configuration.resolve import resolve_configuration, ConfigFieldMissingException +from dlt.common.configuration.resolve import ( + resolve_configuration, + ConfigFieldMissingException, +) from dlt.common.storages import FileStorage from dlt.common.utils import uniq_id 
@@ -11,7 +14,12 @@ from dlt.destinations.impl.postgres.postgres import PostgresClient from dlt.destinations.impl.postgres.sql_client import psycopg2 -from tests.utils import TEST_STORAGE_ROOT, delete_test_storage, skipifpypy, preserve_environ +from tests.utils import ( + TEST_STORAGE_ROOT, + delete_test_storage, + skipifpypy, + preserve_environ, +) from tests.load.utils import expect_load_file, prepare_table, yield_client_with_storage from tests.common.configuration.utils import environment @@ -37,7 +45,9 @@ def test_postgres_credentials_defaults() -> None: assert pg_cred.connect_timeout == 15 assert PostgresCredentials.__config_gen_annotations__ == ["port", "connect_timeout"] # port should be optional - resolve_configuration(pg_cred, explicit_value="postgres://loader:loader@localhost/DLT_DATA") + resolve_configuration( + pg_cred, explicit_value="postgres://loader:loader@localhost/DLT_DATA" + ) assert pg_cred.port == 5432 # preserve case assert pg_cred.database == "DLT_DATA" @@ -57,7 +67,8 @@ def test_postgres_credentials_native_value(environment) -> None: assert c.password == "pass" # but if password is specified - it is final c = resolve_configuration( - PostgresCredentials(), explicit_value="postgres://loader:loader@localhost/dlt_data" + PostgresCredentials(), + explicit_value="postgres://loader:loader@localhost/dlt_data", ) assert c.is_resolved() assert c.password == "loader" diff --git a/tests/load/postgres/test_postgres_table_builder.py b/tests/load/postgres/test_postgres_table_builder.py index 0ab1343a3b..52d0e68fd2 100644 --- a/tests/load/postgres/test_postgres_table_builder.py +++ b/tests/load/postgres/test_postgres_table_builder.py @@ -19,9 +19,9 @@ def client(empty_schema: Schema) -> PostgresClient: # return client without opening connection return PostgresClient( empty_schema, - PostgresClientConfiguration(credentials=PostgresCredentials())._bind_dataset_name( - dataset_name="test_" + uniq_id() - ), + PostgresClientConfiguration( + credentials=PostgresCredentials() + )._bind_dataset_name(dataset_name="test_" + uniq_id()), ) diff --git a/tests/load/qdrant/test_pipeline.py b/tests/load/qdrant/test_pipeline.py index 4c1361dcca..b7aafcabf5 100644 --- a/tests/load/qdrant/test_pipeline.py +++ b/tests/load/qdrant/test_pipeline.py @@ -190,14 +190,16 @@ def test_pipeline_merge() -> None: "doc_id": 1, "title": "The Shawshank Redemption", "description": ( - "Two imprisoned men find redemption through acts of decency over the years." + "Two imprisoned men find redemption through acts of decency over the" + " years." ), }, { "doc_id": 2, "title": "The Godfather", "description": ( - "A crime dynasty's aging patriarch transfers control to his reluctant son." + "A crime dynasty's aging patriarch transfers control to his reluctant" + " son." 
), }, { @@ -225,7 +227,9 @@ def movies_data(): dataset_name="TestPipelineAppendDataset" + uniq_id(), ) info = pipeline.run( - movies_data(), write_disposition="merge", dataset_name="MoviesDataset" + uniq_id() + movies_data(), + write_disposition="merge", + dataset_name="MoviesDataset" + uniq_id(), ) assert_load_info(info) assert_collection(pipeline, "movies_data", items=data) @@ -302,7 +306,9 @@ def test_merge_github_nested() -> None: assert p.dataset_name.startswith("github1_202") with open( - "tests/normalize/cases/github.issues.load_page_5_duck.json", "r", encoding="utf-8" + "tests/normalize/cases/github.issues.load_page_5_duck.json", + "r", + encoding="utf-8", ) as f: data = json.load(f) @@ -348,7 +354,9 @@ def test_empty_dataset_allowed() -> None: client: QdrantClient = p.destination_client() # type: ignore[assignment] assert p.dataset_name is None - info = p.run(qdrant_adapter(["context", "created", "not a stop word"], embed=["value"])) + info = p.run( + qdrant_adapter(["context", "created", "not a stop word"], embed=["value"]) + ) # dataset in load info is empty assert info.dataset_name is None client = p.destination_client() # type: ignore[assignment] diff --git a/tests/load/qdrant/utils.py b/tests/load/qdrant/utils.py index 74d5db9715..d1e2dd72fe 100644 --- a/tests/load/qdrant/utils.py +++ b/tests/load/qdrant/utils.py @@ -39,7 +39,8 @@ def assert_collection( drop_keys = ["_dlt_id", "_dlt_load_id"] objects_without_dlt_keys = [ - {k: v for k, v in point.payload.items() if k not in drop_keys} for point in point_records + {k: v for k, v in point.payload.items() if k not in drop_keys} + for point in point_records ] assert_unordered_list_equal(objects_without_dlt_keys, items) diff --git a/tests/load/redshift/test_redshift_client.py b/tests/load/redshift/test_redshift_client.py index f5efc16a47..3c42920862 100644 --- a/tests/load/redshift/test_redshift_client.py +++ b/tests/load/redshift/test_redshift_client.py @@ -35,7 +35,9 @@ def test_postgres_and_redshift_credentials_defaults() -> None: assert red_cred.port == 5439 assert red_cred.connect_timeout == 15 assert RedshiftCredentials.__config_gen_annotations__ == ["port", "connect_timeout"] - resolve_configuration(red_cred, explicit_value="postgres://loader:loader@localhost/dlt_data") + resolve_configuration( + red_cred, explicit_value="postgres://loader:loader@localhost/dlt_data" + ) assert red_cred.port == 5439 @@ -53,9 +55,13 @@ def test_text_too_long(client: RedshiftClient, file_storage: FileStorage) -> Non # max_len_str_b = max_len_str.encode("utf-8") # print(len(max_len_str_b)) row_id = uniq_id() - insert_values = f"('{row_id}', '{uniq_id()}', '{max_len_str}' , '{str(pendulum.now())}');" + insert_values = ( + f"('{row_id}', '{uniq_id()}', '{max_len_str}' , '{str(pendulum.now())}');" + ) with pytest.raises(DatabaseTerminalException) as exv: - expect_load_file(client, file_storage, insert_sql + insert_values, user_table_name) + expect_load_file( + client, file_storage, insert_sql + insert_values, user_table_name + ) assert type(exv.value.dbapi_exception) is psycopg2.errors.StringDataRightTruncation @@ -72,7 +78,9 @@ def test_wei_value(client: RedshiftClient, file_storage: FileStorage) -> None: f" '{str(pendulum.now())}', {10**38});" ) with pytest.raises(DatabaseTerminalException) as exv: - expect_load_file(client, file_storage, insert_sql + insert_values, user_table_name) + expect_load_file( + client, file_storage, insert_sql + insert_values, user_table_name + ) assert type(exv.value.dbapi_exception) is psycopg2.errors.InternalError_ 
@@ -83,7 +91,9 @@ def test_schema_string_exceeds_max_text_length(client: RedshiftClient) -> None: os.path.join(COMMON_TEST_CASES_PATH, "schemas/ev1"), "event", ("json",) ) schema_str = json.dumps(schema.to_dict()) - assert len(schema_str.encode("utf-8")) > client.capabilities.max_text_data_type_length + assert ( + len(schema_str.encode("utf-8")) > client.capabilities.max_text_data_type_length + ) client._update_schema_in_storage(schema) schema_info = client.get_stored_schema() assert schema_info.schema == schema_str @@ -107,13 +117,18 @@ def test_maximum_query_size(client: RedshiftClient, file_storage: FileStorage) - with patch.object(RedshiftClient, "capabilities") as caps: caps.return_value = mocked_caps - insert_sql = "INSERT INTO {}(_dlt_id, _dlt_root_id, sender_id, timestamp)\nVALUES\n" + insert_sql = ( + "INSERT INTO {}(_dlt_id, _dlt_root_id, sender_id, timestamp)\nVALUES\n" + ) insert_values = "('{}', '{}', '90238094809sajlkjxoiewjhduuiuehd', '{}'){}" insert_sql = ( insert_sql - + insert_values.format(uniq_id(), uniq_id(), str(pendulum.now()), ",\n") * 150000 + + insert_values.format(uniq_id(), uniq_id(), str(pendulum.now()), ",\n") + * 150000 + ) + insert_sql += insert_values.format( + uniq_id(), uniq_id(), str(pendulum.now()), ";" ) - insert_sql += insert_values.format(uniq_id(), uniq_id(), str(pendulum.now()), ";") user_table_name = prepare_table(client) with pytest.raises(DatabaseTerminalException) as exv: diff --git a/tests/load/redshift/test_redshift_table_builder.py b/tests/load/redshift/test_redshift_table_builder.py index bc132c7818..430484b168 100644 --- a/tests/load/redshift/test_redshift_table_builder.py +++ b/tests/load/redshift/test_redshift_table_builder.py @@ -20,9 +20,9 @@ def client(empty_schema: Schema) -> RedshiftClient: # return client without opening connection return RedshiftClient( empty_schema, - RedshiftClientConfiguration(credentials=RedshiftCredentials())._bind_dataset_name( - dataset_name="test_" + uniq_id() - ), + RedshiftClientConfiguration( + credentials=RedshiftCredentials() + )._bind_dataset_name(dataset_name="test_" + uniq_id()), ) @@ -36,7 +36,9 @@ def test_redshift_configuration() -> None: "DESTINATION__MY_REDSHIFT__CREDENTIALS__PASSWORD": " pass\n", } ): - C = resolve_configuration(RedshiftCredentials(), sections=("destination", "my_redshift")) + C = resolve_configuration( + RedshiftCredentials(), sections=("destination", "my_redshift") + ) assert C.database == "UPPER_CASE_DATABASE" assert C.password == "pass" @@ -45,14 +47,20 @@ def test_redshift_configuration() -> None: # based on host c = resolve_configuration( RedshiftCredentials(), - explicit_value="postgres://user1:pass@host1/db1?warehouse=warehouse1&role=role1", + explicit_value=( + "postgres://user1:pass@host1/db1?warehouse=warehouse1&role=role1" + ), + ) + assert RedshiftClientConfiguration(credentials=c).fingerprint() == digest128( + "host1" ) - assert RedshiftClientConfiguration(credentials=c).fingerprint() == digest128("host1") def test_create_table(client: RedshiftClient) -> None: # non existing table - sql = ";".join(client._get_table_update_sql("event_test_table", TABLE_UPDATE, False)) + sql = ";".join( + client._get_table_update_sql("event_test_table", TABLE_UPDATE, False) + ) sqlfluff.parse(sql, dialect="redshift") assert "event_test_table" in sql assert '"col1" bigint NOT NULL' in sql diff --git a/tests/load/snowflake/test_snowflake_configuration.py b/tests/load/snowflake/test_snowflake_configuration.py index d0ca4de41b..441005c74a 100644 --- 
a/tests/load/snowflake/test_snowflake_configuration.py +++ b/tests/load/snowflake/test_snowflake_configuration.py @@ -40,7 +40,9 @@ def test_connection_string_with_all_params() -> None: def test_to_connector_params() -> None: # PEM key - pkey_str = Path("./tests/common/cases/secrets/encrypted-private-key").read_text("utf8") + pkey_str = Path("./tests/common/cases/secrets/encrypted-private-key").read_text( + "utf8" + ) creds = SnowflakeCredentials() creds.private_key = pkey_str # type: ignore[assignment] @@ -66,7 +68,9 @@ def test_to_connector_params() -> None: ) # base64 encoded DER key - pkey_str = Path("./tests/common/cases/secrets/encrypted-private-key-base64").read_text("utf8") + pkey_str = Path( + "./tests/common/cases/secrets/encrypted-private-key-base64" + ).read_text("utf8") creds = SnowflakeCredentials() creds.private_key = pkey_str # type: ignore[assignment] @@ -96,7 +100,9 @@ def test_snowflake_credentials_native_value(environment) -> None: with pytest.raises(ConfigurationValueError): resolve_configuration( SnowflakeCredentials(), - explicit_value="snowflake://user1@host1/db1?warehouse=warehouse1&role=role1", + explicit_value=( + "snowflake://user1@host1/db1?warehouse=warehouse1&role=role1" + ), ) # set password via env os.environ["CREDENTIALS__PASSWORD"] = "pass" @@ -109,7 +115,9 @@ def test_snowflake_credentials_native_value(environment) -> None: # # but if password is specified - it is final c = resolve_configuration( SnowflakeCredentials(), - explicit_value="snowflake://user1:pass1@host1/db1?warehouse=warehouse1&role=role1", + explicit_value=( + "snowflake://user1:pass1@host1/db1?warehouse=warehouse1&role=role1" + ), ) assert c.is_resolved() assert c.password == "pass1" @@ -131,6 +139,10 @@ def test_snowflake_configuration() -> None: # based on host c = resolve_configuration( SnowflakeCredentials(), - explicit_value="snowflake://user1:pass@host1/db1?warehouse=warehouse1&role=role1", + explicit_value=( + "snowflake://user1:pass@host1/db1?warehouse=warehouse1&role=role1" + ), + ) + assert SnowflakeClientConfiguration(credentials=c).fingerprint() == digest128( + "host1" ) - assert SnowflakeClientConfiguration(credentials=c).fingerprint() == digest128("host1") diff --git a/tests/load/snowflake/test_snowflake_table_builder.py b/tests/load/snowflake/test_snowflake_table_builder.py index 5d7108803e..9402f8e9ba 100644 --- a/tests/load/snowflake/test_snowflake_table_builder.py +++ b/tests/load/snowflake/test_snowflake_table_builder.py @@ -28,7 +28,9 @@ def snowflake_client(empty_schema: Schema) -> SnowflakeClient: def test_create_table(snowflake_client: SnowflakeClient) -> None: - statements = snowflake_client._get_table_update_sql("event_test_table", TABLE_UPDATE, False) + statements = snowflake_client._get_table_update_sql( + "event_test_table", TABLE_UPDATE, False + ) assert len(statements) == 1 sql = statements[0] sqlfluff.parse(sql, dialect="snowflake") @@ -48,7 +50,9 @@ def test_create_table(snowflake_client: SnowflakeClient) -> None: def test_alter_table(snowflake_client: SnowflakeClient) -> None: - statements = snowflake_client._get_table_update_sql("event_test_table", TABLE_UPDATE, True) + statements = snowflake_client._get_table_update_sql( + "event_test_table", TABLE_UPDATE, True + ) assert len(statements) == 1 sql = statements[0] @@ -78,13 +82,17 @@ def test_alter_table(snowflake_client: SnowflakeClient) -> None: assert '"COL2" FLOAT NOT NULL' in sql -def test_create_table_with_partition_and_cluster(snowflake_client: SnowflakeClient) -> None: +def 
test_create_table_with_partition_and_cluster( + snowflake_client: SnowflakeClient, +) -> None: mod_update = deepcopy(TABLE_UPDATE) # timestamp mod_update[3]["partition"] = True mod_update[4]["cluster"] = True mod_update[1]["cluster"] = True - statements = snowflake_client._get_table_update_sql("event_test_table", mod_update, False) + statements = snowflake_client._get_table_update_sql( + "event_test_table", mod_update, False + ) assert len(statements) == 1 sql = statements[0] diff --git a/tests/load/synapse/test_synapse_table_builder.py b/tests/load/synapse/test_synapse_table_builder.py index 8575835820..d989839998 100644 --- a/tests/load/synapse/test_synapse_table_builder.py +++ b/tests/load/synapse/test_synapse_table_builder.py @@ -121,9 +121,9 @@ def test_create_table_with_column_hint( assert f" {attr} " not in sql # Case: table with hint, client has indexes enabled. - sql = client_with_indexes_enabled._get_table_update_sql("event_test_table", mod_update, False)[ - 0 - ] + sql = client_with_indexes_enabled._get_table_update_sql( + "event_test_table", mod_update, False + )[0] # We expect an error because "PRIMARY KEY NONCLUSTERED NOT ENFORCED" and # "UNIQUE NOT ENFORCED" are invalid in the generic "tsql" dialect. # They are however valid in the Synapse variant of the dialect. diff --git a/tests/load/synapse/test_synapse_table_indexing.py b/tests/load/synapse/test_synapse_table_indexing.py index df90933de4..f4a3bc99c2 100644 --- a/tests/load/synapse/test_synapse_table_indexing.py +++ b/tests/load/synapse/test_synapse_table_indexing.py @@ -84,7 +84,9 @@ def items_with_table_index_type_specified() -> Iterator[Any]: yield TABLE_ROW_ALL_DATA_TYPES pipeline.run( - synapse_adapter(items_with_table_index_type_specified, "clustered_columnstore_index") + synapse_adapter( + items_with_table_index_type_specified, "clustered_columnstore_index" + ) ) applied_table_index_type = get_storage_table_index_type( job_client.sql_client, # type: ignore[attr-defined] @@ -131,7 +133,9 @@ def items_with_table_index_type_specified() -> Iterator[Any]: pipeline.run(synapse_adapter(items_with_table_index_type_specified, "foo")) # type: ignore[arg-type] # Run the pipeline and create the tables. - pipeline.run(synapse_adapter(items_with_table_index_type_specified, table_index_type)) + pipeline.run( + synapse_adapter(items_with_table_index_type_specified, table_index_type) + ) # For all tables, assert the applied index type equals the expected index type. # Child tables, if any, inherit the index type of their parent. 
diff --git a/tests/load/synapse/utils.py b/tests/load/synapse/utils.py index cd53716878..86d086ec64 100644 --- a/tests/load/synapse/utils.py +++ b/tests/load/synapse/utils.py @@ -5,7 +5,9 @@ from dlt.destinations.impl.synapse.synapse_adapter import TTableIndexType -def get_storage_table_index_type(sql_client: SynapseSqlClient, table_name: str) -> TTableIndexType: +def get_storage_table_index_type( + sql_client: SynapseSqlClient, table_name: str +) -> TTableIndexType: """Returns table index type of table in storage destination.""" with sql_client: schema_name = sql_client.fully_qualified_dataset_name(escape=False) diff --git a/tests/load/test_dummy_client.py b/tests/load/test_dummy_client.py index c5e4f874fc..f313a7df14 100644 --- a/tests/load/test_dummy_client.py +++ b/tests/load/test_dummy_client.py @@ -16,7 +16,9 @@ get_top_level_table, ) -from dlt.destinations.impl.filesystem.configuration import FilesystemDestinationClientConfiguration +from dlt.destinations.impl.filesystem.configuration import ( + FilesystemDestinationClientConfiguration, +) from dlt.destinations.job_impl import EmptyLoadJob from dlt.destinations import dummy, filesystem from dlt.destinations.impl.dummy import dummy as dummy_impl @@ -24,7 +26,11 @@ from dlt.load import Load from dlt.load.exceptions import LoadClientJobFailed, LoadClientJobRetry -from dlt.load.utils import get_completed_table_chain, init_client, _extend_tables_with_table_chain +from dlt.load.utils import ( + get_completed_table_chain, + init_client, + _extend_tables_with_table_chain, +) from tests.utils import ( clean_test_storage, @@ -42,7 +48,9 @@ "event_loop_interrupted.839c6e6b514e427687586ccc65bf133f.0.jsonl", ] -REMOTE_FILESYSTEM = os.path.abspath(os.path.join(TEST_STORAGE_ROOT, "_remote_filesystem")) +REMOTE_FILESYSTEM = os.path.abspath( + os.path.join(TEST_STORAGE_ROOT, "_remote_filesystem") +) @pytest.fixture(autouse=True) @@ -121,7 +129,9 @@ def test_get_completed_table_chain_single_job_per_table() -> None: # update tables so we have all possible hints for table_name, table in schema.tables.items(): - schema.tables[table_name] = fill_hints_from_parent_and_clone_table(schema.tables, table) + schema.tables[table_name] = fill_hints_from_parent_and_clone_table( + schema.tables, table + ) top_job_table = get_top_level_table(schema.tables, "event_user") all_jobs = load.load_storage.normalized_packages.list_all_jobs(load_id) @@ -153,7 +163,10 @@ def test_get_completed_table_chain_single_job_per_table() -> None: schema.get_table("event_loop_interrupted") ] assert get_completed_table_chain( - schema, all_jobs, loop_top_job_table, "event_user.839c6e6b514e427687586ccc65bf133f.0.jsonl" + schema, + all_jobs, + loop_top_job_table, + "event_user.839c6e6b514e427687586ccc65bf133f.0.jsonl", ) == [schema.get_table("event_loop_interrupted")] @@ -184,7 +197,9 @@ def test_spool_job_failed() -> None: ) assert load.load_storage.normalized_packages.storage.has_file( load.load_storage.normalized_packages.get_job_file_path( - load_id, PackageStorage.FAILED_JOBS_FOLDER, job.file_name() + ".exception" + load_id, + PackageStorage.FAILED_JOBS_FOLDER, + job.file_name() + ".exception", ) ) started_files = load.load_storage.normalized_packages.list_started_jobs(load_id) @@ -204,7 +219,9 @@ def test_spool_job_failed_exception_init() -> None: # this config fails job on start os.environ["LOAD__RAISE_ON_FAILED_JOBS"] = "true" os.environ["FAIL_IN_INIT"] = "true" - load = setup_loader(client_config=DummyClientConfiguration(fail_prob=1.0, fail_in_init=True)) + load = setup_loader( + 
client_config=DummyClientConfiguration(fail_prob=1.0, fail_in_init=True) + ) load_id, _ = prepare_load_package(load.load_storage, NORMALIZED_FILES) with patch.object(dummy_impl.DummyClient, "complete_load") as complete_load: with pytest.raises(LoadClientJobFailed) as py_ex: @@ -223,7 +240,9 @@ def test_spool_job_failed_exception_complete() -> None: # this config fails job on start os.environ["LOAD__RAISE_ON_FAILED_JOBS"] = "true" os.environ["FAIL_IN_INIT"] = "false" - load = setup_loader(client_config=DummyClientConfiguration(fail_prob=1.0, fail_in_init=False)) + load = setup_loader( + client_config=DummyClientConfiguration(fail_prob=1.0, fail_in_init=False) + ) load_id, _ = prepare_load_package(load.load_storage, NORMALIZED_FILES) with pytest.raises(LoadClientJobFailed) as py_ex: run_all(load) @@ -334,7 +353,9 @@ def test_completed_loop_followup_jobs() -> None: # TODO: until we fix how we create capabilities we must set env os.environ["CREATE_FOLLOWUP_JOBS"] = "true" load = setup_loader( - client_config=DummyClientConfiguration(completed_prob=1.0, create_followup_jobs=True) + client_config=DummyClientConfiguration( + completed_prob=1.0, create_followup_jobs=True + ) ) assert_complete_job(load) # for each JOB there's REFERENCE JOB @@ -345,7 +366,8 @@ def test_completed_loop_followup_jobs() -> None: def test_failed_loop() -> None: # ask to delete completed load = setup_loader( - delete_completed_jobs=True, client_config=DummyClientConfiguration(fail_prob=1.0) + delete_completed_jobs=True, + client_config=DummyClientConfiguration(fail_prob=1.0), ) # actually not deleted because one of the jobs failed assert_complete_job(load, should_delete_completed=False) @@ -463,7 +485,8 @@ def test_wrong_writer_type() -> None: def test_extend_table_chain() -> None: load = setup_loader() _, schema = prepare_load_package( - load.load_storage, ["event_user.b1d32c6660b242aaabbf3fc27245b7e6.0.insert_values"] + load.load_storage, + ["event_user.b1d32c6660b242aaabbf3fc27245b7e6.0.insert_values"], ) # only event user table (no other jobs) tables = _extend_tables_with_table_chain(schema, ["event_user"], ["event_user"]) @@ -473,18 +496,24 @@ def test_extend_table_chain() -> None: schema, ["event_user"], ["event_user", "event_user__parse_data__entities"] ) assert tables == {"event_user", "event_user__parse_data__entities"} - user_chain = {name for name in schema.data_table_names() if name.startswith("event_user__")} | { - "event_user" - } + user_chain = { + name for name in schema.data_table_names() if name.startswith("event_user__") + } | {"event_user"} # change event user to merge/replace to get full table chain for w_d in ["merge", "replace"]: - schema.tables["event_user"]["write_disposition"] = w_d # type:ignore[typeddict-item] + schema.tables["event_user"][ + "write_disposition" + ] = w_d # type:ignore[typeddict-item] tables = _extend_tables_with_table_chain(schema, ["event_user"], ["event_user"]) assert tables == user_chain # no jobs for bot - assert _extend_tables_with_table_chain(schema, ["event_bot"], ["event_user"]) == set() + assert ( + _extend_tables_with_table_chain(schema, ["event_bot"], ["event_user"]) == set() + ) # skip unseen tables - del schema.tables["event_user__parse_data__entities"][ # type:ignore[typeddict-item] + del schema.tables[ + "event_user__parse_data__entities" + ][ # type:ignore[typeddict-item] "x-normalizer" ] entities_chain = { @@ -496,11 +525,16 @@ def test_extend_table_chain() -> None: assert tables == user_chain - {"event_user__parse_data__entities"} # exclude the whole chain 
tables = _extend_tables_with_table_chain( - schema, ["event_user"], ["event_user"], lambda table: table["name"] not in entities_chain + schema, + ["event_user"], + ["event_user"], + lambda table: table["name"] not in entities_chain, ) assert tables == user_chain - entities_chain # ask for tables that are not top - tables = _extend_tables_with_table_chain(schema, ["event_user__parse_data__entities"], []) + tables = _extend_tables_with_table_chain( + schema, ["event_user__parse_data__entities"], [] + ) # user chain but without entities (not seen data) assert tables == user_chain - {"event_user__parse_data__entities"} # go to append and ask only for entities chain @@ -514,7 +548,9 @@ def test_extend_table_chain() -> None: # add multiple chains bot_jobs = {"event_bot", "event_bot__data__buttons"} tables = _extend_tables_with_table_chain( - schema, ["event_user__parse_data__entities", "event_bot"], entities_chain | bot_jobs + schema, + ["event_user__parse_data__entities", "event_bot"], + entities_chain | bot_jobs, ) assert tables == (entities_chain | bot_jobs) - {"event_user__parse_data__entities"} @@ -522,12 +558,15 @@ def test_extend_table_chain() -> None: def test_get_completed_table_chain_cases() -> None: load = setup_loader() _, schema = prepare_load_package( - load.load_storage, ["event_user.b1d32c6660b242aaabbf3fc27245b7e6.0.insert_values"] + load.load_storage, + ["event_user.b1d32c6660b242aaabbf3fc27245b7e6.0.insert_values"], ) # update tables so we have all possible hints for table_name, table in schema.tables.items(): - schema.tables[table_name] = fill_hints_from_parent_and_clone_table(schema.tables, table) + schema.tables[table_name] = fill_hints_from_parent_and_clone_table( + schema.tables, table + ) # child completed, parent not event_user = schema.get_table("event_user") @@ -548,11 +587,16 @@ def test_get_completed_table_chain_cases() -> None: None, 0, ParsedLoadJobFileName( - "event_user__parse_data__entities", "event_user__parse_data__entities_id", 0, "jsonl" + "event_user__parse_data__entities", + "event_user__parse_data__entities_id", + 0, + "jsonl", ), None, ) - chain = get_completed_table_chain(schema, [event_user_job, event_user_entities_job], event_user) + chain = get_completed_table_chain( + schema, [event_user_job, event_user_entities_job], event_user + ) assert chain is None # parent just got completed @@ -567,7 +611,9 @@ def test_get_completed_table_chain_cases() -> None: # parent failed, child completed chain = get_completed_table_chain( - schema, [event_user_job._replace(state="failed_jobs"), event_user_entities_job], event_user + schema, + [event_user_job._replace(state="failed_jobs"), event_user_entities_job], + event_user, ) assert chain == [event_user, event_user_entities] @@ -617,7 +663,8 @@ def test_get_completed_table_chain_cases() -> None: def test_init_client_truncate_tables() -> None: load = setup_loader() _, schema = prepare_load_package( - load.load_storage, ["event_user.b1d32c6660b242aaabbf3fc27245b7e6.0.insert_values"] + load.load_storage, + ["event_user.b1d32c6660b242aaabbf3fc27245b7e6.0.insert_values"], ) nothing_ = lambda _: False @@ -626,8 +673,12 @@ def test_init_client_truncate_tables() -> None: event_user = ParsedLoadJobFileName("event_user", "event_user_id", 0, "jsonl") event_bot = ParsedLoadJobFileName("event_bot", "event_bot_id", 0, "jsonl") - with patch.object(dummy_impl.DummyClient, "initialize_storage") as initialize_storage: - with patch.object(dummy_impl.DummyClient, "update_stored_schema") as update_stored_schema: + with patch.object( + 
dummy_impl.DummyClient, "initialize_storage" + ) as initialize_storage: + with patch.object( + dummy_impl.DummyClient, "update_stored_schema" + ) as update_stored_schema: with load.get_destination_client(schema) as client: init_client(client, schema, [], {}, nothing_, nothing_) # we do not allow for any staging dataset tables @@ -639,7 +690,9 @@ def test_init_client_truncate_tables() -> None: assert initialize_storage.call_count == 2 # initialize storage is called twice, we deselected all tables to truncate assert initialize_storage.call_args_list[0].args == () - assert initialize_storage.call_args_list[1].kwargs["truncate_tables"] == set() + assert ( + initialize_storage.call_args_list[1].kwargs["truncate_tables"] == set() + ) initialize_storage.reset_mock() update_stored_schema.reset_mock() @@ -651,7 +704,9 @@ def test_init_client_truncate_tables() -> None: assert "event_user" in update_stored_schema.call_args[1]["only_tables"] assert initialize_storage.call_count == 2 assert initialize_storage.call_args_list[0].args == () - assert initialize_storage.call_args_list[1].kwargs["truncate_tables"] == {"event_user"} + assert initialize_storage.call_args_list[1].kwargs["truncate_tables"] == { + "event_user" + } # now we push all to stage initialize_storage.reset_mock() @@ -670,7 +725,9 @@ def test_init_client_truncate_tables() -> None: ) assert initialize_storage.call_count == 4 assert initialize_storage.call_args_list[0].args == () - assert initialize_storage.call_args_list[1].kwargs["truncate_tables"] == set() + assert ( + initialize_storage.call_args_list[1].kwargs["truncate_tables"] == set() + ) assert initialize_storage.call_args_list[2].args == () # all tables that will be used on staging must be truncated assert initialize_storage.call_args_list[3].kwargs["truncate_tables"] == { @@ -690,15 +747,23 @@ def test_init_client_truncate_tables() -> None: bot["write_disposition"] = w_d # type:ignore[typeddict-item] # merge goes to staging, replace goes to truncate with load.get_destination_client(schema) as client: - init_client(client, schema, [event_user, event_bot], {}, replace_, merge_) + init_client( + client, schema, [event_user, event_bot], {}, replace_, merge_ + ) if w_d == "merge": # we use staging dataset assert update_stored_schema.call_count == 2 # 4 tables to update in main dataset - assert len(update_stored_schema.call_args_list[0].kwargs["only_tables"]) == 4 assert ( - "event_user" in update_stored_schema.call_args_list[0].kwargs["only_tables"] + len( + update_stored_schema.call_args_list[0].kwargs["only_tables"] + ) + == 4 + ) + assert ( + "event_user" + in update_stored_schema.call_args_list[0].kwargs["only_tables"] ) # full bot table chain + dlt version but no user assert len( @@ -706,14 +771,21 @@ def test_init_client_truncate_tables() -> None: ) == 1 + len(bot_chain) assert ( "event_user" - not in update_stored_schema.call_args_list[1].kwargs["only_tables"] + not in update_stored_schema.call_args_list[1].kwargs[ + "only_tables" + ] ) assert initialize_storage.call_count == 4 - assert initialize_storage.call_args_list[1].kwargs["truncate_tables"] == set() + assert ( + initialize_storage.call_args_list[1].kwargs["truncate_tables"] + == set() + ) assert initialize_storage.call_args_list[3].kwargs[ "truncate_tables" - ] == update_stored_schema.call_args_list[1].kwargs["only_tables"] - { + ] == update_stored_schema.call_args_list[1].kwargs[ + "only_tables" + ] - { "_dlt_version" } @@ -725,14 +797,20 @@ def test_init_client_truncate_tables() -> None: 
initialize_storage.call_args_list[1].kwargs["truncate_tables"] ) == len(bot_chain) # migrate only tables for which we have jobs - assert len(update_stored_schema.call_args_list[0].kwargs["only_tables"]) == 4 + assert ( + len( + update_stored_schema.call_args_list[0].kwargs["only_tables"] + ) + == 4 + ) # print(initialize_storage.call_args_list) # print(update_stored_schema.call_args_list) def test_dummy_staging_filesystem() -> None: load = setup_loader( - client_config=DummyClientConfiguration(completed_prob=1.0), filesystem_staging=True + client_config=DummyClientConfiguration(completed_prob=1.0), + filesystem_staging=True, ) assert_complete_job(load) # two reference jobs @@ -780,10 +858,14 @@ def assert_complete_job(load: Load, should_delete_completed: bool = False) -> No ) if should_delete_completed: # package was deleted - assert not load.load_storage.loaded_packages.storage.has_folder(completed_path) + assert not load.load_storage.loaded_packages.storage.has_folder( + completed_path + ) else: # package not deleted - assert load.load_storage.loaded_packages.storage.has_folder(completed_path) + assert load.load_storage.loaded_packages.storage.has_folder( + completed_path + ) # complete load on client was called complete_load.assert_called_once_with(load_id) @@ -805,16 +887,22 @@ def setup_loader( # reset jobs for a test dummy_impl.JOBS = {} dummy_impl.CREATED_FOLLOWUP_JOBS = {} - client_config = client_config or DummyClientConfiguration(loader_file_format="jsonl") + client_config = client_config or DummyClientConfiguration( + loader_file_format="jsonl" + ) destination: TDestination = dummy(**client_config) # type: ignore[assignment] # setup staging_system_config = None staging = None if filesystem_staging: # do not accept jsonl to not conflict with filesystem destination - client_config = client_config or DummyClientConfiguration(loader_file_format="reference") - staging_system_config = FilesystemDestinationClientConfiguration()._bind_dataset_name( - dataset_name="dummy" + client_config = client_config or DummyClientConfiguration( + loader_file_format="reference" + ) + staging_system_config = ( + FilesystemDestinationClientConfiguration()._bind_dataset_name( + dataset_name="dummy" + ) ) staging_system_config.as_staging = True os.makedirs(REMOTE_FILESYSTEM) @@ -822,7 +910,9 @@ def setup_loader( # patch destination to provide client_config # destination.client = lambda schema: dummy_impl.DummyClient(schema, client_config) # setup loader - with TEST_DICT_CONFIG_PROVIDER().values({"delete_completed_jobs": delete_completed_jobs}): + with TEST_DICT_CONFIG_PROVIDER().values( + {"delete_completed_jobs": delete_completed_jobs} + ): return Load( destination, initial_client_config=client_config, diff --git a/tests/load/test_insert_job_client.py b/tests/load/test_insert_job_client.py index 1c79b733e5..efbec9be6b 100644 --- a/tests/load/test_insert_job_client.py +++ b/tests/load/test_insert_job_client.py @@ -46,22 +46,34 @@ def test_simple_load(client: InsertValuesJobClient, file_storage: FileStorage) - f"('{uniq_id()}', '{uniq_id()}', '90238094809sajlkjxoiewjhduuiuehd'," f" '{str(pendulum.now())}')" ) - expect_load_file(client, file_storage, insert_sql + insert_values + ";", user_table_name) - rows_count = client.sql_client.execute_sql(f"SELECT COUNT(1) FROM {canonical_name}")[0][0] + expect_load_file( + client, file_storage, insert_sql + insert_values + ";", user_table_name + ) + rows_count = client.sql_client.execute_sql( + f"SELECT COUNT(1) FROM {canonical_name}" + )[0][0] assert rows_count == 1 
# insert 100 more rows query = insert_sql + (insert_values + ",\n") * 99 + insert_values + ";" expect_load_file(client, file_storage, query, user_table_name) - rows_count = client.sql_client.execute_sql(f"SELECT COUNT(1) FROM {canonical_name}")[0][0] + rows_count = client.sql_client.execute_sql( + f"SELECT COUNT(1) FROM {canonical_name}" + )[0][0] assert rows_count == 101 # insert null value - insert_sql_nc = "INSERT INTO {}(_dlt_id, _dlt_root_id, sender_id, timestamp, text)\nVALUES\n" + insert_sql_nc = ( + "INSERT INTO {}(_dlt_id, _dlt_root_id, sender_id, timestamp, text)\nVALUES\n" + ) insert_values_nc = ( f"('{uniq_id()}', '{uniq_id()}', '90238094809sajlkjxoiewjhduuiuehd'," f" '{str(pendulum.now())}', NULL);" ) - expect_load_file(client, file_storage, insert_sql_nc + insert_values_nc, user_table_name) - rows_count = client.sql_client.execute_sql(f"SELECT COUNT(1) FROM {canonical_name}")[0][0] + expect_load_file( + client, file_storage, insert_sql_nc + insert_values_nc, user_table_name + ) + rows_count = client.sql_client.execute_sql( + f"SELECT COUNT(1) FROM {canonical_name}" + )[0][0] assert rows_count == 102 @@ -72,7 +84,9 @@ def test_simple_load(client: InsertValuesJobClient, file_storage: FileStorage) - indirect=True, ids=lambda x: x.name, ) -def test_loading_errors(client: InsertValuesJobClient, file_storage: FileStorage) -> None: +def test_loading_errors( + client: InsertValuesJobClient, file_storage: FileStorage +) -> None: # test expected dbiapi exceptions for supported destinations import duckdb from dlt.destinations.impl.postgres.sql_client import psycopg2 @@ -91,29 +105,42 @@ def test_loading_errors(client: InsertValuesJobClient, file_storage: FileStorage user_table_name = prepare_table(client) # insert into unknown column - insert_sql = "INSERT INTO {}(_dlt_id, _dlt_root_id, sender_id, timestamp, _unk_)\nVALUES\n" + insert_sql = ( + "INSERT INTO {}(_dlt_id, _dlt_root_id, sender_id, timestamp, _unk_)\nVALUES\n" + ) insert_values = ( f"('{uniq_id()}', '{uniq_id()}', '90238094809sajlkjxoiewjhduuiuehd'," f" '{str(pendulum.now())}', NULL);" ) with pytest.raises(DatabaseTerminalException) as exv: - expect_load_file(client, file_storage, insert_sql + insert_values, user_table_name) + expect_load_file( + client, file_storage, insert_sql + insert_values, user_table_name + ) assert type(exv.value.dbapi_exception) is TUndefinedColumn # insert null value insert_sql = "INSERT INTO {}(_dlt_id, _dlt_root_id, sender_id, timestamp)\nVALUES\n" - insert_values = f"('{uniq_id()}', '{uniq_id()}', '90238094809sajlkjxoiewjhduuiuehd', NULL);" + insert_values = ( + f"('{uniq_id()}', '{uniq_id()}', '90238094809sajlkjxoiewjhduuiuehd', NULL);" + ) with pytest.raises(DatabaseTerminalException) as exv: - expect_load_file(client, file_storage, insert_sql + insert_values, user_table_name) + expect_load_file( + client, file_storage, insert_sql + insert_values, user_table_name + ) assert type(exv.value.dbapi_exception) is TNotNullViolation # insert wrong type insert_sql = "INSERT INTO {}(_dlt_id, _dlt_root_id, sender_id, timestamp)\nVALUES\n" - insert_values = f"('{uniq_id()}', '{uniq_id()}', '90238094809sajlkjxoiewjhduuiuehd', TRUE);" + insert_values = ( + f"('{uniq_id()}', '{uniq_id()}', '90238094809sajlkjxoiewjhduuiuehd', TRUE);" + ) with pytest.raises(DatabaseTerminalException) as exv: - expect_load_file(client, file_storage, insert_sql + insert_values, user_table_name) + expect_load_file( + client, file_storage, insert_sql + insert_values, user_table_name + ) assert type(exv.value.dbapi_exception) is 
TDatatypeMismatch # numeric overflow on bigint insert_sql = ( - "INSERT INTO {}(_dlt_id, _dlt_root_id, sender_id, timestamp, metadata__rasa_x_id)\nVALUES\n" + "INSERT INTO {}(_dlt_id, _dlt_root_id, sender_id, timestamp," + " metadata__rasa_x_id)\nVALUES\n" ) # 2**64//2 - 1 is a maximum bigint value insert_values = ( @@ -121,7 +148,9 @@ def test_loading_errors(client: InsertValuesJobClient, file_storage: FileStorage f" '{str(pendulum.now())}', {2**64//2});" ) with pytest.raises(DatabaseTerminalException) as exv: - expect_load_file(client, file_storage, insert_sql + insert_values, user_table_name) + expect_load_file( + client, file_storage, insert_sql + insert_values, user_table_name + ) assert type(exv.value.dbapi_exception) in (TNumericValueOutOfRange,) # numeric overflow on NUMERIC insert_sql = ( @@ -144,7 +173,9 @@ def test_loading_errors(client: InsertValuesJobClient, file_storage: FileStorage f" '{str(pendulum.now())}', {above_limit});" ) with pytest.raises(DatabaseTerminalException) as exv: - expect_load_file(client, file_storage, insert_sql + insert_values, user_table_name) + expect_load_file( + client, file_storage, insert_sql + insert_values, user_table_name + ) assert type(exv.value.dbapi_exception) in ( TNumericValueOutOfRange, psycopg2.errors.InternalError_, @@ -223,7 +254,9 @@ def assert_load_with_max_query( user_table_name = prepare_table(client) insert_sql = prepare_insert_statement(insert_lines) expect_load_file(client, file_storage, insert_sql, user_table_name) - rows_count = client.sql_client.execute_sql(f"SELECT COUNT(1) FROM {user_table_name}")[0][0] + rows_count = client.sql_client.execute_sql( + f"SELECT COUNT(1) FROM {user_table_name}" + )[0][0] assert rows_count == insert_lines # get all uniq ids in order with client.sql_client.execute_query( @@ -242,7 +275,9 @@ def prepare_insert_statement(lines: int) -> str: for i in range(lines): # id_ = uniq_id() # ids.append(id_) - insert_sql += insert_values.format(str(i), uniq_id(), str(pendulum.now().add(seconds=i))) + insert_sql += insert_values.format( + str(i), uniq_id(), str(pendulum.now().add(seconds=i)) + ) if i < 9: insert_sql += ",\n" else: diff --git a/tests/load/test_job_client.py b/tests/load/test_job_client.py index 2e23086f81..9bb5e9bfa8 100644 --- a/tests/load/test_job_client.py +++ b/tests/load/test_job_client.py @@ -57,7 +57,10 @@ def client(request) -> Iterator[SqlJobClientBase]: @pytest.mark.order(1) @pytest.mark.parametrize( - "client", destinations_configs(default_sql_configs=True), indirect=True, ids=lambda x: x.name + "client", + destinations_configs(default_sql_configs=True), + indirect=True, + ids=lambda x: x.name, ) def test_initialize_storage(client: SqlJobClientBase) -> None: pass @@ -65,7 +68,10 @@ def test_initialize_storage(client: SqlJobClientBase) -> None: @pytest.mark.order(2) @pytest.mark.parametrize( - "client", destinations_configs(default_sql_configs=True), indirect=True, ids=lambda x: x.name + "client", + destinations_configs(default_sql_configs=True), + indirect=True, + ids=lambda x: x.name, ) def test_get_schema_on_empty_storage(client: SqlJobClientBase) -> None: # test getting schema on empty dataset without any tables @@ -79,13 +85,20 @@ def test_get_schema_on_empty_storage(client: SqlJobClientBase) -> None: @pytest.mark.order(3) @pytest.mark.parametrize( - "client", destinations_configs(default_sql_configs=True), indirect=True, ids=lambda x: x.name + "client", + destinations_configs(default_sql_configs=True), + indirect=True, + ids=lambda x: x.name, ) def 
test_get_update_basic_schema(client: SqlJobClientBase) -> None: schema = client.schema schema_update = client.update_stored_schema() # expect dlt tables in schema update - assert set(schema_update.keys()) == {VERSION_TABLE_NAME, LOADS_TABLE_NAME, "event_slot"} + assert set(schema_update.keys()) == { + VERSION_TABLE_NAME, + LOADS_TABLE_NAME, + "event_slot", + } # event_bot and event_user are not present because they have no columns # check is event slot has variant assert schema_update["event_slot"]["columns"]["value"]["variant"] is True @@ -157,7 +170,10 @@ def test_get_update_basic_schema(client: SqlJobClientBase) -> None: @pytest.mark.parametrize( - "client", destinations_configs(default_sql_configs=True), indirect=True, ids=lambda x: x.name + "client", + destinations_configs(default_sql_configs=True), + indirect=True, + ids=lambda x: x.name, ) def test_complete_load(client: SqlJobClientBase) -> None: client.update_stored_schema() @@ -188,7 +204,9 @@ def test_complete_load(client: SqlJobClientBase) -> None: @pytest.mark.parametrize( "client", - destinations_configs(default_sql_configs=True, subset=["redshift", "postgres", "duckdb"]), + destinations_configs( + default_sql_configs=True, subset=["redshift", "postgres", "duckdb"] + ), indirect=True, ids=lambda x: x.name, ) @@ -205,7 +223,9 @@ def test_schema_update_create_table_redshift(client: SqlJobClientBase) -> None: # this will be not null record_hash = schema._infer_column("_dlt_id", "m,i0392903jdlkasjdlk") assert record_hash["unique"] is True - schema.update_table(new_table(table_name, columns=[timestamp, sender_id, record_hash])) + schema.update_table( + new_table(table_name, columns=[timestamp, sender_id, record_hash]) + ) schema._bump_version() schema_update = client.update_stored_schema() # check hints in schema update @@ -232,7 +252,9 @@ def test_schema_update_create_table_bigquery(client: SqlJobClientBase) -> None: sender_id = schema._infer_column("sender_id", "982398490809324") # this will be not null record_hash = schema._infer_column("_dlt_id", "m,i0392903jdlkasjdlk") - schema.update_table(new_table("event_test_table", columns=[timestamp, sender_id, record_hash])) + schema.update_table( + new_table("event_test_table", columns=[timestamp, sender_id, record_hash]) + ) schema._bump_version() schema_update = client.update_stored_schema() # check hints in schema update @@ -250,7 +272,10 @@ def test_schema_update_create_table_bigquery(client: SqlJobClientBase) -> None: @pytest.mark.parametrize( - "client", destinations_configs(default_sql_configs=True), indirect=True, ids=lambda x: x.name + "client", + destinations_configs(default_sql_configs=True), + indirect=True, + ids=lambda x: x.name, ) def test_schema_update_alter_table(client: SqlJobClientBase) -> None: # force to update schema in chunks by setting the max query size to 10 bytes/chars @@ -290,12 +315,17 @@ def test_schema_update_alter_table(client: SqlJobClientBase) -> None: @pytest.mark.parametrize( - "client", destinations_configs(default_sql_configs=True), indirect=True, ids=lambda x: x.name + "client", + destinations_configs(default_sql_configs=True), + indirect=True, + ids=lambda x: x.name, ) def test_drop_tables(client: SqlJobClientBase) -> None: schema = client.schema # Add columns in all tables - schema.tables["event_user"]["columns"] = dict(schema.tables["event_slot"]["columns"]) + schema.tables["event_user"]["columns"] = dict( + schema.tables["event_slot"]["columns"] + ) schema.tables["event_bot"]["columns"] = dict(schema.tables["event_slot"]["columns"]) 
schema._bump_version() client.update_stored_schema() @@ -357,7 +387,10 @@ def test_drop_tables(client: SqlJobClientBase) -> None: @pytest.mark.parametrize( - "client", destinations_configs(default_sql_configs=True), indirect=True, ids=lambda x: x.name + "client", + destinations_configs(default_sql_configs=True), + indirect=True, + ids=lambda x: x.name, ) def test_get_storage_table_with_all_types(client: SqlJobClientBase) -> None: schema = client.schema @@ -388,15 +421,23 @@ def test_get_storage_table_with_all_types(client: SqlJobClientBase) -> None: ): continue # mssql and synapse have no native data type for the complex type. - if client.config.destination_type in ("mssql", "synapse") and c["data_type"] in ("complex"): + if client.config.destination_type in ("mssql", "synapse") and c[ + "data_type" + ] in ("complex"): continue - if client.config.destination_type == "databricks" and c["data_type"] in ("complex", "time"): + if client.config.destination_type == "databricks" and c["data_type"] in ( + "complex", + "time", + ): continue assert c["data_type"] == expected_c["data_type"] @pytest.mark.parametrize( - "client", destinations_configs(default_sql_configs=True), indirect=True, ids=lambda x: x.name + "client", + destinations_configs(default_sql_configs=True), + indirect=True, + ids=lambda x: x.name, ) def test_preserve_column_order(client: SqlJobClientBase) -> None: schema = client.schema @@ -421,23 +462,35 @@ def _assert_columns_order(sql_: str) -> None: idx = sql_.find(col_name, idx) assert idx > 0, f"column {col_name} not found in script" - sql = ";".join(client._get_table_update_sql(table_name, columns, generate_alter=False)) + sql = ";".join( + client._get_table_update_sql(table_name, columns, generate_alter=False) + ) _assert_columns_order(sql) - sql = ";".join(client._get_table_update_sql(table_name, columns, generate_alter=True)) + sql = ";".join( + client._get_table_update_sql(table_name, columns, generate_alter=True) + ) _assert_columns_order(sql) @pytest.mark.parametrize( - "client", destinations_configs(default_sql_configs=True), indirect=True, ids=lambda x: x.name + "client", + destinations_configs(default_sql_configs=True), + indirect=True, + ids=lambda x: x.name, ) def test_data_writer_load(client: SqlJobClientBase, file_storage: FileStorage) -> None: if not client.capabilities.preferred_loader_file_format: - pytest.skip("preferred loader file format not set, destination will only work with staging") + pytest.skip( + "preferred loader file format not set, destination will only work with" + " staging" + ) rows, table_name = prepare_schema(client, "simple_row") canonical_name = client.sql_client.make_qualified_table_name(table_name) # write only first row with io.BytesIO() as f: - write_dataset(client, f, [rows[0]], client.schema.get_table(table_name)["columns"]) + write_dataset( + client, f, [rows[0]], client.schema.get_table(table_name)["columns"] + ) query = f.getvalue().decode() expect_load_file(client, file_storage, query, table_name) db_row = client.sql_client.execute_sql(f"SELECT * FROM {canonical_name}")[0] @@ -445,7 +498,9 @@ def test_data_writer_load(client: SqlJobClientBase, file_storage: FileStorage) - assert list(db_row) == list(rows[0].values()) # write second row that contains two nulls with io.BytesIO() as f: - write_dataset(client, f, [rows[1]], client.schema.get_table(table_name)["columns"]) + write_dataset( + client, f, [rows[1]], client.schema.get_table(table_name)["columns"] + ) query = f.getvalue().decode() expect_load_file(client, file_storage, query, 
table_name) db_row = client.sql_client.execute_sql( @@ -456,11 +511,19 @@ def test_data_writer_load(client: SqlJobClientBase, file_storage: FileStorage) - @pytest.mark.parametrize( - "client", destinations_configs(default_sql_configs=True), indirect=True, ids=lambda x: x.name + "client", + destinations_configs(default_sql_configs=True), + indirect=True, + ids=lambda x: x.name, ) -def test_data_writer_string_escape(client: SqlJobClientBase, file_storage: FileStorage) -> None: +def test_data_writer_string_escape( + client: SqlJobClientBase, file_storage: FileStorage +) -> None: if not client.capabilities.preferred_loader_file_format: - pytest.skip("preferred loader file format not set, destination will only work with staging") + pytest.skip( + "preferred loader file format not set, destination will only work with" + " staging" + ) rows, table_name = prepare_schema(client, "simple_row") canonical_name = client.sql_client.make_qualified_table_name(table_name) row = rows[0] @@ -468,7 +531,9 @@ def test_data_writer_string_escape(client: SqlJobClientBase, file_storage: FileS inj_str = f", NULL'); DROP TABLE {canonical_name} --" row["f_str"] = inj_str with io.BytesIO() as f: - write_dataset(client, f, [rows[0]], client.schema.get_table(table_name)["columns"]) + write_dataset( + client, f, [rows[0]], client.schema.get_table(table_name)["columns"] + ) query = f.getvalue().decode() expect_load_file(client, file_storage, query, table_name) db_row = client.sql_client.execute_sql(f"SELECT * FROM {canonical_name}")[0] @@ -476,13 +541,19 @@ def test_data_writer_string_escape(client: SqlJobClientBase, file_storage: FileS @pytest.mark.parametrize( - "client", destinations_configs(default_sql_configs=True), indirect=True, ids=lambda x: x.name + "client", + destinations_configs(default_sql_configs=True), + indirect=True, + ids=lambda x: x.name, ) def test_data_writer_string_escape_edge( client: SqlJobClientBase, file_storage: FileStorage ) -> None: if not client.capabilities.preferred_loader_file_format: - pytest.skip("preferred loader file format not set, destination will only work with staging") + pytest.skip( + "preferred loader file format not set, destination will only work with" + " staging" + ) rows, table_name = prepare_schema(client, "weird_rows") canonical_name = client.sql_client.make_qualified_table_name(table_name) with io.BytesIO() as f: @@ -490,28 +561,42 @@ def test_data_writer_string_escape_edge( query = f.getvalue().decode() expect_load_file(client, file_storage, query, table_name) for i in range(1, len(rows) + 1): - db_row = client.sql_client.execute_sql(f"SELECT str FROM {canonical_name} WHERE idx = {i}") + db_row = client.sql_client.execute_sql( + f"SELECT str FROM {canonical_name} WHERE idx = {i}" + ) row_value, expected = db_row[0][0], rows[i - 1]["str"] assert row_value == expected @pytest.mark.parametrize("write_disposition", ["append", "replace"]) @pytest.mark.parametrize( - "client", destinations_configs(default_sql_configs=True), indirect=True, ids=lambda x: x.name + "client", + destinations_configs(default_sql_configs=True), + indirect=True, + ids=lambda x: x.name, ) def test_load_with_all_types( - client: SqlJobClientBase, write_disposition: TWriteDisposition, file_storage: FileStorage + client: SqlJobClientBase, + write_disposition: TWriteDisposition, + file_storage: FileStorage, ) -> None: if not client.capabilities.preferred_loader_file_format: - pytest.skip("preferred loader file format not set, destination will only work with staging") + pytest.skip( + "preferred loader file 
format not set, destination will only work with" + " staging" + ) table_name = "event_test_table" + uniq_id() column_schemas, data_types = table_update_and_row( - exclude_types=["time"] if client.config.destination_type == "databricks" else None, + exclude_types=( + ["time"] if client.config.destination_type == "databricks" else None + ), ) # we should have identical content with all disposition types client.schema.update_table( new_table( - table_name, write_disposition=write_disposition, columns=list(column_schemas.values()) + table_name, + write_disposition=write_disposition, + columns=list(column_schemas.values()), ) ) client.schema._bump_version() @@ -548,7 +633,10 @@ def test_load_with_all_types( ], ) @pytest.mark.parametrize( - "client", destinations_configs(default_sql_configs=True), indirect=True, ids=lambda x: x.name + "client", + destinations_configs(default_sql_configs=True), + indirect=True, + ids=lambda x: x.name, ) def test_write_dispositions( client: SqlJobClientBase, @@ -557,7 +645,10 @@ def test_write_dispositions( file_storage: FileStorage, ) -> None: if not client.capabilities.preferred_loader_file_format: - pytest.skip("preferred loader file format not set, destination will only work with staging") + pytest.skip( + "preferred loader file format not set, destination will only work with" + " staging" + ) os.environ["DESTINATION__REPLACE_STRATEGY"] = replace_strategy table_name = "event_test_table" + uniq_id() @@ -602,7 +693,8 @@ def test_write_dispositions( expect_load_file(client, file_storage, query, t) db_rows = list( client.sql_client.execute_sql( - f"SELECT * FROM {client.sql_client.make_qualified_table_name(t)} ORDER BY" + "SELECT * FROM" + f" {client.sql_client.make_qualified_table_name(t)} ORDER BY" " col1 ASC" ) ) @@ -620,7 +712,8 @@ def test_write_dispositions( with client.sql_client.with_staging_dataset(staging=True): db_rows = list( client.sql_client.execute_sql( - f"SELECT * FROM {client.sql_client.make_qualified_table_name(t)} ORDER" + "SELECT * FROM" + f" {client.sql_client.make_qualified_table_name(t)} ORDER" " BY col1 ASC" ) ) @@ -630,11 +723,17 @@ def test_write_dispositions( @pytest.mark.parametrize( - "client", destinations_configs(default_sql_configs=True), indirect=True, ids=lambda x: x.name + "client", + destinations_configs(default_sql_configs=True), + indirect=True, + ids=lambda x: x.name, ) def test_retrieve_job(client: SqlJobClientBase, file_storage: FileStorage) -> None: if not client.capabilities.preferred_loader_file_format: - pytest.skip("preferred loader file format not set, destination will only work with staging") + pytest.skip( + "preferred loader file format not set, destination will only work with" + " staging" + ) user_table_name = prepare_table(client) load_json = { "_dlt_id": uniq_id(), @@ -643,7 +742,9 @@ def test_retrieve_job(client: SqlJobClientBase, file_storage: FileStorage) -> No "timestamp": str(pendulum.now()), } with io.BytesIO() as f: - write_dataset(client, f, [load_json], client.schema.get_table(user_table_name)["columns"]) + write_dataset( + client, f, [load_json], client.schema.get_table(user_table_name)["columns"] + ) dataset = f.getvalue().decode() job = expect_load_file(client, file_storage, dataset, user_table_name) # now try to retrieve the job @@ -656,9 +757,13 @@ def test_retrieve_job(client: SqlJobClientBase, file_storage: FileStorage) -> No @pytest.mark.parametrize( - "destination_config", destinations_configs(default_sql_configs=True), ids=lambda x: x.name + "destination_config", + 
destinations_configs(default_sql_configs=True), + ids=lambda x: x.name, ) -def test_default_schema_name_init_storage(destination_config: DestinationTestConfiguration) -> None: +def test_default_schema_name_init_storage( + destination_config: DestinationTestConfiguration, +) -> None: with cm_yield_client_with_storage( destination_config.destination, default_config_values={ @@ -673,9 +778,7 @@ def test_default_schema_name_init_storage(destination_config: DestinationTestCon with cm_yield_client_with_storage( destination_config.destination, default_config_values={ - "default_schema_name": ( - None # no default_schema. that should create dataset with the name `dataset_name` - ) + "default_schema_name": None # no default_schema. that should create dataset with the name `dataset_name` }, ) as client: assert client.sql_client.dataset_name == client.config.dataset_name @@ -694,7 +797,9 @@ def test_default_schema_name_init_storage(destination_config: DestinationTestCon @pytest.mark.parametrize( - "destination_config", destinations_configs(default_sql_configs=True), ids=lambda x: x.name + "destination_config", + destinations_configs(default_sql_configs=True), + ids=lambda x: x.name, ) def test_many_schemas_single_dataset( destination_config: DestinationTestConfiguration, file_storage: FileStorage @@ -710,24 +815,32 @@ def _load_something(_client: SqlJobClientBase, expected_rows: int) -> None: "timestamp": str(pendulum.now()), } with io.BytesIO() as f: - write_dataset(_client, f, [user_row], _client.schema.tables["event_user"]["columns"]) + write_dataset( + _client, f, [user_row], _client.schema.tables["event_user"]["columns"] + ) query = f.getvalue().decode() expect_load_file(_client, file_storage, query, "event_user") qual_table_name = _client.sql_client.make_qualified_table_name("event_user") - db_rows = list(_client.sql_client.execute_sql(f"SELECT * FROM {qual_table_name}")) + db_rows = list( + _client.sql_client.execute_sql(f"SELECT * FROM {qual_table_name}") + ) assert len(db_rows) == expected_rows with cm_yield_client_with_storage( - destination_config.destination, default_config_values={"default_schema_name": None} + destination_config.destination, + default_config_values={"default_schema_name": None}, ) as client: # event schema with event table if not client.capabilities.preferred_loader_file_format: pytest.skip( - "preferred loader file format not set, destination will only work with staging" + "preferred loader file format not set, destination will only work with" + " staging" ) user_table = load_table("event_user")["event_user"] - client.schema.update_table(new_table("event_user", columns=list(user_table.values()))) + client.schema.update_table( + new_table("event_user", columns=list(user_table.values())) + ) client.schema._bump_version() schema_update = client.update_stored_schema() assert len(schema_update) > 0 @@ -757,7 +870,9 @@ def _load_something(_client: SqlJobClientBase, expected_rows: int) -> None: # use third schema where one of the fields is non null, but the field exists so it is ignored schema_dict["name"] = "event_3" event_3_schema = Schema.from_stored_schema(schema_dict) - event_3_schema.tables["event_user"]["columns"]["input_channel"]["nullable"] = False + event_3_schema.tables["event_user"]["columns"]["input_channel"][ + "nullable" + ] = False # swap schemas in client instance client.schema = event_3_schema client.schema._bump_version() @@ -781,11 +896,15 @@ def _load_something(_client: SqlJobClientBase, expected_rows: int) -> None: ) -def prepare_schema(client: 
SqlJobClientBase, case: str) -> Tuple[List[Dict[str, Any]], str]: +def prepare_schema( + client: SqlJobClientBase, case: str +) -> Tuple[List[Dict[str, Any]], str]: client.update_stored_schema() rows = load_json_case(case) # use first row to infer table - table: TTableSchemaColumns = {k: client.schema._infer_column(k, v) for k, v in rows[0].items()} + table: TTableSchemaColumns = { + k: client.schema._infer_column(k, v) for k, v in rows[0].items() + } table_name = f"event_{case}_{uniq_id()}" client.schema.update_table(new_table(table_name, columns=list(table.values()))) client.schema._bump_version() diff --git a/tests/load/test_sql_client.py b/tests/load/test_sql_client.py index d82925a7d3..2d49df8b3e 100644 --- a/tests/load/test_sql_client.py +++ b/tests/load/test_sql_client.py @@ -65,7 +65,10 @@ def test_sql_client_default_dataset_unqualified(client: SqlJobClientBase) -> Non @pytest.mark.parametrize( - "client", destinations_configs(default_sql_configs=True), indirect=True, ids=lambda x: x.name + "client", + destinations_configs(default_sql_configs=True), + indirect=True, + ids=lambda x: x.name, ) def test_malformed_query_parameters(client: SqlJobClientBase) -> None: client.update_stored_schema() @@ -104,7 +107,10 @@ def test_malformed_query_parameters(client: SqlJobClientBase) -> None: @pytest.mark.parametrize( - "client", destinations_configs(default_sql_configs=True), indirect=True, ids=lambda x: x.name + "client", + destinations_configs(default_sql_configs=True), + indirect=True, + ids=lambda x: x.name, ) def test_malformed_execute_parameters(client: SqlJobClientBase) -> None: client.update_stored_schema() @@ -140,7 +146,10 @@ def test_malformed_execute_parameters(client: SqlJobClientBase) -> None: @pytest.mark.parametrize( - "client", destinations_configs(default_sql_configs=True), indirect=True, ids=lambda x: x.name + "client", + destinations_configs(default_sql_configs=True), + indirect=True, + ids=lambda x: x.name, ) def test_execute_sql(client: SqlJobClientBase) -> None: client.update_stored_schema() @@ -154,7 +163,9 @@ def test_execute_sql(client: SqlJobClientBase) -> None: assert len(rows) == 1 assert rows[0][0] == "event" rows = client.sql_client.execute_sql( - f"SELECT schema_name, inserted_at FROM {version_table_name} WHERE schema_name = %s", "event" + f"SELECT schema_name, inserted_at FROM {version_table_name} WHERE schema_name" + " = %s", + "event", ) assert len(rows) == 1 # print(rows) @@ -165,29 +176,33 @@ def test_execute_sql(client: SqlJobClientBase) -> None: # print(type(rows[0][1])) # convert to pendulum to make sure it is supported by dbapi rows = client.sql_client.execute_sql( - f"SELECT schema_name, inserted_at FROM {version_table_name} WHERE inserted_at = %s", + f"SELECT schema_name, inserted_at FROM {version_table_name} WHERE inserted_at" + " = %s", ensure_pendulum_datetime(rows[0][1]), ) assert len(rows) == 1 # use rows in subsequent test if client.sql_client.dbapi.paramstyle == "pyformat": rows = client.sql_client.execute_sql( - f"SELECT schema_name, inserted_at FROM {version_table_name} WHERE inserted_at =" - " %(date)s", + f"SELECT schema_name, inserted_at FROM {version_table_name} WHERE" + " inserted_at = %(date)s", date=rows[0][1], ) assert len(rows) == 1 assert rows[0][0] == "event" rows = client.sql_client.execute_sql( - f"SELECT schema_name, inserted_at FROM {version_table_name} WHERE inserted_at =" - " %(date)s", + f"SELECT schema_name, inserted_at FROM {version_table_name} WHERE" + " inserted_at = %(date)s", date=pendulum.now().add(seconds=1), ) 
assert len(rows) == 0 @pytest.mark.parametrize( - "client", destinations_configs(default_sql_configs=True), indirect=True, ids=lambda x: x.name + "client", + destinations_configs(default_sql_configs=True), + indirect=True, + ids=lambda x: x.name, ) def test_execute_ddl(client: SqlJobClientBase) -> None: uniq_suffix = uniq_id() @@ -199,13 +214,18 @@ def test_execute_ddl(client: SqlJobClientBase) -> None: assert rows[0][0] == Decimal("1.0") # create view, note that bigquery will not let you execute a view that does not have fully qualified table names. view_name = client.sql_client.make_qualified_table_name(f"view_tmp_{uniq_suffix}") - client.sql_client.execute_sql(f"CREATE VIEW {view_name} AS (SELECT * FROM {f_q_table_name});") + client.sql_client.execute_sql( + f"CREATE VIEW {view_name} AS (SELECT * FROM {f_q_table_name});" + ) rows = client.sql_client.execute_sql(f"SELECT * FROM {view_name}") assert rows[0][0] == Decimal("1.0") @pytest.mark.parametrize( - "client", destinations_configs(default_sql_configs=True), indirect=True, ids=lambda x: x.name + "client", + destinations_configs(default_sql_configs=True), + indirect=True, + ids=lambda x: x.name, ) def test_execute_query(client: SqlJobClientBase) -> None: client.update_stored_schema() @@ -217,29 +237,33 @@ def test_execute_query(client: SqlJobClientBase) -> None: assert len(rows) == 1 assert rows[0][0] == "event" with client.sql_client.execute_query( - f"SELECT schema_name, inserted_at FROM {version_table_name} WHERE schema_name = %s", "event" + f"SELECT schema_name, inserted_at FROM {version_table_name} WHERE schema_name" + " = %s", + "event", ) as curr: rows = curr.fetchall() assert len(rows) == 1 assert rows[0][0] == "event" assert isinstance(rows[0][1], datetime.datetime) with client.sql_client.execute_query( - f"SELECT schema_name, inserted_at FROM {version_table_name} WHERE inserted_at = %s", + f"SELECT schema_name, inserted_at FROM {version_table_name} WHERE inserted_at" + " = %s", rows[0][1], ) as curr: rows = curr.fetchall() assert len(rows) == 1 assert rows[0][0] == "event" with client.sql_client.execute_query( - f"SELECT schema_name, inserted_at FROM {version_table_name} WHERE inserted_at = %s", + f"SELECT schema_name, inserted_at FROM {version_table_name} WHERE inserted_at" + " = %s", pendulum.now().add(seconds=1), ) as curr: rows = curr.fetchall() assert len(rows) == 0 if client.sql_client.dbapi.paramstyle == "pyformat": with client.sql_client.execute_query( - f"SELECT schema_name, inserted_at FROM {version_table_name} WHERE inserted_at =" - " %(date)s", + f"SELECT schema_name, inserted_at FROM {version_table_name} WHERE" + " inserted_at = %(date)s", date=pendulum.now().add(seconds=1), ) as curr: rows = curr.fetchall() @@ -247,7 +271,10 @@ def test_execute_query(client: SqlJobClientBase) -> None: @pytest.mark.parametrize( - "client", destinations_configs(default_sql_configs=True), indirect=True, ids=lambda x: x.name + "client", + destinations_configs(default_sql_configs=True), + indirect=True, + ids=lambda x: x.name, ) def test_execute_df(client: SqlJobClientBase) -> None: if client.config.destination_type == "bigquery": @@ -268,7 +295,9 @@ def test_execute_df(client: SqlJobClientBase) -> None: insert_query = ",".join([f"({idx})" for idx in range(0, total_records)]) sql_stmt = f"INSERT INTO {f_q_table_name} VALUES {insert_query};" elif client.capabilities.insert_values_writer_type == "select_union": - insert_query = " UNION ALL ".join([f"SELECT {idx}" for idx in range(0, total_records)]) + insert_query = " UNION ALL ".join( 
+ [f"SELECT {idx}" for idx in range(0, total_records)] + ) sql_stmt = f"INSERT INTO {f_q_table_name} {insert_query};" client.sql_client.execute_sql(sql_stmt) @@ -298,14 +327,19 @@ def test_execute_df(client: SqlJobClientBase) -> None: @pytest.mark.parametrize( - "client", destinations_configs(default_sql_configs=True), indirect=True, ids=lambda x: x.name + "client", + destinations_configs(default_sql_configs=True), + indirect=True, + ids=lambda x: x.name, ) def test_database_exceptions(client: SqlJobClientBase) -> None: client.update_stored_schema() term_ex: Any # invalid table with pytest.raises(DatabaseUndefinedRelation) as term_ex: - with client.sql_client.execute_query("SELECT * FROM TABLE_XXX ORDER BY inserted_at"): + with client.sql_client.execute_query( + "SELECT * FROM TABLE_XXX ORDER BY inserted_at" + ): pass assert client.sql_client.is_dbapi_exception(term_ex.value.dbapi_exception) with pytest.raises(DatabaseUndefinedRelation) as term_ex: @@ -331,7 +365,9 @@ def test_database_exceptions(client: SqlJobClientBase) -> None: # invalid syntax with pytest.raises(DatabaseTransientException) as term_ex: - with client.sql_client.execute_query("SELECTA * FROM TABLE_XXX ORDER BY inserted_at"): + with client.sql_client.execute_query( + "SELECTA * FROM TABLE_XXX ORDER BY inserted_at" + ): pass assert client.sql_client.is_dbapi_exception(term_ex.value.dbapi_exception) # invalid column @@ -357,7 +393,9 @@ def test_database_exceptions(client: SqlJobClientBase) -> None: pass assert client.sql_client.is_dbapi_exception(term_ex.value.dbapi_exception) with pytest.raises(DatabaseUndefinedRelation) as term_ex: - with client.sql_client.execute_query(f"DELETE FROM {qualified_name} WHERE 1=1"): + with client.sql_client.execute_query( + f"DELETE FROM {qualified_name} WHERE 1=1" + ): pass assert client.sql_client.is_dbapi_exception(term_ex.value.dbapi_exception) with pytest.raises(DatabaseUndefinedRelation) as term_ex: @@ -376,7 +414,9 @@ def test_commit_transaction(client: SqlJobClientBase) -> None: table_name = prepare_temp_table(client) f_q_table_name = client.sql_client.make_qualified_table_name(table_name) with client.sql_client.begin_transaction(): - client.sql_client.execute_sql(f"INSERT INTO {f_q_table_name} VALUES (%s)", Decimal("1.0")) + client.sql_client.execute_sql( + f"INSERT INTO {f_q_table_name} VALUES (%s)", Decimal("1.0") + ) # check row still in transaction rows = client.sql_client.execute_sql( f"SELECT col FROM {f_q_table_name} WHERE col = %s", Decimal("1.0") @@ -429,7 +469,9 @@ def test_rollback_transaction(client: SqlJobClientBase) -> None: assert len(rows) == 0 # test rollback on invalid query - f_q_wrong_table_name = client.sql_client.make_qualified_table_name(f"{table_name}_X") + f_q_wrong_table_name = client.sql_client.make_qualified_table_name( + f"{table_name}_X" + ) with pytest.raises(DatabaseException): with client.sql_client.begin_transaction(): client.sql_client.execute_sql( @@ -446,7 +488,9 @@ def test_rollback_transaction(client: SqlJobClientBase) -> None: # test explicit rollback with client.sql_client.begin_transaction() as tx: - client.sql_client.execute_sql(f"INSERT INTO {f_q_table_name} VALUES (%s)", Decimal("1.0")) + client.sql_client.execute_sql( + f"INSERT INTO {f_q_table_name} VALUES (%s)", Decimal("1.0") + ) tx.rollback_transaction() rows = client.sql_client.execute_sql( f"SELECT col FROM {f_q_table_name} WHERE col = %s", Decimal("1.0") @@ -482,11 +526,15 @@ def test_thread(thread_id: Decimal) -> None: ) with thread_client: with thread_client.begin_transaction(): 
- thread_client.execute_sql(f"INSERT INTO {f_q_table_name} VALUES (%s)", thread_id) + thread_client.execute_sql( + f"INSERT INTO {f_q_table_name} VALUES (%s)", thread_id + ) event.wait() with client.sql_client.begin_transaction() as tx: - client.sql_client.execute_sql(f"INSERT INTO {f_q_table_name} VALUES (%s)", Decimal("1.0")) + client.sql_client.execute_sql( + f"INSERT INTO {f_q_table_name} VALUES (%s)", Decimal("1.0") + ) t = Thread(target=test_thread, daemon=True, args=(Decimal("2.0"),)) t.start() # thread 2.0 inserts @@ -501,19 +549,25 @@ def test_thread(thread_id: Decimal) -> None: client.sql_client.close_connection() # re open connection client.sql_client.open_connection() - rows = client.sql_client.execute_sql(f"SELECT col FROM {f_q_table_name} ORDER BY col") + rows = client.sql_client.execute_sql( + f"SELECT col FROM {f_q_table_name} ORDER BY col" + ) assert len(rows) == 1 # only thread 2 is left assert rows[0][0] == Decimal("2.0") @pytest.mark.parametrize( - "client", destinations_configs(default_sql_configs=True), indirect=True, ids=lambda x: x.name + "client", + destinations_configs(default_sql_configs=True), + indirect=True, + ids=lambda x: x.name, ) def test_max_table_identifier_length(client: SqlJobClientBase) -> None: if client.capabilities.max_identifier_length >= 65536: pytest.skip( - f"destination {client.config.destination_type} has no table name length restriction" + f"destination {client.config.destination_type} has no table name length" + " restriction" ) table_name = ( 8 @@ -538,12 +592,16 @@ def test_max_table_identifier_length(client: SqlJobClientBase) -> None: @pytest.mark.parametrize( - "client", destinations_configs(default_sql_configs=True), indirect=True, ids=lambda x: x.name + "client", + destinations_configs(default_sql_configs=True), + indirect=True, + ids=lambda x: x.name, ) def test_max_column_identifier_length(client: SqlJobClientBase) -> None: if client.capabilities.max_column_identifier_length >= 65536: pytest.skip( - f"destination {client.config.destination_type} has no column name length restriction" + f"destination {client.config.destination_type} has no column name length" + " restriction" ) table_name = "prospects_external_data__data365_member__member" column_name = ( @@ -584,7 +642,11 @@ def test_recover_on_explicit_tx(client: SqlJobClientBase) -> None: assert_load_id(client.sql_client, "ABC") # syntax error within tx - statements = ["BEGIN TRANSACTION;", f"INVERT INTO {version_table} VALUES(1);", "COMMIT;"] + statements = [ + "BEGIN TRANSACTION;", + f"INVERT INTO {version_table} VALUES(1);", + "COMMIT;", + ] with pytest.raises(DatabaseTransientException): client.sql_client.execute_many(statements) # assert derives_from_class_of_name(term_ex.value.dbapi_exception, "ProgrammingError") @@ -613,7 +675,9 @@ def assert_load_id(sql_client: SqlClientBase[TNativeConn], load_id: str) -> None sql_client.close_connection() sql_client.open_connection() loads_table = sql_client.make_qualified_table_name(LOADS_TABLE_NAME) - rows = sql_client.execute_sql(f"SELECT load_id FROM {loads_table} WHERE load_id = %s", load_id) + rows = sql_client.execute_sql( + f"SELECT load_id FROM {loads_table} WHERE load_id = %s", load_id + ) assert len(rows) == 1 @@ -624,8 +688,8 @@ def prepare_temp_table(client: SqlJobClientBase) -> str: coltype = "numeric" if client.config.destination_type == "athena": iceberg_table_suffix = ( - f"LOCATION '{AWS_BUCKET}/ci/{table_name}' TBLPROPERTIES ('table_type'='ICEBERG'," - " 'format'='parquet');" + f"LOCATION '{AWS_BUCKET}/ci/{table_name}' 
TBLPROPERTIES" + " ('table_type'='ICEBERG', 'format'='parquet');" ) coltype = "bigint" qualified_table_name = table_name diff --git a/tests/load/utils.py b/tests/load/utils.py index 8c5eda6d3b..0124559e02 100644 --- a/tests/load/utils.py +++ b/tests/load/utils.py @@ -2,7 +2,18 @@ import contextlib import codecs import os -from typing import Any, Iterator, List, Sequence, IO, Tuple, Optional, Dict, Union, Generator +from typing import ( + Any, + Iterator, + List, + Sequence, + IO, + Tuple, + Optional, + Dict, + Union, + Generator, +) import shutil from pathlib import Path from dataclasses import dataclass @@ -67,7 +78,9 @@ # Filter out buckets not in all filesystem drivers DEFAULT_BUCKETS = [GCS_BUCKET, AWS_BUCKET, FILE_BUCKET, MEMORY_BUCKET, AZ_BUCKET] DEFAULT_BUCKETS = [ - bucket for bucket in DEFAULT_BUCKETS if bucket.split(":")[0] in ALL_FILESYSTEM_DRIVERS + bucket + for bucket in DEFAULT_BUCKETS + if bucket.split(":")[0] in ALL_FILESYSTEM_DRIVERS ] # temporary solution to include gdrive bucket in tests, @@ -135,7 +148,11 @@ def setup(self) -> None: os.environ["DATA_WRITER__DISABLE_COMPRESSION"] = "True" def setup_pipeline( - self, pipeline_name: str, dataset_name: str = None, full_refresh: bool = False, **kwargs + self, + pipeline_name: str, + dataset_name: str = None, + full_refresh: bool = False, + **kwargs, ) -> dlt.Pipeline: """Convenience method to setup pipeline with this configuration""" self.setup() @@ -165,7 +182,9 @@ def destinations_configs( ) -> List[DestinationTestConfiguration]: # sanity check for item in subset: - assert item in IMPLEMENTED_DESTINATIONS, f"Destination {item} is not implemented" + assert ( + item in IMPLEMENTED_DESTINATIONS + ), f"Destination {item} is not implemented" # build destination configs destination_configs: List[DestinationTestConfiguration] = [] @@ -348,7 +367,9 @@ def destinations_configs( if local_filesystem_configs: destination_configs += [ DestinationTestConfiguration( - destination="filesystem", bucket_url=FILE_BUCKET, file_format="insert_values" + destination="filesystem", + bucket_url=FILE_BUCKET, + file_format="insert_values", ) ] destination_configs += [ @@ -377,7 +398,9 @@ def destinations_configs( # filter out destinations not in subset if subset: - destination_configs = [conf for conf in destination_configs if conf.destination in subset] + destination_configs = [ + conf for conf in destination_configs if conf.destination in subset + ] if exclude: destination_configs = [ conf for conf in destination_configs if conf.destination not in exclude @@ -392,7 +415,9 @@ def destinations_configs( ] if supports_merge is not None: destination_configs = [ - conf for conf in destination_configs if conf.supports_merge == supports_merge + conf + for conf in destination_configs + if conf.supports_merge == supports_merge ] if supports_dbt is not None: destination_configs = [ @@ -401,7 +426,9 @@ def destinations_configs( # filter out excluded configs destination_configs = [ - conf for conf in destination_configs if conf.name not in EXCLUDED_DESTINATION_CONFIGURATIONS + conf + for conf in destination_configs + if conf.name not in EXCLUDED_DESTINATION_CONFIGURATIONS ] return destination_configs @@ -420,8 +447,8 @@ def get_normalized_dataset_name(client: JobClientBase) -> str: return client.config.normalize_dataset_name(client.schema) else: raise TypeError( - f"{type(client)} client has configuration {type(client.config)} that does not support" - " dataset name" + f"{type(client)} client has configuration {type(client.config)} that does" + " not support 
dataset name" ) @@ -445,7 +472,9 @@ def expect_load_file( ).file_name() file_storage.save(file_name, query.encode("utf-8")) table = client.prepare_load_table(table_name) - job = client.start_file_load(table, file_storage.make_full_path(file_name), uniq_id()) + job = client.start_file_load( + table, file_storage.make_full_path(file_name), uniq_id() + ) while job.state() == "running": sleep(0.5) assert job.file_name() == file_name @@ -466,7 +495,9 @@ def prepare_table( user_table_name = table_name + uniq_id() else: user_table_name = table_name - client.schema.update_table(new_table(user_table_name, columns=list(user_table.values()))) + client.schema.update_table( + new_table(user_table_name, columns=list(user_table.values())) + ) client.schema._bump_version() client.update_stored_schema() return user_table_name @@ -531,11 +562,15 @@ def cm_yield_client( default_config_values: StrAny = None, schema_name: str = "event", ) -> Iterator[SqlJobClientBase]: - return yield_client(destination_type, dataset_name, default_config_values, schema_name) + return yield_client( + destination_type, dataset_name, default_config_values, schema_name + ) def yield_client_with_storage( - destination_type: str, default_config_values: StrAny = None, schema_name: str = "event" + destination_type: str, + default_config_values: StrAny = None, + schema_name: str = "event", ) -> Iterator[SqlJobClientBase]: # create dataset with random name dataset_name = "test_" + uniq_id() @@ -562,9 +597,13 @@ def delete_dataset(client: SqlClientBase[Any], normalized_dataset_name: str) -> @contextlib.contextmanager def cm_yield_client_with_storage( - destination_type: str, default_config_values: StrAny = None, schema_name: str = "event" + destination_type: str, + default_config_values: StrAny = None, + schema_name: str = "event", ) -> Iterator[SqlJobClientBase]: - return yield_client_with_storage(destination_type, default_config_values, schema_name) + return yield_client_with_storage( + destination_type, default_config_values, schema_name + ) def write_dataset( @@ -609,7 +648,9 @@ def prepare_load_package( full_package_path = load_storage.new_packages.storage.make_full_path( load_storage.new_packages.get_package_path(load_id) ) - Path(full_package_path).joinpath(schema_path.name).write_text(json.dumps(data), encoding="utf8") + Path(full_package_path).joinpath(schema_path.name).write_text( + json.dumps(data), encoding="utf8" + ) schema_update_path = "./tests/load/cases/loading/schema_updates.json" shutil.copy(schema_update_path, full_package_path) diff --git a/tests/load/weaviate/test_naming.py b/tests/load/weaviate/test_naming.py index 290879cb67..633390ab9b 100644 --- a/tests/load/weaviate/test_naming.py +++ b/tests/load/weaviate/test_naming.py @@ -1,7 +1,9 @@ import dlt, pytest from dlt.destinations.impl.weaviate.naming import NamingConvention -from dlt.destinations.impl.weaviate.ci_naming import NamingConvention as CINamingConvention +from dlt.destinations.impl.weaviate.ci_naming import ( + NamingConvention as CINamingConvention, +) from tests.common.utils import load_yml_case diff --git a/tests/load/weaviate/test_pipeline.py b/tests/load/weaviate/test_pipeline.py index dc23644940..719dd97f37 100644 --- a/tests/load/weaviate/test_pipeline.py +++ b/tests/load/weaviate/test_pipeline.py @@ -8,7 +8,10 @@ from dlt.destinations.impl.weaviate import weaviate_adapter from dlt.destinations.impl.weaviate.exceptions import PropertyNameConflict -from dlt.destinations.impl.weaviate.weaviate_adapter import VECTORIZE_HINT, TOKENIZATION_HINT +from 
dlt.destinations.impl.weaviate.weaviate_adapter import ( + VECTORIZE_HINT, + TOKENIZATION_HINT, +) from dlt.destinations.impl.weaviate.weaviate_client import WeaviateClient from dlt.pipeline.exceptions import PipelineStepFailed @@ -170,7 +173,9 @@ def some_data(): write_disposition="replace", ) assert_load_info(info) - assert info.dataset_name == "TestPipelineReplaceDataset" + uid # normalized internally + assert ( + info.dataset_name == "TestPipelineReplaceDataset" + uid + ) # normalized internally data = next(generator_instance2) assert_class(pipeline, "SomeData", items=data) @@ -191,14 +196,16 @@ def test_pipeline_merge() -> None: "doc_id": 1, "title": "The Shawshank Redemption", "description": ( - "Two imprisoned men find redemption through acts of decency over the years." + "Two imprisoned men find redemption through acts of decency over the" + " years." ), }, { "doc_id": 2, "title": "The Godfather", "description": ( - "A crime dynasty's aging patriarch transfers control to his reluctant son." + "A crime dynasty's aging patriarch transfers control to his reluctant" + " son." ), }, { @@ -304,13 +311,17 @@ def test_merge_github_nested() -> None: assert p.dataset_name.startswith("github1_202") with open( - "tests/normalize/cases/github.issues.load_page_5_duck.json", "r", encoding="utf-8" + "tests/normalize/cases/github.issues.load_page_5_duck.json", + "r", + encoding="utf-8", ) as f: data = json.load(f) info = p.run( weaviate_adapter( - data[:17], vectorize=["title", "body"], tokenization={"user__login": "lowercase"} + data[:17], + vectorize=["title", "body"], + tokenization={"user__login": "lowercase"}, ), table_name="issues", write_disposition="merge", @@ -356,7 +367,9 @@ def test_empty_dataset_allowed() -> None: pytest.skip("skip to avoid race condition with other tests") assert p.dataset_name is None - info = p.run(weaviate_adapter(["context", "created", "not a stop word"], vectorize=["value"])) + info = p.run( + weaviate_adapter(["context", "created", "not a stop word"], vectorize=["value"]) + ) # dataset in load info is empty assert info.dataset_name is None client = p.destination_client() # type: ignore[assignment] diff --git a/tests/load/weaviate/test_weaviate_client.py b/tests/load/weaviate/test_weaviate_client.py index 3f966c2330..f54ebe1ddf 100644 --- a/tests/load/weaviate/test_weaviate_client.py +++ b/tests/load/weaviate/test_weaviate_client.py @@ -6,7 +6,11 @@ from dlt.common.configuration.container import Container from dlt.common.configuration.specs.config_section_context import ConfigSectionContext from dlt.common.utils import uniq_id -from dlt.common.schema.typing import TWriteDisposition, TColumnSchema, TTableSchemaColumns +from dlt.common.schema.typing import ( + TWriteDisposition, + TColumnSchema, + TTableSchemaColumns, +) from dlt.destinations import weaviate from dlt.destinations.impl.weaviate.exceptions import PropertyNameConflict @@ -69,7 +73,9 @@ def file_storage() -> FileStorage: @pytest.mark.parametrize("write_disposition", ["append", "replace", "merge"]) def test_all_data_types( - client: WeaviateClient, write_disposition: TWriteDisposition, file_storage: FileStorage + client: WeaviateClient, + write_disposition: TWriteDisposition, + file_storage: FileStorage, ) -> None: class_name = "AllTypes" # we should have identical content with all disposition types @@ -81,7 +87,9 @@ def test_all_data_types( # write row with io.BytesIO() as f: - write_dataset(client, f, [TABLE_ROW_ALL_DATA_TYPES], TABLE_UPDATE_COLUMNS_SCHEMA) + write_dataset( + client, f, 
[TABLE_ROW_ALL_DATA_TYPES], TABLE_UPDATE_COLUMNS_SCHEMA + ) query = f.getvalue().decode() expect_load_file(client, file_storage, query, class_name) _, table_columns = client.get_storage_table("AllTypes") @@ -89,7 +97,11 @@ def test_all_data_types( assert len(table_columns) == len(TABLE_UPDATE_COLUMNS_SCHEMA) for col_name in table_columns: assert col_name in TABLE_UPDATE_COLUMNS_SCHEMA - if TABLE_UPDATE_COLUMNS_SCHEMA[col_name]["data_type"] in ["decimal", "complex", "time"]: + if TABLE_UPDATE_COLUMNS_SCHEMA[col_name]["data_type"] in [ + "decimal", + "complex", + "time", + ]: # no native representation assert table_columns[col_name]["data_type"] == "text" elif TABLE_UPDATE_COLUMNS_SCHEMA[col_name]["data_type"] == "wei": @@ -111,7 +123,9 @@ def test_case_sensitive_properties_create(client: WeaviateClient) -> None: {"name": "coL1", "data_type": "double", "nullable": False}, ] client.schema.update_table( - client.schema.normalize_table_identifiers(new_table(class_name, columns=table_create)) + client.schema.normalize_table_identifiers( + new_table(class_name, columns=table_create) + ) ) client.schema._bump_version() with pytest.raises(PropertyNameConflict): @@ -126,7 +140,9 @@ def test_case_insensitive_properties_create(ci_client: WeaviateClient) -> None: {"name": "coL1", "data_type": "double", "nullable": False}, ] ci_client.schema.update_table( - ci_client.schema.normalize_table_identifiers(new_table(class_name, columns=table_create)) + ci_client.schema.normalize_table_identifiers( + new_table(class_name, columns=table_create) + ) ) ci_client.schema._bump_version() ci_client.update_stored_schema() @@ -138,18 +154,24 @@ def test_case_insensitive_properties_create(ci_client: WeaviateClient) -> None: def test_case_sensitive_properties_add(client: WeaviateClient) -> None: class_name = "col_class" # we have two properties which will map to the same name in Weaviate - table_create: List[TColumnSchema] = [{"name": "col1", "data_type": "bigint", "nullable": False}] + table_create: List[TColumnSchema] = [ + {"name": "col1", "data_type": "bigint", "nullable": False} + ] table_update: List[TColumnSchema] = [ {"name": "coL1", "data_type": "double", "nullable": False}, ] client.schema.update_table( - client.schema.normalize_table_identifiers(new_table(class_name, columns=table_create)) + client.schema.normalize_table_identifiers( + new_table(class_name, columns=table_create) + ) ) client.schema._bump_version() client.update_stored_schema() client.schema.update_table( - client.schema.normalize_table_identifiers(new_table(class_name, columns=table_update)) + client.schema.normalize_table_identifiers( + new_table(class_name, columns=table_update) + ) ) client.schema._bump_version() with pytest.raises(PropertyNameConflict): @@ -159,7 +181,9 @@ def test_case_sensitive_properties_add(client: WeaviateClient) -> None: # print(table_columns) -def test_load_case_sensitive_data(client: WeaviateClient, file_storage: FileStorage) -> None: +def test_load_case_sensitive_data( + client: WeaviateClient, file_storage: FileStorage +) -> None: class_name = "col_class" # we have two properties which will map to the same name in Weaviate table_create: TTableSchemaColumns = { @@ -178,7 +202,9 @@ def test_load_case_sensitive_data(client: WeaviateClient, file_storage: FileStor expect_load_file(client, file_storage, query, class_name) -def test_load_case_sensitive_data_ci(ci_client: WeaviateClient, file_storage: FileStorage) -> None: +def test_load_case_sensitive_data_ci( + ci_client: WeaviateClient, file_storage: FileStorage +) 
-> None: class_name = "col_class" # we have two properties which will map to the same name in Weaviate table_create: TTableSchemaColumns = { diff --git a/tests/load/weaviate/utils.py b/tests/load/weaviate/utils.py index 1b2a74fcb8..8e661ec795 100644 --- a/tests/load/weaviate/utils.py +++ b/tests/load/weaviate/utils.py @@ -7,7 +7,10 @@ from dlt.common.schema.utils import get_columns_names_with_prop from dlt.destinations.impl.weaviate.weaviate_client import WeaviateClient -from dlt.destinations.impl.weaviate.weaviate_adapter import VECTORIZE_HINT, TOKENIZATION_HINT +from dlt.destinations.impl.weaviate.weaviate_adapter import ( + VECTORIZE_HINT, + TOKENIZATION_HINT, +) def assert_unordered_list_equal(list1: List[Any], list2: List[Any]) -> None: @@ -45,7 +48,9 @@ def assert_class( assert prop["tokenization"] == column[TOKENIZATION_HINT] # type: ignore[literal-required] # if there's a single vectorize hint, class must have vectorizer enabled - if get_columns_names_with_prop(pipeline.default_schema.get_table(class_name), VECTORIZE_HINT): + if get_columns_names_with_prop( + pipeline.default_schema.get_table(class_name), VECTORIZE_HINT + ): assert schema["vectorizer"] == vectorizer_name else: assert schema["vectorizer"] == "none" diff --git a/tests/normalize/mock_rasa_json_normalizer.py b/tests/normalize/mock_rasa_json_normalizer.py index f911c55493..bcea367583 100644 --- a/tests/normalize/mock_rasa_json_normalizer.py +++ b/tests/normalize/mock_rasa_json_normalizer.py @@ -1,5 +1,7 @@ from dlt.common.normalizers.json import TNormalizedRowIterator -from dlt.common.normalizers.json.relational import DataItemNormalizer as RelationalNormalizer +from dlt.common.normalizers.json.relational import ( + DataItemNormalizer as RelationalNormalizer, +) from dlt.common.schema import Schema from dlt.common.typing import TDataItem diff --git a/tests/normalize/test_normalize.py b/tests/normalize/test_normalize.py index 39a18c5de2..b5e1ab541d 100644 --- a/tests/normalize/test_normalize.py +++ b/tests/normalize/test_normalize.py @@ -12,7 +12,12 @@ from dlt.common.utils import uniq_id from dlt.common.typing import StrAny from dlt.common.data_types import TDataType -from dlt.common.storages import NormalizeStorage, LoadStorage, ParsedLoadJobFileName, PackageStorage +from dlt.common.storages import ( + NormalizeStorage, + LoadStorage, + ParsedLoadJobFileName, + PackageStorage, +) from dlt.common.destination import DestinationCapabilitiesContext from dlt.common.configuration.container import Container @@ -101,7 +106,9 @@ def test_normalize_single_user_event_jsonl( for expected_table in expected_tables: get_line_from_file(raw_normalize.load_storage, load_files[expected_table]) # return first line from event_user file - event_text, lines = get_line_from_file(raw_normalize.load_storage, load_files["event"], 0) + event_text, lines = get_line_from_file( + raw_normalize.load_storage, load_files["event"], 0 + ) assert lines == 1 event_json = json.loads(event_text) assert event_json["event"] == "user" @@ -131,7 +138,9 @@ def test_normalize_single_user_event_insert( for expected_table in expected_tables: get_line_from_file(raw_normalize.load_storage, load_files[expected_table]) # return first values line from event_user file - event_text, lines = get_line_from_file(raw_normalize.load_storage, load_files["event"], 2) + event_text, lines = get_line_from_file( + raw_normalize.load_storage, load_files["event"], 2 + ) assert lines == 3 assert "'user'" in event_text assert "'greet'" in event_text @@ -149,7 +158,9 @@ def 
test_normalize_single_user_event_insert( def test_normalize_filter_user_event( caps: DestinationCapabilitiesContext, rasa_normalize: Normalize ) -> None: - load_id = extract_and_normalize_cases(rasa_normalize, ["event.event.user_load_v228_1"]) + load_id = extract_and_normalize_cases( + rasa_normalize, ["event.event.user_load_v228_1"] + ) _, load_files = expect_load_package( rasa_normalize.load_storage, load_id, @@ -162,7 +173,9 @@ def test_normalize_filter_user_event( "event_user__parse_data__intent_ranking", ], ) - event_text, lines = get_line_from_file(rasa_normalize.load_storage, load_files["event_user"], 0) + event_text, lines = get_line_from_file( + rasa_normalize.load_storage, load_files["event_user"], 0 + ) assert lines == 1 filtered_row = json.loads(event_text) assert "parse_data__intent__name" in filtered_row @@ -180,7 +193,9 @@ def test_normalize_filter_bot_event( _, load_files = expect_load_package( rasa_normalize.load_storage, load_id, ["event", "event_bot"] ) - event_text, lines = get_line_from_file(rasa_normalize.load_storage, load_files["event_bot"], 0) + event_text, lines = get_line_from_file( + rasa_normalize.load_storage, load_files["event_bot"], 0 + ) assert lines == 1 filtered_row = json.loads(event_text) assert "metadata__utter_action" in filtered_row @@ -191,11 +206,15 @@ def test_normalize_filter_bot_event( def test_preserve_slot_complex_value_json_l( caps: DestinationCapabilitiesContext, rasa_normalize: Normalize ) -> None: - load_id = extract_and_normalize_cases(rasa_normalize, ["event.event.slot_session_metadata_1"]) + load_id = extract_and_normalize_cases( + rasa_normalize, ["event.event.slot_session_metadata_1"] + ) _, load_files = expect_load_package( rasa_normalize.load_storage, load_id, ["event", "event_slot"] ) - event_text, lines = get_line_from_file(rasa_normalize.load_storage, load_files["event_slot"], 0) + event_text, lines = get_line_from_file( + rasa_normalize.load_storage, load_files["event_slot"], 0 + ) assert lines == 1 filtered_row = json.loads(event_text) assert type(filtered_row["value"]) is dict @@ -206,11 +225,15 @@ def test_preserve_slot_complex_value_json_l( def test_preserve_slot_complex_value_insert( caps: DestinationCapabilitiesContext, rasa_normalize: Normalize ) -> None: - load_id = extract_and_normalize_cases(rasa_normalize, ["event.event.slot_session_metadata_1"]) + load_id = extract_and_normalize_cases( + rasa_normalize, ["event.event.slot_session_metadata_1"] + ) _, load_files = expect_load_package( rasa_normalize.load_storage, load_id, ["event", "event_slot"] ) - event_text, lines = get_line_from_file(rasa_normalize.load_storage, load_files["event_slot"], 2) + event_text, lines = get_line_from_file( + rasa_normalize.load_storage, load_files["event_slot"], 2 + ) assert lines == 3 c_val = json.dumps({"user_id": "world", "mitter_id": "hello"}) assert c_val in event_text @@ -223,10 +246,17 @@ def test_normalize_many_events_insert( load_id = extract_and_normalize_cases( rasa_normalize, ["event.event.many_load_2", "event.event.user_load_1"] ) - expected_tables = EXPECTED_USER_TABLES_RASA_NORMALIZER + ["event_bot", "event_action"] - _, load_files = expect_load_package(rasa_normalize.load_storage, load_id, expected_tables) + expected_tables = EXPECTED_USER_TABLES_RASA_NORMALIZER + [ + "event_bot", + "event_action", + ] + _, load_files = expect_load_package( + rasa_normalize.load_storage, load_id, expected_tables + ) # return first values line from event_user file - event_text, lines = get_line_from_file(rasa_normalize.load_storage, 
load_files["event"], 4) + event_text, lines = get_line_from_file( + rasa_normalize.load_storage, load_files["event"], 4 + ) # 2 lines header + 3 lines data assert lines == 5 assert f"'{load_id}'" in event_text @@ -239,10 +269,17 @@ def test_normalize_many_events( load_id = extract_and_normalize_cases( rasa_normalize, ["event.event.many_load_2", "event.event.user_load_1"] ) - expected_tables = EXPECTED_USER_TABLES_RASA_NORMALIZER + ["event_bot", "event_action"] - _, load_files = expect_load_package(rasa_normalize.load_storage, load_id, expected_tables) + expected_tables = EXPECTED_USER_TABLES_RASA_NORMALIZER + [ + "event_bot", + "event_action", + ] + _, load_files = expect_load_package( + rasa_normalize.load_storage, load_id, expected_tables + ) # return first values line from event_user file - event_text, lines = get_line_from_file(rasa_normalize.load_storage, load_files["event"], 2) + event_text, lines = get_line_from_file( + rasa_normalize.load_storage, load_files["event"], 2 + ) # 3 lines data assert lines == 3 assert f"{load_id}" in event_text @@ -275,12 +312,16 @@ def test_multiprocessing_row_counting( # get step info step_info = raw_normalize.get_step_info(MockPipeline("multiprocessing_pipeline", True)) # type: ignore[abstract] assert step_info.row_counts["events"] == 100 - assert step_info.row_counts["events__payload__pull_request__requested_reviewers"] == 24 + assert ( + step_info.row_counts["events__payload__pull_request__requested_reviewers"] == 24 + ) # check if single load id assert len(step_info.loads_ids) == 1 row_counts = { t: m.items_count - for t, m in step_info.metrics[step_info.loads_ids[0]][0]["table_metrics"].items() + for t, m in step_info.metrics[step_info.loads_ids[0]][0][ + "table_metrics" + ].items() } assert row_counts == step_info.row_counts @@ -315,11 +356,17 @@ def test_normalize_many_packages( schemas.append(schema.name) # expect event tables if schema.name == "event": - expected_tables = EXPECTED_USER_TABLES_RASA_NORMALIZER + ["event_bot", "event_action"] + expected_tables = EXPECTED_USER_TABLES_RASA_NORMALIZER + [ + "event_bot", + "event_action", + ] expect_load_package(rasa_normalize.load_storage, load_id, expected_tables) if schema.name == "ethereum": expect_load_package( - rasa_normalize.load_storage, load_id, EXPECTED_ETH_TABLES, full_schema_update=False + rasa_normalize.load_storage, + load_id, + EXPECTED_ETH_TABLES, + full_schema_update=False, ) assert set(schemas) == set(["ethereum", "event"]) @@ -328,7 +375,9 @@ def test_normalize_many_packages( def test_normalize_typed_json( caps: DestinationCapabilitiesContext, raw_normalize: Normalize ) -> None: - extract_items(raw_normalize.normalize_storage, [JSON_TYPED_DICT], Schema("special"), "special") + extract_items( + raw_normalize.normalize_storage, [JSON_TYPED_DICT], Schema("special"), "special" + ) with ThreadPoolExecutor(max_workers=1) as pool: raw_normalize.run(pool) loads = raw_normalize.load_storage.list_normalized_packages() @@ -344,7 +393,9 @@ def test_normalize_typed_json( @pytest.mark.parametrize("caps", ALL_CAPABILITIES, indirect=True) -def test_schema_changes(caps: DestinationCapabilitiesContext, raw_normalize: Normalize) -> None: +def test_schema_changes( + caps: DestinationCapabilitiesContext, raw_normalize: Normalize +) -> None: doc = {"str": "text", "int": 1} extract_items(raw_normalize.normalize_storage, [doc], Schema("evolution"), "doc") load_id = normalize_pending(raw_normalize) @@ -372,13 +423,18 @@ def test_schema_changes(caps: DestinationCapabilitiesContext, raw_normalize: Nor 
doc3_2v = {"comp": [doc2]} doc3_doc_v = {"comp": [doc_v]} extract_items( - raw_normalize.normalize_storage, [doc3, doc, doc_v, doc3_2v, doc3_doc_v], schema, "doc" + raw_normalize.normalize_storage, + [doc3, doc, doc_v, doc3_2v, doc3_doc_v], + schema, + "doc", ) # schema = raw_normalize.schema_storage.load_schema("evolution") # extract_items(raw_normalize.normalize_storage, [doc3_2v, doc3_doc_v], schema, "doc") load_id = normalize_pending(raw_normalize) - _, table_files = expect_load_package(raw_normalize.load_storage, load_id, ["doc", "doc__comp"]) + _, table_files = expect_load_package( + raw_normalize.load_storage, load_id, ["doc", "doc__comp"] + ) assert len(table_files["doc"]) == 1 assert len(table_files["doc__comp"]) == 1 schema = raw_normalize.schema_storage.load_schema("evolution") @@ -403,9 +459,13 @@ def test_schema_changes(caps: DestinationCapabilitiesContext, raw_normalize: Nor def test_normalize_twice_with_flatten( caps: DestinationCapabilitiesContext, raw_normalize: Normalize ) -> None: - load_id = extract_and_normalize_cases(raw_normalize, ["github.issues.load_page_5_duck"]) + load_id = extract_and_normalize_cases( + raw_normalize, ["github.issues.load_page_5_duck"] + ) _, table_files = expect_load_package( - raw_normalize.load_storage, load_id, ["issues", "issues__labels", "issues__assignees"] + raw_normalize.load_storage, + load_id, + ["issues", "issues__labels", "issues__assignees"], ) assert len(table_files["issues"]) == 1 _, lines = get_line_from_file(raw_normalize.load_storage, table_files["issues"], 0) @@ -422,7 +482,9 @@ def assert_schema(_schema: Schema): schema = raw_normalize.schema_storage.load_schema("github") assert_schema(schema) - load_id = extract_and_normalize_cases(raw_normalize, ["github.issues.load_page_5_duck"]) + load_id = extract_and_normalize_cases( + raw_normalize, ["github.issues.load_page_5_duck"] + ) _, table_files = expect_load_package( raw_normalize.load_storage, load_id, @@ -454,7 +516,9 @@ def test_normalize_retry(raw_normalize: Normalize) -> None: # subsequent run must succeed raw_normalize.run(None) _, table_files = expect_load_package( - raw_normalize.load_storage, load_id, ["issues", "issues__labels", "issues__assignees"] + raw_normalize.load_storage, + load_id, + ["issues", "issues__labels", "issues__assignees"], ) assert len(table_files["issues"]) == 1 @@ -465,7 +529,12 @@ def test_group_worker_files() -> None: assert Normalize.group_worker_files([], 4) == [] assert Normalize.group_worker_files(["f001"], 1) == [["f001"]] assert Normalize.group_worker_files(["f001"], 100) == [["f001"]] - assert Normalize.group_worker_files(files[:4], 4) == [["f000"], ["f001"], ["f002"], ["f003"]] + assert Normalize.group_worker_files(files[:4], 4) == [ + ["f000"], + ["f001"], + ["f002"], + ["f003"], + ] assert Normalize.group_worker_files(files[:5], 4) == [ ["f000"], ["f001"], @@ -526,11 +595,16 @@ def test_group_worker_files() -> None: def extract_items( - normalize_storage: NormalizeStorage, items: Sequence[StrAny], schema: Schema, table_name: str + normalize_storage: NormalizeStorage, + items: Sequence[StrAny], + schema: Schema, + table_name: str, ) -> str: extractor = ExtractStorage(normalize_storage.config) load_id = extractor.create_load_package(schema) - extractor.write_data_item("puae-jsonl", load_id, schema.name, table_name, items, None) + extractor.write_data_item( + "puae-jsonl", load_id, schema.name, table_name, items, None + ) extractor.close_writers(load_id) extractor.commit_new_load_package(load_id, schema) return load_id @@ -559,7 +633,9 
@@ def normalize_pending(normalize: Normalize) -> str: # read schema from package schema = normalize.normalize_storage.extracted_packages.load_schema(load_id) # get files - schema_files = normalize.normalize_storage.extracted_packages.list_new_jobs(load_id) + schema_files = normalize.normalize_storage.extracted_packages.list_new_jobs( + load_id + ) # normalize without pool normalize.spool_files(load_id, schema, normalize.map_single, schema_files) @@ -654,4 +730,6 @@ def assert_timestamp_data_type(load_storage: LoadStorage, data_type: TDataType) loads = load_storage.list_normalized_packages() event_schema = load_storage.normalized_packages.load_schema(loads[0]) # in raw normalize timestamp column must not be coerced to timestamp - assert event_schema.get_table_columns("event")["timestamp"]["data_type"] == data_type + assert ( + event_schema.get_table_columns("event")["timestamp"]["data_type"] == data_type + ) diff --git a/tests/pipeline/cases/github_pipeline/github_extract.py b/tests/pipeline/cases/github_pipeline/github_extract.py index 6be6643947..91bb2e47e7 100644 --- a/tests/pipeline/cases/github_pipeline/github_extract.py +++ b/tests/pipeline/cases/github_pipeline/github_extract.py @@ -6,7 +6,10 @@ if __name__ == "__main__": p = dlt.pipeline( - "dlt_github_pipeline", destination="duckdb", dataset_name="github_3", full_refresh=False + "dlt_github_pipeline", + destination="duckdb", + dataset_name="github_3", + full_refresh=False, ) github_source = github() if len(sys.argv) > 1: diff --git a/tests/pipeline/cases/github_pipeline/github_pipeline.py b/tests/pipeline/cases/github_pipeline/github_pipeline.py index c55bd02ba0..7dc0b97698 100644 --- a/tests/pipeline/cases/github_pipeline/github_pipeline.py +++ b/tests/pipeline/cases/github_pipeline/github_pipeline.py @@ -20,13 +20,19 @@ def github(): merge_key=("node_id", "url"), ) def load_issues( - created_at=dlt.sources.incremental[pendulum.DateTime]("created_at"), # noqa: B008 + created_at=dlt.sources.incremental[pendulum.DateTime]( + "created_at" + ), # noqa: B008 ): # we should be in TEST_STORAGE folder with open( - "../tests/normalize/cases/github.issues.load_page_5_duck.json", "r", encoding="utf-8" + "../tests/normalize/cases/github.issues.load_page_5_duck.json", + "r", + encoding="utf-8", ) as f: - issues = map(convert_dates, sorted(json.load(f), key=lambda x: x["created_at"])) + issues = map( + convert_dates, sorted(json.load(f), key=lambda x: x["created_at"]) + ) yield from issues return load_issues @@ -34,7 +40,10 @@ def load_issues( if __name__ == "__main__": p = dlt.pipeline( - "dlt_github_pipeline", destination="duckdb", dataset_name="github_3", full_refresh=False + "dlt_github_pipeline", + destination="duckdb", + dataset_name="github_3", + full_refresh=False, ) github_source = github() if len(sys.argv) > 1: diff --git a/tests/pipeline/test_arrow_sources.py b/tests/pipeline/test_arrow_sources.py index 4991afa002..1dab4b5be1 100644 --- a/tests/pipeline/test_arrow_sources.py +++ b/tests/pipeline/test_arrow_sources.py @@ -14,7 +14,11 @@ from dlt.pipeline.exceptions import PipelineStepFailed -from tests.cases import arrow_format_from_pandas, arrow_table_all_data_types, TArrowFormat +from tests.cases import ( + arrow_format_from_pandas, + arrow_table_all_data_types, + TArrowFormat, +) from tests.utils import preserve_environ @@ -44,7 +48,9 @@ def some_data(): pipeline.extract(some_data()) norm_storage = pipeline._get_normalize_storage() extract_files = [ - fn for fn in norm_storage.list_files_to_normalize_sorted() if 
fn.endswith(".parquet") + fn + for fn in norm_storage.list_files_to_normalize_sorted() + if fn.endswith(".parquet") ] assert len(extract_files) == 1 @@ -183,8 +189,12 @@ def data_frames(): yield item # get buffer written and file rotated with each yielded frame - os.environ[f"SOURCES__{pipeline_name.upper()}__DATA_WRITER__BUFFER_MAX_ITEMS"] = str(len(rows)) - os.environ[f"SOURCES__{pipeline_name.upper()}__DATA_WRITER__FILE_MAX_ITEMS"] = str(len(rows)) + os.environ[f"SOURCES__{pipeline_name.upper()}__DATA_WRITER__BUFFER_MAX_ITEMS"] = ( + str(len(rows)) + ) + os.environ[f"SOURCES__{pipeline_name.upper()}__DATA_WRITER__FILE_MAX_ITEMS"] = str( + len(rows) + ) pipeline.extract(data_frames()) # ten parquet files @@ -343,7 +353,9 @@ def test_empty_arrow(item_type: TArrowFormat) -> None: empty_df = pd.DataFrame(columns=item.columns) item_resource = dlt.resource( - arrow_format_from_pandas(empty_df, item_type), name="items", write_disposition="replace" + arrow_format_from_pandas(empty_df, item_type), + name="items", + write_disposition="replace", ) info = pipeline.extract(item_resource) load_id = info.loads_ids[0] diff --git a/tests/pipeline/test_dlt_versions.py b/tests/pipeline/test_dlt_versions.py index ccf926cc62..a038e01ef3 100644 --- a/tests/pipeline/test_dlt_versions.py +++ b/tests/pipeline/test_dlt_versions.py @@ -33,7 +33,9 @@ def test_pipeline_with_dlt_update(test_storage: FileStorage) -> None: - shutil.copytree("tests/pipeline/cases/github_pipeline", TEST_STORAGE_ROOT, dirs_exist_ok=True) + shutil.copytree( + "tests/pipeline/cases/github_pipeline", TEST_STORAGE_ROOT, dirs_exist_ok=True + ) # execute in test storage with set_working_dir(TEST_STORAGE_ROOT): @@ -46,11 +48,14 @@ def test_pipeline_with_dlt_update(test_storage: FileStorage) -> None: # create virtual env with (0.3.0) before the current schema upgrade with Venv.create(tempfile.mkdtemp(), ["dlt[duckdb]==0.3.0"]) as venv: # NOTE: we force a newer duckdb into the 0.3.0 dlt version to get compatible duckdb storage - venv._install_deps(venv.context, ["duckdb" + "==" + pkg_version("duckdb")]) + venv._install_deps( + venv.context, ["duckdb" + "==" + pkg_version("duckdb")] + ) # load 20 issues print( venv.run_script( - "../tests/pipeline/cases/github_pipeline/github_pipeline.py", "20" + "../tests/pipeline/cases/github_pipeline/github_pipeline.py", + "20", ) ) # load schema and check _dlt_loads definition @@ -66,14 +71,19 @@ def test_pipeline_with_dlt_update(test_storage: FileStorage) -> None: ) # check the dlt state table assert { - "version_hash" not in github_schema["tables"][STATE_TABLE_NAME]["columns"] + "version_hash" + not in github_schema["tables"][STATE_TABLE_NAME]["columns"] } # check loads table without attaching to pipeline duckdb_cfg = resolve_configuration( - DuckDbClientConfiguration()._bind_dataset_name(dataset_name=GITHUB_DATASET), + DuckDbClientConfiguration()._bind_dataset_name( + dataset_name=GITHUB_DATASET + ), sections=("destination", "duckdb"), ) - with DuckDbSqlClient(GITHUB_DATASET, duckdb_cfg.credentials) as client: + with DuckDbSqlClient( + GITHUB_DATASET, duckdb_cfg.credentials + ) as client: rows = client.execute_sql(f"SELECT * FROM {LOADS_TABLE_NAME}") # make sure we have just 4 columns assert len(rows[0]) == 4 @@ -84,14 +94,16 @@ def test_pipeline_with_dlt_update(test_storage: FileStorage) -> None: assert len(rows[0]) == 5 + 2 # inspect old state state_dict = json.loads( - test_storage.load(f".dlt/pipelines/{GITHUB_PIPELINE_NAME}/state.json") + test_storage.load( + 
f".dlt/pipelines/{GITHUB_PIPELINE_NAME}/state.json" + ) ) assert "_version_hash" not in state_dict # also we expect correctly decoded pendulum here created_at_value = custom_pua_decode( - state_dict["sources"]["github"]["resources"]["load_issues"]["incremental"][ - "created_at" - ]["last_value"] + state_dict["sources"]["github"]["resources"]["load_issues"][ + "incremental" + ]["created_at"]["last_value"] ) assert isinstance(created_at_value, pendulum.DateTime) assert created_at_value == pendulum.parse("2021-04-16T04:34:05Z") @@ -99,7 +111,11 @@ def test_pipeline_with_dlt_update(test_storage: FileStorage) -> None: # execute in current version venv = Venv.restore_current() # load all issues - print(venv.run_script("../tests/pipeline/cases/github_pipeline/github_pipeline.py")) + print( + venv.run_script( + "../tests/pipeline/cases/github_pipeline/github_pipeline.py" + ) + ) # hash hash in schema github_schema = json.loads( test_storage.load( @@ -107,10 +123,15 @@ def test_pipeline_with_dlt_update(test_storage: FileStorage) -> None: ) ) assert github_schema["engine_version"] == 9 - assert "schema_version_hash" in github_schema["tables"][LOADS_TABLE_NAME]["columns"] + assert ( + "schema_version_hash" + in github_schema["tables"][LOADS_TABLE_NAME]["columns"] + ) # load state state_dict = json.loads( - test_storage.load(f".dlt/pipelines/{GITHUB_PIPELINE_NAME}/state.json") + test_storage.load( + f".dlt/pipelines/{GITHUB_PIPELINE_NAME}/state.json" + ) ) assert "_version_hash" in state_dict @@ -131,7 +152,9 @@ def test_pipeline_with_dlt_update(test_storage: FileStorage) -> None: # two schema versions rows = client.execute_sql(f"SELECT * FROM {VERSION_TABLE_NAME}") assert len(rows) == 2 - rows = client.execute_sql(f"SELECT * FROM {STATE_TABLE_NAME} ORDER BY version") + rows = client.execute_sql( + f"SELECT * FROM {STATE_TABLE_NAME} ORDER BY version" + ) # we have hash columns assert len(rows[0]) == 6 + 2 assert len(rows) == 2 @@ -141,10 +164,12 @@ def test_pipeline_with_dlt_update(test_storage: FileStorage) -> None: assert rows[1][7] == state_dict["_version_hash"] # attach to existing pipeline - pipeline = dlt.attach(GITHUB_PIPELINE_NAME, credentials=duckdb_cfg.credentials) - created_at_value = pipeline.state["sources"]["github"]["resources"]["load_issues"][ - "incremental" - ]["created_at"]["last_value"] + pipeline = dlt.attach( + GITHUB_PIPELINE_NAME, credentials=duckdb_cfg.credentials + ) + created_at_value = pipeline.state["sources"]["github"]["resources"][ + "load_issues" + ]["incremental"]["created_at"]["last_value"] assert isinstance(created_at_value, pendulum.DateTime) assert created_at_value == pendulum.parse("2023-02-17T09:52:12Z") pipeline = pipeline.drop() @@ -157,11 +182,16 @@ def test_pipeline_with_dlt_update(test_storage: FileStorage) -> None: # we have updated schema assert pipeline.default_schema.ENGINE_VERSION == 9 # make sure that schema hash retrieved from the destination is exactly the same as the schema hash that was in storage before the schema was wiped - assert pipeline.default_schema.stored_version_hash == github_schema["version_hash"] + assert ( + pipeline.default_schema.stored_version_hash + == github_schema["version_hash"] + ) def test_load_package_with_dlt_update(test_storage: FileStorage) -> None: - shutil.copytree("tests/pipeline/cases/github_pipeline", TEST_STORAGE_ROOT, dirs_exist_ok=True) + shutil.copytree( + "tests/pipeline/cases/github_pipeline", TEST_STORAGE_ROOT, dirs_exist_ok=True + ) # execute in test storage with set_working_dir(TEST_STORAGE_ROOT): @@ -173,11 
+203,14 @@ def test_load_package_with_dlt_update(test_storage: FileStorage) -> None: ): # create virtual env with (0.3.0) before the current schema upgrade with Venv.create(tempfile.mkdtemp(), ["dlt[duckdb]==0.3.0"]) as venv: - venv._install_deps(venv.context, ["duckdb" + "==" + pkg_version("duckdb")]) + venv._install_deps( + venv.context, ["duckdb" + "==" + pkg_version("duckdb")] + ) # extract and normalize on old version but DO NOT LOAD print( venv.run_script( - "../tests/pipeline/cases/github_pipeline/github_extract.py", "70" + "../tests/pipeline/cases/github_pipeline/github_extract.py", + "70", ) ) print( @@ -187,9 +220,15 @@ def test_load_package_with_dlt_update(test_storage: FileStorage) -> None: ) # switch to current version and make sure the load package loads and schema migrates venv = Venv.restore_current() - print(venv.run_script("../tests/pipeline/cases/github_pipeline/github_load.py")) + print( + venv.run_script( + "../tests/pipeline/cases/github_pipeline/github_load.py" + ) + ) duckdb_cfg = resolve_configuration( - DuckDbClientConfiguration()._bind_dataset_name(dataset_name=GITHUB_DATASET), + DuckDbClientConfiguration()._bind_dataset_name( + dataset_name=GITHUB_DATASET + ), sections=("destination", "duckdb"), ) with DuckDbSqlClient(GITHUB_DATASET, duckdb_cfg.credentials) as client: @@ -201,7 +240,9 @@ def test_load_package_with_dlt_update(test_storage: FileStorage) -> None: ) ) # attach to existing pipeline - pipeline = dlt.attach(GITHUB_PIPELINE_NAME, credentials=duckdb_cfg.credentials) + pipeline = dlt.attach( + GITHUB_PIPELINE_NAME, credentials=duckdb_cfg.credentials + ) # get the schema from schema storage before we sync github_schema = json.loads( test_storage.load( @@ -212,7 +253,10 @@ def test_load_package_with_dlt_update(test_storage: FileStorage) -> None: pipeline.sync_destination() assert pipeline.default_schema.ENGINE_VERSION == 9 # schema version does not match `dlt.attach` does not update to the right schema by itself - assert pipeline.default_schema.stored_version_hash != github_schema["version_hash"] + assert ( + pipeline.default_schema.stored_version_hash + != github_schema["version_hash"] + ) # state has hash assert pipeline.state["_version_hash"] is not None # but in db there's no hash - we loaded an old package with backward compatible schema @@ -232,7 +276,9 @@ def test_load_package_with_dlt_update(test_storage: FileStorage) -> None: def test_normalize_package_with_dlt_update(test_storage: FileStorage) -> None: - shutil.copytree("tests/pipeline/cases/github_pipeline", TEST_STORAGE_ROOT, dirs_exist_ok=True) + shutil.copytree( + "tests/pipeline/cases/github_pipeline", TEST_STORAGE_ROOT, dirs_exist_ok=True + ) # execute in test storage with set_working_dir(TEST_STORAGE_ROOT): @@ -244,11 +290,14 @@ def test_normalize_package_with_dlt_update(test_storage: FileStorage) -> None: ): # create virtual env with (0.3.0) before the current schema upgrade with Venv.create(tempfile.mkdtemp(), ["dlt[duckdb]==0.3.0"]) as venv: - venv._install_deps(venv.context, ["duckdb" + "==" + pkg_version("duckdb")]) + venv._install_deps( + venv.context, ["duckdb" + "==" + pkg_version("duckdb")] + ) # extract only print( venv.run_script( - "../tests/pipeline/cases/github_pipeline/github_extract.py", "70" + "../tests/pipeline/cases/github_pipeline/github_extract.py", + "70", ) ) # switch to current version and normalize existing extract package @@ -259,7 +308,9 @@ def test_normalize_package_with_dlt_update(test_storage: FileStorage) -> None: assert mig_ex.value.from_version == "1.0.0" # 
delete all files in extracted folder - for file in pipeline._pipeline_storage.list_folder_files("normalize/extracted"): + for file in pipeline._pipeline_storage.list_folder_files( + "normalize/extracted" + ): pipeline._pipeline_storage.delete(file) # now we can migrate the storage pipeline.normalize() diff --git a/tests/pipeline/test_import_export_schema.py b/tests/pipeline/test_import_export_schema.py index 6f40e1d1eb..3d369206c3 100644 --- a/tests/pipeline/test_import_export_schema.py +++ b/tests/pipeline/test_import_export_schema.py @@ -99,7 +99,9 @@ def test_import_schema_is_respected() -> None: # take default schema, modify column type and save it to import folder modified_schema = p.default_schema.clone() modified_schema.tables["person"]["columns"]["id"]["data_type"] = "text" - with open(os.path.join(IMPORT_SCHEMA_PATH, name + ".schema.yaml"), "w", encoding="utf-8") as f: + with open( + os.path.join(IMPORT_SCHEMA_PATH, name + ".schema.yaml"), "w", encoding="utf-8" + ) as f: f.write(modified_schema.to_pretty_yaml()) # import schema will be imported into pipeline @@ -109,7 +111,10 @@ def test_import_schema_is_respected() -> None: # change in pipeline schema assert p.default_schema.tables["person"]["columns"]["id"]["data_type"] == "text" # import schema is not overwritten - assert _get_import_schema(name).tables["person"]["columns"]["id"]["data_type"] == "text" + assert ( + _get_import_schema(name).tables["person"]["columns"]["id"]["data_type"] + == "text" + ) # when creating a new schema (e.g. with full refresh), this will work p = dlt.pipeline( @@ -131,7 +136,10 @@ def test_import_schema_is_respected() -> None: assert p.default_schema.tables["person"]["columns"]["id"]["data_type"] == "text" # import schema is not overwritten - assert _get_import_schema(name).tables["person"]["columns"]["id"]["data_type"] == "text" + assert ( + _get_import_schema(name).tables["person"]["columns"]["id"]["data_type"] + == "text" + ) # export now includes the modified column type export_schema = _get_export_schema(name) @@ -181,7 +189,9 @@ def resource(): } # adding column to the resource will not change the import schema, but the pipeline schema will evolve - @dlt.resource(primary_key="id", name="person", columns={"email": {"data_type": "text"}}) + @dlt.resource( + primary_key="id", name="person", columns={"email": {"data_type": "text"}} + ) def resource(): yield EXAMPLE_DATA @@ -205,7 +215,9 @@ def resource(): "name": "age", } with open( - os.path.join(IMPORT_SCHEMA_PATH, "source" + ".schema.yaml"), "w", encoding="utf-8" + os.path.join(IMPORT_SCHEMA_PATH, "source" + ".schema.yaml"), + "w", + encoding="utf-8", ) as f: f.write(import_schema.to_pretty_yaml()) diff --git a/tests/pipeline/test_pipeline.py b/tests/pipeline/test_pipeline.py index 37356c2b44..b71190e672 100644 --- a/tests/pipeline/test_pipeline.py +++ b/tests/pipeline/test_pipeline.py @@ -34,12 +34,20 @@ from dlt.common.schema import Schema from dlt.destinations import filesystem, redshift, dummy -from dlt.extract.exceptions import InvalidResourceDataTypeBasic, PipeGenInvalid, SourceExhausted +from dlt.extract.exceptions import ( + InvalidResourceDataTypeBasic, + PipeGenInvalid, + SourceExhausted, +) from dlt.extract.extract import ExtractStorage from dlt.extract import DltResource, DltSource from dlt.extract.extractors import MaterializedEmptyList from dlt.load.exceptions import LoadClientJobFailed -from dlt.pipeline.exceptions import InvalidPipelineName, PipelineNotActive, PipelineStepFailed +from dlt.pipeline.exceptions import ( + 
InvalidPipelineName, + PipelineNotActive, + PipelineStepFailed, +) from dlt.pipeline.helpers import retry_load from tests.common.utils import TEST_SENTRY_DSN @@ -61,7 +69,9 @@ def test_default_pipeline() -> None: possible_names = ["dlt_pytest", "dlt_pipeline"] possible_dataset_names = ["dlt_pytest_dataset", "dlt_pipeline_dataset"] assert p.pipeline_name in possible_names - assert p.pipelines_dir == os.path.abspath(os.path.join(TEST_STORAGE_ROOT, ".dlt", "pipelines")) + assert p.pipelines_dir == os.path.abspath( + os.path.join(TEST_STORAGE_ROOT, ".dlt", "pipelines") + ) assert p.runtime_config.pipeline_name == p.pipeline_name # dataset that will be used to load data is the pipeline name assert p.dataset_name in possible_dataset_names @@ -277,12 +287,16 @@ def test_destination_staging_config(environment: Any) -> None: fs_dest = filesystem("file:///testing-bucket") p = dlt.pipeline( pipeline_name="staging_pipeline", - destination=redshift(credentials="redshift://loader:loader@localhost:5432/dlt_data"), + destination=redshift( + credentials="redshift://loader:loader@localhost:5432/dlt_data" + ), staging=fs_dest, ) schema = Schema("foo") p._inject_schema(schema) - initial_config = p._get_destination_client_initial_config(p.staging, as_staging=True) + initial_config = p._get_destination_client_initial_config( + p.staging, as_staging=True + ) staging_config = fs_dest.configuration(initial_config) # type: ignore[arg-type] # Ensure that as_staging flag is set in the final resolved conifg @@ -304,7 +318,9 @@ def test_destination_factory_defaults_resolve_from_config(environment: Any) -> N def test_destination_credentials_in_factory(environment: Any) -> None: - os.environ["DESTINATION__REDSHIFT__CREDENTIALS"] = "redshift://abc:123@localhost:5432/some_db" + os.environ["DESTINATION__REDSHIFT__CREDENTIALS"] = ( + "redshift://abc:123@localhost:5432/some_db" + ) redshift_dest = redshift("redshift://abc:123@localhost:5432/other_db") @@ -323,11 +339,15 @@ def test_destination_credentials_in_factory(environment: Any) -> None: assert dest_config.credentials.database == "some_db" -@pytest.mark.skip(reason="does not work on CI. probably takes right credentials from somewhere....") +@pytest.mark.skip( + reason="does not work on CI. probably takes right credentials from somewhere...." 
+) def test_destination_explicit_invalid_credentials_filesystem(environment: Any) -> None: # if string cannot be parsed p = dlt.pipeline( - pipeline_name="postgres_pipeline", destination="filesystem", credentials="PR8BLEM" + pipeline_name="postgres_pipeline", + destination="filesystem", + credentials="PR8BLEM", ) with pytest.raises(NativeValueError): p._get_destination_client_initial_config(p.destination) @@ -372,12 +392,18 @@ def test_extract_multiple_sources() -> None: s1 = DltSource( dlt.Schema("default"), "module", - [dlt.resource([1, 2, 3], name="resource_1"), dlt.resource([3, 4, 5], name="resource_2")], + [ + dlt.resource([1, 2, 3], name="resource_1"), + dlt.resource([3, 4, 5], name="resource_2"), + ], ) s2 = DltSource( dlt.Schema("default_2"), "module", - [dlt.resource([6, 7, 8], name="resource_3"), dlt.resource([9, 10, 0], name="resource_4")], + [ + dlt.resource([6, 7, 8], name="resource_3"), + dlt.resource([9, 10, 0], name="resource_4"), + ], ) p = dlt.pipeline(destination="dummy") @@ -400,10 +426,15 @@ def i_fail(): s3 = DltSource( dlt.Schema("default_3"), "module", - [dlt.resource([1, 2, 3], name="resource_1"), dlt.resource([3, 4, 5], name="resource_2")], + [ + dlt.resource([1, 2, 3], name="resource_1"), + dlt.resource([3, 4, 5], name="resource_2"), + ], ) s4 = DltSource( - dlt.Schema("default_4"), "module", [dlt.resource([6, 7, 8], name="resource_3"), i_fail] + dlt.Schema("default_4"), + "module", + [dlt.resource([6, 7, 8], name="resource_3"), i_fail], ) with pytest.raises(PipelineStepFailed): @@ -432,7 +463,9 @@ def with_mark(): p = dlt.pipeline(destination="dummy", pipeline_name="mark_pipeline") p.extract(with_mark()) storage = ExtractStorage(p._normalize_storage_config()) - expect_extracted_file(storage, "mark", "spec_table", json.dumps([{"id": 1}, {"id": 2}])) + expect_extracted_file( + storage, "mark", "spec_table", json.dumps([{"id": 1}, {"id": 2}]) + ) p.normalize() # no "with_mark" table in the schema: we update resource hints before any table schema is computed assert "with_mark" not in p.default_schema.tables @@ -478,7 +511,9 @@ def with_table_hints(): "with_table_hints": 1, } # check table counts - assert_data_table_counts(pipeline, {"table_a": 2, "table_b": 2, "with_table_hints": 1}) + assert_data_table_counts( + pipeline, {"table_a": 2, "table_b": 2, "with_table_hints": 1} + ) def test_mark_hints_variant_dynamic_name() -> None: @@ -693,7 +728,10 @@ def data_schema_3(): p = dlt.pipeline(pipeline_name=pipeline_name, destination="dummy") with pytest.raises(PipelineStepFailed): - p.run([data_schema_1(), data_schema_2(), data_schema_3()], write_disposition="replace") + p.run( + [data_schema_1(), data_schema_2(), data_schema_3()], + write_disposition="replace", + ) # first run didn't really happen assert p.first_run is True @@ -828,7 +866,9 @@ def fail_extract(): retry_count = 2 with pytest.raises(PipelineStepFailed) as py_ex: for attempt in Retrying( - stop=stop_after_attempt(3), retry=retry_if_exception(retry_load(())), reraise=True + stop=stop_after_attempt(3), + retry=retry_if_exception(retry_load(())), + reraise=True, ): with attempt: p.run(fail_extract()) @@ -903,7 +943,8 @@ def resource_1(): p.run(resource_1, write_disposition="replace") print(list(p._schema_storage.live_schemas.values())[0].to_pretty_yaml()) assert ( - p.schemas[p.default_schema_name].get_table("resource_1")["write_disposition"] == "replace" + p.schemas[p.default_schema_name].get_table("resource_1")["write_disposition"] + == "replace" ) assert 
p.default_schema.get_table("resource_1")["write_disposition"] == "replace" @@ -927,13 +968,17 @@ def github_repo_events_table_meta(page): def _get_shuffled_events(repeat: int = 1): for _ in range(repeat): with open( - "tests/normalize/cases/github.events.load_page_1_duck.json", "r", encoding="utf-8" + "tests/normalize/cases/github.events.load_page_1_duck.json", + "r", + encoding="utf-8", ) as f: issues = json.load(f) yield issues -@pytest.mark.parametrize("github_resource", (github_repo_events_table_meta, github_repo_events)) +@pytest.mark.parametrize( + "github_resource", (github_repo_events_table_meta, github_repo_events) +) def test_dispatch_rows_to_tables(github_resource: DltResource): os.environ["COMPLETED_PROB"] = "1.0" pipeline_name = "pipe_" + uniq_id() @@ -950,7 +995,9 @@ def test_dispatch_rows_to_tables(github_resource: DltResource): # all the tables present assert ( - expected_tables.intersection([t["name"] for t in p.default_schema.data_tables()]) + expected_tables.intersection( + [t["name"] for t in p.default_schema.data_tables()] + ) == expected_tables ) @@ -1083,7 +1130,9 @@ def items(): # complete columns preserve order in "columns" p = p.drop() - @dlt.resource(columns={"c3": {"precision": 32, "data_type": "decimal"}}, primary_key="c1") + @dlt.resource( + columns={"c3": {"precision": 32, "data_type": "decimal"}}, primary_key="c1" + ) def items2(): yield {"c1": 1, "c2": 1, "c3": 1} @@ -1104,7 +1153,8 @@ def test_pipeline_log_progress() -> None: # will attach dlt logger p = dlt.pipeline( - destination="dummy", progress=dlt.progress.log(0.5, logger=None, log_level=logging.WARNING) + destination="dummy", + progress=dlt.progress.log(0.5, logger=None, log_level=logging.WARNING), ) # collector was created before pipeline so logger is not attached assert cast(LogCollector, p.collector).logger is None @@ -1186,7 +1236,12 @@ def writes_state(): def test_extract_add_tables() -> None: # we extract and make sure that tables are added to schema s = airtable_emojis() - assert list(s.resources.keys()) == ["💰Budget", "📆 Schedule", "🦚Peacock", "🦚WidePeacock"] + assert list(s.resources.keys()) == [ + "💰Budget", + "📆 Schedule", + "🦚Peacock", + "🦚WidePeacock", + ] assert s.resources["🦚Peacock"].compute_table_schema()["resource"] == "🦚Peacock" # only name will be normalized assert s.resources["🦚Peacock"].compute_table_schema()["name"] == "🦚Peacock" @@ -1501,7 +1556,8 @@ def autodetect(): # unix ts recognized assert ( - pipeline.default_schema.get_table("numbers")["columns"]["value"]["data_type"] == "timestamp" + pipeline.default_schema.get_table("numbers")["columns"]["value"]["data_type"] + == "timestamp" ) pipeline.load() @@ -1513,7 +1569,10 @@ def autodetect(): pipeline = dlt.pipeline(destination="duckdb") pipeline.run(source) - assert pipeline.default_schema.get_table("numbers")["columns"]["value"]["data_type"] == "bigint" + assert ( + pipeline.default_schema.get_table("numbers")["columns"]["value"]["data_type"] + == "bigint" + ) def test_flattened_column_hint() -> None: @@ -1555,7 +1614,10 @@ def nested_resource(): == "timestamp" ) # make sure data is there - assert pipeline.last_trace.last_normalize_info.row_counts["flattened_dict__values"] == 4 + assert ( + pipeline.last_trace.last_normalize_info.row_counts["flattened_dict__values"] + == 4 + ) def test_empty_rows_are_included() -> None: @@ -1590,7 +1652,9 @@ def test_resource_state_name_not_normalized() -> None: with pipeline.destination_client() as client: # type: ignore[assignment] state = 
load_pipeline_state_from_destination(pipeline.pipeline_name, client) assert "airtable_emojis" in state["sources"] - assert state["sources"]["airtable_emojis"]["resources"] == {"🦚Peacock": {"🦚🦚🦚": "🦚"}} + assert state["sources"]["airtable_emojis"]["resources"] == { + "🦚Peacock": {"🦚🦚🦚": "🦚"} + } def test_pipeline_list_packages() -> None: @@ -1600,7 +1664,11 @@ def test_pipeline_list_packages() -> None: assert len(load_ids) == 1 # two new packages: for emojis schema and emojis_2 pipeline.extract( - [airtable_emojis(), airtable_emojis(), airtable_emojis().clone(with_name="emojis_2")] + [ + airtable_emojis(), + airtable_emojis(), + airtable_emojis().clone(with_name="emojis_2"), + ] ) load_ids = pipeline.list_extracted_load_packages() assert len(load_ids) == 3 @@ -1618,14 +1686,18 @@ def test_pipeline_list_packages() -> None: normalized_package = pipeline.get_load_package_info(load_ids[0]) # same number of new jobs assert normalized_package.state == "normalized" - assert len(normalized_package.jobs["new_jobs"]) == len(extracted_package.jobs["new_jobs"]) + assert len(normalized_package.jobs["new_jobs"]) == len( + extracted_package.jobs["new_jobs"] + ) # load all 3 packages and fail all jobs in them os.environ["FAIL_PROB"] = "1.0" pipeline.load() load_ids_l = pipeline.list_completed_load_packages() assert load_ids == load_ids_l loaded_package = pipeline.get_load_package_info(load_ids[0]) - assert len(loaded_package.jobs["failed_jobs"]) == len(extracted_package.jobs["new_jobs"]) + assert len(loaded_package.jobs["failed_jobs"]) == len( + extracted_package.jobs["new_jobs"] + ) assert loaded_package.state == "loaded" failed_jobs = pipeline.list_failed_jobs_in_package(load_ids[0]) assert len(loaded_package.jobs["failed_jobs"]) == len(failed_jobs) @@ -1680,7 +1752,9 @@ def test_parallel_pipelines_threads(workers: int) -> None: os.environ["PIPELINE_1__EXTRA"] = "CFG_P_1" os.environ["PIPELINE_2__EXTRA"] = "CFG_P_2" - def _run_pipeline(pipeline_name: str) -> Tuple[LoadInfo, PipelineContext, DictStrAny]: + def _run_pipeline( + pipeline_name: str, + ) -> Tuple[LoadInfo, PipelineContext, DictStrAny]: try: @dlt.transformer( @@ -1717,7 +1791,9 @@ def github(extra: str = dlt.config.value): # make sure that only one pipeline is created with init_lock: - pipeline = dlt.pipeline(pipeline_name=pipeline_name, destination="duckdb") + pipeline = dlt.pipeline( + pipeline_name=pipeline_name, destination="duckdb" + ) context = Container()[PipelineContext] finally: sem.release() @@ -1778,9 +1854,17 @@ def github(extra: str = dlt.config.value): pipeline_2: dlt.Pipeline = context_2.pipeline() # type: ignore n_counts_1 = pipeline_1.last_trace.last_normalize_info - assert n_counts_1.row_counts["push_event"] == 8 * page_repeats == counts_1["push_event"] + assert ( + n_counts_1.row_counts["push_event"] + == 8 * page_repeats + == counts_1["push_event"] + ) n_counts_2 = pipeline_2.last_trace.last_normalize_info - assert n_counts_2.row_counts["push_event"] == 8 * page_repeats == counts_2["push_event"] + assert ( + n_counts_2.row_counts["push_event"] + == 8 * page_repeats + == counts_2["push_event"] + ) assert pipeline_1.pipeline_name == "pipeline_1" assert pipeline_2.pipeline_name == "pipeline_2" @@ -1968,7 +2052,8 @@ def users_source(): @dlt.source def taxi_demand_source(): @dlt.resource( - primary_key="city", columns=[{"name": "id", "data_type": "bigint", "precision": 4}] + primary_key="city", + columns=[{"name": "id", "data_type": "bigint", "precision": 4}], ) def locations(idx=dlt.sources.incremental("id")): for idx in 
range(10): diff --git a/tests/pipeline/test_pipeline_extra.py b/tests/pipeline/test_pipeline_extra.py index 98323a2412..e2a6613988 100644 --- a/tests/pipeline/test_pipeline_extra.py +++ b/tests/pipeline/test_pipeline_extra.py @@ -40,9 +40,13 @@ class BaseModel: # type: ignore[no-redef] @pytest.mark.parametrize( - "destination_config", destinations_configs(default_sql_configs=True), ids=lambda x: x.name + "destination_config", + destinations_configs(default_sql_configs=True), + ids=lambda x: x.name, ) -def test_create_pipeline_all_destinations(destination_config: DestinationTestConfiguration) -> None: +def test_create_pipeline_all_destinations( + destination_config: DestinationTestConfiguration, +) -> None: # create pipelines, extract and normalize. that should be possible without installing any dependencies p = dlt.pipeline( pipeline_name=destination_config.destination + "_pipeline", @@ -235,7 +239,9 @@ def with_mark(): @pytest.mark.parametrize("file_format", ("parquet", "insert_values", "jsonl")) def test_columns_hint_with_file_formats(file_format: TLoaderFileFormat) -> None: - @dlt.resource(write_disposition="replace", columns=[{"name": "text", "data_type": "text"}]) + @dlt.resource( + write_disposition="replace", columns=[{"name": "text", "data_type": "text"}] + ) def generic(start=8): yield [{"id": idx, "text": "A" * idx} for idx in range(start, start + 10)] @@ -342,7 +348,9 @@ class Parent(BaseModel): # Check if complex fields preserved # their contents and were not flattened assert loaded_values == { - "child": '{"child_attribute":"any string","optional_child_attribute":null}', + "child": ( + '{"child_attribute":"any string","optional_child_attribute":null}' + ), "optional_parent_attribute": None, "data_dictionary": '{"child_attribute":"any string"}', } @@ -422,7 +430,10 @@ def pandas_incremental(numbers=dlt.sources.incremental("Numbers")): yield df info = dlt.run( - pandas_incremental(), write_disposition="append", table_name="data", destination="duckdb" + pandas_incremental(), + write_disposition="append", + table_name="data", + destination="duckdb", ) with info.pipeline.sql_client() as client: # type: ignore @@ -447,14 +458,21 @@ def users(): # emit table schema with the item dlt.mark.make_hints( columns=[ - {"name": "id", "data_type": "bigint", "precision": 4, "nullable": False}, + { + "name": "id", + "data_type": "bigint", + "precision": 4, + "nullable": False, + }, {"name": "name", "data_type": "text", "nullable": False}, ] ), ) # write parquet file to storage - info = dlt.run(users, destination=local, loader_file_format="parquet", dataset_name="user_data") + info = dlt.run( + users, destination=local, loader_file_format="parquet", dataset_name="user_data" + ) assert_load_info(info) assert set(info.pipeline.default_schema.tables["users"]["columns"].keys()) == {"id", "name", "_dlt_load_id", "_dlt_id"} # type: ignore # find parquet file diff --git a/tests/pipeline/test_pipeline_file_format_resolver.py b/tests/pipeline/test_pipeline_file_format_resolver.py index 588ad720a5..0565576d12 100644 --- a/tests/pipeline/test_pipeline_file_format_resolver.py +++ b/tests/pipeline/test_pipeline_file_format_resolver.py @@ -15,7 +15,9 @@ def test_file_format_resolution() -> None: # raise on destinations that does not support staging with pytest.raises(DestinationLoadingViaStagingNotSupported): p = dlt.pipeline( - pipeline_name="managed_state_pipeline", destination="postgres", staging="filesystem" + pipeline_name="managed_state_pipeline", + destination="postgres", + staging="filesystem", ) # 
raise on staging that does not support staging interface @@ -43,7 +45,10 @@ def __init__(self) -> None: assert p._resolve_loader_file_format("some", "some", destcp, None, None) == "jsonl" # check resolution with input - assert p._resolve_loader_file_format("some", "some", destcp, None, "parquet") == "parquet" + assert ( + p._resolve_loader_file_format("some", "some", destcp, None, "parquet") + == "parquet" + ) # check invalid input with pytest.raises(DestinationIncompatibleLoaderFileFormatException): @@ -53,7 +58,10 @@ def __init__(self) -> None: destcp.supported_staging_file_formats = ["jsonl", "insert_values", "parquet"] destcp.preferred_staging_file_format = "insert_values" stagecp.supported_loader_file_formats = ["jsonl", "insert_values", "parquet"] - assert p._resolve_loader_file_format("some", "some", destcp, stagecp, None) == "insert_values" + assert ( + p._resolve_loader_file_format("some", "some", destcp, stagecp, None) + == "insert_values" + ) # check invalid input with pytest.raises(DestinationIncompatibleLoaderFileFormatException): @@ -63,8 +71,14 @@ def __init__(self) -> None: destcp.supported_staging_file_formats = ["insert_values", "parquet"] destcp.preferred_staging_file_format = "csv" # type: ignore[assignment] stagecp.supported_loader_file_formats = ["jsonl", "insert_values", "parquet"] - assert p._resolve_loader_file_format("some", "some", destcp, stagecp, None) == "insert_values" - assert p._resolve_loader_file_format("some", "some", destcp, stagecp, "parquet") == "parquet" + assert ( + p._resolve_loader_file_format("some", "some", destcp, stagecp, None) + == "insert_values" + ) + assert ( + p._resolve_loader_file_format("some", "some", destcp, stagecp, "parquet") + == "parquet" + ) # check incompatible staging destcp.supported_staging_file_formats = ["insert_values", "csv"] # type: ignore[list-item] diff --git a/tests/pipeline/test_pipeline_state.py b/tests/pipeline/test_pipeline_state.py index f0bcda2717..76d6765a60 100644 --- a/tests/pipeline/test_pipeline_state.py +++ b/tests/pipeline/test_pipeline_state.py @@ -12,7 +12,10 @@ from dlt.common.utils import uniq_id from dlt.common.destination.reference import Destination -from dlt.pipeline.exceptions import PipelineStateEngineNoUpgradePathException, PipelineStepFailed +from dlt.pipeline.exceptions import ( + PipelineStateEngineNoUpgradePathException, + PipelineStepFailed, +) from dlt.pipeline.pipeline import Pipeline from dlt.pipeline.state_sync import ( generate_pipeline_state_version_hash, @@ -41,8 +44,12 @@ def some_data_resource_state(): def test_restore_state_props() -> None: p = dlt.pipeline( pipeline_name="restore_state_props", - destination=Destination.from_reference("redshift", destination_name="redshift_name"), - staging=Destination.from_reference("filesystem", destination_name="filesystem_name"), + destination=Destination.from_reference( + "redshift", destination_name="redshift_name" + ), + staging=Destination.from_reference( + "filesystem", destination_name="filesystem_name" + ), dataset_name="the_dataset", ) p.extract(some_data()) @@ -239,7 +246,10 @@ def some_source(): def test_unmanaged_state_no_pipeline() -> None: list(some_data()) print(state_module._last_full_state) - assert state_module._last_full_state["sources"]["test_pipeline_state"]["last_value"] == 1 + assert ( + state_module._last_full_state["sources"]["test_pipeline_state"]["last_value"] + == 1 + ) def _gen_inner(): dlt.current.state()["gen"] = True @@ -270,9 +280,9 @@ def _gen_inner(): r = dlt.resource(_gen_inner(), name="name_ovrd") assert 
list(r) == [1] assert ( - state_module._last_full_state["sources"][p._make_schema_with_default_name().name][ - "resources" - ]["name_ovrd"]["gen"] + state_module._last_full_state["sources"][ + p._make_schema_with_default_name().name + ]["resources"]["name_ovrd"]["gen"] is True ) with pytest.raises(ResourceNameNotAvailable): @@ -295,9 +305,9 @@ def _gen_inner(tv="df"): p.extract(r) assert r.state["gen"] == "gen_tf" assert ( - state_module._last_full_state["sources"][p.default_schema_name]["resources"]["name_ovrd"][ - "gen" - ] + state_module._last_full_state["sources"][p.default_schema_name]["resources"][ + "name_ovrd" + ]["gen"] == "gen_tf" ) with pytest.raises(ResourceNameNotAvailable): @@ -393,7 +403,11 @@ def _gen_inner(item): # p = dlt.pipeline() # p.extract(dlt.transformer(_gen_inner, data_from=r, name="tx_other_name")) - assert list(dlt.transformer(_gen_inner, data_from=r, name="tx_other_name")) == [2, 4, 6] + assert list(dlt.transformer(_gen_inner, data_from=r, name="tx_other_name")) == [ + 2, + 4, + 6, + ] assert ( state_module._last_full_state["sources"]["test_pipeline_state"]["resources"][ "some_data_resource_state" @@ -413,7 +427,9 @@ def _gen_inner_rv(item): return item * 2 r = some_data_resource_state() - assert list(dlt.transformer(_gen_inner_rv, data_from=r, name="tx_other_name_rv")) == [ + assert list( + dlt.transformer(_gen_inner_rv, data_from=r, name="tx_other_name_rv") + ) == [ 1, 2, 3, @@ -437,7 +453,13 @@ def _gen_inner_rv_defer(item): r = some_data_resource_state() # not available because executed in a pool with pytest.raises(ResourceNameNotAvailable): - print(list(dlt.transformer(_gen_inner_rv_defer, data_from=r, name="tx_other_name_defer"))) + print( + list( + dlt.transformer( + _gen_inner_rv_defer, data_from=r, name="tx_other_name_defer" + ) + ) + ) # async transformer async def _gen_inner_rv_async(item): @@ -447,7 +469,13 @@ async def _gen_inner_rv_async(item): r = some_data_resource_state() # not available because executed in a pool with pytest.raises(ResourceNameNotAvailable): - print(list(dlt.transformer(_gen_inner_rv_async, data_from=r, name="tx_other_name_async"))) + print( + list( + dlt.transformer( + _gen_inner_rv_async, data_from=r, name="tx_other_name_async" + ) + ) + ) # async transformer with explicit resource name async def _gen_inner_rv_async_name(item, r_name): @@ -456,9 +484,9 @@ async def _gen_inner_rv_async_name(item, r_name): r = some_data_resource_state() assert list( - dlt.transformer(_gen_inner_rv_async_name, data_from=r, name="tx_other_name_async")( - "tx_other_name_async" - ) + dlt.transformer( + _gen_inner_rv_async_name, data_from=r, name="tx_other_name_async" + )("tx_other_name_async") ) == [1, 2, 3] assert ( state_module._last_full_state["sources"]["test_pipeline_state"]["resources"][ @@ -489,7 +517,9 @@ def transform(item): def test_migrate_pipeline_state(test_storage: FileStorage) -> None: # test generation of version hash on migration to v3 state_v1 = load_json_case("state/state.v1") - state = migrate_pipeline_state("test_pipeline", state_v1, state_v1["_state_engine_version"], 3) + state = migrate_pipeline_state( + "test_pipeline", state_v1, state_v1["_state_engine_version"], 3 + ) assert state["_state_engine_version"] == 3 assert "_local" in state assert "_version_hash" in state @@ -498,7 +528,10 @@ def test_migrate_pipeline_state(test_storage: FileStorage) -> None: # full migration state_v1 = load_json_case("state/state.v1") state = migrate_pipeline_state( - "test_pipeline", state_v1, state_v1["_state_engine_version"], 
PIPELINE_STATE_ENGINE_VERSION + "test_pipeline", + state_v1, + state_v1["_state_engine_version"], + PIPELINE_STATE_ENGINE_VERSION, ) assert state["_state_engine_version"] == PIPELINE_STATE_ENGINE_VERSION @@ -525,7 +558,9 @@ def test_migrate_pipeline_state(test_storage: FileStorage) -> None: json_case_path("state/state.v1"), test_storage.make_full_path(f"debug_pipeline/{Pipeline.STATE_FILE}"), ) - p = dlt.attach(pipeline_name="debug_pipeline", pipelines_dir=test_storage.storage_path) + p = dlt.attach( + pipeline_name="debug_pipeline", pipelines_dir=test_storage.storage_path + ) assert p.dataset_name == "debug_pipeline_data" assert p.default_schema_name == "example_source" state = p.state diff --git a/tests/pipeline/test_pipeline_trace.py b/tests/pipeline/test_pipeline_trace.py index cec578cb7b..88cab7624e 100644 --- a/tests/pipeline/test_pipeline_trace.py +++ b/tests/pipeline/test_pipeline_trace.py @@ -11,7 +11,9 @@ from dlt.common import json from dlt.common.configuration.specs import CredentialsConfiguration -from dlt.common.configuration.specs.config_providers_context import ConfigProvidersContext +from dlt.common.configuration.specs.config_providers_context import ( + ConfigProvidersContext, +) from dlt.common.pipeline import ExtractInfo, NormalizeInfo, LoadInfo from dlt.common.schema import Schema from dlt.common.runtime.telemetry import stop_telemetry @@ -66,7 +68,9 @@ def data(): assert isinstance(step.started_at, datetime.datetime) assert isinstance(step.finished_at, datetime.datetime) assert isinstance(step.step_info, ExtractInfo) - assert step.step_info.extract_data_info == [{"name": "inject_tomls", "data_type": "source"}] + assert step.step_info.extract_data_info == [ + {"name": "inject_tomls", "data_type": "source"} + ] # check infos extract_info = p.last_trace.last_extract_info assert isinstance(extract_info, ExtractInfo) @@ -84,7 +88,10 @@ def data(): assert metrics["schema_name"] == "inject_tomls" # check dag and hints assert metrics["dag"] == [("data", "data")] - assert metrics["hints"]["data"] == {"write_disposition": "replace", "primary_key": "id"} + assert metrics["hints"]["data"] == { + "write_disposition": "replace", + "primary_key": "id", + } metrics = extract_info.metrics[load_id][1] # inject tomls and dlt state @@ -118,7 +125,9 @@ def data(): assert resolved.is_secret_hint is True assert resolved.value == "2137" assert resolved.default_value == "123" - resolved = _find_resolved_value(trace.resolved_config_values, "credentials", ["databricks"]) + resolved = _find_resolved_value( + trace.resolved_config_values, "credentials", ["databricks"] + ) assert resolved.is_secret_hint is True assert resolved.value == databricks_creds assert_trace_printable(trace) @@ -151,7 +160,9 @@ def data(): assert isinstance(step.step_exception, str) assert isinstance(step.step_info, ExtractInfo) assert len(step.exception_traces) > 0 - assert step.step_info.extract_data_info == [{"name": "async_exception", "data_type": "source"}] + assert step.step_info.extract_data_info == [ + {"name": "async_exception", "data_type": "source"} + ] assert_trace_printable(trace) extract_info = step.step_info @@ -175,7 +186,10 @@ def data(): assert step.step_info is norm_info assert_trace_printable(trace) assert isinstance(p.last_trace.last_normalize_info, NormalizeInfo) - assert p.last_trace.last_normalize_info.row_counts == {"_dlt_pipeline_state": 1, "data": 3} + assert p.last_trace.last_normalize_info.row_counts == { + "_dlt_pipeline_state": 1, + "data": 3, + } assert len(norm_info.loads_ids) == 1 load_id = 
norm_info.loads_ids[0] @@ -305,7 +319,10 @@ def test_trace_on_restore_state(environment: DictStrStr) -> None: environment["COMPLETED_PROB"] = "1.0" def _sync_destination_patch( - self: Pipeline, destination: str = None, staging: str = None, dataset_name: str = None + self: Pipeline, + destination: str = None, + staging: str = None, + dataset_name: str = None, ): # just wipe the pipeline simulating deleted dataset self._wipe_working_folder() @@ -330,9 +347,9 @@ def test_load_none_trace() -> None: def test_trace_telemetry() -> None: - with patch("dlt.common.runtime.sentry.before_send", _mock_sentry_before_send), patch( - "dlt.common.runtime.segment.before_send", _mock_segment_before_send - ): + with patch( + "dlt.common.runtime.sentry.before_send", _mock_sentry_before_send + ), patch("dlt.common.runtime.segment.before_send", _mock_segment_before_send): # os.environ["FAIL_PROB"] = "1.0" # make it complete immediately start_test_telemetry() @@ -363,7 +380,9 @@ def test_trace_telemetry() -> None: assert isinstance(event["properties"]["transaction_id"], str) # check extract info if step == "extract": - assert event["properties"]["extract_data"] == [{"name": "", "data_type": "int"}] + assert event["properties"]["extract_data"] == [ + {"name": "", "data_type": "int"} + ] if step == "load": # dummy has empty fingerprint assert event["properties"]["destination_fingerprint"] == "" @@ -407,7 +426,9 @@ def data(): assert event["properties"]["destination_type"] is None assert event["properties"]["pipeline_name_hash"] == digest128("fresh") assert event["properties"]["dataset_name_hash"] == digest128(p.dataset_name) - assert event["properties"]["default_schema_name_hash"] == digest128(p.default_schema_name) + assert event["properties"]["default_schema_name_hash"] == digest128( + p.default_schema_name + ) def test_extract_data_describe() -> None: @@ -426,7 +447,10 @@ def test_extract_data_describe() -> None: ] assert describe_extract_data( [DltResource(Pipe("rrr_extract"), None, False), DltSource(schema, "sect")] - ) == [{"name": "rrr_extract", "data_type": "resource"}, {"name": "test", "data_type": "source"}] + ) == [ + {"name": "rrr_extract", "data_type": "resource"}, + {"name": "test", "data_type": "source"}, + ] assert describe_extract_data([{"a": "b"}]) == [{"name": "", "data_type": "dict"}] from pandas import DataFrame @@ -436,8 +460,15 @@ def test_extract_data_describe() -> None: ] # first unnamed element in the list breaks checking info assert describe_extract_data( - [DltResource(Pipe("rrr_extract"), None, False), DataFrame(), DltSource(schema, "sect")] - ) == [{"name": "rrr_extract", "data_type": "resource"}, {"name": "", "data_type": "DataFrame"}] + [ + DltResource(Pipe("rrr_extract"), None, False), + DataFrame(), + DltSource(schema, "sect"), + ] + ) == [ + {"name": "rrr_extract", "data_type": "resource"}, + {"name": "", "data_type": "DataFrame"}, + ] def test_slack_hook(environment: DictStrStr) -> None: @@ -449,7 +480,9 @@ def test_slack_hook(environment: DictStrStr) -> None: environment["RUNTIME__SLACK_INCOMING_HOOK"] = hook_url with requests_mock.mock() as m: m.post(hook_url, json={}) - load_info = dlt.pipeline().run([1, 2, 3], table_name="data", destination="dummy") + load_info = dlt.pipeline().run( + [1, 2, 3], table_name="data", destination="dummy" + ) assert slack_notify_load_success(load_info.pipeline.runtime_config.slack_incoming_hook, load_info, load_info.pipeline.last_trace) == 200 # type: ignore[attr-defined] assert m.called message = m.last_request.json() diff --git 
a/tests/pipeline/test_platform_connection.py b/tests/pipeline/test_platform_connection.py index a0893cfc93..6922972330 100644 --- a/tests/pipeline/test_platform_connection.py +++ b/tests/pipeline/test_platform_connection.py @@ -67,7 +67,10 @@ def data(): assert state_result["pipeline_name"] == "platform_test_pipeline" assert state_result["dataset_name"] == "platform_test_dataset" assert len(state_result["schemas"]) == 2 - assert {state_result["schemas"][0]["name"], state_result["schemas"][1]["name"]} == { + assert { + state_result["schemas"][0]["name"], + state_result["schemas"][1]["name"], + } == { "first_source", "second_source", } diff --git a/tests/pipeline/test_resources_evaluation.py b/tests/pipeline/test_resources_evaluation.py index 5a85c06462..a67ea4dd17 100644 --- a/tests/pipeline/test_resources_evaluation.py +++ b/tests/pipeline/test_resources_evaluation.py @@ -276,7 +276,9 @@ def source(): @pytest.mark.parametrize( "n_resources,next_item_mode", product([8, 1, 5, 25, 20], ["fifo", "round_robin"]) ) -def test_parallelized_resource_extract_order(n_resources: int, next_item_mode: str) -> None: +def test_parallelized_resource_extract_order( + n_resources: int, next_item_mode: str +) -> None: os.environ["EXTRACT__NEXT_ITEM_MODE"] = next_item_mode threads = set() @@ -396,7 +398,10 @@ def some_source(): # Nothing runs in main thread assert len(threads) > 1 and threading.get_ident() not in threads - assert len(transformer_threads) > 1 and threading.get_ident() not in transformer_threads + assert ( + len(transformer_threads) > 1 + and threading.get_ident() not in transformer_threads + ) def test_parallelized_resource_bare_generator() -> None: diff --git a/tests/pipeline/test_schema_contracts.py b/tests/pipeline/test_schema_contracts.py index 2f2e6b6932..3fd5a5421c 100644 --- a/tests/pipeline/test_schema_contracts.py +++ b/tests/pipeline/test_schema_contracts.py @@ -81,7 +81,9 @@ def load_items(): def new_items(settings: TSchemaContract) -> Any: - @dlt.resource(name="new_items", write_disposition="append", schema_contract=settings) + @dlt.resource( + name="new_items", write_disposition="append", schema_contract=settings + ) def load_items(): for _, index in enumerate(range(0, 10), 1): yield {"id": index, "some_int": 1, "name": f"item {index}"} @@ -176,7 +178,9 @@ def test_new_tables( table_counts = load_table_counts( pipeline, *[t["name"] for t in pipeline.default_schema.data_tables()] ) - assert table_counts.get("new_items", 0) == (10 if contract_setting in ["evolve"] else 0) + assert table_counts.get("new_items", 0) == ( + 10 if contract_setting in ["evolve"] else 0 + ) # delete extracted files if left after exception pipeline.drop_pending_packages() @@ -198,7 +202,9 @@ def test_new_tables( pipeline, *[t["name"] for t in pipeline.default_schema.data_tables()] ) assert table_counts["items"] == 30 if contract_setting in ["freeze"] else 40 - assert table_counts.get(SUBITEMS_TABLE, 0) == (10 if contract_setting in ["evolve"] else 0) + assert table_counts.get(SUBITEMS_TABLE, 0) == ( + 10 if contract_setting in ["evolve"] else 0 + ) @pytest.mark.parametrize("contract_setting", schema_contract) @@ -228,7 +234,9 @@ def test_new_columns( # test adding new column twice: filter will try to catch it before it is added for the second time with raises_frozen_exception(contract_setting == "freeze"): - run_resource(pipeline, items_with_new_column, full_settings, item_format, duplicates=2) + run_resource( + pipeline, items_with_new_column, full_settings, item_format, duplicates=2 + ) # delete extracted 
files if left after exception pipeline.drop_pending_packages() @@ -307,11 +315,16 @@ def test_freeze_variants(contract_setting: str, setting_location: str) -> None: if contract_setting == "evolve": assert VARIANT_COLUMN_NAME in pipeline.default_schema.tables["items"]["columns"] else: - assert VARIANT_COLUMN_NAME not in pipeline.default_schema.tables["items"]["columns"] + assert ( + VARIANT_COLUMN_NAME + not in pipeline.default_schema.tables["items"]["columns"] + ) table_counts = load_table_counts( pipeline, *[t["name"] for t in pipeline.default_schema.data_tables()] ) - assert table_counts["items"] == (40 if contract_setting in ["evolve", "discard_value"] else 30) + assert table_counts["items"] == ( + 40 if contract_setting in ["evolve", "discard_value"] else 30 + ) def test_settings_precedence() -> None: @@ -321,7 +334,9 @@ def test_settings_precedence() -> None: run_resource(pipeline, items, {}) # trying to add new column when forbidden on resource will fail - run_resource(pipeline, items_with_new_column, {"resource": {"columns": "discard_row"}}) + run_resource( + pipeline, items_with_new_column, {"resource": {"columns": "discard_row"}} + ) # when allowed on override it will work run_resource( @@ -387,14 +402,18 @@ def test_change_mode(setting_location: str) -> None: assert table_counts["items"] == 10 # trying to add variant when forbidden will fail - run_resource(pipeline, items_with_variant, {setting_location: {"data_type": "discard_row"}}) + run_resource( + pipeline, items_with_variant, {setting_location: {"data_type": "discard_row"}} + ) table_counts = load_table_counts( pipeline, *[t["name"] for t in pipeline.default_schema.data_tables()] ) assert table_counts["items"] == 10 # now allow - run_resource(pipeline, items_with_variant, {setting_location: {"data_type": "evolve"}}) + run_resource( + pipeline, items_with_variant, {setting_location: {"data_type": "evolve"}} + ) table_counts = load_table_counts( pipeline, *[t["name"] for t in pipeline.default_schema.data_tables()] ) @@ -504,7 +523,9 @@ def get_items(): yield {"id": 1, "name": "dave", "amount": 50} yield {"id": 2, "name": "dave", "amount": 50, "new_column": "some val"} - pipeline.run([get_items()], schema_contract={"columns": "freeze", "tables": "evolve"}) + pipeline.run( + [get_items()], schema_contract={"columns": "freeze", "tables": "evolve"} + ) assert pipeline.last_trace.last_normalize_info.row_counts["items"] == 2 @@ -540,7 +561,9 @@ def get_items(): def test_defined_column_in_new_table(column_mode: str) -> None: pipeline = get_pipeline() - @dlt.resource(name="items", columns=[{"name": "id", "data_type": "bigint", "nullable": False}]) + @dlt.resource( + name="items", columns=[{"name": "id", "data_type": "bigint", "nullable": False}] + ) def get_items(): yield { "id": 1, @@ -559,7 +582,9 @@ def test_new_column_from_hint_and_data(column_mode: str) -> None: # normalizer does not know that it is a new table and discards the row # and it also excepts on column freeze - @dlt.resource(name="items", columns=[{"name": "id", "data_type": "bigint", "nullable": False}]) + @dlt.resource( + name="items", columns=[{"name": "id", "data_type": "bigint", "nullable": False}] + ) def get_items(): yield { "id": 1, diff --git a/tests/pipeline/test_schema_updates.py b/tests/pipeline/test_schema_updates.py index be397f796c..39488e65ff 100644 --- a/tests/pipeline/test_schema_updates.py +++ b/tests/pipeline/test_schema_updates.py @@ -5,7 +5,9 @@ def test_schema_updates() -> None: os.environ["COMPLETED_PROB"] = "1.0" # make it complete immediately 
- p = dlt.pipeline(pipeline_name="test_schema_updates", full_refresh=True, destination="dummy") + p = dlt.pipeline( + pipeline_name="test_schema_updates", full_refresh=True, destination="dummy" + ) @dlt.source() def source(): diff --git a/tests/pipeline/utils.py b/tests/pipeline/utils.py index 94683e4995..32c33e255b 100644 --- a/tests/pipeline/utils.py +++ b/tests/pipeline/utils.py @@ -47,7 +47,9 @@ def load_table_counts(p: dlt.Pipeline, *table_names: str) -> DictStrAny: # try sql, could be other destination though try: with p.sql_client() as c: - qualified_names = [c.make_qualified_table_name(name) for name in table_names] + qualified_names = [ + c.make_qualified_table_name(name) for name in table_names + ] query = "\nUNION ALL\n".join( [ f"SELECT '{name}' as name, COUNT(1) as c FROM {q_name}" @@ -164,7 +166,9 @@ def load_files(p: dlt.Pipeline, *table_names: str) -> Dict[str, List[Dict[str, A return result -def load_tables_to_dicts(p: dlt.Pipeline, *table_names: str) -> Dict[str, List[Dict[str, Any]]]: +def load_tables_to_dicts( + p: dlt.Pipeline, *table_names: str +) -> Dict[str, List[Dict[str, Any]]]: # try sql, could be other destination though try: result = {} @@ -195,7 +199,8 @@ def load_table_distinct_counts( """Returns counts of distinct values for column `distinct_column` for `table_names` as dict""" query = "\nUNION ALL\n".join( [ - f"SELECT '{name}' as name, COUNT(DISTINCT {distinct_column}) as c FROM {name}" + f"SELECT '{name}' as name, COUNT(DISTINCT {distinct_column}) as c FROM" + f" {name}" for name in table_names ] ) diff --git a/tests/reflection/test_script_inspector.py b/tests/reflection/test_script_inspector.py index 0769a2aa82..113941e370 100644 --- a/tests/reflection/test_script_inspector.py +++ b/tests/reflection/test_script_inspector.py @@ -15,8 +15,12 @@ def test_import_init_module() -> None: with pytest.raises(ModuleNotFoundError): - load_script_module("./tests/reflection/", "module_cases", ignore_missing_imports=False) - m = load_script_module("./tests/reflection/", "module_cases", ignore_missing_imports=True) + load_script_module( + "./tests/reflection/", "module_cases", ignore_missing_imports=False + ) + m = load_script_module( + "./tests/reflection/", "module_cases", ignore_missing_imports=True + ) assert isinstance(m.xxx, DummyModule) assert isinstance(m.a1, SimpleNamespace) @@ -42,7 +46,9 @@ def test_import_module() -> None: def test_import_module_with_missing_dep_exc() -> None: # will ignore MissingDependencyException - m = load_script_module(MODULE_CASES, "dlt_import_exception", ignore_missing_imports=True) + m = load_script_module( + MODULE_CASES, "dlt_import_exception", ignore_missing_imports=True + ) assert isinstance(m.e, SimpleNamespace) @@ -55,7 +61,9 @@ def test_import_module_capitalized_as_type() -> None: def test_import_wrong_pipeline_script() -> None: with pytest.raises(PipelineIsRunning): - inspect_pipeline_script(MODULE_CASES, "executes_resource", ignore_missing_imports=False) + inspect_pipeline_script( + MODULE_CASES, "executes_resource", ignore_missing_imports=False + ) def test_package_dummy_clash() -> None: @@ -63,7 +71,9 @@ def test_package_dummy_clash() -> None: # so if do not recognize package names with following condition (mind the dot): # if any(name == m or name.startswith(m + ".") for m in missing_modules): # we would return dummy for the whole module - m = load_script_module(MODULE_CASES, "stripe_analytics_pipeline", ignore_missing_imports=True) + m = load_script_module( + MODULE_CASES, "stripe_analytics_pipeline", 
ignore_missing_imports=True + ) # and those would fails assert m.VALUE == 1 assert m.HELPERS_VALUE == 3 diff --git a/tests/sources/helpers/rest_client/test_paginators.py b/tests/sources/helpers/rest_client/test_paginators.py index 4d086f1486..4eb9ebe5aa 100644 --- a/tests/sources/helpers/rest_client/test_paginators.py +++ b/tests/sources/helpers/rest_client/test_paginators.py @@ -109,7 +109,9 @@ def test_update_state(self, test_case): { "next_reference": "/api/resource/subresource?page=3&sort=desc", "request_url": "http://example.com/api/resource/subresource", - "expected": "http://example.com/api/resource/subresource?page=3&sort=desc", + "expected": ( + "http://example.com/api/resource/subresource?page=3&sort=desc" + ), }, # Test with 'page' in path { @@ -121,7 +123,9 @@ def test_update_state(self, test_case): { "next_reference": "/api/resource?page=3&category=books&sort=author", "request_url": "http://example.com/api/resource?page=2", - "expected": "http://example.com/api/resource?page=3&category=books&sort=author", + "expected": ( + "http://example.com/api/resource?page=3&category=books&sort=author" + ), }, # Test with URL having port number { diff --git a/tests/sources/helpers/test_requests.py b/tests/sources/helpers/test_requests.py index aefdf23e77..a46461adb3 100644 --- a/tests/sources/helpers/test_requests.py +++ b/tests/sources/helpers/test_requests.py @@ -124,7 +124,11 @@ def _no_content(resp: requests.Response, *args, **kwargs) -> requests.Response: @pytest.mark.parametrize( "exception_class", - [requests.ConnectionError, requests.ConnectTimeout, requests.exceptions.ChunkedEncodingError], + [ + requests.ConnectionError, + requests.ConnectTimeout, + requests.exceptions.ChunkedEncodingError, + ], ) def test_retry_on_exception_all_fails( exception_class: Type[Exception], mock_sleep: mock.MagicMock diff --git a/tests/utils.py b/tests/utils.py index 00523486ea..6ee6bf59db 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -49,7 +49,14 @@ "synapse", "databricks", } -NON_SQL_DESTINATIONS = {"filesystem", "weaviate", "dummy", "motherduck", "qdrant", "destination"} +NON_SQL_DESTINATIONS = { + "filesystem", + "weaviate", + "dummy", + "motherduck", + "qdrant", + "destination", +} SQL_DESTINATIONS = IMPLEMENTED_DESTINATIONS - NON_SQL_DESTINATIONS # exclude destination configs (for now used for athena and athena iceberg separation) @@ -59,7 +66,9 @@ # filter out active destinations for current tests -ACTIVE_DESTINATIONS = set(dlt.config.get("ACTIVE_DESTINATIONS", list) or IMPLEMENTED_DESTINATIONS) +ACTIVE_DESTINATIONS = set( + dlt.config.get("ACTIVE_DESTINATIONS", list) or IMPLEMENTED_DESTINATIONS +) ACTIVE_SQL_DESTINATIONS = SQL_DESTINATIONS.intersection(ACTIVE_DESTINATIONS) ACTIVE_NON_SQL_DESTINATIONS = NON_SQL_DESTINATIONS.intersection(ACTIVE_DESTINATIONS) @@ -68,13 +77,19 @@ assert len(ACTIVE_DESTINATIONS) >= 0, "No active destinations selected" for destination in NON_SQL_DESTINATIONS: - assert destination in IMPLEMENTED_DESTINATIONS, f"Unknown non sql destination {destination}" + assert ( + destination in IMPLEMENTED_DESTINATIONS + ), f"Unknown non sql destination {destination}" for destination in SQL_DESTINATIONS: - assert destination in IMPLEMENTED_DESTINATIONS, f"Unknown sql destination {destination}" + assert ( + destination in IMPLEMENTED_DESTINATIONS + ), f"Unknown sql destination {destination}" for destination in ACTIVE_DESTINATIONS: - assert destination in IMPLEMENTED_DESTINATIONS, f"Unknown active destination {destination}" + assert ( + destination in IMPLEMENTED_DESTINATIONS 
+ ), f"Unknown active destination {destination}" # possible TDataItem types @@ -212,7 +227,10 @@ def data_item_length(data: TDataItem) -> int: if isinstance(data, list): # If data is a list, check if it's a list of supported data types - if all(isinstance(item, (list, pd.DataFrame, pa.Table, pa.RecordBatch)) for item in data): + if all( + isinstance(item, (list, pd.DataFrame, pa.Table, pa.RecordBatch)) + for item in data + ): return sum(data_item_length(item) for item in data) # If it's a list but not a list of supported types, treat it as a single list object else: @@ -265,15 +283,21 @@ def assert_no_dict_key_starts_with(d: StrAny, key_prefix: str) -> None: def skip_if_not_active(destination: str) -> None: - assert destination in IMPLEMENTED_DESTINATIONS, f"Unknown skipped destination {destination}" + assert ( + destination in IMPLEMENTED_DESTINATIONS + ), f"Unknown skipped destination {destination}" if destination not in ACTIVE_DESTINATIONS: - pytest.skip(f"{destination} not in ACTIVE_DESTINATIONS", allow_module_level=True) + pytest.skip( + f"{destination} not in ACTIVE_DESTINATIONS", allow_module_level=True + ) def is_running_in_github_fork() -> bool: """Check if executed by GitHub Actions, in a repo fork.""" is_github_actions = os.environ.get("GITHUB_ACTIONS") == "true" - is_fork = os.environ.get("IS_FORK") == "true" # custom var set by us in the workflow's YAML + is_fork = ( + os.environ.get("IS_FORK") == "true" + ) # custom var set by us in the workflow's YAML return is_github_actions and is_fork @@ -285,12 +309,15 @@ def is_running_in_github_fork() -> bool: platform.python_implementation() == "PyPy", reason="won't run in PyPy interpreter" ) -skipifnotwindows = pytest.mark.skipif(platform.system() != "Windows", reason="runs only on windows") +skipifnotwindows = pytest.mark.skipif( + platform.system() != "Windows", reason="runs only on windows" +) skipifwindows = pytest.mark.skipif( platform.system() == "Windows", reason="does not runs on windows" ) skipifgithubfork = pytest.mark.skipif( - is_running_in_github_fork(), reason="Skipping test because it runs on a PR coming from fork" + is_running_in_github_fork(), + reason="Skipping test because it runs on a PR coming from fork", )