diff --git a/.github/workflows/changelog.yml b/.github/workflows/changelog.yml index 2490b579a..a3d39f471 100644 --- a/.github/workflows/changelog.yml +++ b/.github/workflows/changelog.yml @@ -20,7 +20,7 @@ jobs: if: "!contains(github.event.pull_request.labels.*.name, 'ci:skip-changelog') && github.event.pull_request.user.login != 'pre-commit-ci[bot]' && github.event.pull_request.user.login != 'dependabot[bot]'" steps: - name: Checkout code - uses: actions/checkout@v3 + uses: actions/checkout@v4 with: fetch-depth: 0 diff --git a/.github/workflows/codeql-analysis.yml b/.github/workflows/codeql-analysis.yml index 22bdc6b84..434905856 100644 --- a/.github/workflows/codeql-analysis.yml +++ b/.github/workflows/codeql-analysis.yml @@ -27,7 +27,7 @@ jobs: steps: - name: Checkout repository - uses: actions/checkout@v3 + uses: actions/checkout@v4 - name: Set up Python ${{ env.DEFAULT_PYTHON }} uses: actions/setup-python@v4 @@ -84,7 +84,7 @@ jobs: steps: - name: Checkout repository - uses: actions/checkout@v3 + uses: actions/checkout@v4 - name: Set up Python ${{ env.DEFAULT_PYTHON }} uses: actions/setup-python@v4 diff --git a/.github/workflows/dev-release.yml b/.github/workflows/dev-release.yml index fc10af392..76f362733 100644 --- a/.github/workflows/dev-release.yml +++ b/.github/workflows/dev-release.yml @@ -29,7 +29,7 @@ jobs: steps: - name: Checkout code - uses: actions/checkout@v3 + uses: actions/checkout@v4 with: fetch-depth: 0 diff --git a/.github/workflows/get-matrix.yml b/.github/workflows/get-matrix.yml index d487e64e9..fd7e24aae 100644 --- a/.github/workflows/get-matrix.yml +++ b/.github/workflows/get-matrix.yml @@ -72,7 +72,7 @@ jobs: matrix-webdav: ${{ toJson(fromJson(steps.matrix-webdav.outputs.result)[steps.key-webdav.outputs.key]) }} steps: - name: Checkout code - uses: actions/checkout@v3 + uses: actions/checkout@v4 with: fetch-depth: 2 diff --git a/.github/workflows/nightly.yml b/.github/workflows/nightly.yml index c8bd83c70..3286c5d34 100644 --- a/.github/workflows/nightly.yml +++ b/.github/workflows/nightly.yml @@ -342,7 +342,7 @@ jobs: steps: - name: Checkout code - uses: actions/checkout@v3 + uses: actions/checkout@v4 - name: Set up Python ${{ env.DEFAULT_PYTHON }} uses: actions/setup-python@v4 diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml index 39ed60801..e0d227ea9 100644 --- a/.github/workflows/release.yml +++ b/.github/workflows/release.yml @@ -23,7 +23,7 @@ jobs: steps: - name: Checkout code - uses: actions/checkout@v3 + uses: actions/checkout@v4 with: fetch-depth: 0 diff --git a/.github/workflows/test-clickhouse.yml b/.github/workflows/test-clickhouse.yml index bb36c5b0d..42170ed2e 100644 --- a/.github/workflows/test-clickhouse.yml +++ b/.github/workflows/test-clickhouse.yml @@ -40,7 +40,7 @@ jobs: steps: - name: Checkout code - uses: actions/checkout@v3 + uses: actions/checkout@v4 - name: Set up Java ${{ inputs.java-version }} uses: actions/setup-java@v3 diff --git a/.github/workflows/test-core.yml b/.github/workflows/test-core.yml index 6d8cc93bd..0ed807efa 100644 --- a/.github/workflows/test-core.yml +++ b/.github/workflows/test-core.yml @@ -26,7 +26,7 @@ jobs: steps: - name: Checkout code - uses: actions/checkout@v3 + uses: actions/checkout@v4 - name: Set up Python ${{ inputs.python-version }} uses: actions/setup-python@v4 diff --git a/.github/workflows/test-ftp.yml b/.github/workflows/test-ftp.yml index 000e200a6..4ed90ec89 100644 --- a/.github/workflows/test-ftp.yml +++ b/.github/workflows/test-ftp.yml @@ -23,7 +23,7 @@ jobs: steps: - 
name: Checkout code - uses: actions/checkout@v3 + uses: actions/checkout@v4 - name: Set up Python ${{ inputs.python-version }} uses: actions/setup-python@v4 diff --git a/.github/workflows/test-ftps.yml b/.github/workflows/test-ftps.yml index f67abe96f..741bd7a0e 100644 --- a/.github/workflows/test-ftps.yml +++ b/.github/workflows/test-ftps.yml @@ -23,7 +23,7 @@ jobs: steps: - name: Checkout code - uses: actions/checkout@v3 + uses: actions/checkout@v4 - name: Set up Python ${{ inputs.python-version }} uses: actions/setup-python@v4 diff --git a/.github/workflows/test-greenplum.yml b/.github/workflows/test-greenplum.yml index 0a7812ec9..67026aab0 100644 --- a/.github/workflows/test-greenplum.yml +++ b/.github/workflows/test-greenplum.yml @@ -42,7 +42,7 @@ jobs: steps: - name: Checkout code - uses: actions/checkout@v3 + uses: actions/checkout@v4 - name: Set up Java ${{ inputs.java-version }} uses: actions/setup-java@v3 diff --git a/.github/workflows/test-hdfs.yml b/.github/workflows/test-hdfs.yml index da4cbdc57..e48f4dd0a 100644 --- a/.github/workflows/test-hdfs.yml +++ b/.github/workflows/test-hdfs.yml @@ -29,7 +29,7 @@ jobs: steps: - name: Checkout code - uses: actions/checkout@v3 + uses: actions/checkout@v4 - name: Set up Java ${{ inputs.java-version }} uses: actions/setup-java@v3 diff --git a/.github/workflows/test-hive.yml b/.github/workflows/test-hive.yml index 939fee079..cce56d675 100644 --- a/.github/workflows/test-hive.yml +++ b/.github/workflows/test-hive.yml @@ -26,7 +26,7 @@ jobs: steps: - name: Checkout code - uses: actions/checkout@v3 + uses: actions/checkout@v4 - name: Set up Java ${{ inputs.java-version }} uses: actions/setup-java@v3 diff --git a/.github/workflows/test-kafka.yml b/.github/workflows/test-kafka.yml index ffd7c388d..b6641557b 100644 --- a/.github/workflows/test-kafka.yml +++ b/.github/workflows/test-kafka.yml @@ -73,7 +73,7 @@ jobs: steps: - name: Checkout code - uses: actions/checkout@v3 + uses: actions/checkout@v4 - name: Set up Java ${{ inputs.java-version }} uses: actions/setup-java@v3 diff --git a/.github/workflows/test-local-fs.yml b/.github/workflows/test-local-fs.yml index 57873a5c9..f23beeb49 100644 --- a/.github/workflows/test-local-fs.yml +++ b/.github/workflows/test-local-fs.yml @@ -26,7 +26,7 @@ jobs: steps: - name: Checkout code - uses: actions/checkout@v3 + uses: actions/checkout@v4 - name: Set up Java ${{ inputs.java-version }} uses: actions/setup-java@v3 diff --git a/.github/workflows/test-mongodb.yml b/.github/workflows/test-mongodb.yml index 0d9d80cca..38086671e 100644 --- a/.github/workflows/test-mongodb.yml +++ b/.github/workflows/test-mongodb.yml @@ -38,7 +38,7 @@ jobs: steps: - name: Checkout code - uses: actions/checkout@v3 + uses: actions/checkout@v4 - name: Set up Java ${{ inputs.java-version }} uses: actions/setup-java@v3 diff --git a/.github/workflows/test-mssql.yml b/.github/workflows/test-mssql.yml index b84c2f7f1..d3b2a21a8 100644 --- a/.github/workflows/test-mssql.yml +++ b/.github/workflows/test-mssql.yml @@ -41,7 +41,7 @@ jobs: steps: - name: Checkout code - uses: actions/checkout@v3 + uses: actions/checkout@v4 - name: Set up Java ${{ inputs.java-version }} uses: actions/setup-java@v3 diff --git a/.github/workflows/test-mysql.yml b/.github/workflows/test-mysql.yml index c7d4937b5..16745e904 100644 --- a/.github/workflows/test-mysql.yml +++ b/.github/workflows/test-mysql.yml @@ -40,7 +40,7 @@ jobs: steps: - name: Checkout code - uses: actions/checkout@v3 + uses: actions/checkout@v4 - name: Set up Java ${{ inputs.java-version }} 
uses: actions/setup-java@v3 diff --git a/.github/workflows/test-oracle.yml b/.github/workflows/test-oracle.yml index 213d555cd..dcd51b42a 100644 --- a/.github/workflows/test-oracle.yml +++ b/.github/workflows/test-oracle.yml @@ -43,7 +43,7 @@ jobs: steps: - name: Checkout code - uses: actions/checkout@v3 + uses: actions/checkout@v4 - name: Set up Java ${{ inputs.java-version }} uses: actions/setup-java@v3 diff --git a/.github/workflows/test-postgres.yml b/.github/workflows/test-postgres.yml index 819d3f533..30c91dfca 100644 --- a/.github/workflows/test-postgres.yml +++ b/.github/workflows/test-postgres.yml @@ -39,7 +39,7 @@ jobs: steps: - name: Checkout code - uses: actions/checkout@v3 + uses: actions/checkout@v4 - name: Set up Java ${{ inputs.java-version }} uses: actions/setup-java@v3 diff --git a/.github/workflows/test-s3.yml b/.github/workflows/test-s3.yml index f2005b15c..99a51269e 100644 --- a/.github/workflows/test-s3.yml +++ b/.github/workflows/test-s3.yml @@ -40,7 +40,7 @@ jobs: steps: - name: Checkout code - uses: actions/checkout@v3 + uses: actions/checkout@v4 - name: Set up Java ${{ inputs.java-version }} uses: actions/setup-java@v3 diff --git a/.github/workflows/test-sftp.yml b/.github/workflows/test-sftp.yml index e22f2301c..bd630710b 100644 --- a/.github/workflows/test-sftp.yml +++ b/.github/workflows/test-sftp.yml @@ -33,7 +33,7 @@ jobs: steps: - name: Checkout code - uses: actions/checkout@v3 + uses: actions/checkout@v4 - name: Set up Python ${{ inputs.python-version }} uses: actions/setup-python@v4 diff --git a/.github/workflows/test-teradata.yml b/.github/workflows/test-teradata.yml index 31ec3712b..20ef294b7 100644 --- a/.github/workflows/test-teradata.yml +++ b/.github/workflows/test-teradata.yml @@ -26,7 +26,7 @@ jobs: steps: - name: Checkout code - uses: actions/checkout@v3 + uses: actions/checkout@v4 - name: Set up Java ${{ inputs.java-version }} uses: actions/setup-java@v3 diff --git a/.github/workflows/test-webdav.yml b/.github/workflows/test-webdav.yml index f2e6acf5c..fda365489 100644 --- a/.github/workflows/test-webdav.yml +++ b/.github/workflows/test-webdav.yml @@ -23,7 +23,7 @@ jobs: steps: - name: Checkout code - uses: actions/checkout@v3 + uses: actions/checkout@v4 - name: Set up Python ${{ inputs.python-version }} uses: actions/setup-python@v4 diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index f027fb99d..44125d701 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -325,7 +325,7 @@ jobs: steps: - name: Checkout code - uses: actions/checkout@v3 + uses: actions/checkout@v4 - name: Set up Python ${{ env.DEFAULT_PYTHON }} uses: actions/setup-python@v4 diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index ba0ab0b60..fd0c89d6b 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -69,7 +69,7 @@ repos: - id: black language_version: python3 - repo: https://github.com/asottile/blacken-docs - rev: 1.15.0 + rev: 1.16.0 hooks: - id: blacken-docs - repo: meta @@ -77,7 +77,7 @@ repos: - id: check-hooks-apply - id: check-useless-excludes - repo: https://github.com/PyCQA/autoflake - rev: v2.2.0 + rev: v2.2.1 hooks: - id: autoflake args: diff --git a/CONTRIBUTING.rst b/CONTRIBUTING.rst index c8e1fb52c..4dc2d824f 100644 --- a/CONTRIBUTING.rst +++ b/CONTRIBUTING.rst @@ -170,7 +170,7 @@ Without docker-compose To run Greenplum tests, you should: - * Download `Pivotal connector for Spark `_ + * Download `Pivotal connector for Spark `_ * Either move it to ``~/.ivy2/jars/``, or pass file path to 
``CLASSPATH`` * Set environment variable ``ONETL_DB_WITH_GREENPLUM=true`` to enable adding connector to Spark session diff --git a/README.rst b/README.rst index a98368684..13b280830 100644 --- a/README.rst +++ b/README.rst @@ -321,7 +321,7 @@ Read data from MSSQL, transform & write to Hive. extra={"ApplicationIntent": "ReadOnly"}, ).check() - # >>> INFO:|MSSQL| Connection is available. + # >>> INFO:|MSSQL| Connection is available # Initialize DB reader reader = DBReader( @@ -408,7 +408,7 @@ Download files from SFTP & upload them to HDFS. password="somepassword", ).check() - # >>> INFO:|SFTP| Connection is available. + # >>> INFO:|SFTP| Connection is available # Initialize downloader file_downloader = FileDownloader( @@ -546,7 +546,7 @@ Read files directly from S3 path, convert them to dataframe, transform it and th spark=spark, ).check() - # >>> INFO:|SparkS3| Connection is available. + # >>> INFO:|SparkS3| Connection is available # Describe file format and parsing options csv = CSV( @@ -577,7 +577,7 @@ Read files directly from S3 path, convert them to dataframe, transform it and th ], ) - # Initialize file reader + # Initialize file df reader reader = FileDFReader( connection=spark_s3, source_path="/remote/tests/Report", # path on S3 there *.csv files are located diff --git a/codecov.yml b/codecov.yml index 6087b0edf..7291c7aad 100644 --- a/codecov.yml +++ b/codecov.yml @@ -2,86 +2,5 @@ coverage: status: project: default: - target: 93% + target: 94% threshold: 1% - -flags: - core: - paths: - - onetl/base/base_connection.py - - onetl/hwm/** - - onetl/impl/*model*.py - - onetl/impl/*options*.py - - onetl/strategy/** - - onetl/hooks/** - - onetl/plugins/** - - onetl/exception.py - - onetl/log.py - - onetl/_internal.py - db: - paths: - - onetl/base/*db*.py - - onetl/base/*df*.py - - onetl/core/db*/** - - onetl/db_connection/db_connection.py - - onetl/db_connection/dialect_mixins/** - - onetl/db_connection/jdbc*.py - clickhouse: - paths: - - onetl/db_connection/clickhouse.py - greenplum: - paths: - - onetl/db_connection/greenplum.py - carryforward: true # if someone creates pull request from a fork, do not fail if Greenplum coverage is 0% - hive: - paths: - - onetl/db_connection/hive.py - mongodb: - paths: - - onetl/db_connection/mongodb.py - mssql: - paths: - - onetl/db_connection/mongodb.py - mysql: - paths: - - onetl/db_connection/mongodb.py - oracle: - paths: - - onetl/db_connection/oracle.py - postgres: - paths: - - onetl/db_connection/postgres.py - teradata: - paths: - - onetl/db_connection/teradata.py - file: - paths: - - onetl/base/*file*.py - - onetl/base/*path*.py - - onetl/base/contains_exception.py - - onetl/core/file*/** - - onetl/core/kerberos_helpers.py - - onetl/file_connection/file_connection.py - - onetl/impl/*path*.py - - onetl/impl/*file*.py - - onetl/impl/*directory*.py - ftp: - paths: - - onetl/file_connection/ftp.py - ftps: - paths: - - onetl/file_connection/ftps.py - hdfs: - paths: - - onetl/file_connection/hdfs.py - s3: - paths: - - onetl/file_connection/s3.py - sftp: - paths: - - onetl/file_connection/sftp.py - webdav: - paths: - - onetl/file_connection/webdav.py - nightly: - joined: false diff --git a/docker-compose.yml b/docker-compose.yml index d6f796d31..a08d8fc38 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -33,7 +33,7 @@ services: - onetl clickhouse: - image: ${CLICKHOUSE_IMAGE:-clickhouse/clickhouse-server:latest} + image: ${CLICKHOUSE_IMAGE:-clickhouse/clickhouse-server:latest-alpine} restart: unless-stopped ports: - 8123:8123 @@ -101,7 +101,7 
@@ services: - onetl postgres: - image: ${POSTGRES_IMAGE:-postgres:15.2} + image: ${POSTGRES_IMAGE:-postgres:15.2-alpine} restart: unless-stopped env_file: .env.dependencies ports: diff --git a/docs/changelog/0.9.2.rst b/docs/changelog/0.9.2.rst new file mode 100644 index 000000000..4a865da81 --- /dev/null +++ b/docs/changelog/0.9.2.rst @@ -0,0 +1,28 @@ +0.9.2 (2023-09-06) +================== + +Features +-------- + +- Add ``if_exists="ignore"`` and ``error`` to ``Greenplum.WriteOptions`` (:github:pull:`142`) + + +Improvements +------------ + +- Improve validation messages while writing a dataframe to Kafka. (:github:pull:`131`) +- Improve documentation: + + * Add notes about reading and writing to database connections documentation + * Add notes about executing statements in JDBC and Greenplum connections + + +Bug Fixes +--------- + +- Fixed validation of the ``headers`` column while writing to Kafka with the default ``Kafka.WriteOptions()`` - the default value was ``False``, + but instead of raising an exception, the column value was just ignored. (:github:pull:`131`) +- Fix reading data from Oracle with ``partitioningMode="range"`` without explicitly setting ``lowerBound`` / ``upperBound``. (:github:pull:`133`) +- Update Kafka documentation with SSLProtocol usage. (:github:pull:`136`) +- Raise an exception if someone tries to read data from a Kafka topic which does not exist. (:github:pull:`138`) +- Allow passing Kafka topics with names like ``some.topic.name`` to DBReader. Same for MongoDB collections. (:github:pull:`139`) diff --git a/docs/connection/db_connection/clickhouse.rst b/docs/connection/db_connection/clickhouse.rst deleted file mode 100644 index 9c358c653..000000000 --- a/docs/connection/db_connection/clickhouse.rst +++ /dev/null @@ -1,30 +0,0 @@ -.. _clickhouse: - -Clickhouse connection -===================== - -.. currentmodule:: onetl.connection.db_connection.clickhouse - -.. autosummary:: - - Clickhouse - Clickhouse.ReadOptions - Clickhouse.WriteOptions - Clickhouse.JDBCOptions - -.. autoclass:: Clickhouse - :members: get_packages, check, sql, fetch, execute - -.. currentmodule:: onetl.connection.db_connection.clickhouse.Clickhouse - -.. autopydantic_model:: ReadOptions - :members: fetchsize, partitioning_mode, partition_column, num_partitions, lower_bound, upper_bound, session_init_statement - :member-order: bysource - -.. autopydantic_model:: WriteOptions - :members: mode, batchsize, isolation_level, query_timeout - :member-order: bysource - -.. autopydantic_model:: JDBCOptions - :members: query_timeout, fetchsize - :member-order: bysource diff --git a/docs/connection/db_connection/clickhouse/connection.rst b/docs/connection/db_connection/clickhouse/connection.rst new file mode 100644 index 000000000..862c503e5 --- /dev/null +++ b/docs/connection/db_connection/clickhouse/connection.rst @@ -0,0 +1,9 @@ +.. _clickhouse-connection: + +Clickhouse connection +===================== + +.. currentmodule:: onetl.connection.db_connection.clickhouse.connection + +.. autoclass:: Clickhouse + :members: get_packages, check diff --git a/docs/connection/db_connection/clickhouse/execute.rst b/docs/connection/db_connection/clickhouse/execute.rst new file mode 100644 index 000000000..f43eb3995 --- /dev/null +++ b/docs/connection/db_connection/clickhouse/execute.rst @@ -0,0 +1,17 @@ +.. _clickhouse-execute: + +Executing statements in Clickhouse +================================== + +.. currentmodule:: onetl.connection.db_connection.clickhouse.connection + +.. automethod:: Clickhouse.fetch +.. 
automethod:: Clickhouse.execute +.. automethod:: Clickhouse.close + +.. currentmodule:: onetl.connection.db_connection.jdbc_mixin.options + +.. autopydantic_model:: JDBCOptions + :member-order: bysource + :model-show-field-summary: false + :field-show-constraints: false diff --git a/docs/connection/db_connection/clickhouse/index.rst b/docs/connection/db_connection/clickhouse/index.rst new file mode 100644 index 000000000..1e0d1de65 --- /dev/null +++ b/docs/connection/db_connection/clickhouse/index.rst @@ -0,0 +1,18 @@ +.. _clickhouse: + +Clickhouse +========== + +.. toctree:: + :maxdepth: 1 + :caption: Connection + + connection + +.. toctree:: + :maxdepth: 1 + :caption: Operations + + read + write + execute diff --git a/docs/connection/db_connection/clickhouse/read.rst b/docs/connection/db_connection/clickhouse/read.rst new file mode 100644 index 000000000..a2c07733c --- /dev/null +++ b/docs/connection/db_connection/clickhouse/read.rst @@ -0,0 +1,22 @@ +.. _clickhouse-read: + +Reading from Clickhouse +======================= + +There are 2 ways of distributed data reading from Clickhouse: + +* Using :obj:`DBReader ` with different :ref:`strategy` +* Using :obj:`Clickhouse.sql ` + +Both methods accept :obj:`JDBCReadOptions ` + +.. currentmodule:: onetl.connection.db_connection.clickhouse.connection + +.. automethod:: Clickhouse.sql + +.. currentmodule:: onetl.connection.db_connection.jdbc_connection.options + +.. autopydantic_model:: JDBCReadOptions + :member-order: bysource + :model-show-field-summary: false + :field-show-constraints: false diff --git a/docs/connection/db_connection/clickhouse/write.rst b/docs/connection/db_connection/clickhouse/write.rst new file mode 100644 index 000000000..97cc95d36 --- /dev/null +++ b/docs/connection/db_connection/clickhouse/write.rst @@ -0,0 +1,13 @@ +.. _clickhouse-write: + +Writing to Clickhouse +===================== + +For writing data to Clickhouse, use :obj:`DBWriter ` with options below. + +.. currentmodule:: onetl.connection.db_connection.jdbc_connection.options + +.. autopydantic_model:: JDBCWriteOptions + :member-order: bysource + :model-show-field-summary: false + :field-show-constraints: false diff --git a/docs/connection/db_connection/greenplum/connection.rst b/docs/connection/db_connection/greenplum/connection.rst new file mode 100644 index 000000000..59bca8421 --- /dev/null +++ b/docs/connection/db_connection/greenplum/connection.rst @@ -0,0 +1,9 @@ +.. _greenplum-connection: + +Greenplum connection +==================== + +.. currentmodule:: onetl.connection.db_connection.greenplum.connection + +.. autoclass:: Greenplum + :members: get_packages, check diff --git a/docs/connection/db_connection/greenplum/execute.rst b/docs/connection/db_connection/greenplum/execute.rst new file mode 100644 index 000000000..b0833b213 --- /dev/null +++ b/docs/connection/db_connection/greenplum/execute.rst @@ -0,0 +1,17 @@ +.. _greenplum-execute: + +Executing statements in Greenplum +================================== + +.. currentmodule:: onetl.connection.db_connection.greenplum.connection + +.. automethod:: Greenplum.fetch +.. automethod:: Greenplum.execute +.. automethod:: Greenplum.close + +.. currentmodule:: onetl.connection.db_connection.jdbc_mixin.options + +.. 
autopydantic_model:: JDBCOptions + :member-order: bysource + :model-show-field-summary: false + :field-show-constraints: false diff --git a/docs/connection/db_connection/greenplum/greenplum.rst b/docs/connection/db_connection/greenplum/greenplum.rst deleted file mode 100644 index c8192a823..000000000 --- a/docs/connection/db_connection/greenplum/greenplum.rst +++ /dev/null @@ -1,42 +0,0 @@ -.. _greenplum: - -Greenplum connector -==================== - -.. currentmodule:: onetl.connection.db_connection.greenplum - -.. autosummary:: - - Greenplum - Greenplum.ReadOptions - Greenplum.WriteOptions - Greenplum.JDBCOptions - -.. note:: - - Unlike JDBC connectors, *Greenplum connector for Spark* does not support - executing **custom** SQL queries using ``.sql`` method, because this leads to sending - the result through *master* node which is really bad for cluster performance. - - To make distributed queries like ``JOIN`` **on Greenplum side**, you should create a temporary table, - populate it with the data you need (using ``.execute`` method to call ``INSERT INTO ... AS SELECT ...``), - and then read the data from this table using :obj:`DBReader `. - - In this case data will be read directly from segment nodes in a distributed way - -.. autoclass:: Greenplum - :members: get_packages, check, fetch, execute, close - -.. currentmodule:: onetl.connection.db_connection.greenplum.Greenplum - -.. autopydantic_model:: ReadOptions - :members: partition_column, num_partitions - :member-order: bysource - -.. autopydantic_model:: WriteOptions - :members: mode - :member-order: bysource - -.. autopydantic_model:: JDBCOptions - :members: query_timeout, fetchsize - :member-order: bysource diff --git a/docs/connection/db_connection/greenplum/index.rst b/docs/connection/db_connection/greenplum/index.rst index b99888c9f..b4ff40331 100644 --- a/docs/connection/db_connection/greenplum/index.rst +++ b/docs/connection/db_connection/greenplum/index.rst @@ -1,11 +1,19 @@ .. _greenplum: -Greenplum connector +Greenplum ==================== .. toctree:: :maxdepth: 1 - :caption: Greenplum connector + :caption: Connection prerequisites - greenplum + connection + +.. toctree:: + :maxdepth: 1 + :caption: Operations + + read + write + execute diff --git a/docs/connection/db_connection/greenplum/prerequisites.rst b/docs/connection/db_connection/greenplum/prerequisites.rst index 166a936aa..964d9cdcf 100644 --- a/docs/connection/db_connection/greenplum/prerequisites.rst +++ b/docs/connection/db_connection/greenplum/prerequisites.rst @@ -30,7 +30,7 @@ Downloading Pivotal package --------------------------- To use Greenplum connector you should download connector ``.jar`` file from -`Pivotal website `_ +`Pivotal website `_ and then pass it to Spark session. There are several ways to do that. diff --git a/docs/connection/db_connection/greenplum/read.rst b/docs/connection/db_connection/greenplum/read.rst new file mode 100644 index 000000000..2640f7e6c --- /dev/null +++ b/docs/connection/db_connection/greenplum/read.rst @@ -0,0 +1,31 @@ +.. _greenplum-read: + +Reading from Greenplum +======================= + +For reading data from Greenplum, use :obj:`DBReader ` with options below. + +.. note:: + + Unlike JDBC connectors, *Greenplum connector for Spark* does not support + executing **custom** SQL queries using ``.sql`` method, because this leads to sending + the result through *master* node which is really bad for cluster performance. 
+ + To make distributed queries like ``JOIN`` **on Greenplum side**, you should create a staging table, + populate it with the data you need (using the ``.execute`` method to call ``INSERT INTO ... AS SELECT ...``), + then read the data from this table using :obj:`DBReader `, + and drop the staging table after reading is finished. + + In this case data will be read directly from Greenplum segment nodes in a distributed way. + +.. warning:: + + Greenplum connection does **NOT** support reading data from views which do not have a ``gp_segment_id`` column. + Either add this column to the view, or use the staging table solution (see above). + +.. currentmodule:: onetl.connection.db_connection.greenplum.options + +.. autopydantic_model:: GreenplumReadOptions + :member-order: bysource + :model-show-field-summary: false + :field-show-constraints: false diff --git a/docs/connection/db_connection/greenplum/write.rst b/docs/connection/db_connection/greenplum/write.rst new file mode 100644 index 000000000..aeb688ac5 --- /dev/null +++ b/docs/connection/db_connection/greenplum/write.rst @@ -0,0 +1,13 @@ +.. _greenplum-write: + +Writing to Greenplum +===================== + +For writing data to Greenplum, use :obj:`DBWriter ` with options below. + +.. currentmodule:: onetl.connection.db_connection.greenplum.options + +.. autopydantic_model:: GreenplumWriteOptions + :member-order: bysource + :model-show-field-summary: false + :field-show-constraints: false diff --git a/docs/connection/db_connection/hive.rst b/docs/connection/db_connection/hive.rst deleted file mode 100644 index 7cb09ad53..000000000 --- a/docs/connection/db_connection/hive.rst +++ /dev/null @@ -1,26 +0,0 @@ -.. _hive: - -Hive connection -=============== - -.. currentmodule:: onetl.connection.db_connection.hive - -.. autosummary:: - - Hive - Hive.WriteOptions - Hive.Slots - -.. autoclass:: Hive - :members: get_current, check, sql, execute - :member-order: bysource - -.. currentmodule:: onetl.connection.db_connection.hive.Hive - -.. autopydantic_model:: WriteOptions - :members: mode, format, partition_by, bucket_by, sort_by, compression - :member-order: bysource - -.. autoclass:: Slots - :members: normalize_cluster_name, get_known_clusters, get_current_cluster - :member-order: bysource diff --git a/docs/connection/db_connection/hive/connection.rst b/docs/connection/db_connection/hive/connection.rst new file mode 100644 index 000000000..cbc51eac3 --- /dev/null +++ b/docs/connection/db_connection/hive/connection.rst @@ -0,0 +1,10 @@ +.. _hive-connection: + +Hive Connection +=============== + +.. currentmodule:: onetl.connection.db_connection.hive.connection + +.. autoclass:: Hive + :members: get_current, check + :member-order: bysource diff --git a/docs/connection/db_connection/hive/execute.rst b/docs/connection/db_connection/hive/execute.rst new file mode 100644 index 000000000..ae32e61d2 --- /dev/null +++ b/docs/connection/db_connection/hive/execute.rst @@ -0,0 +1,8 @@ +.. _hive-execute: + +Executing statements in Hive +============================ + +.. currentmodule:: onetl.connection.db_connection.hive.connection + +.. automethod:: Hive.execute diff --git a/docs/connection/db_connection/hive/index.rst b/docs/connection/db_connection/hive/index.rst new file mode 100644 index 000000000..9dd900b07 --- /dev/null +++ b/docs/connection/db_connection/hive/index.rst @@ -0,0 +1,24 @@ +.. _hive: + +Hive +==== + +.. toctree:: + :maxdepth: 1 + :caption: Connection + + connection + +.. 
toctree:: + :maxdepth: 1 + :caption: Operations + + read + write + execute + +.. toctree:: + :maxdepth: 1 + :caption: For developers + + slots diff --git a/docs/connection/db_connection/hive/read.rst b/docs/connection/db_connection/hive/read.rst new file mode 100644 index 000000000..a9961b4ab --- /dev/null +++ b/docs/connection/db_connection/hive/read.rst @@ -0,0 +1,13 @@ +.. _hive-read: + +Reading from Hive +================= + +There are 2 ways of distributed data reading from Hive: + +* Using :obj:`DBReader ` with different :ref:`strategy` +* Using :obj:`Hive.sql ` + +.. currentmodule:: onetl.connection.db_connection.hive.connection + +.. automethod:: Hive.sql diff --git a/docs/connection/db_connection/hive/slots.rst b/docs/connection/db_connection/hive/slots.rst new file mode 100644 index 000000000..f60dc34e5 --- /dev/null +++ b/docs/connection/db_connection/hive/slots.rst @@ -0,0 +1,10 @@ +.. _hive-slots: + +Hive Slots +========== + +.. currentmodule:: onetl.connection.db_connection.hive.slots + +.. autoclass:: HiveSlots + :members: normalize_cluster_name, get_known_clusters, get_current_cluster + :member-order: bysource diff --git a/docs/connection/db_connection/hive/write.rst b/docs/connection/db_connection/hive/write.rst new file mode 100644 index 000000000..70c9f3099 --- /dev/null +++ b/docs/connection/db_connection/hive/write.rst @@ -0,0 +1,13 @@ +.. _hive-write: + +Writing to Hive +=============== + +For writing data to Hive, use :obj:`DBWriter ` with options below. + +.. currentmodule:: onetl.connection.db_connection.hive.options + +.. autopydantic_model:: HiveWriteOptions + :member-order: bysource + :model-show-field-summary: false + :field-show-constraints: false diff --git a/docs/connection/db_connection/index.rst b/docs/connection/db_connection/index.rst index 3a4036379..429aaed71 100644 --- a/docs/connection/db_connection/index.rst +++ b/docs/connection/db_connection/index.rst @@ -7,13 +7,13 @@ DB Connections :maxdepth: 1 :caption: DB Connections - Clickhouse + Clickhouse Greenplum Kafka - Hive - MongoDB - MSSQL - MySQL - Oracle - Postgres - Teradata + Hive + MongoDB + MSSQL + MySQL + Oracle + Postgres + Teradata diff --git a/docs/connection/db_connection/kafka/index.rst b/docs/connection/db_connection/kafka/index.rst index 7edec8d10..a02aaecdc 100644 --- a/docs/connection/db_connection/kafka/index.rst +++ b/docs/connection/db_connection/kafka/index.rst @@ -5,11 +5,9 @@ Kafka .. toctree:: :maxdepth: 1 - :caption: Connection & options + :caption: Connection connection - read_options - write_options .. toctree:: :maxdepth: 1 @@ -26,6 +24,13 @@ Kafka kerberos_auth scram_auth +.. toctree:: + :maxdepth: 1 + :caption: Operations + + read + write + .. toctree:: :maxdepth: 1 :caption: For developers diff --git a/docs/connection/db_connection/kafka/read.rst b/docs/connection/db_connection/kafka/read.rst new file mode 100644 index 000000000..a19c5e57b --- /dev/null +++ b/docs/connection/db_connection/kafka/read.rst @@ -0,0 +1,72 @@ +.. _kafka-read: + +Reading from Kafka +================== + +For reading data from Kafka, use :obj:`DBReader ` with specific options (see below). + +.. warning:: + + Currently, Kafka does not support :ref:`strategy`. You can only read the whole topic. + +.. note:: + + Unlike other connection classes, Kafka always returns a dataframe with a fixed schema + (see `documentation `_): + + .. dropdown:: DataFrame Schema + + .. 
code:: python + + from pyspark.sql.types import ( + ArrayType, + BinaryType, + IntegerType, + LongType, + StringType, + StructField, + StructType, + TimestampType, + ) + + schema = StructType( + [ + StructField("value", BinaryType(), nullable=True), + StructField("key", BinaryType(), nullable=True), + StructField("topic", StringType(), nullable=False), + StructField("partition", IntegerType(), nullable=False), + StructField("offset", LongType(), nullable=False), + StructField("timestamp", TimestampType(), nullable=False), + StructField("timestampType", IntegerType(), nullable=False), + # this field is returned only with ``include_headers=True`` + StructField( + "headers", + ArrayType( + StructType( + [ + StructField("key", StringType(), nullable=False), + StructField("value", BinaryType(), nullable=True), + ], + ), + ), + nullable=True, + ), + ], + ) + +.. warning:: + + Columns: + + * ``value`` + * ``key`` + * ``headers[*].value`` + + are always returned as raw bytes. If they contain values of a custom type, these values should be deserialized manually. + +.. currentmodule:: onetl.connection.db_connection.kafka.options + +.. autopydantic_model:: KafkaReadOptions + :member-order: bysource + :model-show-field-summary: false + :field-show-constraints: false diff --git a/docs/connection/db_connection/kafka/read_options.rst b/docs/connection/db_connection/kafka/read_options.rst deleted file mode 100644 index 37071c27d..000000000 --- a/docs/connection/db_connection/kafka/read_options.rst +++ /dev/null @@ -1,11 +0,0 @@ -.. _kafka-read-options: - -Kafka ReadOptions -================= - -.. currentmodule:: onetl.connection.db_connection.kafka.options - -.. autopydantic_model:: KafkaReadOptions - :member-order: bysource - :model-show-field-summary: false - :field-show-constraints: false diff --git a/docs/connection/db_connection/kafka/write.rst b/docs/connection/db_connection/kafka/write.rst new file mode 100644 index 000000000..eb04ecccb --- /dev/null +++ b/docs/connection/db_connection/kafka/write.rst @@ -0,0 +1,64 @@ +.. _kafka-write: + +Writing to Kafka +================ + +For writing data to Kafka, use :obj:`DBWriter ` with specific options (see below). + +.. note:: + + Unlike other connection classes, Kafka only accepts a dataframe with a fixed schema + (see `documentation `_): + + .. dropdown:: DataFrame Schema + + .. code:: python + + from pyspark.sql.types import ( + ArrayType, + BinaryType, + IntegerType, + StringType, + StructField, + StructType, + ) + + schema = StructType( + [ + # mandatory fields: + StructField("value", BinaryType(), nullable=True), + # optional fields, can be omitted: + StructField("key", BinaryType(), nullable=True), + StructField("partition", IntegerType(), nullable=True), + StructField( + "headers", + ArrayType( + StructType( + [ + StructField("key", StringType(), nullable=False), + StructField("value", BinaryType(), nullable=True), + ], + ), + ), + nullable=True, + ), + ], + ) + You cannot pass a dataframe with other column names or types. + +.. warning:: + + Columns: + + * ``value`` + * ``key`` + * ``headers[*].value`` + + can only be strings or raw bytes. If they contain values of a custom type, these values should be serialized manually. + +.. currentmodule:: onetl.connection.db_connection.kafka.options + +.. 
autopydantic_model:: KafkaWriteOptions + :member-order: bysource + :model-show-field-summary: false + :field-show-constraints: false diff --git a/docs/connection/db_connection/kafka/write_options.rst b/docs/connection/db_connection/kafka/write_options.rst deleted file mode 100644 index a3b678951..000000000 --- a/docs/connection/db_connection/kafka/write_options.rst +++ /dev/null @@ -1,11 +0,0 @@ -.. _kafka-write-options: - -Kafka WriteOptions -================== - -.. currentmodule:: onetl.connection.db_connection.kafka.options - -.. autopydantic_model:: KafkaWriteOptions - :member-order: bysource - :model-show-field-summary: false - :field-show-constraints: false diff --git a/docs/connection/db_connection/mongodb.rst b/docs/connection/db_connection/mongodb.rst deleted file mode 100644 index 2bc42d804..000000000 --- a/docs/connection/db_connection/mongodb.rst +++ /dev/null @@ -1,26 +0,0 @@ -.. _mongo: - -MongoDB connection -===================== - -.. currentmodule:: onetl.connection.db_connection.mongodb - -.. autosummary:: - - MongoDB - MongoDB.ReadOptions - MongoDB.WriteOptions - MongoDB.PipelineOptions - -.. autoclass:: MongoDB - :members: get_packages, check, pipeline - -.. currentmodule:: onetl.connection.db_connection.mongodb.MongoDB - -.. autopydantic_model:: ReadOptions - -.. autopydantic_model:: WriteOptions - :members: mode - :member-order: bysource - -.. autopydantic_model:: PipelineOptions diff --git a/docs/connection/db_connection/mongodb/connection.rst b/docs/connection/db_connection/mongodb/connection.rst new file mode 100644 index 000000000..1d0504609 --- /dev/null +++ b/docs/connection/db_connection/mongodb/connection.rst @@ -0,0 +1,10 @@ +.. _mongodb-connection: + +MongoDB Connection +================== + +.. currentmodule:: onetl.connection.db_connection.mongodb.connection + +.. autoclass:: MongoDB + :members: get_packages, check + :member-order: bysource diff --git a/docs/connection/db_connection/mongodb/index.rst b/docs/connection/db_connection/mongodb/index.rst new file mode 100644 index 000000000..a863265a4 --- /dev/null +++ b/docs/connection/db_connection/mongodb/index.rst @@ -0,0 +1,17 @@ +.. _mongodb: + +MongoDB +======= + +.. toctree:: + :maxdepth: 1 + :caption: Connection + + connection + +.. toctree:: + :maxdepth: 1 + :caption: Operations + + read + write diff --git a/docs/connection/db_connection/mongodb/read.rst b/docs/connection/db_connection/mongodb/read.rst new file mode 100644 index 000000000..b3d51686a --- /dev/null +++ b/docs/connection/db_connection/mongodb/read.rst @@ -0,0 +1,25 @@ +.. _mongodb-read: + +Reading from MongoDB +==================== + +There are 2 ways of distributed data reading from MongoDB: + +* Using :obj:`DBReader ` with different :ref:`strategy` and :obj:`MongoDBReadOptions ` +* Using :obj:`MongoDB.pipeline ` with :obj:`MongoDBPipelineOptions ` + +.. currentmodule:: onetl.connection.db_connection.mongodb.connection + +.. automethod:: MongoDB.pipeline + +.. currentmodule:: onetl.connection.db_connection.mongodb.options + +.. autopydantic_model:: MongoDBReadOptions + :member-order: bysource + :model-show-field-summary: false + :field-show-constraints: false + +.. 
autopydantic_model:: MongoDBPipelineOptions + :member-order: bysource + :model-show-field-summary: false + :field-show-constraints: false diff --git a/docs/connection/db_connection/mongodb/write.rst b/docs/connection/db_connection/mongodb/write.rst new file mode 100644 index 000000000..5fff32d70 --- /dev/null +++ b/docs/connection/db_connection/mongodb/write.rst @@ -0,0 +1,13 @@ +.. _mongodb-write: + +Writing to MongoDB +================== + +For writing data to MongoDB, use :obj:`DBWriter ` with options below. + +.. currentmodule:: onetl.connection.db_connection.mongodb.options + +.. autopydantic_model:: MongoDBWriteOptions + :member-order: bysource + :model-show-field-summary: false + :field-show-constraints: false diff --git a/docs/connection/db_connection/mssql.rst b/docs/connection/db_connection/mssql.rst deleted file mode 100644 index 5b3941b04..000000000 --- a/docs/connection/db_connection/mssql.rst +++ /dev/null @@ -1,30 +0,0 @@ -.. _mssql: - -MSSQL connection -================ - -.. currentmodule:: onetl.connection.db_connection.mssql - -.. autosummary:: - - MSSQL - MSSQL.ReadOptions - MSSQL.WriteOptions - MSSQL.JDBCOptions - -.. autoclass:: MSSQL - :members: get_packages, check, sql, fetch, execute, close - -.. currentmodule:: onetl.connection.db_connection.mssql.MSSQL - -.. autopydantic_model:: ReadOptions - :members: fetchsize, partitioning_mode, partition_column, num_partitions, lower_bound, upper_bound, session_init_statement - :member-order: bysource - -.. autopydantic_model:: WriteOptions - :members: mode, batchsize, isolation_level, query_timeout - :member-order: bysource - -.. autopydantic_model:: JDBCOptions - :members: query_timeout, fetchsize - :member-order: bysource diff --git a/docs/connection/db_connection/mssql/connection.rst b/docs/connection/db_connection/mssql/connection.rst new file mode 100644 index 000000000..15fe31260 --- /dev/null +++ b/docs/connection/db_connection/mssql/connection.rst @@ -0,0 +1,9 @@ +.. _mssql-connection: + +MSSQL connection +================ + +.. currentmodule:: onetl.connection.db_connection.mssql.connection + +.. autoclass:: MSSQL + :members: get_packages, check diff --git a/docs/connection/db_connection/mssql/execute.rst b/docs/connection/db_connection/mssql/execute.rst new file mode 100644 index 000000000..bed53fec5 --- /dev/null +++ b/docs/connection/db_connection/mssql/execute.rst @@ -0,0 +1,17 @@ +.. _mssql-execute: + +Executing statements in MSSQL +============================= + +.. currentmodule:: onetl.connection.db_connection.mssql.connection + +.. automethod:: MSSQL.fetch +.. automethod:: MSSQL.execute +.. automethod:: MSSQL.close + +.. currentmodule:: onetl.connection.db_connection.jdbc_mixin.options + +.. autopydantic_model:: JDBCOptions + :member-order: bysource + :model-show-field-summary: false + :field-show-constraints: false diff --git a/docs/connection/db_connection/mssql/index.rst b/docs/connection/db_connection/mssql/index.rst new file mode 100644 index 000000000..5e511e83e --- /dev/null +++ b/docs/connection/db_connection/mssql/index.rst @@ -0,0 +1,18 @@ +.. _mssql: + +MSSQL +===== + +.. toctree:: + :maxdepth: 1 + :caption: Connection + + connection + +.. toctree:: + :maxdepth: 1 + :caption: Operations + + read + write + execute diff --git a/docs/connection/db_connection/mssql/read.rst b/docs/connection/db_connection/mssql/read.rst new file mode 100644 index 000000000..3a336f823 --- /dev/null +++ b/docs/connection/db_connection/mssql/read.rst @@ -0,0 +1,22 @@ +.. 
_mssql-read: + +Reading from MSSQL +================== + +There are 2 ways of distributed data reading from MSSQL: + +* Using :obj:`DBReader ` with different :ref:`strategy` +* Using :obj:`MSSQL.sql ` + +Both methods accept :obj:`JDBCReadOptions ` + +.. currentmodule:: onetl.connection.db_connection.mssql.connection + +.. automethod:: MSSQL.sql + +.. currentmodule:: onetl.connection.db_connection.jdbc_connection.options + +.. autopydantic_model:: JDBCReadOptions + :member-order: bysource + :model-show-field-summary: false + :field-show-constraints: false diff --git a/docs/connection/db_connection/mssql/write.rst b/docs/connection/db_connection/mssql/write.rst new file mode 100644 index 000000000..c8a5e5906 --- /dev/null +++ b/docs/connection/db_connection/mssql/write.rst @@ -0,0 +1,13 @@ +.. _mssql-write: + +Writing to MSSQL +================ + +For writing data to MSSQL, use :obj:`DBWriter ` with options below. + +.. currentmodule:: onetl.connection.db_connection.jdbc_connection.options + +.. autopydantic_model:: JDBCWriteOptions + :member-order: bysource + :model-show-field-summary: false + :field-show-constraints: false diff --git a/docs/connection/db_connection/mysql.rst b/docs/connection/db_connection/mysql.rst deleted file mode 100644 index 85f59403d..000000000 --- a/docs/connection/db_connection/mysql.rst +++ /dev/null @@ -1,30 +0,0 @@ -.. _mysql: - -MySQL connection -================= - -.. currentmodule:: onetl.connection.db_connection.mysql - -.. autosummary:: - - MySQL - MySQL.ReadOptions - MySQL.WriteOptions - MySQL.JDBCOptions - -.. autoclass:: MySQL - :members: get_packages, check, sql, fetch, execute, close - -.. currentmodule:: onetl.connection.db_connection.mysql.MySQL - -.. autopydantic_model:: ReadOptions - :members: fetchsize, partitioning_mode, partition_column, num_partitions, lower_bound, upper_bound, session_init_statement - :member-order: bysource - -.. autopydantic_model:: WriteOptions - :members: mode, batchsize, isolation_level, query_timeout - :member-order: bysource - -.. autopydantic_model:: JDBCOptions - :members: query_timeout, fetchsize - :member-order: bysource diff --git a/docs/connection/db_connection/mysql/connection.rst b/docs/connection/db_connection/mysql/connection.rst new file mode 100644 index 000000000..cfeb00206 --- /dev/null +++ b/docs/connection/db_connection/mysql/connection.rst @@ -0,0 +1,9 @@ +.. _mysql-connection: + +MySQL connection +================ + +.. currentmodule:: onetl.connection.db_connection.mysql.connection + +.. autoclass:: MySQL + :members: get_packages, check diff --git a/docs/connection/db_connection/mysql/execute.rst b/docs/connection/db_connection/mysql/execute.rst new file mode 100644 index 000000000..ec5d01482 --- /dev/null +++ b/docs/connection/db_connection/mysql/execute.rst @@ -0,0 +1,17 @@ +.. _mysql-execute: + +Executing statements in MySQL +============================= + +.. currentmodule:: onetl.connection.db_connection.mysql.connection + +.. automethod:: MySQL.fetch +.. automethod:: MySQL.execute +.. automethod:: MySQL.close + +.. currentmodule:: onetl.connection.db_connection.jdbc_mixin.options + +.. autopydantic_model:: JDBCOptions + :member-order: bysource + :model-show-field-summary: false + :field-show-constraints: false diff --git a/docs/connection/db_connection/mysql/index.rst b/docs/connection/db_connection/mysql/index.rst new file mode 100644 index 000000000..e221165cd --- /dev/null +++ b/docs/connection/db_connection/mysql/index.rst @@ -0,0 +1,18 @@ +.. _mysql: + +MySQL +===== + +.. 
toctree:: + :maxdepth: 1 + :caption: Connection + + connection + +.. toctree:: + :maxdepth: 1 + :caption: Operations + + read + write + execute diff --git a/docs/connection/db_connection/mysql/read.rst b/docs/connection/db_connection/mysql/read.rst new file mode 100644 index 000000000..6a6960532 --- /dev/null +++ b/docs/connection/db_connection/mysql/read.rst @@ -0,0 +1,22 @@ +.. _mysql-read: + +Reading from MySQL +================== + +There are 2 ways of distributed data reading from MySQL: + +* Using :obj:`DBReader ` with different :ref:`strategy` +* Using :obj:`MySQL.sql ` + +Both methods accept :obj:`JDBCReadOptions ` + +.. currentmodule:: onetl.connection.db_connection.mysql.connection + +.. automethod:: MySQL.sql + +.. currentmodule:: onetl.connection.db_connection.jdbc_connection.options + +.. autopydantic_model:: JDBCReadOptions + :member-order: bysource + :model-show-field-summary: false + :field-show-constraints: false diff --git a/docs/connection/db_connection/mysql/write.rst b/docs/connection/db_connection/mysql/write.rst new file mode 100644 index 000000000..67f13cf1b --- /dev/null +++ b/docs/connection/db_connection/mysql/write.rst @@ -0,0 +1,13 @@ +.. _mysql-write: + +Writing to MySQL +================ + +For writing data to MySQL, use :obj:`DBWriter ` with options below. + +.. currentmodule:: onetl.connection.db_connection.jdbc_connection.options + +.. autopydantic_model:: JDBCWriteOptions + :member-order: bysource + :model-show-field-summary: false + :field-show-constraints: false diff --git a/docs/connection/db_connection/oracle.rst b/docs/connection/db_connection/oracle.rst deleted file mode 100644 index 45b3f0b84..000000000 --- a/docs/connection/db_connection/oracle.rst +++ /dev/null @@ -1,30 +0,0 @@ -.. _oracle: - -Oracle connection -================== - -.. currentmodule:: onetl.connection.db_connection.oracle - -.. autosummary:: - - Oracle - Oracle.ReadOptions - Oracle.WriteOptions - Oracle.JDBCOptions - -.. autoclass:: Oracle - :members: get_packages, check, sql, fetch, execute, close - -.. currentmodule:: onetl.connection.db_connection.oracle.Oracle - -.. autopydantic_model:: ReadOptions - :members: fetchsize, partitioning_mode, partition_column, num_partitions, lower_bound, upper_bound, session_init_statement - :member-order: bysource - -.. autopydantic_model:: WriteOptions - :members: mode, batchsize, isolation_level, query_timeout - :member-order: bysource - -.. autopydantic_model:: JDBCOptions - :members: query_timeout, fetchsize - :member-order: bysource diff --git a/docs/connection/db_connection/oracle/connection.rst b/docs/connection/db_connection/oracle/connection.rst new file mode 100644 index 000000000..25e544823 --- /dev/null +++ b/docs/connection/db_connection/oracle/connection.rst @@ -0,0 +1,9 @@ +.. _oracle-connection: + +Oracle connection +================= + +.. currentmodule:: onetl.connection.db_connection.oracle.connection + +.. autoclass:: Oracle + :members: get_packages, check diff --git a/docs/connection/db_connection/oracle/execute.rst b/docs/connection/db_connection/oracle/execute.rst new file mode 100644 index 000000000..24ea689a4 --- /dev/null +++ b/docs/connection/db_connection/oracle/execute.rst @@ -0,0 +1,17 @@ +.. _oracle-execute: + +Executing statements in Oracle +============================== + +.. currentmodule:: onetl.connection.db_connection.oracle.connection + +.. automethod:: Oracle.fetch +.. automethod:: Oracle.execute +.. automethod:: Oracle.close + +.. currentmodule:: onetl.connection.db_connection.jdbc_mixin.options + +.. 
autopydantic_model:: JDBCOptions + :member-order: bysource + :model-show-field-summary: false + :field-show-constraints: false diff --git a/docs/connection/db_connection/oracle/index.rst b/docs/connection/db_connection/oracle/index.rst new file mode 100644 index 000000000..519250fb5 --- /dev/null +++ b/docs/connection/db_connection/oracle/index.rst @@ -0,0 +1,18 @@ +.. _oracle: + +Oracle +====== + +.. toctree:: + :maxdepth: 1 + :caption: Connection + + connection + +.. toctree:: + :maxdepth: 1 + :caption: Operations + + read + write + execute diff --git a/docs/connection/db_connection/oracle/read.rst b/docs/connection/db_connection/oracle/read.rst new file mode 100644 index 000000000..ffd393e6e --- /dev/null +++ b/docs/connection/db_connection/oracle/read.rst @@ -0,0 +1,22 @@ +.. _oracle-read: + +Reading from Oracle +=================== + +There are 2 ways of distributed data reading from Oracle: + +* Using :obj:`DBReader ` with different :ref:`strategy` +* Using :obj:`Oracle.sql ` + +Both methods accept :obj:`JDBCReadOptions ` + +.. currentmodule:: onetl.connection.db_connection.oracle.connection + +.. automethod:: Oracle.sql + +.. currentmodule:: onetl.connection.db_connection.jdbc_connection.options + +.. autopydantic_model:: JDBCReadOptions + :member-order: bysource + :model-show-field-summary: false + :field-show-constraints: false diff --git a/docs/connection/db_connection/oracle/write.rst b/docs/connection/db_connection/oracle/write.rst new file mode 100644 index 000000000..78c57d915 --- /dev/null +++ b/docs/connection/db_connection/oracle/write.rst @@ -0,0 +1,13 @@ +.. _oracle-write: + +Writing to Oracle +================= + +For writing data to Oracle, use :obj:`DBWriter ` with options below. + +.. currentmodule:: onetl.connection.db_connection.jdbc_connection.options + +.. autopydantic_model:: JDBCWriteOptions + :member-order: bysource + :model-show-field-summary: false + :field-show-constraints: false diff --git a/docs/connection/db_connection/postgres.rst b/docs/connection/db_connection/postgres.rst deleted file mode 100644 index 605a462d7..000000000 --- a/docs/connection/db_connection/postgres.rst +++ /dev/null @@ -1,30 +0,0 @@ -.. _postgres: - -Postgres connection -==================== - -.. currentmodule:: onetl.connection.db_connection.postgres - -.. autosummary:: - - Postgres - Postgres.ReadOptions - Postgres.WriteOptions - Postgres.JDBCOptions - -.. autoclass:: Postgres - :members: get_packages, check, sql, fetch, execute, close - -.. currentmodule:: onetl.connection.db_connection.postgres.Postgres - -.. autopydantic_model:: ReadOptions - :members: fetchsize, partitioning_mode, partition_column, num_partitions, lower_bound, upper_bound, session_init_statement - :member-order: bysource - -.. autopydantic_model:: WriteOptions - :members: mode, batchsize, isolation_level, query_timeout - :member-order: bysource - -.. autopydantic_model:: JDBCOptions - :members: query_timeout, fetchsize - :member-order: bysource diff --git a/docs/connection/db_connection/postgres/connection.rst b/docs/connection/db_connection/postgres/connection.rst new file mode 100644 index 000000000..517bcd5f2 --- /dev/null +++ b/docs/connection/db_connection/postgres/connection.rst @@ -0,0 +1,9 @@ +.. _postgres-connection: + +Postgres connection +=================== + +.. currentmodule:: onetl.connection.db_connection.postgres.connection + +.. 
autoclass:: Postgres + :members: get_packages, check diff --git a/docs/connection/db_connection/postgres/execute.rst b/docs/connection/db_connection/postgres/execute.rst new file mode 100644 index 000000000..042b97197 --- /dev/null +++ b/docs/connection/db_connection/postgres/execute.rst @@ -0,0 +1,17 @@ +.. _postgres-execute: + +Executing statements in Postgres +================================ + +.. currentmodule:: onetl.connection.db_connection.postgres.connection + +.. automethod:: Postgres.fetch +.. automethod:: Postgres.execute +.. automethod:: Postgres.close + +.. currentmodule:: onetl.connection.db_connection.jdbc_mixin.options + +.. autopydantic_model:: JDBCOptions + :member-order: bysource + :model-show-field-summary: false + :field-show-constraints: false diff --git a/docs/connection/db_connection/postgres/index.rst b/docs/connection/db_connection/postgres/index.rst new file mode 100644 index 000000000..f5376ee93 --- /dev/null +++ b/docs/connection/db_connection/postgres/index.rst @@ -0,0 +1,18 @@ +.. _postgres: + +Postgres +======== + +.. toctree:: + :maxdepth: 1 + :caption: Connection + + connection + +.. toctree:: + :maxdepth: 1 + :caption: Operations + + read + write + execute diff --git a/docs/connection/db_connection/postgres/read.rst b/docs/connection/db_connection/postgres/read.rst new file mode 100644 index 000000000..737a731db --- /dev/null +++ b/docs/connection/db_connection/postgres/read.rst @@ -0,0 +1,22 @@ +.. _postgres-read: + +Reading from Postgres +===================== + +There are 2 ways of distributed data reading from Postgres: + +* Using :obj:`DBReader ` with different :ref:`strategy` +* Using :obj:`Postgres.sql ` + +Both methods accept :obj:`JDBCReadOptions ` + +.. currentmodule:: onetl.connection.db_connection.postgres.connection + +.. automethod:: Postgres.sql + +.. currentmodule:: onetl.connection.db_connection.jdbc_connection.options + +.. autopydantic_model:: JDBCReadOptions + :member-order: bysource + :model-show-field-summary: false + :field-show-constraints: false diff --git a/docs/connection/db_connection/postgres/write.rst b/docs/connection/db_connection/postgres/write.rst new file mode 100644 index 000000000..db96b4ec8 --- /dev/null +++ b/docs/connection/db_connection/postgres/write.rst @@ -0,0 +1,13 @@ +.. _postgres-write: + +Writing to Postgres +=================== + +For writing data to Postgres, use :obj:`DBWriter ` with options below. + +.. currentmodule:: onetl.connection.db_connection.jdbc_connection.options + +.. autopydantic_model:: JDBCWriteOptions + :member-order: bysource + :model-show-field-summary: false + :field-show-constraints: false diff --git a/docs/connection/db_connection/teradata.rst b/docs/connection/db_connection/teradata.rst deleted file mode 100644 index cca8ea3cb..000000000 --- a/docs/connection/db_connection/teradata.rst +++ /dev/null @@ -1,30 +0,0 @@ -.. _teradata: - -Teradata connection -==================== - -.. currentmodule:: onetl.connection.db_connection.teradata - -.. autosummary:: - - Teradata - Teradata.ReadOptions - Teradata.WriteOptions - Teradata.JDBCOptions - -.. autoclass:: Teradata - :members: get_packages, check, sql, fetch, execute, close - -.. currentmodule:: onetl.connection.db_connection.teradata.Teradata - -.. autopydantic_model:: ReadOptions - :members: fetchsize, partitioning_mode, partition_column, num_partitions, lower_bound, upper_bound, session_init_statement - :member-order: bysource - -.. 
autopydantic_model:: WriteOptions - :members: mode, batchsize, isolation_level, query_timeout - :member-order: bysource - -.. autopydantic_model:: JDBCOptions - :members: query_timeout, fetchsize - :member-order: bysource diff --git a/docs/connection/db_connection/teradata/connection.rst b/docs/connection/db_connection/teradata/connection.rst new file mode 100644 index 000000000..0e70dda34 --- /dev/null +++ b/docs/connection/db_connection/teradata/connection.rst @@ -0,0 +1,9 @@ +.. _teradata-connection: + +Teradata connection +=================== + +.. currentmodule:: onetl.connection.db_connection.teradata.connection + +.. autoclass:: Teradata + :members: get_packages, check diff --git a/docs/connection/db_connection/teradata/execute.rst b/docs/connection/db_connection/teradata/execute.rst new file mode 100644 index 000000000..80853f919 --- /dev/null +++ b/docs/connection/db_connection/teradata/execute.rst @@ -0,0 +1,17 @@ +.. _teradata-execute: + +Executing statements in Teradata +================================ + +.. currentmodule:: onetl.connection.db_connection.teradata.connection + +.. automethod:: Teradata.fetch +.. automethod:: Teradata.execute +.. automethod:: Teradata.close + +.. currentmodule:: onetl.connection.db_connection.jdbc_mixin.options + +.. autopydantic_model:: JDBCOptions + :member-order: bysource + :model-show-field-summary: false + :field-show-constraints: false diff --git a/docs/connection/db_connection/teradata/index.rst b/docs/connection/db_connection/teradata/index.rst new file mode 100644 index 000000000..2f6d6636d --- /dev/null +++ b/docs/connection/db_connection/teradata/index.rst @@ -0,0 +1,18 @@ +.. _teradata: + +Teradata +======== + +.. toctree:: + :maxdepth: 1 + :caption: Connection + + connection + +.. toctree:: + :maxdepth: 1 + :caption: Operations + + read + write + execute diff --git a/docs/connection/db_connection/teradata/read.rst b/docs/connection/db_connection/teradata/read.rst new file mode 100644 index 000000000..b23c42a59 --- /dev/null +++ b/docs/connection/db_connection/teradata/read.rst @@ -0,0 +1,22 @@ +.. _teradata-read: + +Reading from Teradata +===================== + +There are 2 ways of distributed data reading from Teradata: + +* Using :obj:`DBReader ` with different :ref:`strategy` +* Using :obj:`Teradata.sql ` + +Both methods accept :obj:`JDBCReadOptions ` + +.. currentmodule:: onetl.connection.db_connection.teradata.connection + +.. automethod:: Teradata.sql + +.. currentmodule:: onetl.connection.db_connection.jdbc_connection.options + +.. autopydantic_model:: JDBCReadOptions + :member-order: bysource + :model-show-field-summary: false + :field-show-constraints: false diff --git a/docs/connection/db_connection/teradata/write.rst b/docs/connection/db_connection/teradata/write.rst new file mode 100644 index 000000000..7e5cbef40 --- /dev/null +++ b/docs/connection/db_connection/teradata/write.rst @@ -0,0 +1,13 @@ +.. _teradata-write: + +Writing to Teradata +=================== + +For writing data to Teradata, use :obj:`DBWriter ` with options below. + +.. currentmodule:: onetl.connection.db_connection.jdbc_connection.options + +.. autopydantic_model:: JDBCWriteOptions + :member-order: bysource + :model-show-field-summary: false + :field-show-constraints: false diff --git a/docs/connection/file_connection/hdfs.rst b/docs/connection/file_connection/hdfs.rst deleted file mode 100644 index 38752a467..000000000 --- a/docs/connection/file_connection/hdfs.rst +++ /dev/null @@ -1,20 +0,0 @@ -.. _hdfs: - -HDFS connection -=============== - -.. 
currentmodule:: onetl.connection.file_connection.hdfs - -.. autosummary:: - - HDFS - HDFS.Slots - -.. autoclass:: HDFS - :members: __init__, check, path_exists, is_file, is_dir, get_stat, resolve_dir, resolve_file, create_dir, remove_file, remove_dir, rename_dir, rename_file, list_dir, walk, download_file, upload_file - -.. currentmodule:: onetl.connection.file_connection.hdfs.HDFS - -.. autoclass:: Slots - :members: normalize_cluster_name, normalize_namenode_name, get_known_clusters, get_cluster_namenodes, get_current_cluster, get_webhdfs_port, is_namenode_active - :member-order: bysource diff --git a/docs/connection/file_connection/hdfs/connection.rst b/docs/connection/file_connection/hdfs/connection.rst new file mode 100644 index 000000000..7fd657571 --- /dev/null +++ b/docs/connection/file_connection/hdfs/connection.rst @@ -0,0 +1,9 @@ +.. _hdfs-connection: + +HDFS connection +=============== + +.. currentmodule:: onetl.connection.file_connection.hdfs.connection + +.. autoclass:: HDFS + :members: get_current, check, path_exists, is_file, is_dir, get_stat, resolve_dir, resolve_file, create_dir, remove_file, remove_dir, rename_dir, rename_file, list_dir, walk, download_file, upload_file diff --git a/docs/connection/file_connection/hdfs/index.rst b/docs/connection/file_connection/hdfs/index.rst new file mode 100644 index 000000000..a9d57a7a5 --- /dev/null +++ b/docs/connection/file_connection/hdfs/index.rst @@ -0,0 +1,16 @@ +.. _hdfs: + +HDFS +==== + +.. toctree:: + :maxdepth: 1 + :caption: Connection + + connection + +.. toctree:: + :maxdepth: 1 + :caption: For developers + + slots diff --git a/docs/connection/file_connection/hdfs/slots.rst b/docs/connection/file_connection/hdfs/slots.rst new file mode 100644 index 000000000..2128b328c --- /dev/null +++ b/docs/connection/file_connection/hdfs/slots.rst @@ -0,0 +1,10 @@ +.. _hdfs-slots: + +HDFS Slots +========== + +.. currentmodule:: onetl.connection.file_connection.hdfs.slots + +.. autoclass:: HDFSSlots + :members: normalize_cluster_name, normalize_namenode_host, get_known_clusters, get_cluster_namenodes, get_current_cluster, get_webhdfs_port, is_namenode_active + :member-order: bysource diff --git a/docs/connection/file_connection/index.rst b/docs/connection/file_connection/index.rst index c47ddf75c..2fc998c7f 100644 --- a/docs/connection/file_connection/index.rst +++ b/docs/connection/file_connection/index.rst @@ -9,7 +9,7 @@ File Connections FTP FTPS - HDFS + HDFS SFTP S3 Webdav diff --git a/docs/connection/file_df_connection/spark_hdfs/slots.rst b/docs/connection/file_df_connection/spark_hdfs/slots.rst index 6adb4e1f0..3797c54ca 100644 --- a/docs/connection/file_df_connection/spark_hdfs/slots.rst +++ b/docs/connection/file_df_connection/spark_hdfs/slots.rst @@ -6,5 +6,5 @@ Spark HDFS Slots .. currentmodule:: onetl.connection.file_df_connection.spark_hdfs.slots .. autoclass:: SparkHDFSSlots - :members: normalize_cluster_name, normalize_namenode_name, get_known_clusters, get_cluster_namenodes, get_current_cluster, get_ipc_port, is_namenode_active + :members: normalize_cluster_name, normalize_namenode_host, get_known_clusters, get_cluster_namenodes, get_current_cluster, get_ipc_port, is_namenode_active :member-order: bysource diff --git a/docs/file/file_downloader/file_downloader.rst b/docs/file/file_downloader/file_downloader.rst index e1d807443..6c8728f1b 100644 --- a/docs/file/file_downloader/file_downloader.rst +++ b/docs/file/file_downloader/file_downloader.rst @@ -8,13 +8,9 @@ File Downloader .. 
autosummary:: FileDownloader - FileDownloader.Options + FileDownloader.run + FileDownloader.view_files .. autoclass:: FileDownloader :members: run, view_files - -.. currentmodule:: onetl.file.file_downloader.file_downloader.FileDownloader - -.. autopydantic_model:: Options - :members: mode, delete_source, workers :member-order: bysource diff --git a/docs/file/file_downloader/index.rst b/docs/file/file_downloader/index.rst index cf7276529..b20d859ea 100644 --- a/docs/file/file_downloader/index.rst +++ b/docs/file/file_downloader/index.rst @@ -8,4 +8,5 @@ File Downloader :caption: File Downloader file_downloader - download_result + options + result diff --git a/docs/file/file_downloader/options.rst b/docs/file/file_downloader/options.rst new file mode 100644 index 000000000..8fa37e613 --- /dev/null +++ b/docs/file/file_downloader/options.rst @@ -0,0 +1,11 @@ +.. _file-downloader-options: + +File Downloader Options +======================= + +.. currentmodule:: onetl.file.file_downloader.options + +.. autopydantic_model:: FileDownloaderOptions + :member-order: bysource + :model-show-field-summary: false + :field-show-constraints: false diff --git a/docs/file/file_downloader/download_result.rst b/docs/file/file_downloader/result.rst similarity index 74% rename from docs/file/file_downloader/download_result.rst rename to docs/file/file_downloader/result.rst index 3d1a9d4f5..8fd20e9df 100644 --- a/docs/file/file_downloader/download_result.rst +++ b/docs/file/file_downloader/result.rst @@ -1,9 +1,9 @@ -.. _download-result: +.. _file-downloader-result: -Download result -============== +File Downloader Result +====================== -.. currentmodule:: onetl.file.file_downloader.download_result +.. currentmodule:: onetl.file.file_downloader.result .. autoclass:: DownloadResult :members: successful, failed, skipped, missing, successful_count, failed_count, skipped_count, missing_count, total_count, successful_size, failed_size, skipped_size, total_size, raise_if_failed, reraise_failed, raise_if_missing, raise_if_skipped, raise_if_empty, is_empty, raise_if_contains_zero_size, details, summary, dict, json diff --git a/docs/file/file_mover/file_mover.rst b/docs/file/file_mover/file_mover.rst index 6e12bb9b4..4c191ba94 100644 --- a/docs/file/file_mover/file_mover.rst +++ b/docs/file/file_mover/file_mover.rst @@ -1,7 +1,7 @@ .. _file-mover: File Mover -============== +========== .. currentmodule:: onetl.file.file_mover.file_mover @@ -10,14 +10,7 @@ File Mover FileMover FileMover.run FileMover.view_files - FileMover.Options .. autoclass:: FileMover :members: run, view_files :member-order: bysource - -.. currentmodule:: onetl.file.file_mover.file_mover.FileMover - -.. autopydantic_model:: Options - :members: mode, workers - :member-order: bysource diff --git a/docs/file/file_mover/index.rst b/docs/file/file_mover/index.rst index c0ca8a19d..e28f6316f 100644 --- a/docs/file/file_mover/index.rst +++ b/docs/file/file_mover/index.rst @@ -8,4 +8,5 @@ File Mover :caption: File Mover file_mover - move_result + options + result diff --git a/docs/file/file_mover/options.rst b/docs/file/file_mover/options.rst new file mode 100644 index 000000000..743ae5dd0 --- /dev/null +++ b/docs/file/file_mover/options.rst @@ -0,0 +1,11 @@ +.. _file-mover-options: + +File Mover Options +================== + +.. currentmodule:: onetl.file.file_mover.options + +.. 
autopydantic_model:: FileMoverOptions + :member-order: bysource + :model-show-field-summary: false + :field-show-constraints: false diff --git a/docs/file/file_mover/move_result.rst b/docs/file/file_mover/result.rst similarity index 77% rename from docs/file/file_mover/move_result.rst rename to docs/file/file_mover/result.rst index c2c4581fe..d4ea950f3 100644 --- a/docs/file/file_mover/move_result.rst +++ b/docs/file/file_mover/result.rst @@ -1,9 +1,9 @@ -.. _move-result: +.. _file-mover-result: -Move result -============== +File Mover Result +================= -.. currentmodule:: onetl.file.file_mover.move_result +.. currentmodule:: onetl.file.file_mover.result .. autoclass:: MoveResult :members: successful, failed, skipped, missing, successful_count, failed_count, skipped_count, missing_count, total_count, successful_size, failed_size, skipped_size, total_size, raise_if_failed, reraise_failed, raise_if_missing, raise_if_skipped, raise_if_empty, is_empty, raise_if_contains_zero_size, details, summary, dict, json diff --git a/docs/file/file_uploader/file_uploader.rst b/docs/file/file_uploader/file_uploader.rst index 3f13c44b2..6f00c7dc6 100644 --- a/docs/file/file_uploader/file_uploader.rst +++ b/docs/file/file_uploader/file_uploader.rst @@ -1,7 +1,7 @@ .. _file-uploader: File Uploader -============== +============= .. currentmodule:: onetl.file.file_uploader.file_uploader @@ -10,14 +10,7 @@ File Uploader FileUploader FileUploader.run FileUploader.view_files - FileUploader.Options .. autoclass:: FileUploader :members: run, view_files :member-order: bysource - -.. currentmodule:: onetl.file.file_uploader.file_uploader.FileUploader - -.. autopydantic_model:: Options - :members: mode, delete_local, workers - :member-order: bysource diff --git a/docs/file/file_uploader/index.rst b/docs/file/file_uploader/index.rst index e12c65b20..d65c83e42 100644 --- a/docs/file/file_uploader/index.rst +++ b/docs/file/file_uploader/index.rst @@ -8,4 +8,5 @@ File Uploader :caption: File Uploader file_uploader - upload_result + options + result diff --git a/docs/file/file_uploader/options.rst b/docs/file/file_uploader/options.rst new file mode 100644 index 000000000..b0e614b53 --- /dev/null +++ b/docs/file/file_uploader/options.rst @@ -0,0 +1,11 @@ +.. _file-uploader-options: + +File Uploader Options +===================== + +.. currentmodule:: onetl.file.file_uploader.options + +.. autopydantic_model:: FileUploaderOptions + :member-order: bysource + :model-show-field-summary: false + :field-show-constraints: false diff --git a/docs/file/file_uploader/upload_result.rst b/docs/file/file_uploader/result.rst similarity index 75% rename from docs/file/file_uploader/upload_result.rst rename to docs/file/file_uploader/result.rst index 9c7b189d6..af20ace14 100644 --- a/docs/file/file_uploader/upload_result.rst +++ b/docs/file/file_uploader/result.rst @@ -1,9 +1,9 @@ -.. _upload-result: +.. _file-uploader-result: -Upload result -============== +File Uploader Result +==================== -.. currentmodule:: onetl.file.file_uploader.upload_result +.. currentmodule:: onetl.file.file_uploader.result .. 
autoclass:: UploadResult :members: successful, failed, skipped, missing, successful_count, failed_count, skipped_count, missing_count, total_count, successful_size, failed_size, skipped_size, total_size, raise_if_failed, reraise_failed, raise_if_missing, raise_if_skipped, raise_if_empty, is_empty, raise_if_contains_zero_size, details, summary, dict, json diff --git a/docs/file_df/file_df_reader/file_df_reader.rst b/docs/file_df/file_df_reader/file_df_reader.rst index d0c3db6d7..f9e2ef739 100644 --- a/docs/file_df/file_df_reader/file_df_reader.rst +++ b/docs/file_df/file_df_reader/file_df_reader.rst @@ -1,7 +1,7 @@ .. _file-df-reader: -File Reader -=========== +FileDF Reader +============= .. currentmodule:: onetl.file.file_df_reader.file_df_reader diff --git a/docs/file_df/file_df_reader/index.rst b/docs/file_df/file_df_reader/index.rst index c135b618d..e151c7376 100644 --- a/docs/file_df/file_df_reader/index.rst +++ b/docs/file_df/file_df_reader/index.rst @@ -1,11 +1,11 @@ .. _file-df-reader-root: -File Reader -=============== +FileDF Reader +============= .. toctree:: :maxdepth: 1 - :caption: File Reader + :caption: FileDF Reader file_df_reader options diff --git a/docs/file_df/file_df_writer/file_df_writer.rst b/docs/file_df/file_df_writer/file_df_writer.rst index 73f46bdb6..b7482cd26 100644 --- a/docs/file_df/file_df_writer/file_df_writer.rst +++ b/docs/file_df/file_df_writer/file_df_writer.rst @@ -1,7 +1,7 @@ .. _file-df-writer: -File Writer -=========== +FileDF Writer +============= .. currentmodule:: onetl.file.file_df_writer.file_df_writer diff --git a/docs/file_df/file_df_writer/index.rst b/docs/file_df/file_df_writer/index.rst index 22c02c09b..e9b74ee0f 100644 --- a/docs/file_df/file_df_writer/index.rst +++ b/docs/file_df/file_df_writer/index.rst @@ -1,11 +1,11 @@ .. _file-df-writer-root: -File Writer -=============== +FileDF Writer +============= .. toctree:: :maxdepth: 1 - :caption: File Writer + :caption: FileDF Writer file_df_writer options diff --git a/docs/hooks/design.rst b/docs/hooks/design.rst index 1b5266e8b..b05f9d64e 100644 --- a/docs/hooks/design.rst +++ b/docs/hooks/design.rst @@ -696,7 +696,7 @@ But most of logs are emitted with even lower level ``NOTICE``, to make output le NOTICE |Hooks| Calling hook 'mymodule.callback1' (1/2) NOTICE |Hooks| Hook is finished with returning non-None result NOTICE |Hooks| Calling hook 'mymodule.callback2' (2/2) - NOTICE |Hooks| This is a context manager, entering... + NOTICE |Hooks| This is a context manager, entering ... NOTICE |Hooks| Calling original method 'MyClass.method' NOTICE |Hooks| Method call is finished NOTICE |Hooks| Method call result (*NOT* None) will be replaced with result of hook 'mymodule.callback1' diff --git a/onetl/VERSION b/onetl/VERSION index f374f6662..2003b639c 100644 --- a/onetl/VERSION +++ b/onetl/VERSION @@ -1 +1 @@ -0.9.1 +0.9.2 diff --git a/onetl/base/__init__.py b/onetl/base/__init__.py index d418c33cc..c46eb5e27 100644 --- a/onetl/base/__init__.py +++ b/onetl/base/__init__.py @@ -13,7 +13,7 @@ # limitations under the License. 
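The File Downloader / Uploader / Mover pages restructured above now document ``run``/``view_files`` together with a separate options model and result class. For orientation, a minimal usage sketch; the connection class, credentials and paths are illustrative assumptions, while ``run``, ``view_files`` and the ``DownloadResult`` members come from the pages above:

.. code:: python

    from onetl.connection import SFTP
    from onetl.file import FileDownloader

    # illustrative connection; any file connection documented above works the same way
    sftp = SFTP(host="sftp.example.com", user="onetl", password="***")

    downloader = FileDownloader(
        connection=sftp,
        source_path="/remote/export",
        local_path="/tmp/export",
    )

    downloader.view_files()  # list files that would be downloaded
    result = downloader.run()  # returns a DownloadResult

    # DownloadResult members documented in file_downloader/result.rst
    result.raise_if_failed()
    print(result.successful_count, result.total_size)
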
from onetl.base.base_connection import BaseConnection -from onetl.base.base_db_connection import BaseDBConnection +from onetl.base.base_db_connection import BaseDBConnection, BaseDBDialect from onetl.base.base_file_connection import BaseFileConnection from onetl.base.base_file_df_connection import ( BaseFileDFConnection, diff --git a/onetl/base/base_db_connection.py b/onetl/base/base_db_connection.py index cb6b92a5f..56a7c5ae5 100644 --- a/onetl/base/base_db_connection.py +++ b/onetl/base/base_db_connection.py @@ -27,131 +27,134 @@ from pyspark.sql.types import StructType +class BaseDBDialect(ABC): + """ + Collection of methods used for validating input values before passing them to read_source_as_df/write_df_to_target + """ + + @classmethod + @abstractmethod + def validate_name(cls, connection: BaseDBConnection, value: Table) -> Table: + """Check if ``source`` or ``target`` value is valid. + + Raises + ------ + TypeError + If value type is invalid + ValueError + If value is invalid + """ + + @classmethod + @abstractmethod + def validate_columns(cls, connection: BaseDBConnection, columns: list[str] | None) -> list[str] | None: + """Check if ``columns`` value is valid. + + Raises + ------ + TypeError + If value type is invalid + ValueError + If value is invalid + """ + + @classmethod + @abstractmethod + def validate_hwm_column( + cls, + connection: BaseDBConnection, + hwm_column: str | None, + ) -> str | None: + """Check if ``hwm_column`` value is valid. + + Raises + ------ + TypeError + If value type is invalid + ValueError + If value is invalid + """ + + @classmethod + @abstractmethod + def validate_df_schema(cls, connection: BaseDBConnection, df_schema: StructType | None) -> StructType | None: + """Check if ``df_schema`` value is valid. + + Raises + ------ + TypeError + If value type is invalid + ValueError + If value is invalid + """ + + @classmethod + @abstractmethod + def validate_where(cls, connection: BaseDBConnection, where: Any) -> Any | None: + """Check if ``where`` value is valid. + + Raises + ------ + TypeError + If value type is invalid + ValueError + If value is invalid + """ + + @classmethod + @abstractmethod + def validate_hint(cls, connection: BaseDBConnection, hint: Any) -> Any | None: + """Check if ``hint`` value is valid. + + Raises + ------ + TypeError + If value type is invalid + ValueError + If value is invalid + """ + + @classmethod + @abstractmethod + def validate_hwm_expression(cls, connection: BaseDBConnection, value: Any) -> str | None: + """Check if ``hwm_expression`` value is valid. 
+ + Raises + ------ + TypeError + If value type is invalid + ValueError + If value is invalid + """ + + @classmethod + @abstractmethod + def _merge_conditions(cls, conditions: list[Any]) -> Any: + """ + Convert multiple WHERE conditions to one + """ + + @classmethod + @abstractmethod + def _expression_with_alias(cls, expression: Any, alias: str) -> Any: + """ + Return "expression AS alias" statement + """ + + @classmethod + @abstractmethod + def _get_compare_statement(cls, comparator: Callable, arg1: Any, arg2: Any) -> Any: + """ + Return "arg1 COMPARATOR arg2" statement + """ + + class BaseDBConnection(BaseConnection): """ Implements generic methods for reading and writing dataframe from/to database-like source """ - class Dialect(ABC): - """ - Collection of methods used for validating input values before passing them to read_source_as_df/write_df_to_target - """ - - @classmethod - @abstractmethod - def validate_name(cls, connection: BaseDBConnection, value: Table) -> Table: - """Check if ``source`` or ``target`` value is valid. - - Raises - ------ - TypeError - If value type is invalid - ValueError - If value is invalid - """ - - @classmethod - @abstractmethod - def validate_columns(cls, connection: BaseDBConnection, columns: list[str] | None) -> list[str] | None: - """Check if ``columns`` value is valid. - - Raises - ------ - TypeError - If value type is invalid - ValueError - If value is invalid - """ - - @classmethod - @abstractmethod - def validate_hwm_column( - cls, - connection: BaseDBConnection, - hwm_column: str | None, - ) -> str | None: - """Check if ``hwm_column`` value is valid. - - Raises - ------ - TypeError - If value type is invalid - ValueError - If value is invalid - """ - - @classmethod - @abstractmethod - def validate_df_schema(cls, connection: BaseDBConnection, df_schema: StructType | None) -> StructType | None: - """Check if ``df_schema`` value is valid. - - Raises - ------ - TypeError - If value type is invalid - ValueError - If value is invalid - """ - - @classmethod - @abstractmethod - def validate_where(cls, connection: BaseDBConnection, where: Any) -> Any | None: - """Check if ``where`` value is valid. - - Raises - ------ - TypeError - If value type is invalid - ValueError - If value is invalid - """ - - @classmethod - @abstractmethod - def validate_hint(cls, connection: BaseDBConnection, hint: Any) -> Any | None: - """Check if ``hint`` value is valid. - - Raises - ------ - TypeError - If value type is invalid - ValueError - If value is invalid - """ - - @classmethod - @abstractmethod - def validate_hwm_expression(cls, connection: BaseDBConnection, value: Any) -> str | None: - """Check if ``hwm_expression`` value is valid. 
- - Raises - ------ - TypeError - If value type is invalid - ValueError - If value is invalid - """ - - @classmethod - @abstractmethod - def _merge_conditions(cls, conditions: list[Any]) -> Any: - """ - Convert multiple WHERE conditions to one - """ - - @classmethod - @abstractmethod - def _expression_with_alias(cls, expression: Any, alias: str) -> Any: - """ - Return "expression AS alias" statement - """ - - @classmethod - @abstractmethod - def _get_compare_statement(cls, comparator: Callable, arg1: Any, arg2: Any) -> Any: - """ - Return "arg1 COMPARATOR arg2" statement - """ + Dialect = BaseDBDialect @property @abstractmethod diff --git a/onetl/connection/db_connection/clickhouse/__init__.py b/onetl/connection/db_connection/clickhouse/__init__.py new file mode 100644 index 000000000..0fbebdd70 --- /dev/null +++ b/onetl/connection/db_connection/clickhouse/__init__.py @@ -0,0 +1,19 @@ +# Copyright 2023 MTS (Mobile Telesystems) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from onetl.connection.db_connection.clickhouse.connection import ( + Clickhouse, + ClickhouseExtra, +) +from onetl.connection.db_connection.clickhouse.dialect import ClickhouseDialect diff --git a/onetl/connection/db_connection/clickhouse.py b/onetl/connection/db_connection/clickhouse/connection.py similarity index 83% rename from onetl/connection/db_connection/clickhouse.py rename to onetl/connection/db_connection/clickhouse/connection.py index 12a1e5126..dc6acf163 100644 --- a/onetl/connection/db_connection/clickhouse.py +++ b/onetl/connection/db_connection/clickhouse/connection.py @@ -16,13 +16,14 @@ import logging import warnings -from datetime import date, datetime from typing import ClassVar, Optional from onetl._util.classproperty import classproperty +from onetl.connection.db_connection.clickhouse.dialect import ClickhouseDialect from onetl.connection.db_connection.jdbc_connection import JDBCConnection -from onetl.connection.db_connection.jdbc_mixin import StatementType +from onetl.connection.db_connection.jdbc_mixin import JDBCStatementType from onetl.hooks import slot, support_hooks +from onetl.impl import GenericOptions # do not import PySpark here, as we allow user to use `Clickhouse.get_packages()` for creating Spark session @@ -30,6 +31,11 @@ log = logging.getLogger(__name__) +class ClickhouseExtra(GenericOptions): + class Config: + extra = "allow" + + @support_hooks class Clickhouse(JDBCConnection): """Clickhouse JDBC connection. 
|support_hooks| @@ -125,6 +131,10 @@ class Clickhouse(JDBCConnection): port: int = 8123 database: Optional[str] = None + extra: ClickhouseExtra = ClickhouseExtra() + + Extra = ClickhouseExtra + Dialect = ClickhouseDialect DRIVER: ClassVar[str] = "ru.yandex.clickhouse.ClickHouseDriver" @@ -163,32 +173,10 @@ def jdbc_url(self) -> str: return f"jdbc:clickhouse://{self.host}:{self.port}?{parameters}".rstrip("?") - class Dialect(JDBCConnection.Dialect): - @classmethod - def _get_datetime_value_sql(cls, value: datetime) -> str: - result = value.strftime("%Y-%m-%d %H:%M:%S") - return f"CAST('{result}' AS DateTime)" - - @classmethod - def _get_date_value_sql(cls, value: date) -> str: - result = value.strftime("%Y-%m-%d") - return f"CAST('{result}' AS Date)" - - class ReadOptions(JDBCConnection.ReadOptions): - @classmethod - def _get_partition_column_hash(cls, partition_column: str, num_partitions: int) -> str: - return f"modulo(halfMD5({partition_column}), {num_partitions})" - - @classmethod - def _get_partition_column_mod(cls, partition_column: str, num_partitions: int) -> str: - return f"{partition_column} % {num_partitions}" - - ReadOptions.__doc__ = JDBCConnection.ReadOptions.__doc__ - @staticmethod def _build_statement( statement: str, - statement_type: StatementType, + statement_type: JDBCStatementType, jdbc_connection, statement_args, ): diff --git a/onetl/connection/db_connection/clickhouse/dialect.py b/onetl/connection/db_connection/clickhouse/dialect.py new file mode 100644 index 000000000..56fe44b33 --- /dev/null +++ b/onetl/connection/db_connection/clickhouse/dialect.py @@ -0,0 +1,39 @@ +# Copyright 2023 MTS (Mobile Telesystems) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +from datetime import date, datetime + +from onetl.connection.db_connection.jdbc_connection import JDBCDialect + + +class ClickhouseDialect(JDBCDialect): + @classmethod + def _get_datetime_value_sql(cls, value: datetime) -> str: + result = value.strftime("%Y-%m-%d %H:%M:%S") + return f"CAST('{result}' AS DateTime)" + + @classmethod + def _get_date_value_sql(cls, value: date) -> str: + result = value.strftime("%Y-%m-%d") + return f"CAST('{result}' AS Date)" + + @classmethod + def _get_partition_column_hash(cls, partition_column: str, num_partitions: int) -> str: + return f"modulo(halfMD5({partition_column}), {num_partitions})" + + @classmethod + def _get_partition_column_mod(cls, partition_column: str, num_partitions: int) -> str: + return f"{partition_column} % {num_partitions}" diff --git a/onetl/connection/db_connection/db_connection.py b/onetl/connection/db_connection/db_connection.py deleted file mode 100644 index e86c989ff..000000000 --- a/onetl/connection/db_connection/db_connection.py +++ /dev/null @@ -1,159 +0,0 @@ -# Copyright 2023 MTS (Mobile Telesystems) -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from __future__ import annotations - -import operator -from datetime import date, datetime -from logging import getLogger -from typing import TYPE_CHECKING, Any, Callable, ClassVar, Dict - -from pydantic import Field - -from onetl._util.spark import try_import_pyspark -from onetl.base import BaseDBConnection -from onetl.hwm import Statement -from onetl.impl import FrozenModel -from onetl.log import log_with_indent - -if TYPE_CHECKING: - from pyspark.sql import SparkSession - -log = getLogger(__name__) - - -class DBConnection(BaseDBConnection, FrozenModel): - spark: SparkSession = Field(repr=False) - - class Dialect(BaseDBConnection.Dialect): - @classmethod - def _expression_with_alias(cls, expression: str, alias: str) -> str: - return f"{expression} AS {alias}" - - @classmethod - def _get_compare_statement(cls, comparator: Callable, arg1: Any, arg2: Any) -> Any: - template = cls._compare_statements[comparator] - return template.format(arg1, cls._serialize_datetime_value(arg2)) - - @classmethod - def _merge_conditions(cls, conditions: list[Any]) -> Any: - if len(conditions) == 1: - return conditions[0] - - return " AND ".join(f"({item})" for item in conditions) - - @classmethod - def _condition_assembler( - cls, - condition: Any, - start_from: Statement | None, - end_at: Statement | None, - ) -> Any: - conditions = [condition] - - if start_from: - condition1 = cls._get_compare_statement( - comparator=start_from.operator, - arg1=start_from.expression, - arg2=start_from.value, - ) - conditions.append(condition1) - - if end_at: - condition2 = cls._get_compare_statement( - comparator=end_at.operator, - arg1=end_at.expression, - arg2=end_at.value, - ) - conditions.append(condition2) - - result: list[Any] = list(filter(None, conditions)) - if not result: - return None - - return cls._merge_conditions(result) - - _compare_statements: ClassVar[Dict[Callable, str]] = { - operator.ge: "{} >= {}", - operator.gt: "{} > {}", - operator.le: "{} <= {}", - operator.lt: "{} < {}", - operator.eq: "{} == {}", - operator.ne: "{} != {}", - } - - @classmethod - def _serialize_datetime_value(cls, value: Any) -> str | int | dict: - """ - Transform the value into an SQL Dialect-supported form. 
- """ - - if isinstance(value, datetime): - return cls._get_datetime_value_sql(value) - - if isinstance(value, date): - return cls._get_date_value_sql(value) - - return str(value) - - @classmethod - def _get_datetime_value_sql(cls, value: datetime) -> str: - """ - Transform the datetime value into supported by SQL Dialect - """ - result = value.isoformat() - return repr(result) - - @classmethod - def _get_date_value_sql(cls, value: date) -> str: - """ - Transform the date value into supported by SQL Dialect - """ - result = value.isoformat() - return repr(result) - - @classmethod - def _get_max_value_sql(cls, value: Any) -> str: - """ - Generate `MAX(value)` clause for given value - """ - result = cls._serialize_datetime_value(value) - return f"MAX({result})" - - @classmethod - def _get_min_value_sql(cls, value: Any) -> str: - """ - Generate `MIN(value)` clause for given value - """ - result = cls._serialize_datetime_value(value) - return f"MIN({result})" - - @classmethod - def _forward_refs(cls) -> dict[str, type]: - try_import_pyspark() - - from pyspark.sql import SparkSession # noqa: WPS442 - - # avoid importing pyspark unless user called the constructor, - # as we allow user to use `Connection.get_packages()` for creating Spark session - refs = super()._forward_refs() - refs["SparkSession"] = SparkSession - return refs - - def _log_parameters(self): - log.info("|Spark| Using connection parameters:") - log_with_indent(log, "type = %s", self.__class__.__name__) - parameters = self.dict(exclude_none=True, exclude={"spark"}) - for attr, value in sorted(parameters.items()): - log_with_indent(log, "%s = %r", attr, value) diff --git a/onetl/connection/db_connection/db_connection/__init__.py b/onetl/connection/db_connection/db_connection/__init__.py new file mode 100644 index 000000000..9eb19e84f --- /dev/null +++ b/onetl/connection/db_connection/db_connection/__init__.py @@ -0,0 +1,16 @@ +# Copyright 2023 MTS (Mobile Telesystems) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from onetl.connection.db_connection.db_connection.connection import DBConnection +from onetl.connection.db_connection.db_connection.dialect import DBDialect diff --git a/onetl/connection/db_connection/db_connection/connection.py b/onetl/connection/db_connection/db_connection/connection.py new file mode 100644 index 000000000..315f5b17c --- /dev/null +++ b/onetl/connection/db_connection/db_connection/connection.py @@ -0,0 +1,55 @@ +# Copyright 2023 MTS (Mobile Telesystems) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from __future__ import annotations + +from logging import getLogger +from typing import TYPE_CHECKING + +from pydantic import Field + +from onetl._util.spark import try_import_pyspark +from onetl.base import BaseDBConnection +from onetl.connection.db_connection.db_connection.dialect import DBDialect +from onetl.impl import FrozenModel +from onetl.log import log_with_indent + +if TYPE_CHECKING: + from pyspark.sql import SparkSession + +log = getLogger(__name__) + + +class DBConnection(BaseDBConnection, FrozenModel): + spark: SparkSession = Field(repr=False) + + Dialect = DBDialect + + @classmethod + def _forward_refs(cls) -> dict[str, type]: + try_import_pyspark() + + from pyspark.sql import SparkSession # noqa: WPS442 + + # avoid importing pyspark unless user called the constructor, + # as we allow user to use `Connection.get_packages()` for creating Spark session + refs = super()._forward_refs() + refs["SparkSession"] = SparkSession + return refs + + def _log_parameters(self): + log.info("|%s| Using connection parameters:", self.__class__.__name__) + parameters = self.dict(exclude_none=True, exclude={"spark"}) + for attr, value in parameters.items(): + log_with_indent(log, "%s = %r", attr, value) diff --git a/onetl/connection/db_connection/db_connection/dialect.py b/onetl/connection/db_connection/db_connection/dialect.py new file mode 100644 index 000000000..5c9472189 --- /dev/null +++ b/onetl/connection/db_connection/db_connection/dialect.py @@ -0,0 +1,130 @@ +# Copyright 2023 MTS (Mobile Telesystems) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
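The new ``DBConnection._log_parameters`` above prefixes the log record with the concrete class name instead of the previous ``|Spark|`` marker. A small sketch of the resulting output, using plain indentation in place of ``log_with_indent`` and a stand-in class rather than the real pydantic model:

.. code:: python

    import logging

    logging.basicConfig(level=logging.INFO, format="%(message)s")
    log = logging.getLogger("onetl-sketch")


    class FakeConnection:
        # stand-in for a DBConnection subclass, not the real pydantic model
        def __init__(self, **params):
            self._params = params

        def _log_parameters(self):
            log.info("|%s| Using connection parameters:", self.__class__.__name__)
            for attr, value in self._params.items():
                log.info("    %s = %r", attr, value)


    FakeConnection(host="clickhouse.example.com", port=8123, database="default")._log_parameters()
    # |FakeConnection| Using connection parameters:
    #     host = 'clickhouse.example.com'
    #     port = 8123
    #     database = 'default'
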
+ +from __future__ import annotations + +import operator +from datetime import date, datetime +from typing import Any, Callable, ClassVar, Dict + +from onetl.base import BaseDBDialect +from onetl.hwm import Statement + + +class DBDialect(BaseDBDialect): + _compare_statements: ClassVar[Dict[Callable, str]] = { + operator.ge: "{} >= {}", + operator.gt: "{} > {}", + operator.le: "{} <= {}", + operator.lt: "{} < {}", + operator.eq: "{} == {}", + operator.ne: "{} != {}", + } + + @classmethod + def _escape_column(cls, value: str) -> str: + return f'"{value}"' + + @classmethod + def _expression_with_alias(cls, expression: str, alias: str) -> str: + return f"{expression} AS {alias}" + + @classmethod + def _get_compare_statement(cls, comparator: Callable, arg1: Any, arg2: Any) -> Any: + template = cls._compare_statements[comparator] + return template.format(arg1, cls._serialize_datetime_value(arg2)) + + @classmethod + def _merge_conditions(cls, conditions: list[Any]) -> Any: + if len(conditions) == 1: + return conditions[0] + + return " AND ".join(f"({item})" for item in conditions) + + @classmethod + def _condition_assembler( + cls, + condition: Any, + start_from: Statement | None, + end_at: Statement | None, + ) -> Any: + conditions = [condition] + + if start_from: + condition1 = cls._get_compare_statement( + comparator=start_from.operator, + arg1=start_from.expression, + arg2=start_from.value, + ) + conditions.append(condition1) + + if end_at: + condition2 = cls._get_compare_statement( + comparator=end_at.operator, + arg1=end_at.expression, + arg2=end_at.value, + ) + conditions.append(condition2) + + result: list[Any] = list(filter(None, conditions)) + if not result: + return None + + return cls._merge_conditions(result) + + @classmethod + def _serialize_datetime_value(cls, value: Any) -> str | int | dict: + """ + Transform the value into an SQL Dialect-supported form. 
+ """ + + if isinstance(value, datetime): + return cls._get_datetime_value_sql(value) + + if isinstance(value, date): + return cls._get_date_value_sql(value) + + return str(value) + + @classmethod + def _get_datetime_value_sql(cls, value: datetime) -> str: + """ + Transform the datetime value into supported by SQL Dialect + """ + result = value.isoformat() + return repr(result) + + @classmethod + def _get_date_value_sql(cls, value: date) -> str: + """ + Transform the date value into supported by SQL Dialect + """ + result = value.isoformat() + return repr(result) + + @classmethod + def _get_max_value_sql(cls, value: Any) -> str: + """ + Generate `MAX(value)` clause for given value + """ + result = cls._serialize_datetime_value(value) + return f"MAX({result})" + + @classmethod + def _get_min_value_sql(cls, value: Any) -> str: + """ + Generate `MIN(value)` clause for given value + """ + result = cls._serialize_datetime_value(value) + return f"MIN({result})" diff --git a/onetl/connection/db_connection/dialect_mixins/__init__.py b/onetl/connection/db_connection/dialect_mixins/__init__.py index 1eb0d7f17..4d889b276 100644 --- a/onetl/connection/db_connection/dialect_mixins/__init__.py +++ b/onetl/connection/db_connection/dialect_mixins/__init__.py @@ -25,11 +25,11 @@ from onetl.connection.db_connection.dialect_mixins.support_hwm_expression_str import ( SupportHWMExpressionStr, ) -from onetl.connection.db_connection.dialect_mixins.support_table_with_dbschema import ( - SupportTableWithDBSchema, +from onetl.connection.db_connection.dialect_mixins.support_name_any import ( + SupportNameAny, ) -from onetl.connection.db_connection.dialect_mixins.support_table_without_dbschema import ( - SupportTableWithoutDBSchema, +from onetl.connection.db_connection.dialect_mixins.support_name_with_schema_only import ( + SupportNameWithSchemaOnly, ) from onetl.connection.db_connection.dialect_mixins.support_where_none import ( SupportWhereNone, diff --git a/onetl/connection/db_connection/dialect_mixins/support_name_any.py b/onetl/connection/db_connection/dialect_mixins/support_name_any.py new file mode 100644 index 000000000..8ecb34fd6 --- /dev/null +++ b/onetl/connection/db_connection/dialect_mixins/support_name_any.py @@ -0,0 +1,11 @@ +from __future__ import annotations + +from etl_entities import Table + +from onetl.base import BaseDBConnection + + +class SupportNameAny: + @classmethod + def validate_name(cls, connection: BaseDBConnection, value: Table) -> Table: + return value diff --git a/onetl/connection/db_connection/dialect_mixins/support_table_without_dbschema.py b/onetl/connection/db_connection/dialect_mixins/support_name_with_schema_only.py similarity index 60% rename from onetl/connection/db_connection/dialect_mixins/support_table_without_dbschema.py rename to onetl/connection/db_connection/dialect_mixins/support_name_with_schema_only.py index 456181411..eb374ca3a 100644 --- a/onetl/connection/db_connection/dialect_mixins/support_table_without_dbschema.py +++ b/onetl/connection/db_connection/dialect_mixins/support_name_with_schema_only.py @@ -5,11 +5,12 @@ from onetl.base import BaseDBConnection -class SupportTableWithoutDBSchema: +class SupportNameWithSchemaOnly: @classmethod def validate_name(cls, connection: BaseDBConnection, value: Table) -> Table: - if value.db is not None: + if value.name.count(".") != 1: raise ValueError( - f"Table name should be passed in `mytable` format (not `myschema.mytable`), got '{value}'", + f"Name should be passed in `schema.name` format, got '{value}'", ) + return 
value diff --git a/onetl/connection/db_connection/dialect_mixins/support_table_with_dbschema.py b/onetl/connection/db_connection/dialect_mixins/support_table_with_dbschema.py deleted file mode 100644 index 3fa81ad54..000000000 --- a/onetl/connection/db_connection/dialect_mixins/support_table_with_dbschema.py +++ /dev/null @@ -1,17 +0,0 @@ -from __future__ import annotations - -from etl_entities import Table - -from onetl.base import BaseDBConnection - - -class SupportTableWithDBSchema: - @classmethod - def validate_name(cls, connection: BaseDBConnection, value: Table) -> Table: - if value.db is None: - # Same error text as in etl_entites.Table value error. - raise ValueError( - f"Table name should be passed in `schema.name` format, got '{value}'", - ) - - return value diff --git a/onetl/connection/db_connection/greenplum/__init__.py b/onetl/connection/db_connection/greenplum/__init__.py new file mode 100644 index 000000000..d080e8932 --- /dev/null +++ b/onetl/connection/db_connection/greenplum/__init__.py @@ -0,0 +1,21 @@ +# Copyright 2023 MTS (Mobile Telesystems) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from onetl.connection.db_connection.greenplum.connection import Greenplum +from onetl.connection.db_connection.greenplum.dialect import GreenplumDialect +from onetl.connection.db_connection.greenplum.options import ( + GreenplumReadOptions, + GreenplumTableExistBehavior, + GreenplumWriteOptions, +) diff --git a/onetl/connection/db_connection/greenplum.py b/onetl/connection/db_connection/greenplum/connection.py similarity index 60% rename from onetl/connection/db_connection/greenplum.py rename to onetl/connection/db_connection/greenplum/connection.py index 0986ac333..99de7d90c 100644 --- a/onetl/connection/db_connection/greenplum.py +++ b/onetl/connection/db_connection/greenplum/connection.py @@ -18,13 +18,10 @@ import os import textwrap import warnings -from dataclasses import dataclass -from datetime import date, datetime -from enum import Enum -from typing import TYPE_CHECKING, Any, ClassVar, Optional +from typing import TYPE_CHECKING, Any, ClassVar from etl_entities.instance import Host -from pydantic import Field, root_validator, validator +from pydantic import validator from onetl._internal import get_sql_query from onetl._util.classproperty import classproperty @@ -33,16 +30,17 @@ from onetl._util.spark import get_executor_total_cores, get_spark_version from onetl._util.version import Version from onetl.connection.db_connection.db_connection import DBConnection -from onetl.connection.db_connection.dialect_mixins import ( - SupportColumnsList, - SupportDfSchemaNone, - SupportHintNone, - SupportHWMColumnStr, - SupportHWMExpressionStr, - SupportTableWithDBSchema, - SupportWhereStr, +from onetl.connection.db_connection.greenplum.connection_limit import ( + GreenplumConnectionLimit, +) +from onetl.connection.db_connection.greenplum.dialect import GreenplumDialect +from onetl.connection.db_connection.greenplum.options import ( + GreenplumReadOptions, + GreenplumTableExistBehavior, + 
GreenplumWriteOptions, ) from onetl.connection.db_connection.jdbc_mixin import JDBCMixin +from onetl.connection.db_connection.jdbc_mixin.options import JDBCOptions from onetl.exception import MISSING_JVM_CLASS_MSG, TooManyParallelJobsError from onetl.hooks import slot, support_hooks from onetl.hwm import Statement @@ -56,14 +54,6 @@ log = logging.getLogger(__name__) -# options from which are populated by Greenplum class methods -GENERIC_PROHIBITED_OPTIONS = frozenset( - ( - "dbschema", - "dbtable", - ), -) - EXTRA_OPTIONS = frozenset( ( "server.*", @@ -71,67 +61,15 @@ ), ) -WRITE_OPTIONS = frozenset( - ( - "mode", - "truncate", - "distributedBy", - "distributed_by", - "iteratorOptimization", - "iterator_optimization", - ), -) - -READ_OPTIONS = frozenset( - ( - "partitions", - "num_partitions", - "numPartitions", - "partitionColumn", - "partition_column", - ), -) - - -class GreenplumTableExistBehavior(str, Enum): - APPEND = "append" - REPLACE_ENTIRE_TABLE = "replace_entire_table" - def __str__(self) -> str: - return str(self.value) - - @classmethod # noqa: WPS120 - def _missing_(cls, value: object): # noqa: WPS120 - if str(value) == "overwrite": - warnings.warn( - "Mode `overwrite` is deprecated since v0.9.0 and will be removed in v1.0.0. " - "Use `replace_entire_table` instead", - category=UserWarning, - stacklevel=4, - ) - return cls.REPLACE_ENTIRE_TABLE +class GreenplumExtra(GenericOptions): + # avoid closing connections from server side + # while connector is moving data to executors before insert + tcpKeepAlive: str = "true" # noqa: N815 - -@dataclass -class ConnectionLimits: - maximum: int - reserved: int - occupied: int - - @property - def available(self) -> int: - return self.maximum - self.reserved - self.occupied - - @property - def summary(self) -> str: - return textwrap.dedent( - f""" - available connections: {self.available} - occupied: {self.occupied} - max: {self.maximum} ("max_connection" in postgresql.conf) - reserved: {self.reserved} ("superuser_reserved_connections" in postgresql.conf) - """, - ).strip() + class Config: + extra = "allow" + prohibited_options = JDBCOptions.Config.prohibited_options @support_hooks @@ -139,7 +77,7 @@ class Greenplum(JDBCMixin, DBConnection): """Greenplum connection. |support_hooks| Based on package ``io.pivotal:greenplum-spark:2.1.4`` - (`Pivotal connector for Spark `_). + (`Pivotal connector for Spark `_). .. warning:: @@ -226,272 +164,15 @@ class Greenplum(JDBCMixin, DBConnection): ) """ - class Extra(GenericOptions): - # avoid closing connections from server side - # while connector is moving data to executors before insert - tcpKeepAlive: str = "true" # noqa: N815 - - class Config: - extra = "allow" - prohibited_options = JDBCMixin.JDBCOptions.Config.prohibited_options | GENERIC_PROHIBITED_OPTIONS - - class ReadOptions(JDBCMixin.JDBCOptions): - """Pivotal's Greenplum Spark connector reading options. - - .. note :: - - You can pass any value - `supported by connector `_, - even if it is not mentioned in this documentation. - - The set of supported options depends on connector version. See link above. - - .. warning:: - - Some options, like ``url``, ``dbtable``, ``server.*``, ``pool.*``, - etc are populated from connection attributes, and cannot be set in ``ReadOptions`` class - - Examples - -------- - - Read options initialization - - .. 
code:: python - - Greenplum.ReadOptions( - partition_column="reg_id", - num_partitions=10, - ) - """ - - class Config: - known_options = READ_OPTIONS - prohibited_options = ( - JDBCMixin.JDBCOptions.Config.prohibited_options - | EXTRA_OPTIONS - | GENERIC_PROHIBITED_OPTIONS - | WRITE_OPTIONS - ) - - partition_column: Optional[str] = Field(alias="partitionColumn") - """Column used to parallelize reading from a table. - - .. warning:: - - You should not change this option, unless you know what you're doing - - Possible values: - * ``None`` (default): - Spark generates N jobs (where N == number of segments in Greenplum cluster), - each job is reading only data from a specific segment - (filtering data by ``gp_segment_id`` column). - - This is very effective way to fetch the data from a cluster. - - * table column - Allocate each executor a range of values from a specific column. - - .. note:: - Column type must be numeric. Other types are not supported. - - Spark generates for each executor an SQL query like: - - Executor 1: - - .. code:: sql - - SELECT ... FROM table - WHERE (partition_column >= lowerBound - OR partition_column IS NULL) - AND partition_column < (lower_bound + stride) - - Executor 2: - - .. code:: sql - - SELECT ... FROM table - WHERE partition_column >= (lower_bound + stride) - AND partition_column < (lower_bound + 2 * stride) - - ... - - Executor N: - - .. code:: sql - - SELECT ... FROM table - WHERE partition_column >= (lower_bound + (N-1) * stride) - AND partition_column <= upper_bound - - Where ``stride=(upper_bound - lower_bound) / num_partitions``, - ``lower_bound=MIN(partition_column)``, ``upper_bound=MAX(partition_column)``. - - .. note:: - - :obj:`~num_partitions` is used just to - calculate the partition stride, **NOT** for filtering the rows in table. - So all rows in the table will be returned (unlike *Incremental* :ref:`strategy`). - - .. note:: - - All queries are executed in parallel. To execute them sequentially, use *Batch* :ref:`strategy`. - - .. warning:: - - Both options :obj:`~partition_column` and :obj:`~num_partitions` should have a value, - or both should be ``None`` - - Examples - -------- - - Read data in 10 parallel jobs by range of values in ``id_column`` column: - - .. code:: python - - Greenplum.ReadOptions( - partition_column="id_column", - num_partitions=10, - ) - """ - - num_partitions: Optional[int] = Field(alias="partitions") - """Number of jobs created by Spark to read the table content in parallel. - - See documentation for :obj:`~partition_column` for more details - - .. warning:: - - By default connector uses number of segments in the Greenplum cluster. - You should not change this option, unless you know what you're doing - - .. warning:: - - Both options :obj:`~partition_column` and :obj:`~num_partitions` should have a value, - or both should be ``None`` - """ - - class WriteOptions(JDBCMixin.JDBCOptions): - """Pivotal's Greenplum Spark connector writing options. - - .. note :: - - You can pass any value - `supported by connector `_, - even if it is not mentioned in this documentation. - - The set of supported options depends on connector version. See link above. - - .. warning:: - - Some options, like ``url``, ``dbtable``, ``server.*``, ``pool.*``, - etc are populated from connection attributes, and cannot be set in ``WriteOptions`` class - - Examples - -------- - - Write options initialization - - .. 
code:: python - - options = Greenplum.WriteOptions( - if_exists="append", - truncate="false", - distributedBy="mycolumn", - ) - """ - - class Config: - known_options = WRITE_OPTIONS - prohibited_options = ( - JDBCMixin.JDBCOptions.Config.prohibited_options - | EXTRA_OPTIONS - | GENERIC_PROHIBITED_OPTIONS - | READ_OPTIONS - ) - - if_exists: GreenplumTableExistBehavior = Field(default=GreenplumTableExistBehavior.APPEND, alias="mode") - """Behavior of writing data into existing table. - - Possible values: - * ``append`` (default) - Adds new rows into existing table. - - .. dropdown:: Behavior in details - - * Table does not exist - Table is created using options provided by user - (``distributedBy`` and others). - - * Table exists - Data is appended to a table. Table has the same DDL as before writing data. - - .. warning:: - - This mode does not check whether table already contains - rows from dataframe, so duplicated rows can be created. - - Also Spark does not support passing custom options to - insert statement, like ``ON CONFLICT``, so don't try to - implement deduplication using unique indexes or constraints. - - Instead, write to staging table and perform deduplication - using :obj:`~execute` method. - - * ``replace_entire_table`` - **Table is dropped and then created**. - - .. dropdown:: Behavior in details - - * Table does not exist - Table is created using options provided by user - (``distributedBy`` and others). - - * Table exists - Table content is replaced with dataframe content. - - After writing completed, target table could either have the same DDL as - before writing data (``truncate=True``), or can be recreated (``truncate=False``). - - .. note:: - - ``error`` and ``ignore`` modes are not supported. - """ - - @root_validator(pre=True) - def mode_is_deprecated(cls, values): - if "mode" in values: - warnings.warn( - "Option `Greenplum.WriteOptions(mode=...)` is deprecated since v0.9.0 and will be removed in v1.0.0. 
" - "Use `Greenplum.WriteOptions(if_exists=...)` instead", - category=UserWarning, - stacklevel=3, - ) - return values - - class Dialect( # noqa: WPS215 - SupportTableWithDBSchema, - SupportColumnsList, - SupportDfSchemaNone, - SupportWhereStr, - SupportHintNone, - SupportHWMExpressionStr, - SupportHWMColumnStr, - DBConnection.Dialect, - ): - @classmethod - def _get_datetime_value_sql(cls, value: datetime) -> str: - result = value.isoformat() - return f"cast('{result}' as timestamp)" - - @classmethod - def _get_date_value_sql(cls, value: date) -> str: - result = value.isoformat() - return f"cast('{result}' as date)" - host: Host database: str port: int = 5432 - extra: Extra = Extra() + extra: GreenplumExtra = GreenplumExtra() + + Extra = GreenplumExtra + Dialect = GreenplumDialect + ReadOptions = GreenplumReadOptions + WriteOptions = GreenplumWriteOptions DRIVER: ClassVar[str] = "org.postgresql.Driver" CONNECTIONS_WARNING_LIMIT: ClassVar[int] = 31 @@ -598,7 +279,7 @@ def read_source_as_df( df_schema: StructType | None = None, start_from: Statement | None = None, end_at: Statement | None = None, - options: ReadOptions | dict | None = None, + options: GreenplumReadOptions | None = None, ) -> DataFrame: read_options = self.ReadOptions.parse(options).dict(by_alias=True, exclude_none=True) log.info("|%s| Executing SQL query (on executor):", self.__class__.__name__) @@ -623,7 +304,7 @@ def write_df_to_target( self, df: DataFrame, target: str, - options: WriteOptions | dict | None = None, + options: GreenplumWriteOptions | None = None, ) -> None: write_options = self.WriteOptions.parse(options) options_dict = write_options.dict(by_alias=True, exclude_none=True, exclude={"if_exists"}) @@ -631,7 +312,11 @@ def write_df_to_target( self._check_expected_jobs_number(df, action="write") log.info("|%s| Saving data to a table %r", self.__class__.__name__, target) - mode = "overwrite" if write_options.if_exists == GreenplumTableExistBehavior.REPLACE_ENTIRE_TABLE else "append" + mode = ( + "overwrite" + if write_options.if_exists == GreenplumTableExistBehavior.REPLACE_ENTIRE_TABLE + else write_options.if_exists.value + ) df.write.format("greenplum").options( **self._connector_params(target), **options_dict, @@ -644,9 +329,9 @@ def get_df_schema( self, source: str, columns: list[str] | None = None, - options: JDBCMixin.JDBCOptions | dict | None = None, + options: JDBCOptions | None = None, ) -> StructType: - log.info("|%s| Fetching schema of table %r", self.__class__.__name__, source) + log.info("|%s| Fetching schema of table %r ...", self.__class__.__name__, source) query = get_sql_query(source, columns=columns, where="1=0", compact=True) jdbc_options = self.JDBCOptions.parse(options).copy(update={"fetchsize": 0}) @@ -655,7 +340,7 @@ def get_df_schema( log_lines(log, query, level=logging.DEBUG) df = self._query_on_driver(query, jdbc_options) - log.info("|%s| Schema fetched", self.__class__.__name__) + log.info("|%s| Schema fetched.", self.__class__.__name__) return df.schema @@ -667,9 +352,9 @@ def get_min_max_bounds( expression: str | None = None, hint: str | None = None, where: str | None = None, - options: JDBCMixin.JDBCOptions | dict | None = None, + options: JDBCOptions | None = None, ) -> tuple[Any, Any]: - log.info("|Spark| Getting min and max values for column %r", column) + log.info("|%s| Getting min and max values for column %r ...", self.__class__.__name__, column) jdbc_options = self.JDBCOptions.parse(options).copy(update={"fetchsize": 1}) @@ -678,11 +363,11 @@ def get_min_max_bounds( columns=[ 
self.Dialect._expression_with_alias( self.Dialect._get_min_value_sql(expression or column), - "min", + self.Dialect._escape_column("min"), ), self.Dialect._expression_with_alias( self.Dialect._get_max_value_sql(expression or column), - "max", + self.Dialect._escape_column("max"), ), ], where=where, @@ -696,7 +381,7 @@ def get_min_max_bounds( min_value = row["min"] max_value = row["max"] - log.info("|Spark| Received values:") + log.info("|%s| Received values:", self.__class__.__name__) log_with_indent(log, "MIN(%r) = %r", column, min_value) log_with_indent(log, "MAX(%r) = %r", column, max_value) @@ -737,7 +422,7 @@ def _connector_params( **extra, } - def _options_to_connection_properties(self, options: JDBCMixin.JDBCOptions): + def _options_to_connection_properties(self, options: JDBCOptions): # See https://github.com/pgjdbc/pgjdbc/pull/1252 # Since 42.2.9 Postgres JDBC Driver added new option readOnlyMode=transaction # Which is not a desired behavior, because `.fetch()` method should always be read-only @@ -786,11 +471,11 @@ def _get_occupied_connections_count(self) -> int: ) return int(result[0][0]) - def _get_connections_limits(self) -> ConnectionLimits: + def _get_connections_limits(self) -> GreenplumConnectionLimit: max_connections = int(self._get_server_setting("max_connections")) reserved_connections = int(self._get_server_setting("superuser_reserved_connections")) occupied_connections = self._get_occupied_connections_count() - return ConnectionLimits( + return GreenplumConnectionLimit( maximum=max_connections, reserved=reserved_connections, occupied=occupied_connections, diff --git a/onetl/connection/db_connection/greenplum/connection_limit.py b/onetl/connection/db_connection/greenplum/connection_limit.py new file mode 100644 index 000000000..7dec5ac26 --- /dev/null +++ b/onetl/connection/db_connection/greenplum/connection_limit.py @@ -0,0 +1,40 @@ +# Copyright 2023 MTS (Mobile Telesystems) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import textwrap +from dataclasses import dataclass + + +@dataclass +class GreenplumConnectionLimit: + maximum: int + reserved: int + occupied: int + + @property + def available(self) -> int: + return self.maximum - self.reserved - self.occupied + + @property + def summary(self) -> str: + return textwrap.dedent( + f""" + available connections: {self.available} + occupied: {self.occupied} + max: {self.maximum} ("max_connection" in postgresql.conf) + reserved: {self.reserved} ("superuser_reserved_connections" in postgresql.conf) + """, + ).strip() diff --git a/onetl/connection/db_connection/greenplum/dialect.py b/onetl/connection/db_connection/greenplum/dialect.py new file mode 100644 index 000000000..a998811aa --- /dev/null +++ b/onetl/connection/db_connection/greenplum/dialect.py @@ -0,0 +1,49 @@ +# Copyright 2023 MTS (Mobile Telesystems) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +from datetime import date, datetime + +from onetl.connection.db_connection.db_connection import DBDialect +from onetl.connection.db_connection.dialect_mixins import ( + SupportColumnsList, + SupportDfSchemaNone, + SupportHintNone, + SupportHWMColumnStr, + SupportHWMExpressionStr, + SupportNameWithSchemaOnly, + SupportWhereStr, +) + + +class GreenplumDialect( # noqa: WPS215 + SupportNameWithSchemaOnly, + SupportColumnsList, + SupportDfSchemaNone, + SupportWhereStr, + SupportHintNone, + SupportHWMExpressionStr, + SupportHWMColumnStr, + DBDialect, +): + @classmethod + def _get_datetime_value_sql(cls, value: datetime) -> str: + result = value.isoformat() + return f"cast('{result}' as timestamp)" + + @classmethod + def _get_date_value_sql(cls, value: date) -> str: + result = value.isoformat() + return f"cast('{result}' as date)" diff --git a/onetl/connection/db_connection/greenplum/options.py b/onetl/connection/db_connection/greenplum/options.py new file mode 100644 index 000000000..86785155e --- /dev/null +++ b/onetl/connection/db_connection/greenplum/options.py @@ -0,0 +1,315 @@ +# Copyright 2023 MTS (Mobile Telesystems) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import warnings +from enum import Enum +from typing import Optional + +from pydantic import Field, root_validator + +from onetl.connection.db_connection.jdbc_mixin import JDBCOptions + +# options from which are populated by Greenplum class methods +GENERIC_PROHIBITED_OPTIONS = frozenset( + ( + "dbschema", + "dbtable", + ), +) + +WRITE_OPTIONS = frozenset( + ( + "mode", + "truncate", + "distributedBy", + "iteratorOptimization", + ), +) + +READ_OPTIONS = frozenset( + ( + "partitions", + "numPartitions", + "partitionColumn", + ), +) + + +class GreenplumTableExistBehavior(str, Enum): + APPEND = "append" + IGNORE = "ignore" + ERROR = "error" + REPLACE_ENTIRE_TABLE = "replace_entire_table" + + def __str__(self) -> str: + return str(self.value) + + @classmethod # noqa: WPS120 + def _missing_(cls, value: object): # noqa: WPS120 + if str(value) == "overwrite": + warnings.warn( + "Mode `overwrite` is deprecated since v0.9.0 and will be removed in v1.0.0. " + "Use `replace_entire_table` instead", + category=UserWarning, + stacklevel=4, + ) + return cls.REPLACE_ENTIRE_TABLE + + +class GreenplumReadOptions(JDBCOptions): + """Pivotal's Greenplum Spark connector reading options. + + .. note :: + + You can pass any value + `supported by connector `_, + even if it is not mentioned in this documentation. 
+ + The set of supported options depends on connector version. See link above. + + .. warning:: + + Some options, like ``url``, ``dbtable``, ``server.*``, ``pool.*``, + etc are populated from connection attributes, and cannot be set in ``ReadOptions`` class + + Examples + -------- + + Read options initialization + + .. code:: python + + Greenplum.ReadOptions( + partition_column="reg_id", + num_partitions=10, + ) + """ + + class Config: + known_options = READ_OPTIONS + prohibited_options = JDBCOptions.Config.prohibited_options | GENERIC_PROHIBITED_OPTIONS | WRITE_OPTIONS + + partition_column: Optional[str] = Field(alias="partitionColumn") + """Column used to parallelize reading from a table. + + .. warning:: + + You should not change this option, unless you know what you're doing + + Possible values: + * ``None`` (default): + Spark generates N jobs (where N == number of segments in Greenplum cluster), + each job is reading only data from a specific segment + (filtering data by ``gp_segment_id`` column). + + This is very effective way to fetch the data from a cluster. + + * table column + Allocate each executor a range of values from a specific column. + + .. note:: + Column type must be numeric. Other types are not supported. + + Spark generates for each executor an SQL query like: + + Executor 1: + + .. code:: sql + + SELECT ... FROM table + WHERE (partition_column >= lowerBound + OR partition_column IS NULL) + AND partition_column < (lower_bound + stride) + + Executor 2: + + .. code:: sql + + SELECT ... FROM table + WHERE partition_column >= (lower_bound + stride) + AND partition_column < (lower_bound + 2 * stride) + + ... + + Executor N: + + .. code:: sql + + SELECT ... FROM table + WHERE partition_column >= (lower_bound + (N-1) * stride) + AND partition_column <= upper_bound + + Where ``stride=(upper_bound - lower_bound) / num_partitions``, + ``lower_bound=MIN(partition_column)``, ``upper_bound=MAX(partition_column)``. + + .. note:: + + :obj:`~num_partitions` is used just to + calculate the partition stride, **NOT** for filtering the rows in table. + So all rows in the table will be returned (unlike *Incremental* :ref:`strategy`). + + .. note:: + + All queries are executed in parallel. To execute them sequentially, use *Batch* :ref:`strategy`. + + .. warning:: + + Both options :obj:`~partition_column` and :obj:`~num_partitions` should have a value, + or both should be ``None`` + + Examples + -------- + + Read data in 10 parallel jobs by range of values in ``id_column`` column: + + .. code:: python + + Greenplum.ReadOptions( + partition_column="id_column", + num_partitions=10, + ) + """ + + num_partitions: Optional[int] = Field(alias="partitions") + """Number of jobs created by Spark to read the table content in parallel. + + See documentation for :obj:`~partition_column` for more details + + .. warning:: + + By default connector uses number of segments in the Greenplum cluster. + You should not change this option, unless you know what you're doing + + .. warning:: + + Both options :obj:`~partition_column` and :obj:`~num_partitions` should have a value, + or both should be ``None`` + """ + + +class GreenplumWriteOptions(JDBCOptions): + """Pivotal's Greenplum Spark connector writing options. + + .. note :: + + You can pass any value + `supported by connector `_, + even if it is not mentioned in this documentation. + + The set of supported options depends on connector version. See link above. + + .. 
warning:: + + Some options, like ``url``, ``dbtable``, ``server.*``, ``pool.*``, + etc are populated from connection attributes, and cannot be set in ``WriteOptions`` class + + Examples + -------- + + Write options initialization + + .. code:: python + + options = Greenplum.WriteOptions( + if_exists="append", + truncate="false", + distributedBy="mycolumn", + ) + """ + + class Config: + known_options = WRITE_OPTIONS + prohibited_options = JDBCOptions.Config.prohibited_options | GENERIC_PROHIBITED_OPTIONS | READ_OPTIONS + + if_exists: GreenplumTableExistBehavior = Field(default=GreenplumTableExistBehavior.APPEND, alias="mode") + """Behavior of writing data into existing table. + + Possible values: + * ``append`` (default) + Adds new rows into existing table. + + .. dropdown:: Behavior in details + + * Table does not exist + Table is created using options provided by user + (``distributedBy`` and others). + + * Table exists + Data is appended to a table. Table has the same DDL as before writing data. + + .. warning:: + + This mode does not check whether table already contains + rows from dataframe, so duplicated rows can be created. + + Also Spark does not support passing custom options to + insert statement, like ``ON CONFLICT``, so don't try to + implement deduplication using unique indexes or constraints. + + Instead, write to staging table and perform deduplication + using :obj:`~execute` method. + + * ``replace_entire_table`` + **Table is dropped and then created**. + + .. dropdown:: Behavior in details + + * Table does not exist + Table is created using options provided by user + (``distributedBy`` and others). + + * Table exists + Table content is replaced with dataframe content. + + After writing completed, target table could either have the same DDL as + before writing data (``truncate=True``), or can be recreated (``truncate=False``). + + * ``ignore`` + Ignores the write operation if the table already exists. + + .. dropdown:: Behavior in details + + * Table does not exist + Table is created using options provided by user + (``distributedBy`` and others). + + * Table exists + The write operation is ignored, and no data is written to the table. + + * ``error`` + Raises an error if the table already exists. + + .. dropdown:: Behavior in details + + * Table does not exist + Table is created using options provided by user + (``distributedBy`` and others). + + * Table exists + An error is raised, and no data is written to the table. + + """ + + @root_validator(pre=True) + def _mode_is_deprecated(cls, values): + if "mode" in values: + warnings.warn( + "Option `Greenplum.WriteOptions(mode=...)` is deprecated since v0.9.0 and will be removed in v1.0.0. " + "Use `Greenplum.WriteOptions(if_exists=...)` instead", + category=UserWarning, + stacklevel=3, + ) + return values diff --git a/onetl/connection/db_connection/hive.py b/onetl/connection/db_connection/hive.py deleted file mode 100644 index d82d05142..000000000 --- a/onetl/connection/db_connection/hive.py +++ /dev/null @@ -1,978 +0,0 @@ -# Copyright 2023 MTS (Mobile Telesystems) -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
-# See the License for the specific language governing permissions and -# limitations under the License. - -from __future__ import annotations - -import logging -import warnings -from enum import Enum -from textwrap import dedent -from typing import TYPE_CHECKING, Any, ClassVar, Iterable, List, Optional, Tuple, Union - -from deprecated import deprecated -from etl_entities.instance import Cluster -from pydantic import Field, root_validator, validator - -from onetl._internal import clear_statement, get_sql_query -from onetl._util.spark import inject_spark_param -from onetl.connection.db_connection.db_connection import DBConnection -from onetl.connection.db_connection.dialect_mixins import ( - SupportColumnsList, - SupportDfSchemaNone, - SupportHintStr, - SupportHWMColumnStr, - SupportHWMExpressionStr, - SupportWhereStr, -) -from onetl.connection.db_connection.dialect_mixins.support_table_with_dbschema import ( - SupportTableWithDBSchema, -) -from onetl.hooks import slot, support_hooks -from onetl.hwm import Statement -from onetl.impl import GenericOptions -from onetl.log import log_lines, log_with_indent - -if TYPE_CHECKING: - from pyspark.sql import DataFrame, SparkSession - from pyspark.sql.types import StructType - -PARTITION_OVERWRITE_MODE_PARAM = "spark.sql.sources.partitionOverwriteMode" -log = logging.getLogger(__name__) - - -class HiveTableExistBehavior(str, Enum): - APPEND = "append" - REPLACE_ENTIRE_TABLE = "replace_entire_table" - REPLACE_OVERLAPPING_PARTITIONS = "replace_overlapping_partitions" - - def __str__(self): - return str(self.value) - - @classmethod # noqa: WPS120 - def _missing_(cls, value: object): # noqa: WPS120 - if str(value) == "overwrite": - warnings.warn( - "Mode `overwrite` is deprecated since v0.4.0 and will be removed in v1.0.0. " - "Use `replace_overlapping_partitions` instead", - category=UserWarning, - stacklevel=4, - ) - return cls.REPLACE_OVERLAPPING_PARTITIONS - - if str(value) == "overwrite_partitions": - warnings.warn( - "Mode `overwrite_partitions` is deprecated since v0.9.0 and will be removed in v1.0.0. " - "Use `replace_overlapping_partitions` instead", - category=UserWarning, - stacklevel=4, - ) - return cls.REPLACE_OVERLAPPING_PARTITIONS - - if str(value) == "overwrite_table": - warnings.warn( - "Mode `overwrite_table` is deprecated since v0.9.0 and will be removed in v1.0.0. " - "Use `replace_entire_table` instead", - category=UserWarning, - stacklevel=4, - ) - return cls.REPLACE_ENTIRE_TABLE - - -@support_hooks -class Hive(DBConnection): - """Spark connection with Hive MetaStore support. |support_hooks| - - You don't need a Hive server to use this connector. - - .. dropdown:: Version compatibility - - * Hive metastore version: 0.12 - 3.1.2 (may require to add proper .jar file explicitly) - * Spark versions: 2.3.x - 3.4.x - * Java versions: 8 - 20 - - .. warning:: - - To use Hive connector you should have PySpark installed (or injected to ``sys.path``) - BEFORE creating the connector instance. - - You can install PySpark as follows: - - .. code:: bash - - pip install onetl[spark] # latest PySpark version - - # or - pip install onetl pyspark=3.4.1 # pass specific PySpark version - - See :ref:`spark-install` instruction for more details. - - .. warning:: - - This connector requires some additional configuration files to be present (``hive-site.xml`` and so on), - as well as .jar files with Hive MetaStore client. - - See `Spark Hive Tables documentation `_ - and `this guide `_ for more details. - - .. 
note:: - - Most of Hadoop instances use Kerberos authentication. In this case, you should call ``kinit`` - **BEFORE** starting Spark session to generate Kerberos ticket. See :ref:`kerberos-install`. - - In case of creating session with ``"spark.master": "yarn"``, you should also pass some additional options - to Spark session, allowing executors to generate their own Kerberos tickets to access HDFS. - See `Spark security documentation `_ - for more details. - - Parameters - ---------- - cluster : str - Cluster name. Used for HWM and lineage. - - spark : :obj:`pyspark.sql.SparkSession` - Spark session with Hive metastore support enabled - - Examples - -------- - - Hive connection initialization - - .. code:: python - - from onetl.connection import Hive - from pyspark.sql import SparkSession - - # Create Spark session - spark = SparkSession.builder.appName("spark-app-name").enableHiveSupport().getOrCreate() - - # Create connection - hive = Hive(cluster="rnd-dwh", spark=spark).check() - - Hive connection initialization with Kerberos support - - .. code:: python - - from onetl.connection import Hive - from pyspark.sql import SparkSession - - # Create Spark session - # Use names "spark.yarn.access.hadoopFileSystems", "spark.yarn.principal" - # and "spark.yarn.keytab" for Spark 2 - - spark = ( - SparkSession.builder.appName("spark-app-name") - .option("spark.kerberos.access.hadoopFileSystems", "hdfs://cluster.name.node:8020") - .option("spark.kerberos.principal", "user") - .option("spark.kerberos.keytab", "/path/to/keytab") - .enableHiveSupport() - .getOrCreate() - ) - - # Create connection - hive = Hive(cluster="rnd-dwh", spark=spark).check() - """ - - class WriteOptions(GenericOptions): - """Hive source writing options. - - You can pass here key-value items which then will be converted to calls - of :obj:`pyspark.sql.readwriter.DataFrameWriter` methods. - - For example, ``Hive.WriteOptions(if_exists="append", partitionBy="reg_id")`` will - be converted to ``df.write.mode("append").partitionBy("reg_id")`` call, and so on. - - .. note:: - - You can pass any method and its value - `supported by Spark `_, - even if it is not mentioned in this documentation. **Option names should be in** ``camelCase``! - - The set of supported options depends on Spark version used. See link above. - - Examples - -------- - - Writing options initialization - - .. code:: python - - options = Hive.WriteOptions( - if_exists="append", - partitionBy="reg_id", - someNewOption="value", - ) - """ - - class Config: - known_options: frozenset = frozenset() - extra = "allow" - - if_exists: HiveTableExistBehavior = Field(default=HiveTableExistBehavior.APPEND, alias="mode") - """Behavior of writing data into existing table. - - Possible values: - * ``append`` (default) - Appends data into existing partition/table, or create partition/table if it does not exist. - - Same as Spark's ``df.write.insertInto(table, overwrite=False)``. - - .. dropdown:: Behavior in details - - * Table does not exist - Table is created using options provided by user (``format``, ``compression``, etc). - - * Table exists, but not partitioned, :obj:`~partition_by` is set - Data is appended to a table. Table is still not partitioned (DDL is unchanged). - - * Table exists and partitioned, but has different partitioning schema than :obj:`~partition_by` - Partition is created based on table's ``PARTITIONED BY (...)`` options. - Explicit :obj:`~partition_by` value is ignored. 
- - * Table exists and partitioned according :obj:`~partition_by`, but partition is present only in dataframe - Partition is created. - - * Table exists and partitioned according :obj:`~partition_by`, partition is present in both dataframe and table - Data is appended to existing partition. - - .. warning:: - - This mode does not check whether table already contains - rows from dataframe, so duplicated rows can be created. - - To implement deduplication, write data to staging table first, - and then perform some deduplication logic using :obj:`~sql`. - - * Table exists and partitioned according :obj:`~partition_by`, but partition is present only in table, not dataframe - Existing partition is left intact. - - * ``replace_overlapping_partitions`` - Overwrites data in the existing partition, or create partition/table if it does not exist. - - Same as Spark's ``df.write.insertInto(table, overwrite=True)`` + - ``spark.sql.sources.partitionOverwriteMode=dynamic``. - - .. dropdown:: Behavior in details - - * Table does not exist - Table is created using options provided by user (``format``, ``compression``, etc). - - * Table exists, but not partitioned, :obj:`~partition_by` is set - Data is **overwritten in all the table**. Table is still not partitioned (DDL is unchanged). - - * Table exists and partitioned, but has different partitioning schema than :obj:`~partition_by` - Partition is created based on table's ``PARTITIONED BY (...)`` options. - Explicit :obj:`~partition_by` value is ignored. - - * Table exists and partitioned according :obj:`~partition_by`, but partition is present only in dataframe - Partition is created. - - * Table exists and partitioned according :obj:`~partition_by`, partition is present in both dataframe and table - Existing partition **replaced** with data from dataframe. - - * Table exists and partitioned according :obj:`~partition_by`, but partition is present only in table, not dataframe - Existing partition is left intact. - - * ``replace_entire_table`` - **Recreates table** (via ``DROP + CREATE``), **deleting all existing data**. - **All existing partitions are dropped.** - - Same as Spark's ``df.write.saveAsTable(table, mode="overwrite")`` (NOT ``insertInto``)! - - .. warning:: - - Table is recreated using options provided by user (``format``, ``compression``, etc) - **instead of using original table options**. Be careful - - .. note:: - - ``error`` and ``ignore`` modes are not supported. - - .. note:: - - Unlike using pure Spark, config option ``spark.sql.sources.partitionOverwriteMode`` - does not affect behavior. - """ - - format: str = "orc" - """Format of files which should be used for storing table data. - - Examples: ``orc`` (default), ``parquet``, ``csv`` (NOT recommended) - - .. note:: - - It's better to use column-based formats like ``orc`` or ``parquet``, - not row-based (``csv``, ``json``) - - .. warning:: - - Used **only** while **creating new table**, or in case of ``if_exists=recreate_entire_table`` - """ - - partition_by: Optional[Union[List[str], str]] = Field(default=None, alias="partitionBy") - """ - List of columns should be used for data partitioning. ``None`` means partitioning is disabled. - - Each partition is a folder which contains only files with the specific column value, - like ``myschema.db/mytable/col1=value1``, ``myschema.db/mytable/col1=value2``, and so on. - - Multiple partitions columns means nested folder structure, like ``myschema.db/mytable/col1=val1/col2=val2``. 
- - If ``WHERE`` clause in the query contains expression like ``partition = value``, - Spark will scan only files in a specific partition. - - Examples: ``reg_id`` or ``["reg_id", "business_dt"]`` - - .. note:: - - Values should be scalars (integers, strings), - and either static (``countryId``) or incrementing (dates, years), with low - number of distinct values. - - Columns like ``userId`` or ``datetime``/``timestamp`` should **NOT** be used for partitioning. - - .. warning:: - - Used **only** while **creating new table**, or in case of ``if_exists=recreate_entire_table`` - """ - - bucket_by: Optional[Tuple[int, Union[List[str], str]]] = Field(default=None, alias="bucketBy") # noqa: WPS234 - """Number of buckets plus bucketing columns. ``None`` means bucketing is disabled. - - Each bucket is created as a set of files with name containing result of calculation ``hash(columns) mod num_buckets``. - - This allows to remove shuffle from queries containing ``GROUP BY`` or ``JOIN`` or using ``=`` / ``IN`` predicates - on specific columns. - - Examples: ``(10, "user_id")``, ``(10, ["user_id", "user_phone"])`` - - .. note:: - - Bucketing should be used on columns containing a lot of unique values, - like ``userId``. - - Columns like ``date`` should **NOT** be used for bucketing - because of too low number of unique values. - - .. warning:: - - It is recommended to use this option **ONLY** if you have a large table - (hundreds of Gb or more), which is used mostly for JOINs with other tables, - and you're inserting data using ``if_exists=overwrite_partitions`` or ``if_exists=recreate_entire_table``. - - Otherwise Spark will create a lot of small files - (one file for each bucket and each executor), drastically **decreasing** HDFS performance. - - .. warning:: - - Used **only** while **creating new table**, or in case of ``if_exists=recreate_entire_table`` - """ - - sort_by: Optional[Union[List[str], str]] = Field(default=None, alias="sortBy") - """Each file in a bucket will be sorted by these columns value. ``None`` means sorting is disabled. - - Examples: ``user_id`` or ``["user_id", "user_phone"]`` - - .. note:: - - Sorting columns should contain values which are used in ``ORDER BY`` clauses. - - .. warning:: - - Could be used only with :obj:`~bucket_by` option - - .. warning:: - - Used **only** while **creating new table**, or in case of ``if_exists=recreate_entire_table`` - """ - - compression: Optional[str] = None - """Compressing algorithm which should be used for compressing created files in HDFS. - ``None`` means compression is disabled. - - Examples: ``snappy``, ``zlib`` - - .. 
warning:: - - Used **only** while **creating new table**, or in case of ``if_exists=recreate_entire_table`` - """ - - @validator("sort_by") - def sort_by_cannot_be_used_without_bucket_by(cls, sort_by, values): - options = values.copy() - bucket_by = options.pop("bucket_by", None) - if sort_by and not bucket_by: - raise ValueError("`sort_by` option can only be used with non-empty `bucket_by`") - - return sort_by - - @root_validator - def partition_overwrite_mode_is_not_allowed(cls, values): - partition_overwrite_mode = values.get("partitionOverwriteMode") or values.get("partition_overwrite_mode") - if partition_overwrite_mode: - if partition_overwrite_mode == "static": - recommend_mode = "replace_entire_table" - else: - recommend_mode = "replace_overlapping_partitions" - raise ValueError( - f"`partitionOverwriteMode` option should be replaced with if_exists='{recommend_mode}'", - ) - - if values.get("insert_into") is not None or values.get("insertInto") is not None: - raise ValueError( - "`insertInto` option was removed in onETL 0.4.0, " - "now df.write.insertInto or df.write.saveAsTable is selected based on table existence", - ) - - return values - - @root_validator(pre=True) - def mode_is_deprecated(cls, values): - if "mode" in values: - warnings.warn( - "Option `Hive.WriteOptions(mode=...)` is deprecated since v0.9.0 and will be removed in v1.0.0. " - "Use `Hive.WriteOptions(if_exists=...)` instead", - category=UserWarning, - stacklevel=3, - ) - return values - - @deprecated( - version="0.5.0", - reason="Please use 'WriteOptions' class instead. Will be removed in v1.0.0", - action="always", - category=UserWarning, - ) - class Options(WriteOptions): - pass - - @support_hooks - class Slots: - """:ref:`Slots ` that could be implemented by third-party plugins.""" - - @slot - @staticmethod - def normalize_cluster_name(cluster: str) -> str | None: - """ - Normalize cluster name passed into Hive constructor. |support_hooks| - - If hooks didn't return anything, cluster name is left intact. - - Parameters - ---------- - cluster : :obj:`str` - Cluster name (raw) - - Returns - ------- - str | None - Normalized cluster name. - - If hook cannot be applied to a specific cluster, it should return ``None``. - - Examples - -------- - - .. code:: python - - from onetl.connection import Hive - from onetl.hooks import hook - - - @Hive.Slots.normalize_cluster_name.bind - @hook - def normalize_cluster_name(cluster: str) -> str: - return cluster.lower() - """ - - @slot - @staticmethod - def get_known_clusters() -> set[str] | None: - """ - Return collection of known clusters. |support_hooks| - - Cluster passed into Hive constructor should be present in this list. - If hooks didn't return anything, no validation will be performed. - - Returns - ------- - set[str] | None - Collection of cluster names (normalized). - - If hook cannot be applied, it should return ``None``. - - Examples - -------- - - .. code:: python - - from onetl.connection import Hive - from onetl.hooks import hook - - - @Hive.Slots.get_known_clusters.bind - @hook - def get_known_clusters() -> str[str]: - return {"rnd-dwh", "rnd-prod"} - """ - - @slot - @staticmethod - def get_current_cluster() -> str | None: - """ - Get current cluster name. |support_hooks| - - Used in :obj:`~check` method to verify that connection is created only from the same cluster. - If hooks didn't return anything, no validation will be performed. - - Returns - ------- - str | None - Current cluster name (normalized). 
- - If hook cannot be applied, it should return ``None``. - - Examples - -------- - - .. code:: python - - from onetl.connection import Hive - from onetl.hooks import hook - - - @Hive.Slots.get_current_cluster.bind - @hook - def get_current_cluster() -> str: - # some magic here - return "rnd-dwh" - """ - - # TODO: remove in v1.0.0 - slots = Slots - - class Dialect( # noqa: WPS215 - SupportTableWithDBSchema, - SupportColumnsList, - SupportDfSchemaNone, - SupportWhereStr, - SupportHintStr, - SupportHWMExpressionStr, - SupportHWMColumnStr, - DBConnection.Dialect, - ): - pass - - cluster: Cluster - _CHECK_QUERY: ClassVar[str] = "SELECT 1" - - @validator("cluster") - def validate_cluster_name(cls, cluster): - log.debug("|%s| Normalizing cluster %r name ...", cls.__name__, cluster) - validated_cluster = cls.Slots.normalize_cluster_name(cluster) or cluster - if validated_cluster != cluster: - log.debug("|%s| Got %r", cls.__name__) - - log.debug("|%s| Checking if cluster %r is a known cluster ...", cls.__name__, validated_cluster) - known_clusters = cls.Slots.get_known_clusters() - if known_clusters and validated_cluster not in known_clusters: - raise ValueError( - f"Cluster {validated_cluster!r} is not in the known clusters list: {sorted(known_clusters)!r}", - ) - - return validated_cluster - - @slot - @classmethod - def get_current(cls, spark: SparkSession): - """ - Create connection for current cluster. |support_hooks| - - .. note:: - - Can be used only if there are some hooks bound :obj:`~slots.get_current_cluster` slot. - - Parameters - ---------- - spark : :obj:`pyspark.sql.SparkSession` - Spark session - - Examples - -------- - - .. code:: python - - from onetl.connection import Hive - from pyspark.sql import SparkSession - - spark = SparkSession.builder.appName("spark-app-name").enableHiveSupport().getOrCreate() - - # injecting current cluster name via hooks mechanism - hive = Hive.get_current(spark=spark) - """ - - log.info("|%s| Detecting current cluster...", cls.__name__) - current_cluster = cls.Slots.get_current_cluster() - if not current_cluster: - raise RuntimeError( - f"{cls.__name__}.get_current() can be used only if there are " - f"some hooks bound to {cls.__name__}.Slots.get_current_cluster", - ) - - log.info("|%s| Got %r", cls.__name__, current_cluster) - return cls(cluster=current_cluster, spark=spark) # type: ignore[arg-type] - - @property - def instance_url(self) -> str: - return self.cluster - - @slot - def check(self): - log.debug("|%s| Detecting current cluster...", self.__class__.__name__) - current_cluster = self.Slots.get_current_cluster() - if current_cluster and self.cluster != current_cluster: - raise ValueError("You can connect to a Hive cluster only from the same cluster") - - log.info("|%s| Checking connection availability...", self.__class__.__name__) - self._log_parameters() - - log.debug("|%s| Executing SQL query:", self.__class__.__name__) - log_lines(log, self._CHECK_QUERY, level=logging.DEBUG) - - try: - self._execute_sql(self._CHECK_QUERY) - log.info("|%s| Connection is available.", self.__class__.__name__) - except Exception as e: - log.exception("|%s| Connection is unavailable", self.__class__.__name__) - raise RuntimeError("Connection is unavailable") from e - - return self - - @slot - def sql( - self, - query: str, - ) -> DataFrame: - """ - Lazily execute SELECT statement and return DataFrame. |support_hooks| - - Same as ``spark.sql(query)``. - - Parameters - ---------- - query : str - - SQL query to be executed, like: - - * ``SELECT ... 
FROM ...`` - * ``WITH ... AS (...) SELECT ... FROM ...`` - * ``SHOW ...`` queries are also supported, like ``SHOW TABLES`` - - Returns - ------- - df : pyspark.sql.dataframe.DataFrame - - Spark dataframe - - Examples - -------- - - Read data from Hive table: - - .. code:: python - - connection = Hive(cluster="rnd-dwh", spark=spark) - - df = connection.sql("SELECT * FROM mytable") - """ - - query = clear_statement(query) - - log.info("|%s| Executing SQL query:", self.__class__.__name__) - log_lines(log, query) - - df = self._execute_sql(query) - log.info("|Spark| DataFrame successfully created from SQL statement") - return df - - @slot - def execute( - self, - statement: str, - ) -> None: - """ - Execute DDL or DML statement. |support_hooks| - - Parameters - ---------- - statement : str - - Statement to be executed, like: - - DML statements: - - * ``INSERT INTO target_table SELECT * FROM source_table`` - * ``TRUNCATE TABLE mytable`` - - DDL statements: - - * ``CREATE TABLE mytable (...)`` - * ``ALTER TABLE mytable ...`` - * ``DROP TABLE mytable`` - * ``MSCK REPAIR TABLE mytable`` - - The exact list of supported statements depends on Hive version, - for example some new versions support ``CREATE FUNCTION`` syntax. - - Examples - -------- - - Create table: - - .. code:: python - - connection = Hive(cluster="rnd-dwh", spark=spark) - - connection.execute( - "CREATE TABLE mytable (id NUMBER, data VARCHAR) PARTITIONED BY (date DATE)" - ) - - Drop table partition: - - .. code:: python - - connection = Hive(cluster="rnd-dwh", spark=spark) - - connection.execute("ALTER TABLE mytable DROP PARTITION(date='2023-02-01')") - """ - - statement = clear_statement(statement) - - log.info("|%s| Executing statement:", self.__class__.__name__) - log_lines(log, statement) - - self._execute_sql(statement).collect() - log.info("|%s| Call succeeded", self.__class__.__name__) - - @slot - def write_df_to_target( - self, - df: DataFrame, - target: str, - options: WriteOptions | dict | None = None, - ) -> None: - write_options = self.WriteOptions.parse(options) - - try: - self.get_df_schema(target) - table_exists = True - - log.info("|%s| Table %r already exists", self.__class__.__name__, target) - except Exception: - table_exists = False - - # https://stackoverflow.com/a/72747050 - if table_exists and write_options.if_exists != HiveTableExistBehavior.REPLACE_ENTIRE_TABLE: - # using saveAsTable on existing table does not handle - # spark.sql.sources.partitionOverwriteMode=dynamic, so using insertInto instead. 
- self._insert_into(df, target, options) - else: - # if someone needs to recreate the entire table using new set of options, like partitionBy or bucketBy, - # if_exists="replace_entire_table" should be used - self._save_as_table(df, target, options) - - @slot - def read_source_as_df( - self, - source: str, - columns: list[str] | None = None, - hint: str | None = None, - where: str | None = None, - df_schema: StructType | None = None, - start_from: Statement | None = None, - end_at: Statement | None = None, - ) -> DataFrame: - where = self.Dialect._condition_assembler(condition=where, start_from=start_from, end_at=end_at) - sql_text = get_sql_query( - table=source, - columns=columns, - where=where, - hint=hint, - ) - - return self.sql(sql_text) - - @slot - def get_df_schema( - self, - source: str, - columns: list[str] | None = None, - ) -> StructType: - log.info("|%s| Fetching schema of table table %r", self.__class__.__name__, source) - query = get_sql_query(source, columns=columns, where="1=0", compact=True) - - log.debug("|%s| Executing SQL query:", self.__class__.__name__) - log_lines(log, query, level=logging.DEBUG) - - df = self._execute_sql(query) - log.info("|%s| Schema fetched", self.__class__.__name__) - return df.schema - - @slot - def get_min_max_bounds( - self, - source: str, - column: str, - expression: str | None = None, - hint: str | None = None, - where: str | None = None, - ) -> Tuple[Any, Any]: - log.info("|Spark| Getting min and max values for column %r", column) - - sql_text = get_sql_query( - table=source, - columns=[ - self.Dialect._expression_with_alias( - self.Dialect._get_min_value_sql(expression or column), - "min", - ), - self.Dialect._expression_with_alias( - self.Dialect._get_max_value_sql(expression or column), - "max", - ), - ], - where=where, - hint=hint, - ) - - log.debug("|%s| Executing SQL query:", self.__class__.__name__) - log_lines(log, sql_text, level=logging.DEBUG) - - df = self._execute_sql(sql_text) - row = df.collect()[0] - min_value = row["min"] - max_value = row["max"] - - log.info("|Spark| Received values:") - log_with_indent(log, "MIN(%s) = %r", column, min_value) - log_with_indent(log, "MAX(%s) = %r", column, max_value) - - return min_value, max_value - - def _execute_sql(self, query: str) -> DataFrame: - return self.spark.sql(query) - - def _sort_df_columns_like_table(self, table: str, df_columns: list[str]) -> list[str]: - # Hive is inserting columns by the order, not by their name - # so if you're inserting dataframe with columns B, A, C to table with columns A, B, C, data will be damaged - # so it is important to sort columns in dataframe to match columns in the table. 
- - table_columns = self.spark.table(table).columns - - # But names could have different cases, this should not cause errors - table_columns_normalized = [column.casefold() for column in table_columns] - df_columns_normalized = [column.casefold() for column in df_columns] - - missing_columns_df = [column for column in df_columns_normalized if column not in table_columns_normalized] - missing_columns_table = [column for column in table_columns_normalized if column not in df_columns_normalized] - - if missing_columns_df or missing_columns_table: - missing_columns_df_message = "" - if missing_columns_df: - missing_columns_df_message = f""" - These columns present only in dataframe: - {missing_columns_df!r} - """ - - missing_columns_table_message = "" - if missing_columns_table: - missing_columns_table_message = f""" - These columns present only in table: - {missing_columns_table!r} - """ - - raise ValueError( - dedent( - f""" - Inconsistent columns between a table and the dataframe! - - Table {table!r} has columns: - {table_columns!r} - - Dataframe has columns: - {df_columns!r} - {missing_columns_df_message}{missing_columns_table_message} - """, - ).strip(), - ) - - return sorted(df_columns, key=lambda column: table_columns_normalized.index(column.casefold())) - - def _insert_into( - self, - df: DataFrame, - table: str, - options: WriteOptions | dict | None = None, - ) -> None: - write_options = self.WriteOptions.parse(options) - - log.info("|%s| Inserting data into existing table %r", self.__class__.__name__, table) - - unsupported_options = write_options.dict(by_alias=True, exclude_unset=True, exclude={"if_exists"}) - if unsupported_options: - log.warning( - "|%s| Options %r are not supported while inserting into existing table, ignoring", - self.__class__.__name__, - unsupported_options, - ) - - # Hive is inserting data to table by column position, not by name - # So we should sort columns according their order in the existing table - # instead of using order from the dataframe - columns = self._sort_df_columns_like_table(table, df.columns) - writer = df.select(*columns).write - - # Writer option "partitionOverwriteMode" was added to Spark only in 2.4.0 - # so using a workaround with patching Spark config and then setting up the previous value - with inject_spark_param(self.spark.conf, PARTITION_OVERWRITE_MODE_PARAM, "dynamic"): - overwrite = write_options.if_exists != HiveTableExistBehavior.APPEND - writer.insertInto(table, overwrite=overwrite) - - log.info("|%s| Data is successfully inserted into table %r", self.__class__.__name__, table) - - def _save_as_table( - self, - df: DataFrame, - table: str, - options: WriteOptions | dict | None = None, - ) -> None: - write_options = self.WriteOptions.parse(options) - - log.info("|%s| Saving data to a table %r", self.__class__.__name__, table) - - writer = df.write - for method, value in write_options.dict(by_alias=True, exclude_none=True, exclude={"if_exists"}).items(): - # is the arguments that will be passed to the - # format orc, parquet methods and format simultaneously - if hasattr(writer, method): - if isinstance(value, Iterable) and not isinstance(value, str): - writer = getattr(writer, method)(*value) # noqa: WPS220 - else: - writer = getattr(writer, method)(value) # noqa: WPS220 - else: - writer = writer.option(method, value) - - mode = "append" if write_options.if_exists == HiveTableExistBehavior.APPEND else "overwrite" - writer.mode(mode).saveAsTable(table) - - log.info("|%s| Table %r is successfully created", 
self.__class__.__name__, table) diff --git a/onetl/connection/db_connection/hive/__init__.py b/onetl/connection/db_connection/hive/__init__.py new file mode 100644 index 000000000..54a00f004 --- /dev/null +++ b/onetl/connection/db_connection/hive/__init__.py @@ -0,0 +1,22 @@ +# Copyright 2023 MTS (Mobile Telesystems) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from onetl.connection.db_connection.hive.connection import Hive +from onetl.connection.db_connection.hive.dialect import HiveDialect +from onetl.connection.db_connection.hive.options import ( + HiveLegacyOptions, + HiveTableExistBehavior, + HiveWriteOptions, +) +from onetl.connection.db_connection.hive.slots import HiveSlots diff --git a/onetl/connection/db_connection/hive/connection.py b/onetl/connection/db_connection/hive/connection.py new file mode 100644 index 000000000..d0bc08d29 --- /dev/null +++ b/onetl/connection/db_connection/hive/connection.py @@ -0,0 +1,551 @@ +# Copyright 2023 MTS (Mobile Telesystems) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import logging +from textwrap import dedent +from typing import TYPE_CHECKING, Any, ClassVar, Iterable, Tuple + +from etl_entities.instance import Cluster +from pydantic import validator + +from onetl._internal import clear_statement, get_sql_query +from onetl._util.spark import inject_spark_param +from onetl.connection.db_connection.db_connection import DBConnection +from onetl.connection.db_connection.hive.dialect import HiveDialect +from onetl.connection.db_connection.hive.options import ( + HiveLegacyOptions, + HiveTableExistBehavior, + HiveWriteOptions, +) +from onetl.connection.db_connection.hive.slots import HiveSlots +from onetl.hooks import slot, support_hooks +from onetl.hwm import Statement +from onetl.log import log_lines, log_with_indent + +if TYPE_CHECKING: + from pyspark.sql import DataFrame, SparkSession + from pyspark.sql.types import StructType + +PARTITION_OVERWRITE_MODE_PARAM = "spark.sql.sources.partitionOverwriteMode" +log = logging.getLogger(__name__) + + +@support_hooks +class Hive(DBConnection): + """Spark connection with Hive MetaStore support. |support_hooks| + + You don't need a Hive server to use this connector. + + .. dropdown:: Version compatibility + + * Hive metastore version: 0.12 - 3.1.2 (may require to add proper .jar file explicitly) + * Spark versions: 2.3.x - 3.4.x + * Java versions: 8 - 20 + + .. 
warning:: + + To use Hive connector you should have PySpark installed (or injected to ``sys.path``) + BEFORE creating the connector instance. + + You can install PySpark as follows: + + .. code:: bash + + pip install onetl[spark] # latest PySpark version + + # or + pip install onetl pyspark=3.4.1 # pass specific PySpark version + + See :ref:`spark-install` instruction for more details. + + .. warning:: + + This connector requires some additional configuration files to be present (``hive-site.xml`` and so on), + as well as .jar files with Hive MetaStore client. + + See `Spark Hive Tables documentation `_ + and `this guide `_ for more details. + + .. note:: + + Most of Hadoop instances use Kerberos authentication. In this case, you should call ``kinit`` + **BEFORE** starting Spark session to generate Kerberos ticket. See :ref:`kerberos-install`. + + In case of creating session with ``"spark.master": "yarn"``, you should also pass some additional options + to Spark session, allowing executors to generate their own Kerberos tickets to access HDFS. + See `Spark security documentation `_ + for more details. + + Parameters + ---------- + cluster : str + Cluster name. Used for HWM and lineage. + + spark : :obj:`pyspark.sql.SparkSession` + Spark session with Hive metastore support enabled + + Examples + -------- + + Hive connection initialization + + .. code:: python + + from onetl.connection import Hive + from pyspark.sql import SparkSession + + # Create Spark session + spark = SparkSession.builder.appName("spark-app-name").enableHiveSupport().getOrCreate() + + # Create connection + hive = Hive(cluster="rnd-dwh", spark=spark).check() + + Hive connection initialization with Kerberos support + + .. code:: python + + from onetl.connection import Hive + from pyspark.sql import SparkSession + + # Create Spark session + # Use names "spark.yarn.access.hadoopFileSystems", "spark.yarn.principal" + # and "spark.yarn.keytab" for Spark 2 + + spark = ( + SparkSession.builder.appName("spark-app-name") + .option("spark.kerberos.access.hadoopFileSystems", "hdfs://cluster.name.node:8020") + .option("spark.kerberos.principal", "user") + .option("spark.kerberos.keytab", "/path/to/keytab") + .enableHiveSupport() + .getOrCreate() + ) + + # Create connection + hive = Hive(cluster="rnd-dwh", spark=spark).check() + """ + + cluster: Cluster + + Dialect = HiveDialect + WriteOptions = HiveWriteOptions + Options = HiveLegacyOptions + Slots = HiveSlots + # TODO: remove in v1.0.0 + slots = HiveSlots + + _CHECK_QUERY: ClassVar[str] = "SELECT 1" + + @slot + @classmethod + def get_current(cls, spark: SparkSession): + """ + Create connection for current cluster. |support_hooks| + + .. note:: + + Can be used only if there are some hooks bound to + :obj:`Slots.get_current_cluster ` slot. + + Parameters + ---------- + spark : :obj:`pyspark.sql.SparkSession` + Spark session + + Examples + -------- + + .. 
code:: python + + from onetl.connection import Hive + from pyspark.sql import SparkSession + + spark = SparkSession.builder.appName("spark-app-name").enableHiveSupport().getOrCreate() + + # injecting current cluster name via hooks mechanism + hive = Hive.get_current(spark=spark) + """ + + log.info("|%s| Detecting current cluster...", cls.__name__) + current_cluster = cls.Slots.get_current_cluster() + if not current_cluster: + raise RuntimeError( + f"{cls.__name__}.get_current() can be used only if there are " + f"some hooks bound to {cls.__name__}.Slots.get_current_cluster", + ) + + log.info("|%s| Got %r", cls.__name__, current_cluster) + return cls(cluster=current_cluster, spark=spark) # type: ignore[arg-type] + + @property + def instance_url(self) -> str: + return self.cluster + + @slot + def check(self): + log.debug("|%s| Detecting current cluster...", self.__class__.__name__) + current_cluster = self.Slots.get_current_cluster() + if current_cluster and self.cluster != current_cluster: + raise ValueError("You can connect to a Hive cluster only from the same cluster") + + log.info("|%s| Checking connection availability...", self.__class__.__name__) + self._log_parameters() + + log.debug("|%s| Executing SQL query:", self.__class__.__name__) + log_lines(log, self._CHECK_QUERY, level=logging.DEBUG) + + try: + self._execute_sql(self._CHECK_QUERY) + log.info("|%s| Connection is available.", self.__class__.__name__) + except Exception as e: + log.exception("|%s| Connection is unavailable", self.__class__.__name__) + raise RuntimeError("Connection is unavailable") from e + + return self + + @slot + def sql( + self, + query: str, + ) -> DataFrame: + """ + Lazily execute SELECT statement and return DataFrame. |support_hooks| + + Same as ``spark.sql(query)``. + + Parameters + ---------- + query : str + + SQL query to be executed, like: + + * ``SELECT ... FROM ...`` + * ``WITH ... AS (...) SELECT ... FROM ...`` + * ``SHOW ...`` queries are also supported, like ``SHOW TABLES`` + + Returns + ------- + df : pyspark.sql.dataframe.DataFrame + + Spark dataframe + + Examples + -------- + + Read data from Hive table: + + .. code:: python + + connection = Hive(cluster="rnd-dwh", spark=spark) + + df = connection.sql("SELECT * FROM mytable") + """ + + query = clear_statement(query) + + log.info("|%s| Executing SQL query:", self.__class__.__name__) + log_lines(log, query) + + df = self._execute_sql(query) + log.info("|Spark| DataFrame successfully created from SQL statement") + return df + + @slot + def execute( + self, + statement: str, + ) -> None: + """ + Execute DDL or DML statement. |support_hooks| + + Parameters + ---------- + statement : str + + Statement to be executed, like: + + DML statements: + + * ``INSERT INTO target_table SELECT * FROM source_table`` + * ``TRUNCATE TABLE mytable`` + + DDL statements: + + * ``CREATE TABLE mytable (...)`` + * ``ALTER TABLE mytable ...`` + * ``DROP TABLE mytable`` + * ``MSCK REPAIR TABLE mytable`` + + The exact list of supported statements depends on Hive version, + for example some new versions support ``CREATE FUNCTION`` syntax. + + Examples + -------- + + Create table: + + .. code:: python + + connection = Hive(cluster="rnd-dwh", spark=spark) + + connection.execute( + "CREATE TABLE mytable (id NUMBER, data VARCHAR) PARTITIONED BY (date DATE)" + ) + + Drop table partition: + + .. 
code:: python + + connection = Hive(cluster="rnd-dwh", spark=spark) + + connection.execute("ALTER TABLE mytable DROP PARTITION(date='2023-02-01')") + """ + + statement = clear_statement(statement) + + log.info("|%s| Executing statement:", self.__class__.__name__) + log_lines(log, statement) + + self._execute_sql(statement).collect() + log.info("|%s| Call succeeded", self.__class__.__name__) + + @slot + def write_df_to_target( + self, + df: DataFrame, + target: str, + options: HiveWriteOptions | None = None, + ) -> None: + write_options = self.WriteOptions.parse(options) + + try: + self.get_df_schema(target) + table_exists = True + + log.info("|%s| Table %r already exists", self.__class__.__name__, target) + except Exception: + table_exists = False + + # https://stackoverflow.com/a/72747050 + if table_exists and write_options.if_exists != HiveTableExistBehavior.REPLACE_ENTIRE_TABLE: + # using saveAsTable on existing table does not handle + # spark.sql.sources.partitionOverwriteMode=dynamic, so using insertInto instead. + self._insert_into(df, target, options) + else: + # if someone needs to recreate the entire table using new set of options, like partitionBy or bucketBy, + # if_exists="replace_entire_table" should be used + self._save_as_table(df, target, options) + + @slot + def read_source_as_df( + self, + source: str, + columns: list[str] | None = None, + hint: str | None = None, + where: str | None = None, + df_schema: StructType | None = None, + start_from: Statement | None = None, + end_at: Statement | None = None, + ) -> DataFrame: + where = self.Dialect._condition_assembler(condition=where, start_from=start_from, end_at=end_at) + sql_text = get_sql_query( + table=source, + columns=columns, + where=where, + hint=hint, + ) + + return self.sql(sql_text) + + @slot + def get_df_schema( + self, + source: str, + columns: list[str] | None = None, + ) -> StructType: + log.info("|%s| Fetching schema of table table %r ...", self.__class__.__name__, source) + query = get_sql_query(source, columns=columns, where="1=0", compact=True) + + log.debug("|%s| Executing SQL query:", self.__class__.__name__) + log_lines(log, query, level=logging.DEBUG) + + df = self._execute_sql(query) + log.info("|%s| Schema fetched.", self.__class__.__name__) + return df.schema + + @slot + def get_min_max_bounds( + self, + source: str, + column: str, + expression: str | None = None, + hint: str | None = None, + where: str | None = None, + ) -> Tuple[Any, Any]: + log.info("|%s| Getting min and max values for column %r ...", self.__class__.__name__, column) + + sql_text = get_sql_query( + table=source, + columns=[ + self.Dialect._expression_with_alias( + self.Dialect._get_min_value_sql(expression or column), + self.Dialect._escape_column("min"), + ), + self.Dialect._expression_with_alias( + self.Dialect._get_max_value_sql(expression or column), + self.Dialect._escape_column("max"), + ), + ], + where=where, + hint=hint, + ) + + log.debug("|%s| Executing SQL query:", self.__class__.__name__) + log_lines(log, sql_text, level=logging.DEBUG) + + df = self._execute_sql(sql_text) + row = df.collect()[0] + min_value = row["min"] + max_value = row["max"] + + log.info("|%s| Received values:", self.__class__.__name__) + log_with_indent(log, "MIN(%s) = %r", column, min_value) + log_with_indent(log, "MAX(%s) = %r", column, max_value) + + return min_value, max_value + + @validator("cluster") + def _validate_cluster_name(cls, cluster): + log.debug("|%s| Normalizing cluster %r name...", cls.__name__, cluster) + validated_cluster = 
cls.Slots.normalize_cluster_name(cluster) or cluster + if validated_cluster != cluster: + log.debug("|%s| Got %r", cls.__name__) + + log.debug("|%s| Checking if cluster %r is a known cluster...", cls.__name__, validated_cluster) + known_clusters = cls.Slots.get_known_clusters() + if known_clusters and validated_cluster not in known_clusters: + raise ValueError( + f"Cluster {validated_cluster!r} is not in the known clusters list: {sorted(known_clusters)!r}", + ) + + return validated_cluster + + def _execute_sql(self, query: str) -> DataFrame: + return self.spark.sql(query) + + def _sort_df_columns_like_table(self, table: str, df_columns: list[str]) -> list[str]: + # Hive is inserting columns by the order, not by their name + # so if you're inserting dataframe with columns B, A, C to table with columns A, B, C, data will be damaged + # so it is important to sort columns in dataframe to match columns in the table. + + table_columns = self.spark.table(table).columns + + # But names could have different cases, this should not cause errors + table_columns_normalized = [column.casefold() for column in table_columns] + df_columns_normalized = [column.casefold() for column in df_columns] + + missing_columns_df = [column for column in df_columns_normalized if column not in table_columns_normalized] + missing_columns_table = [column for column in table_columns_normalized if column not in df_columns_normalized] + + if missing_columns_df or missing_columns_table: + missing_columns_df_message = "" + if missing_columns_df: + missing_columns_df_message = f""" + These columns present only in dataframe: + {missing_columns_df!r} + """ + + missing_columns_table_message = "" + if missing_columns_table: + missing_columns_table_message = f""" + These columns present only in table: + {missing_columns_table!r} + """ + + raise ValueError( + dedent( + f""" + Inconsistent columns between a table and the dataframe! 
+ + Table {table!r} has columns: + {table_columns!r} + + Dataframe has columns: + {df_columns!r} + {missing_columns_df_message}{missing_columns_table_message} + """, + ).strip(), + ) + + return sorted(df_columns, key=lambda column: table_columns_normalized.index(column.casefold())) + + def _insert_into( + self, + df: DataFrame, + table: str, + options: HiveWriteOptions | dict | None = None, + ) -> None: + write_options = self.WriteOptions.parse(options) + + unsupported_options = write_options.dict(by_alias=True, exclude_unset=True, exclude={"if_exists"}) + if unsupported_options: + log.warning( + "|%s| Options %r are not supported while inserting into existing table, ignoring", + self.__class__.__name__, + unsupported_options, + ) + + # Hive is inserting data to table by column position, not by name + # So we should sort columns according their order in the existing table + # instead of using order from the dataframe + columns = self._sort_df_columns_like_table(table, df.columns) + writer = df.select(*columns).write + + # Writer option "partitionOverwriteMode" was added to Spark only in 2.4.0 + # so using a workaround with patching Spark config and then setting up the previous value + with inject_spark_param(self.spark.conf, PARTITION_OVERWRITE_MODE_PARAM, "dynamic"): + overwrite = write_options.if_exists != HiveTableExistBehavior.APPEND + + log.info("|%s| Inserting data into existing table %r ...", self.__class__.__name__, table) + writer.insertInto(table, overwrite=overwrite) + + log.info("|%s| Data is successfully inserted into table %r.", self.__class__.__name__, table) + + def _save_as_table( + self, + df: DataFrame, + table: str, + options: HiveWriteOptions | dict | None = None, + ) -> None: + write_options = self.WriteOptions.parse(options) + + writer = df.write + for method, value in write_options.dict(by_alias=True, exclude_none=True, exclude={"if_exists"}).items(): + # is the arguments that will be passed to the + # format orc, parquet methods and format simultaneously + if hasattr(writer, method): + if isinstance(value, Iterable) and not isinstance(value, str): + writer = getattr(writer, method)(*value) # noqa: WPS220 + else: + writer = getattr(writer, method)(value) # noqa: WPS220 + else: + writer = writer.option(method, value) + + mode = "append" if write_options.if_exists == HiveTableExistBehavior.APPEND else "overwrite" + + log.info("|%s| Saving data to a table %r ...", self.__class__.__name__, table) + writer.mode(mode).saveAsTable(table) + + log.info("|%s| Table %r is successfully created.", self.__class__.__name__, table) diff --git a/onetl/connection/db_connection/hive/dialect.py b/onetl/connection/db_connection/hive/dialect.py new file mode 100644 index 000000000..552e66559 --- /dev/null +++ b/onetl/connection/db_connection/hive/dialect.py @@ -0,0 +1,41 @@ +# Copyright 2023 MTS (Mobile Telesystems) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from __future__ import annotations + +from onetl.connection.db_connection.db_connection import DBDialect +from onetl.connection.db_connection.dialect_mixins import ( + SupportColumnsList, + SupportDfSchemaNone, + SupportHintStr, + SupportHWMColumnStr, + SupportHWMExpressionStr, + SupportNameWithSchemaOnly, + SupportWhereStr, +) + + +class HiveDialect( # noqa: WPS215 + SupportNameWithSchemaOnly, + SupportColumnsList, + SupportDfSchemaNone, + SupportWhereStr, + SupportHintStr, + SupportHWMExpressionStr, + SupportHWMColumnStr, + DBDialect, +): + @classmethod + def _escape_column(cls, value: str) -> str: + return f"`{value}`" diff --git a/onetl/connection/db_connection/hive/options.py b/onetl/connection/db_connection/hive/options.py new file mode 100644 index 000000000..c46b7882d --- /dev/null +++ b/onetl/connection/db_connection/hive/options.py @@ -0,0 +1,337 @@ +# Copyright 2023 MTS (Mobile Telesystems) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import warnings +from enum import Enum +from typing import List, Optional, Tuple, Union + +from deprecated import deprecated +from pydantic import Field, root_validator, validator + +from onetl.impl import GenericOptions + + +class HiveTableExistBehavior(str, Enum): + APPEND = "append" + REPLACE_ENTIRE_TABLE = "replace_entire_table" + REPLACE_OVERLAPPING_PARTITIONS = "replace_overlapping_partitions" + + def __str__(self): + return str(self.value) + + @classmethod # noqa: WPS120 + def _missing_(cls, value: object): # noqa: WPS120 + if str(value) == "overwrite": + warnings.warn( + "Mode `overwrite` is deprecated since v0.4.0 and will be removed in v1.0.0. " + "Use `replace_overlapping_partitions` instead", + category=UserWarning, + stacklevel=4, + ) + return cls.REPLACE_OVERLAPPING_PARTITIONS + + if str(value) == "overwrite_partitions": + warnings.warn( + "Mode `overwrite_partitions` is deprecated since v0.9.0 and will be removed in v1.0.0. " + "Use `replace_overlapping_partitions` instead", + category=UserWarning, + stacklevel=4, + ) + return cls.REPLACE_OVERLAPPING_PARTITIONS + + if str(value) == "overwrite_table": + warnings.warn( + "Mode `overwrite_table` is deprecated since v0.9.0 and will be removed in v1.0.0. " + "Use `replace_entire_table` instead", + category=UserWarning, + stacklevel=4, + ) + return cls.REPLACE_ENTIRE_TABLE + + +class HiveWriteOptions(GenericOptions): + """Hive source writing options. + + You can pass here key-value items which then will be converted to calls + of :obj:`pyspark.sql.readwriter.DataFrameWriter` methods. + + For example, ``Hive.WriteOptions(if_exists="append", partitionBy="reg_id")`` will + be converted to ``df.write.mode("append").partitionBy("reg_id")`` call, and so on. + + .. note:: + + You can pass any method and its value + `supported by Spark `_, + even if it is not mentioned in this documentation. **Option names should be in** ``camelCase``! + + The set of supported options depends on Spark version used. See link above. 
+ + Examples + -------- + + Writing options initialization + + .. code:: python + + options = Hive.WriteOptions( + if_exists="append", + partitionBy="reg_id", + someNewOption="value", + ) + """ + + class Config: + known_options: frozenset = frozenset() + extra = "allow" + + if_exists: HiveTableExistBehavior = Field(default=HiveTableExistBehavior.APPEND, alias="mode") + """Behavior of writing data into existing table. + + Possible values: + * ``append`` (default) + Appends data into existing partition/table, or create partition/table if it does not exist. + + Same as Spark's ``df.write.insertInto(table, overwrite=False)``. + + .. dropdown:: Behavior in details + + * Table does not exist + Table is created using options provided by user (``format``, ``compression``, etc). + + * Table exists, but not partitioned, :obj:`~partition_by` is set + Data is appended to a table. Table is still not partitioned (DDL is unchanged). + + * Table exists and partitioned, but has different partitioning schema than :obj:`~partition_by` + Partition is created based on table's ``PARTITIONED BY (...)`` options. + Explicit :obj:`~partition_by` value is ignored. + + * Table exists and partitioned according :obj:`~partition_by`, but partition is present only in dataframe + Partition is created. + + * Table exists and partitioned according :obj:`~partition_by`, partition is present in both dataframe and table + Data is appended to existing partition. + + .. warning:: + + This mode does not check whether table already contains + rows from dataframe, so duplicated rows can be created. + + To implement deduplication, write data to staging table first, + and then perform some deduplication logic using :obj:`~sql`. + + * Table exists and partitioned according :obj:`~partition_by`, but partition is present only in table, not dataframe + Existing partition is left intact. + + * ``replace_overlapping_partitions`` + Overwrites data in the existing partition, or create partition/table if it does not exist. + + Same as Spark's ``df.write.insertInto(table, overwrite=True)`` + + ``spark.sql.sources.partitionOverwriteMode=dynamic``. + + .. dropdown:: Behavior in details + + * Table does not exist + Table is created using options provided by user (``format``, ``compression``, etc). + + * Table exists, but not partitioned, :obj:`~partition_by` is set + Data is **overwritten in all the table**. Table is still not partitioned (DDL is unchanged). + + * Table exists and partitioned, but has different partitioning schema than :obj:`~partition_by` + Partition is created based on table's ``PARTITIONED BY (...)`` options. + Explicit :obj:`~partition_by` value is ignored. + + * Table exists and partitioned according :obj:`~partition_by`, but partition is present only in dataframe + Partition is created. + + * Table exists and partitioned according :obj:`~partition_by`, partition is present in both dataframe and table + Existing partition **replaced** with data from dataframe. + + * Table exists and partitioned according :obj:`~partition_by`, but partition is present only in table, not dataframe + Existing partition is left intact. + + * ``replace_entire_table`` + **Recreates table** (via ``DROP + CREATE``), **deleting all existing data**. + **All existing partitions are dropped.** + + Same as Spark's ``df.write.saveAsTable(table, mode="overwrite")`` (NOT ``insertInto``)! + + .. warning:: + + Table is recreated using options provided by user (``format``, ``compression``, etc) + **instead of using original table options**. Be careful + + .. 
note:: + + ``error`` and ``ignore`` modes are not supported. + + .. note:: + + Unlike using pure Spark, config option ``spark.sql.sources.partitionOverwriteMode`` + does not affect behavior. + """ + + format: str = "orc" + """Format of files which should be used for storing table data. + + Examples: ``orc`` (default), ``parquet``, ``csv`` (NOT recommended) + + .. note:: + + It's better to use column-based formats like ``orc`` or ``parquet``, + not row-based (``csv``, ``json``) + + .. warning:: + + Used **only** while **creating new table**, or in case of ``if_exists=recreate_entire_table`` + """ + + partition_by: Optional[Union[List[str], str]] = Field(default=None, alias="partitionBy") + """ + List of columns should be used for data partitioning. ``None`` means partitioning is disabled. + + Each partition is a folder which contains only files with the specific column value, + like ``myschema.db/mytable/col1=value1``, ``myschema.db/mytable/col1=value2``, and so on. + + Multiple partitions columns means nested folder structure, like ``myschema.db/mytable/col1=val1/col2=val2``. + + If ``WHERE`` clause in the query contains expression like ``partition = value``, + Spark will scan only files in a specific partition. + + Examples: ``reg_id`` or ``["reg_id", "business_dt"]`` + + .. note:: + + Values should be scalars (integers, strings), + and either static (``countryId``) or incrementing (dates, years), with low + number of distinct values. + + Columns like ``userId`` or ``datetime``/``timestamp`` should **NOT** be used for partitioning. + + .. warning:: + + Used **only** while **creating new table**, or in case of ``if_exists=recreate_entire_table`` + """ + + bucket_by: Optional[Tuple[int, Union[List[str], str]]] = Field(default=None, alias="bucketBy") # noqa: WPS234 + """Number of buckets plus bucketing columns. ``None`` means bucketing is disabled. + + Each bucket is created as a set of files with name containing result of calculation ``hash(columns) mod num_buckets``. + + This allows to remove shuffle from queries containing ``GROUP BY`` or ``JOIN`` or using ``=`` / ``IN`` predicates + on specific columns. + + Examples: ``(10, "user_id")``, ``(10, ["user_id", "user_phone"])`` + + .. note:: + + Bucketing should be used on columns containing a lot of unique values, + like ``userId``. + + Columns like ``date`` should **NOT** be used for bucketing + because of too low number of unique values. + + .. warning:: + + It is recommended to use this option **ONLY** if you have a large table + (hundreds of Gb or more), which is used mostly for JOINs with other tables, + and you're inserting data using ``if_exists=overwrite_partitions`` or ``if_exists=recreate_entire_table``. + + Otherwise Spark will create a lot of small files + (one file for each bucket and each executor), drastically **decreasing** HDFS performance. + + .. warning:: + + Used **only** while **creating new table**, or in case of ``if_exists=recreate_entire_table`` + """ + + sort_by: Optional[Union[List[str], str]] = Field(default=None, alias="sortBy") + """Each file in a bucket will be sorted by these columns value. ``None`` means sorting is disabled. + + Examples: ``user_id`` or ``["user_id", "user_phone"]`` + + .. note:: + + Sorting columns should contain values which are used in ``ORDER BY`` clauses. + + .. warning:: + + Could be used only with :obj:`~bucket_by` option + + .. 
warning:: + + Used **only** while **creating new table**, or in case of ``if_exists=recreate_entire_table`` + """ + + compression: Optional[str] = None + """Compressing algorithm which should be used for compressing created files in HDFS. + ``None`` means compression is disabled. + + Examples: ``snappy``, ``zlib`` + + .. warning:: + + Used **only** while **creating new table**, or in case of ``if_exists=recreate_entire_table`` + """ + + @validator("sort_by") + def _sort_by_cannot_be_used_without_bucket_by(cls, sort_by, values): + options = values.copy() + bucket_by = options.pop("bucket_by", None) + if sort_by and not bucket_by: + raise ValueError("`sort_by` option can only be used with non-empty `bucket_by`") + + return sort_by + + @root_validator + def _partition_overwrite_mode_is_not_allowed(cls, values): + partition_overwrite_mode = values.get("partitionOverwriteMode") or values.get("partition_overwrite_mode") + if partition_overwrite_mode: + if partition_overwrite_mode == "static": + recommend_mode = "replace_entire_table" + else: + recommend_mode = "replace_overlapping_partitions" + raise ValueError( + f"`partitionOverwriteMode` option should be replaced with if_exists='{recommend_mode}'", + ) + + if values.get("insert_into") is not None or values.get("insertInto") is not None: + raise ValueError( + "`insertInto` option was removed in onETL 0.4.0, " + "now df.write.insertInto or df.write.saveAsTable is selected based on table existence", + ) + + return values + + @root_validator(pre=True) + def _mode_is_deprecated(cls, values): + if "mode" in values: + warnings.warn( + "Option `Hive.WriteOptions(mode=...)` is deprecated since v0.9.0 and will be removed in v1.0.0. " + "Use `Hive.WriteOptions(if_exists=...)` instead", + category=UserWarning, + stacklevel=3, + ) + return values + + +@deprecated( + version="0.5.0", + reason="Please use 'WriteOptions' class instead. Will be removed in v1.0.0", + action="always", + category=UserWarning, +) +class HiveLegacyOptions(HiveWriteOptions): + pass diff --git a/onetl/connection/db_connection/hive/slots.py b/onetl/connection/db_connection/hive/slots.py new file mode 100644 index 000000000..4105b5901 --- /dev/null +++ b/onetl/connection/db_connection/hive/slots.py @@ -0,0 +1,120 @@ +# Copyright 2023 MTS (Mobile Telesystems) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +from onetl.hooks import slot, support_hooks + + +@support_hooks +class HiveSlots: + """:ref:`Slots ` that could be implemented by third-party plugins.""" + + @slot + @staticmethod + def normalize_cluster_name(cluster: str) -> str | None: + """ + Normalize cluster name passed into Hive constructor. |support_hooks| + + If hooks didn't return anything, cluster name is left intact. + + Parameters + ---------- + cluster : :obj:`str` + Cluster name (raw) + + Returns + ------- + str | None + Normalized cluster name. + + If hook cannot be applied to a specific cluster, it should return ``None``. + + Examples + -------- + + .. 
code:: python + + from onetl.connection import Hive + from onetl.hooks import hook + + + @Hive.Slots.normalize_cluster_name.bind + @hook + def normalize_cluster_name(cluster: str) -> str: + return cluster.lower() + """ + + @slot + @staticmethod + def get_known_clusters() -> set[str] | None: + """ + Return collection of known clusters. |support_hooks| + + Cluster passed into Hive constructor should be present in this list. + If hooks didn't return anything, no validation will be performed. + + Returns + ------- + set[str] | None + Collection of cluster names (normalized). + + If hook cannot be applied, it should return ``None``. + + Examples + -------- + + .. code:: python + + from onetl.connection import Hive + from onetl.hooks import hook + + + @Hive.Slots.get_known_clusters.bind + @hook + def get_known_clusters() -> str[str]: + return {"rnd-dwh", "rnd-prod"} + """ + + @slot + @staticmethod + def get_current_cluster() -> str | None: + """ + Get current cluster name. |support_hooks| + + Used in :obj:`~check` method to verify that connection is created only from the same cluster. + If hooks didn't return anything, no validation will be performed. + + Returns + ------- + str | None + Current cluster name (normalized). + + If hook cannot be applied, it should return ``None``. + + Examples + -------- + + .. code:: python + + from onetl.connection import Hive + from onetl.hooks import hook + + + @Hive.Slots.get_current_cluster.bind + @hook + def get_current_cluster() -> str: + # some magic here + return "rnd-dwh" + """ diff --git a/onetl/connection/db_connection/jdbc_connection.py b/onetl/connection/db_connection/jdbc_connection.py deleted file mode 100644 index 4554e6660..000000000 --- a/onetl/connection/db_connection/jdbc_connection.py +++ /dev/null @@ -1,906 +0,0 @@ -# Copyright 2023 MTS (Mobile Telesystems) -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
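# A combined usage sketch of the three Hive slots documented above: with all of
# them bound, Hive(cluster=...) normalizes the passed cluster name and validates
# it against the known clusters list. Cluster names are the same illustrative
# values used in the docstrings ("rnd-dwh", "rnd-prod").

from onetl.connection import Hive
from onetl.hooks import hook


@Hive.Slots.normalize_cluster_name.bind
@hook
def normalize_cluster_name(cluster: str) -> str:
    return cluster.lower()


@Hive.Slots.get_known_clusters.bind
@hook
def get_known_clusters() -> set[str]:
    return {"rnd-dwh", "rnd-prod"}


@Hive.Slots.get_current_cluster.bind
@hook
def get_current_cluster() -> str:
    return "rnd-dwh"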
- -from __future__ import annotations - -import logging -import secrets -import warnings -from enum import Enum -from typing import TYPE_CHECKING, Any, Optional - -from deprecated import deprecated -from etl_entities.instance import Host -from pydantic import Field, PositiveInt, root_validator - -from onetl._internal import clear_statement, get_sql_query, to_camel -from onetl.connection.db_connection.db_connection import DBConnection -from onetl.connection.db_connection.dialect_mixins import ( - SupportColumnsList, - SupportDfSchemaNone, - SupportHintStr, - SupportHWMColumnStr, - SupportHWMExpressionStr, - SupportWhereStr, -) -from onetl.connection.db_connection.dialect_mixins.support_table_with_dbschema import ( - SupportTableWithDBSchema, -) -from onetl.connection.db_connection.jdbc_mixin import JDBCMixin -from onetl.hooks import slot, support_hooks -from onetl.hwm import Statement -from onetl.impl.generic_options import GenericOptions -from onetl.log import log_lines, log_with_indent - -if TYPE_CHECKING: - from pyspark.sql import DataFrame - from pyspark.sql.types import StructType - -log = logging.getLogger(__name__) - -# options from spark.read.jdbc which are populated by JDBCConnection methods -GENERIC_PROHIBITED_OPTIONS = frozenset( - ( - "table", - "dbtable", - "query", - "properties", - ), -) - -READ_WRITE_OPTIONS = frozenset( - ( - "keytab", - "principal", - "refreshKrb5Config", - "connectionProvider", - ), -) - -WRITE_OPTIONS = frozenset( - ( - "mode", - "column", # in some part of Spark source code option 'partitionColumn' is called just 'column' - "batchsize", - "isolationLevel", - "isolation_level", - "truncate", - "cascadeTruncate", - "createTableOptions", - "createTableColumnTypes", - "createTableColumnTypes", - ), -) - -READ_OPTIONS = frozenset( - ( - "column", # in some part of Spark source code option 'partitionColumn' is called just 'column' - "partitionColumn", - "partition_column", - "lowerBound", - "lower_bound", - "upperBound", - "upper_bound", - "numPartitions", - "num_partitions", - "fetchsize", - "sessionInitStatement", - "session_init_statement", - "customSchema", - "pushDownPredicate", - "pushDownAggregate", - "pushDownLimit", - "pushDownTableSample", - "predicates", - ), -) - - -# parameters accepted by spark.read.jdbc method: -# spark.read.jdbc( -# url, table, column, lowerBound, upperBound, numPartitions, predicates -# properties: { "user" : "SYSTEM", "password" : "mypassword", ... }) -READ_TOP_LEVEL_OPTIONS = frozenset(("url", "column", "lower_bound", "upper_bound", "num_partitions", "predicates")) - -# parameters accepted by spark.write.jdbc method: -# spark.write.jdbc( -# url, table, mode, -# properties: { "user" : "SYSTEM", "password" : "mypassword", ... }) -WRITE_TOP_LEVEL_OPTIONS = frozenset("url") - - -class JDBCTableExistBehavior(str, Enum): - APPEND = "append" - REPLACE_ENTIRE_TABLE = "replace_entire_table" - - def __str__(self) -> str: - return str(self.value) - - @classmethod # noqa: WPS120 - def _missing_(cls, value: object): # noqa: WPS120 - if str(value) == "overwrite": - warnings.warn( - "Mode `overwrite` is deprecated since v0.9.0 and will be removed in v1.0.0. 
" - "Use `replace_entire_table` instead", - category=UserWarning, - stacklevel=4, - ) - return cls.REPLACE_ENTIRE_TABLE - - -class PartitioningMode(str, Enum): - range = "range" - hash = "hash" - mod = "mod" - - def __str__(self): - return str(self.value) - - -@support_hooks -class JDBCConnection(SupportDfSchemaNone, JDBCMixin, DBConnection): - class Extra(GenericOptions): - class Config: - extra = "allow" - - class Dialect( # noqa: WPS215 - SupportTableWithDBSchema, - SupportColumnsList, - SupportDfSchemaNone, - SupportWhereStr, - SupportHintStr, - SupportHWMExpressionStr, - SupportHWMColumnStr, - DBConnection.Dialect, - ): - pass - - class ReadOptions(JDBCMixin.JDBCOptions): - """Spark JDBC options. - - .. note :: - - You can pass any value - `supported by Spark `_, - even if it is not mentioned in this documentation. **Option names should be in** ``camelCase``! - - The set of supported options depends on Spark version. See link above. - - Examples - -------- - - Read options initialization - - .. code:: python - - options = JDBC.ReadOptions( - partitionColumn="reg_id", - numPartitions=10, - lowerBound=0, - upperBound=1000, - someNewOption="value", - ) - """ - - class Config: - known_options = READ_OPTIONS | READ_WRITE_OPTIONS - prohibited_options = ( - JDBCMixin.JDBCOptions.Config.prohibited_options | GENERIC_PROHIBITED_OPTIONS | WRITE_OPTIONS - ) - alias_generator = to_camel - - # Options in DataFrameWriter.jdbc() method - partition_column: Optional[str] = None - """Column used to parallelize reading from a table. - - .. warning:: - It is highly recommended to use primary key, or at least a column with an index - to avoid performance issues. - - .. note:: - Column type depends on :obj:`~partitioning_mode`. - - * ``partitioning_mode="range"`` requires column to be an integer or date (can be NULL, but not recommended). - * ``partitioning_mode="hash"`` requires column to be an string (NOT NULL). - * ``partitioning_mode="mod"`` requires column to be an integer (NOT NULL). - - - See documentation for :obj:`~partitioning_mode` for more details""" - - num_partitions: PositiveInt = 1 - """Number of jobs created by Spark to read the table content in parallel. - See documentation for :obj:`~partitioning_mode` for more details""" - - lower_bound: Optional[int] = None - """See documentation for :obj:`~partitioning_mode` for more details""" # noqa: WPS322 - - upper_bound: Optional[int] = None - """See documentation for :obj:`~partitioning_mode` for more details""" # noqa: WPS322 - - session_init_statement: Optional[str] = None - '''After each database session is opened to the remote DB and before starting to read data, - this option executes a custom SQL statement (or a PL/SQL block). - - Use this to implement session initialization code. - - Example: - - .. code:: python - - sessionInitStatement = """ - BEGIN - execute immediate - 'alter session set "_serial_direct_read"=true'; - END; - """ - ''' - - fetchsize: int = 100_000 - """Fetch N rows from an opened cursor per one read round. - - Tuning this option can influence performance of reading. - - .. warning:: - - Default value is different from Spark. - - Spark uses driver's own value, and it may be different in different drivers, - and even versions of the same driver. For example, Oracle has - default ``fetchsize=10``, which is absolutely not usable. - - Thus we've overridden default value with ``100_000``, which should increase reading performance. 
- """ - - partitioning_mode: PartitioningMode = PartitioningMode.range - """Defines how Spark will parallelize reading from table. - - Possible values: - - * ``range`` (default) - Allocate each executor a range of values from column passed into :obj:`~partition_column`. - - Spark generates for each executor an SQL query like: - - Executor 1: - - .. code:: sql - - SELECT ... FROM table - WHERE (partition_column >= lowerBound - OR partition_column IS NULL) - AND partition_column < (lower_bound + stride) - - Executor 2: - - .. code:: sql - - SELECT ... FROM table - WHERE partition_column >= (lower_bound + stride) - AND partition_column < (lower_bound + 2 * stride) - - ... - - Executor N: - - .. code:: sql - - SELECT ... FROM table - WHERE partition_column >= (lower_bound + (N-1) * stride) - AND partition_column <= upper_bound - - Where ``stride=(upper_bound - lower_bound) / num_partitions``. - - .. note:: - - :obj:`~lower_bound`, :obj:`~upper_bound` and :obj:`~num_partitions` are used just to - calculate the partition stride, **NOT** for filtering the rows in table. - So all rows in the table will be returned (unlike *Incremental* :ref:`strategy`). - - .. note:: - - All queries are executed in parallel. To execute them sequentially, use *Batch* :ref:`strategy`. - - * ``hash`` - Allocate each executor a set of values based on hash of the :obj:`~partition_column` column. - - Spark generates for each executor an SQL query like: - - Executor 1: - - .. code:: sql - - SELECT ... FROM table - WHERE (some_hash(partition_column) mod num_partitions) = 0 -- lower_bound - - Executor 2: - - .. code:: sql - - SELECT ... FROM table - WHERE (some_hash(partition_column) mod num_partitions) = 1 -- lower_bound + 1 - - ... - - Executor N: - - .. code:: sql - - SELECT ... FROM table - WHERE (some_hash(partition_column) mod num_partitions) = num_partitions-1 -- upper_bound - - .. note:: - - The hash function implementation depends on RDBMS. It can be ``MD5`` or any other fast hash function, - or expression based on this function call. - - * ``mod`` - Allocate each executor a set of values based on modulus of the :obj:`~partition_column` column. - - Spark generates for each executor an SQL query like: - - Executor 1: - - .. code:: sql - - SELECT ... FROM table - WHERE (partition_column mod num_partitions) = 0 -- lower_bound - - Executor 2: - - .. code:: sql - - SELECT ... FROM table - WHERE (partition_column mod num_partitions) = 1 -- lower_bound + 1 - - Executor N: - - .. code:: sql - - SELECT ... FROM table - WHERE (partition_column mod num_partitions) = num_partitions-1 -- upper_bound - - Examples - -------- - - Read data in 10 parallel jobs by range of values in ``id_column`` column: - - .. code:: python - - Postgres.ReadOptions( - partitioning_mode="range", # default mode, can be omitted - partition_column="id_column", - num_partitions=10, - # if you're using DBReader, options below can be omitted - # because they are calculated by automatically as - # MIN and MAX values of `partition_column` - lower_bound=0, - upper_bound=100_000, - ) - - Read data in 10 parallel jobs by hash of values in ``some_column`` column: - - .. code:: python - - Postgres.ReadOptions( - partitioning_mode="hash", - partition_column="some_column", - num_partitions=10, - # lower_bound and upper_bound are automatically set to `0` and `9` - ) - - Read data in 10 parallel jobs by modulus of values in ``id_column`` column: - - .. 
code:: python - - Postgres.ReadOptions( - partitioning_mode="mod", - partition_column="id_column", - num_partitions=10, - # lower_bound and upper_bound are automatically set to `0` and `9` - ) - """ - - @root_validator - def partitioning_mode_actions(cls, values): - mode = values["partitioning_mode"] - num_partitions = values.get("num_partitions") - partition_column = values.get("partition_column") - lower_bound = values.get("lower_bound") - upper_bound = values.get("upper_bound") - - if not partition_column: - if num_partitions == 1: - return values - - raise ValueError("You should set partition_column to enable partitioning") - - elif num_partitions == 1: - raise ValueError("You should set num_partitions > 1 to enable partitioning") - - if mode == PartitioningMode.range: - return values - - if mode == PartitioningMode.hash: - values["partition_column"] = cls._get_partition_column_hash( - partition_column=partition_column, - num_partitions=num_partitions, - ) - - if mode == PartitioningMode.mod: - values["partition_column"] = cls._get_partition_column_mod( - partition_column=partition_column, - num_partitions=num_partitions, - ) - - values["lower_bound"] = lower_bound if lower_bound is not None else 0 - values["upper_bound"] = upper_bound if upper_bound is not None else num_partitions - - return values - - class WriteOptions(JDBCMixin.JDBCOptions): - """Spark JDBC writing options. - - .. note :: - - You can pass any value - `supported by Spark `_, - even if it is not mentioned in this documentation. **Option names should be in** ``camelCase``! - - The set of supported options depends on Spark version. See link above. - - Examples - -------- - - Write options initialization - - .. code:: python - - options = JDBC.WriteOptions(if_exists="append", batchsize=20_000, someNewOption="value") - """ - - class Config: - known_options = WRITE_OPTIONS | READ_WRITE_OPTIONS - prohibited_options = ( - JDBCMixin.JDBCOptions.Config.prohibited_options | GENERIC_PROHIBITED_OPTIONS | READ_OPTIONS - ) - alias_generator = to_camel - - if_exists: JDBCTableExistBehavior = Field(default=JDBCTableExistBehavior.APPEND, alias="mode") - """Behavior of writing data into existing table. - - Possible values: - * ``append`` (default) - Adds new rows into existing table. - - .. dropdown:: Behavior in details - - * Table does not exist - Table is created using options provided by user - (``createTableOptions``, ``createTableColumnTypes``, etc). - - * Table exists - Data is appended to a table. Table has the same DDL as before writing data - - .. warning:: - - This mode does not check whether table already contains - rows from dataframe, so duplicated rows can be created. - - Also Spark does not support passing custom options to - insert statement, like ``ON CONFLICT``, so don't try to - implement deduplication using unique indexes or constraints. - - Instead, write to staging table and perform deduplication - using :obj:`~execute` method. - - * ``replace_entire_table`` - **Table is dropped and then created, or truncated**. - - .. dropdown:: Behavior in details - - * Table does not exist - Table is created using options provided by user - (``createTableOptions``, ``createTableColumnTypes``, etc). - - * Table exists - Table content is replaced with dataframe content. - - After writing completed, target table could either have the same DDL as - before writing data (``truncate=True``), or can be recreated (``truncate=False`` - or source does not support truncation). - - .. 
note:: - - ``error`` and ``ignore`` modes are not supported. - """ - - batchsize: int = 20_000 - """How many rows can be inserted per round trip. - - Tuning this option can influence performance of writing. - - .. warning:: - - Default value is different from Spark. - - Spark uses quite small value ``1000``, which is absolutely not usable - in BigData world. - - Thus we've overridden default value with ``20_000``, - which should increase writing performance. - - You can increase it even more, up to ``50_000``, - but it depends on your database load and number of columns in the row. - Higher values does not increase performance. - """ - - isolation_level: str = "READ_UNCOMMITTED" - """The transaction isolation level, which applies to current connection. - - Possible values: - * ``NONE`` (as string, not Python's ``None``) - * ``READ_COMMITTED`` - * ``READ_UNCOMMITTED`` - * ``REPEATABLE_READ`` - * ``SERIALIZABLE`` - - Values correspond to transaction isolation levels defined by JDBC standard. - Please refer the documentation for - `java.sql.Connection `_. - """ - - @root_validator(pre=True) - def mode_is_deprecated(cls, values): - if "mode" in values: - warnings.warn( - "Option `WriteOptions(mode=...)` is deprecated since v0.9.0 and will be removed in v1.0.0. " - "Use `WriteOptions(if_exists=...)` instead", - category=UserWarning, - stacklevel=3, - ) - return values - - @deprecated( - version="0.5.0", - reason="Please use 'ReadOptions' or 'WriteOptions' class instead. Will be removed in v1.0.0", - action="always", - category=UserWarning, - ) - class Options(ReadOptions, WriteOptions): - class Config: - prohibited_options = JDBCMixin.JDBCOptions.Config.prohibited_options - - host: Host - port: int - extra: Extra = Extra() - - @property - def instance_url(self) -> str: - return f"{self.__class__.__name__.lower()}://{self.host}:{self.port}" - - @slot - def sql( - self, - query: str, - options: ReadOptions | dict | None = None, - ) -> DataFrame: - """ - **Lazily** execute SELECT statement **on Spark executor** and return DataFrame. |support_hooks| - - Same as ``spark.read.jdbc(query)``. - - .. note:: - - This method does not support :ref:`strategy`, - use :obj:`DBReader ` instead - - .. note:: - - Statement is executed in read-write connection, - so if you're calling some functions/procedures with DDL/DML statements inside, - they can change data in your database. - - Unfortunately, Spark does no provide any option to change this behavior. - - Parameters - ---------- - query : str - - SQL query to be executed. - - Only ``SELECT ... FROM ...`` form is supported. - - Some databases also supports ``WITH ... AS (...) SELECT ... FROM ...`` form. - - Queries like ``SHOW ...`` are not supported. - - .. warning:: - - The exact syntax **depends on RDBMS** is being used. - - options : dict, :obj:`~ReadOptions`, default: ``None`` - - Spark options to be used while fetching data, like ``fetchsize`` or ``partitionColumn`` - - Returns - ------- - df : pyspark.sql.dataframe.DataFrame - - Spark dataframe - - Examples - -------- - - Read data from a table: - - .. code:: python - - df = connection.sql("SELECT * FROM mytable") - - Read data from a table with options: - - .. 
code:: python - - # reads data from table in batches, 10000 rows per batch - df = connection.sql("SELECT * FROM mytable", {"fetchsize": 10000}) - assert df.count() - - """ - - query = clear_statement(query) - - log.info("|%s| Executing SQL query (on executor):", self.__class__.__name__) - log_lines(log, query) - - df = self._query_on_executor(query, self.ReadOptions.parse(options)) - - log.info("|Spark| DataFrame successfully created from SQL statement ") - return df - - @slot - def read_source_as_df( - self, - source: str, - columns: list[str] | None = None, - hint: str | None = None, - where: str | None = None, - df_schema: StructType | None = None, - start_from: Statement | None = None, - end_at: Statement | None = None, - options: ReadOptions | dict | None = None, - ) -> DataFrame: - read_options = self._set_lower_upper_bound( - table=source, - where=where, - hint=hint, - options=self.ReadOptions.parse(options).copy(exclude={"if_exists", "partitioning_mode"}), - ) - - # hack to avoid column name verification - # in the spark, the expression in the partitioning of the column must - # have the same name as the field in the table ( 2.4 version ) - # https://github.com/apache/spark/pull/21379 - - new_columns = columns or ["*"] - alias = "x" + secrets.token_hex(5) - - if read_options.partition_column: - aliased = self.Dialect._expression_with_alias(read_options.partition_column, alias) - read_options = read_options.copy(update={"partition_column": alias}) - new_columns.append(aliased) - - where = self.Dialect._condition_assembler(condition=where, start_from=start_from, end_at=end_at) - - query = get_sql_query( - table=source, - columns=new_columns, - where=where, - hint=hint, - ) - - result = self.sql(query, read_options) - - if read_options.partition_column: - result = result.drop(alias) - - return result - - @slot - def write_df_to_target( - self, - df: DataFrame, - target: str, - options: WriteOptions | dict | None = None, - ) -> None: - write_options = self.WriteOptions.parse(options) - jdbc_params = self.options_to_jdbc_params(write_options) - - mode = "append" if write_options.if_exists == JDBCTableExistBehavior.APPEND else "overwrite" - log.info("|%s| Saving data to a table %r", self.__class__.__name__, target) - df.write.jdbc(table=target, mode=mode, **jdbc_params) - log.info("|%s| Table %r successfully written", self.__class__.__name__, target) - - @slot - def get_df_schema( - self, - source: str, - columns: list[str] | None = None, - options: JDBCMixin.JDBCOptions | dict | None = None, - ) -> StructType: - log.info("|%s| Fetching schema of table %r", self.__class__.__name__, source) - - query = get_sql_query(source, columns=columns, where="1=0", compact=True) - read_options = self._exclude_partition_options(options, fetchsize=0) - - log.debug("|%s| Executing SQL query (on driver):", self.__class__.__name__) - log_lines(log, query, level=logging.DEBUG) - - df = self._query_on_driver(query, read_options) - log.info("|%s| Schema fetched", self.__class__.__name__) - - return df.schema - - def options_to_jdbc_params( - self, - options: ReadOptions | WriteOptions, - ) -> dict: - # Have to replace the parameter with - # since the method takes the named parameter - # link to source below - # https://github.com/apache/spark/blob/2ef8ced27a6b0170a691722a855d3886e079f037/python/pyspark/sql/readwriter.py#L465 - - partition_column = getattr(options, "partition_column", None) - if partition_column: - options = options.copy( - update={"column": partition_column}, - exclude={"partition_column"}, 
- ) - - result = self._get_jdbc_properties( - options, - include=READ_TOP_LEVEL_OPTIONS | WRITE_TOP_LEVEL_OPTIONS, - exclude={"if_exists"}, - exclude_none=True, - ) - - result["properties"] = self._get_jdbc_properties( - options, - exclude=READ_TOP_LEVEL_OPTIONS | WRITE_TOP_LEVEL_OPTIONS | {"if_exists"}, - exclude_none=True, - ) - - result["properties"].pop("partitioningMode", None) - - return result - - @slot - def get_min_max_bounds( - self, - source: str, - column: str, - expression: str | None = None, - hint: str | None = None, - where: str | None = None, - options: JDBCMixin.JDBCOptions | dict | None = None, - ) -> tuple[Any, Any]: - log.info("|Spark| Getting min and max values for column %r", column) - - read_options = self._exclude_partition_options(options, fetchsize=1) - - query = get_sql_query( - table=source, - columns=[ - self.Dialect._expression_with_alias( - self.Dialect._get_min_value_sql(expression or column), - "min", - ), - self.Dialect._expression_with_alias( - self.Dialect._get_max_value_sql(expression or column), - "max", - ), - ], - where=where, - hint=hint, - ) - - log.info("|%s| Executing SQL query (on driver):", self.__class__.__name__) - log_lines(log, query) - - df = self._query_on_driver(query, read_options) - row = df.collect()[0] - min_value = row["min"] - max_value = row["max"] - - log.info("|Spark| Received values:") - log_with_indent(log, "MIN(%s) = %r", column, min_value) - log_with_indent(log, "MAX(%s) = %r", column, max_value) - - return min_value, max_value - - def _query_on_executor( - self, - query: str, - options: ReadOptions, - ) -> DataFrame: - jdbc_params = self.options_to_jdbc_params(options) - return self.spark.read.jdbc(table=f"({query}) T", **jdbc_params) - - def _exclude_partition_options( - self, - options: JDBCMixin.JDBCOptions | dict | None, - fetchsize: int, - ) -> JDBCMixin.JDBCOptions: - return self.JDBCOptions.parse(options).copy( - update={"fetchsize": fetchsize}, - exclude={"partition_column", "lower_bound", "upper_bound", "num_partitions", "partitioning_mode"}, - ) - - def _set_lower_upper_bound( - self, - table: str, - hint: str | None = None, - where: str | None = None, - options: ReadOptions | dict | None = None, - ) -> ReadOptions: - """ - Determine values of upperBound and lowerBound options - """ - - result_options = self.ReadOptions.parse(options) - - if not result_options.partition_column: - return result_options - - missing_values: list[str] = [] - - is_missed_lower_bound = result_options.lower_bound is None - is_missed_upper_bound = result_options.upper_bound is None - - if is_missed_lower_bound: - missing_values.append("lowerBound") - - if is_missed_upper_bound: - missing_values.append("upperBound") - - if not missing_values: - return result_options - - log.warning( - "|Spark| Passed numPartitions = %d, but values %r are not set. " - "They will be detected automatically based on values in partitionColumn %r", - result_options.num_partitions, - missing_values, - result_options.partition_column, - ) - - min_partition_value, max_partition_value = self.get_min_max_bounds( - source=table, - column=result_options.partition_column, - where=where, - hint=hint, - options=result_options, - ) - - # The sessionInitStatement parameter is removed because it only needs to be applied once. 
- return result_options.copy( - exclude={"session_init_statement"}, - update={ - "lower_bound": result_options.lower_bound if not is_missed_lower_bound else min_partition_value, - "upper_bound": result_options.upper_bound if not is_missed_upper_bound else max_partition_value, - }, - ) - - def _log_parameters(self): - super()._log_parameters() - log_with_indent(log, "jdbc_url = %r", self.jdbc_url) diff --git a/onetl/connection/db_connection/jdbc_connection/__init__.py b/onetl/connection/db_connection/jdbc_connection/__init__.py new file mode 100644 index 000000000..9c11532c1 --- /dev/null +++ b/onetl/connection/db_connection/jdbc_connection/__init__.py @@ -0,0 +1,22 @@ +# Copyright 2023 MTS (Mobile Telesystems) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from onetl.connection.db_connection.jdbc_connection.connection import JDBCConnection +from onetl.connection.db_connection.jdbc_connection.dialect import JDBCDialect +from onetl.connection.db_connection.jdbc_connection.options import ( + JDBCPartitioningMode, + JDBCReadOptions, + JDBCTableExistBehavior, + JDBCWriteOptions, +) diff --git a/onetl/connection/db_connection/jdbc_connection/connection.py b/onetl/connection/db_connection/jdbc_connection/connection.py new file mode 100644 index 000000000..3eb83f538 --- /dev/null +++ b/onetl/connection/db_connection/jdbc_connection/connection.py @@ -0,0 +1,395 @@ +# Copyright 2023 MTS (Mobile Telesystems) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
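# A rough sketch of the dictionary shape produced by options_to_jdbc_params:
# the top-level arguments accepted by spark.read.jdbc / spark.write.jdbc
# (url, column, lowerBound, upperBound, numPartitions, predicates) stay at the
# root, everything else moves into the "properties" dict, and partition_column
# is renamed to "column". Connection details and values below are placeholders,
# not taken from the diff.

jdbc_params = {
    "url": "jdbc:postgresql://example.org:5432/mydb",
    "column": "id_column",  # was ReadOptions.partition_column
    "lowerBound": 0,
    "upperBound": 100_000,
    "numPartitions": 10,
    "properties": {
        "user": "onetl_user",
        "password": "***",
        "fetchsize": "100000",
    },
}

# The SELECT statement itself is wrapped into a derived table before being
# handed to Spark, roughly: spark.read.jdbc(table="(SELECT ...) T", **jdbc_params)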
+ +from __future__ import annotations + +import logging +import secrets +from typing import TYPE_CHECKING, Any + +from etl_entities.instance import Host + +from onetl._internal import clear_statement, get_sql_query +from onetl.connection.db_connection.db_connection import DBConnection +from onetl.connection.db_connection.jdbc_connection.dialect import JDBCDialect +from onetl.connection.db_connection.jdbc_connection.options import ( + JDBCLegacyOptions, + JDBCPartitioningMode, + JDBCReadOptions, + JDBCTableExistBehavior, + JDBCWriteOptions, +) +from onetl.connection.db_connection.jdbc_mixin import JDBCMixin +from onetl.connection.db_connection.jdbc_mixin.options import JDBCOptions +from onetl.hooks import slot, support_hooks +from onetl.hwm import Statement +from onetl.log import log_lines, log_with_indent + +if TYPE_CHECKING: + from pyspark.sql import DataFrame + from pyspark.sql.types import StructType + +log = logging.getLogger(__name__) + +# parameters accepted by spark.read.jdbc method: +# spark.read.jdbc( +# url, table, column, lowerBound, upperBound, numPartitions, predicates +# properties: { "user" : "SYSTEM", "password" : "mypassword", ... }) +READ_TOP_LEVEL_OPTIONS = frozenset(("url", "column", "lower_bound", "upper_bound", "num_partitions", "predicates")) + +# parameters accepted by spark.write.jdbc method: +# spark.write.jdbc( +# url, table, mode, +# properties: { "user" : "SYSTEM", "password" : "mypassword", ... }) +WRITE_TOP_LEVEL_OPTIONS = frozenset("url") + + +@support_hooks +class JDBCConnection(JDBCMixin, DBConnection): + host: Host + port: int + + Dialect = JDBCDialect + ReadOptions = JDBCReadOptions + WriteOptions = JDBCWriteOptions + Options = JDBCLegacyOptions + + @property + def instance_url(self) -> str: + return f"{self.__class__.__name__.lower()}://{self.host}:{self.port}" + + @slot + def sql( + self, + query: str, + options: JDBCReadOptions | dict | None = None, + ) -> DataFrame: + """ + **Lazily** execute SELECT statement **on Spark executor** and return DataFrame. |support_hooks| + + Same as ``spark.read.jdbc(query)``. + + .. note:: + + This method does not support :ref:`strategy`, + use :obj:`DBReader ` instead + + .. note:: + + Statement is executed in read-write connection, + so if you're calling some functions/procedures with DDL/DML statements inside, + they can change data in your database. + + Unfortunately, Spark does no provide any option to change this behavior. + + Parameters + ---------- + query : str + + SQL query to be executed. + + Only ``SELECT ... FROM ...`` form is supported. + + Some databases also supports ``WITH ... AS (...) SELECT ... FROM ...`` form. + + Queries like ``SHOW ...`` are not supported. + + .. warning:: + + The exact syntax **depends on RDBMS** is being used. + + options : dict, :obj:`~ReadOptions`, default: ``None`` + + Spark options to be used while fetching data, like ``fetchsize`` or ``partitionColumn`` + + Returns + ------- + df : pyspark.sql.dataframe.DataFrame + + Spark dataframe + + Examples + -------- + + Read data from a table: + + .. code:: python + + df = connection.sql("SELECT * FROM mytable") + + Read data from a table with options: + + .. 
code:: python + + # reads data from table in batches, 10000 rows per batch + df = connection.sql("SELECT * FROM mytable", {"fetchsize": 10000}) + assert df.count() + + """ + + query = clear_statement(query) + + log.info("|%s| Executing SQL query (on executor):", self.__class__.__name__) + log_lines(log, query) + + df = self._query_on_executor(query, self.ReadOptions.parse(options)) + + log.info("|Spark| DataFrame successfully created from SQL statement ") + return df + + @slot + def read_source_as_df( + self, + source: str, + columns: list[str] | None = None, + hint: str | None = None, + where: str | None = None, + df_schema: StructType | None = None, + start_from: Statement | None = None, + end_at: Statement | None = None, + options: JDBCReadOptions | None = None, + ) -> DataFrame: + read_options = self._set_lower_upper_bound( + table=source, + where=where, + hint=hint, + options=self.ReadOptions.parse(options), + ) + + new_columns = columns or ["*"] + alias: str | None = None + + if read_options.partition_column: + if read_options.partitioning_mode == JDBCPartitioningMode.MOD: + partition_column = self.Dialect._get_partition_column_mod( + read_options.partition_column, + read_options.num_partitions, + ) + elif read_options.partitioning_mode == JDBCPartitioningMode.HASH: + partition_column = self.Dialect._get_partition_column_hash( + read_options.partition_column, + read_options.num_partitions, + ) + else: + partition_column = read_options.partition_column + + # hack to avoid column name verification + # in the spark, the expression in the partitioning of the column must + # have the same name as the field in the table ( 2.4 version ) + # https://github.com/apache/spark/pull/21379 + alias = "generated_" + secrets.token_hex(5) + alias_escaped = self.Dialect._escape_column(alias) + aliased_column = self.Dialect._expression_with_alias(partition_column, alias_escaped) + read_options = read_options.copy(update={"partition_column": alias_escaped}) + new_columns.append(aliased_column) + + where = self.Dialect._condition_assembler(condition=where, start_from=start_from, end_at=end_at) + query = get_sql_query( + table=source, + columns=new_columns, + where=where, + hint=hint, + ) + + result = self.sql(query, read_options) + if alias: + result = result.drop(alias) + + return result + + @slot + def write_df_to_target( + self, + df: DataFrame, + target: str, + options: JDBCWriteOptions | None = None, + ) -> None: + write_options = self.WriteOptions.parse(options) + jdbc_params = self.options_to_jdbc_params(write_options) + + mode = "append" if write_options.if_exists == JDBCTableExistBehavior.APPEND else "overwrite" + log.info("|%s| Saving data to a table %r", self.__class__.__name__, target) + df.write.jdbc(table=target, mode=mode, **jdbc_params) + log.info("|%s| Table %r successfully written", self.__class__.__name__, target) + + @slot + def get_df_schema( + self, + source: str, + columns: list[str] | None = None, + options: JDBCReadOptions | None = None, + ) -> StructType: + log.info("|%s| Fetching schema of table %r ...", self.__class__.__name__, source) + + query = get_sql_query(source, columns=columns, where="1=0", compact=True) + read_options = self._exclude_partition_options(self.ReadOptions.parse(options), fetchsize=0) + + log.debug("|%s| Executing SQL query (on driver):", self.__class__.__name__) + log_lines(log, query, level=logging.DEBUG) + + df = self._query_on_driver(query, read_options) + log.info("|%s| Schema fetched.", self.__class__.__name__) + + return df.schema + + def 
options_to_jdbc_params( + self, + options: JDBCReadOptions | JDBCWriteOptions, + ) -> dict: + # Have to replace the parameter with + # since the method takes the named parameter + # link to source below + # https://github.com/apache/spark/blob/2ef8ced27a6b0170a691722a855d3886e079f037/python/pyspark/sql/readwriter.py#L465 + + partition_column = getattr(options, "partition_column", None) + if partition_column: + options = options.copy( + update={"column": partition_column}, + exclude={"partition_column"}, + ) + + result = self._get_jdbc_properties( + options, + include=READ_TOP_LEVEL_OPTIONS | WRITE_TOP_LEVEL_OPTIONS, + exclude={"if_exists"}, + exclude_none=True, + ) + + result["properties"] = self._get_jdbc_properties( + options, + exclude=READ_TOP_LEVEL_OPTIONS | WRITE_TOP_LEVEL_OPTIONS | {"if_exists"}, + exclude_none=True, + ) + + result["properties"].pop("partitioningMode", None) + return result + + @slot + def get_min_max_bounds( + self, + source: str, + column: str, + expression: str | None = None, + hint: str | None = None, + where: str | None = None, + options: JDBCReadOptions | None = None, + ) -> tuple[Any, Any]: + log.info("|%s| Getting min and max values for column %r ...", self.__class__.__name__, column) + + read_options = self._exclude_partition_options(self.ReadOptions.parse(options), fetchsize=1) + + query = get_sql_query( + table=source, + columns=[ + self.Dialect._expression_with_alias( + self.Dialect._get_min_value_sql(expression or column), + self.Dialect._escape_column("min"), + ), + self.Dialect._expression_with_alias( + self.Dialect._get_max_value_sql(expression or column), + self.Dialect._escape_column("max"), + ), + ], + where=where, + hint=hint, + ) + + log.info("|%s| Executing SQL query (on driver):", self.__class__.__name__) + log_lines(log, query) + + df = self._query_on_driver(query, read_options) + row = df.collect()[0] + min_value = row["min"] + max_value = row["max"] + + log.info("|%s| Received values:", self.__class__.__name__) + log_with_indent(log, "MIN(%s) = %r", column, min_value) + log_with_indent(log, "MAX(%s) = %r", column, max_value) + + return min_value, max_value + + def _query_on_executor( + self, + query: str, + options: JDBCReadOptions, + ) -> DataFrame: + jdbc_params = self.options_to_jdbc_params(options) + return self.spark.read.jdbc(table=f"({query}) T", **jdbc_params) + + def _exclude_partition_options( + self, + options: JDBCReadOptions, + fetchsize: int, + ) -> JDBCOptions: + return options.copy( + update={"fetchsize": fetchsize}, + exclude={"partition_column", "lower_bound", "upper_bound", "num_partitions", "partitioning_mode"}, + ) + + def _set_lower_upper_bound( + self, + table: str, + hint: str | None, + where: str | None, + options: JDBCReadOptions, + ) -> JDBCReadOptions: + """ + Determine values of upperBound and lowerBound options + """ + if not options.partition_column: + return options + + missing_values: list[str] = [] + + is_missed_lower_bound = options.lower_bound is None + is_missed_upper_bound = options.upper_bound is None + + if is_missed_lower_bound: + missing_values.append("lowerBound") + + if is_missed_upper_bound: + missing_values.append("upperBound") + + if not missing_values: + return options + + log.warning( + "|%s| Passed numPartitions = %d, but values %r are not set. 
" + "They will be detected automatically based on values in partitionColumn %r", + self.__class__.__name__, + options.num_partitions, + missing_values, + options.partition_column, + ) + + min_partition_value, max_partition_value = self.get_min_max_bounds( + source=table, + column=options.partition_column, + where=where, + hint=hint, + options=options, + ) + + # The sessionInitStatement parameter is removed because it only needs to be applied once. + return options.copy( + exclude={"session_init_statement"}, + update={ + "lower_bound": options.lower_bound if not is_missed_lower_bound else min_partition_value, + "upper_bound": options.upper_bound if not is_missed_upper_bound else max_partition_value, + }, + ) + + def _log_parameters(self): + super()._log_parameters() + log_with_indent(log, "jdbc_url = %r", self.jdbc_url) diff --git a/onetl/connection/db_connection/jdbc_connection/dialect.py b/onetl/connection/db_connection/jdbc_connection/dialect.py new file mode 100644 index 000000000..790a0c300 --- /dev/null +++ b/onetl/connection/db_connection/jdbc_connection/dialect.py @@ -0,0 +1,49 @@ +# Copyright 2023 MTS (Mobile Telesystems) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +from abc import abstractmethod + +from onetl.connection.db_connection.db_connection import DBDialect +from onetl.connection.db_connection.dialect_mixins import ( + SupportColumnsList, + SupportDfSchemaNone, + SupportHintStr, + SupportHWMColumnStr, + SupportHWMExpressionStr, + SupportNameWithSchemaOnly, + SupportWhereStr, +) + + +class JDBCDialect( # noqa: WPS215 + SupportNameWithSchemaOnly, + SupportColumnsList, + SupportDfSchemaNone, + SupportWhereStr, + SupportHintStr, + SupportHWMExpressionStr, + SupportHWMColumnStr, + DBDialect, +): + @classmethod + @abstractmethod + def _get_partition_column_hash(cls, partition_column: str, num_partitions: int) -> str: + ... + + @classmethod + @abstractmethod + def _get_partition_column_mod(cls, partition_column: str, num_partitions: int) -> str: + ... diff --git a/onetl/connection/db_connection/jdbc_connection/options.py b/onetl/connection/db_connection/jdbc_connection/options.py new file mode 100644 index 000000000..c998055fe --- /dev/null +++ b/onetl/connection/db_connection/jdbc_connection/options.py @@ -0,0 +1,511 @@ +# Copyright 2023 MTS (Mobile Telesystems) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from __future__ import annotations + +import warnings +from enum import Enum +from typing import Optional + +from deprecated import deprecated +from pydantic import Field, PositiveInt, root_validator + +from onetl._internal import to_camel +from onetl.connection.db_connection.jdbc_mixin.options import JDBCOptions + +# options from spark.read.jdbc which are populated by JDBCConnection methods +GENERIC_PROHIBITED_OPTIONS = frozenset( + ( + "table", + "dbtable", + "query", + "properties", + ), +) + +READ_WRITE_OPTIONS = frozenset( + ( + "keytab", + "principal", + "refreshKrb5Config", + "connectionProvider", + ), +) + +WRITE_OPTIONS = frozenset( + ( + "mode", + "column", # in some part of Spark source code option 'partitionColumn' is called just 'column' + "batchsize", + "isolationLevel", + "isolation_level", + "truncate", + "cascadeTruncate", + "createTableOptions", + "createTableColumnTypes", + "createTableColumnTypes", + ), +) + +READ_OPTIONS = frozenset( + ( + "column", # in some part of Spark source code option 'partitionColumn' is called just 'column' + "partitionColumn", + "partition_column", + "lowerBound", + "lower_bound", + "upperBound", + "upper_bound", + "numPartitions", + "num_partitions", + "fetchsize", + "sessionInitStatement", + "session_init_statement", + "customSchema", + "pushDownPredicate", + "pushDownAggregate", + "pushDownLimit", + "pushDownTableSample", + "predicates", + ), +) + + +class JDBCTableExistBehavior(str, Enum): + APPEND = "append" + REPLACE_ENTIRE_TABLE = "replace_entire_table" + + def __str__(self) -> str: + return str(self.value) + + @classmethod # noqa: WPS120 + def _missing_(cls, value: object): # noqa: WPS120 + if str(value) == "overwrite": + warnings.warn( + "Mode `overwrite` is deprecated since v0.9.0 and will be removed in v1.0.0. " + "Use `replace_entire_table` instead", + category=UserWarning, + stacklevel=4, + ) + return cls.REPLACE_ENTIRE_TABLE + + +class JDBCPartitioningMode(str, Enum): + RANGE = "range" + HASH = "hash" + MOD = "mod" + + def __str__(self): + return str(self.value) + + +class JDBCReadOptions(JDBCOptions): + """Spark JDBC reading options. + + .. note :: + + You can pass any value + `supported by Spark `_, + even if it is not mentioned in this documentation. **Option names should be in** ``camelCase``! + + The set of supported options depends on Spark version. See link above. + + Examples + -------- + + Read options initialization + + .. code:: python + + options = JDBC.ReadOptions( + partitionColumn="reg_id", + numPartitions=10, + lowerBound=0, + upperBound=1000, + someNewOption="value", + ) + """ + + class Config: + known_options = READ_OPTIONS | READ_WRITE_OPTIONS + prohibited_options = JDBCOptions.Config.prohibited_options | GENERIC_PROHIBITED_OPTIONS | WRITE_OPTIONS + alias_generator = to_camel + + # Options in DataFrameWriter.jdbc() method + partition_column: Optional[str] = None + """Column used to parallelize reading from a table. + + .. warning:: + It is highly recommended to use primary key, or at least a column with an index + to avoid performance issues. + + .. note:: + Column type depends on :obj:`~partitioning_mode`. + + * ``partitioning_mode="range"`` requires column to be an integer or date (can be NULL, but not recommended). + * ``partitioning_mode="hash"`` requires column to be an string (NOT NULL). + * ``partitioning_mode="mod"`` requires column to be an integer (NOT NULL). 
+ + + See documentation for :obj:`~partitioning_mode` for more details""" + + num_partitions: PositiveInt = 1 + """Number of jobs created by Spark to read the table content in parallel. + See documentation for :obj:`~partitioning_mode` for more details""" + + lower_bound: Optional[int] = None + """See documentation for :obj:`~partitioning_mode` for more details""" # noqa: WPS322 + + upper_bound: Optional[int] = None + """See documentation for :obj:`~partitioning_mode` for more details""" # noqa: WPS322 + + session_init_statement: Optional[str] = None + '''After each database session is opened to the remote DB and before starting to read data, + this option executes a custom SQL statement (or a PL/SQL block). + + Use this to implement session initialization code. + + Example: + + .. code:: python + + sessionInitStatement = """ + BEGIN + execute immediate + 'alter session set "_serial_direct_read"=true'; + END; + """ + ''' + + fetchsize: int = 100_000 + """Fetch N rows from an opened cursor per one read round. + + Tuning this option can influence performance of reading. + + .. warning:: + + Default value is different from Spark. + + Spark uses driver's own value, and it may be different in different drivers, + and even versions of the same driver. For example, Oracle has + default ``fetchsize=10``, which is absolutely not usable. + + Thus we've overridden default value with ``100_000``, which should increase reading performance. + """ + + partitioning_mode: JDBCPartitioningMode = JDBCPartitioningMode.RANGE + """Defines how Spark will parallelize reading from table. + + Possible values: + + * ``range`` (default) + Allocate each executor a range of values from column passed into :obj:`~partition_column`. + + Spark generates for each executor an SQL query like: + + Executor 1: + + .. code:: sql + + SELECT ... FROM table + WHERE (partition_column >= lowerBound + OR partition_column IS NULL) + AND partition_column < (lower_bound + stride) + + Executor 2: + + .. code:: sql + + SELECT ... FROM table + WHERE partition_column >= (lower_bound + stride) + AND partition_column < (lower_bound + 2 * stride) + + ... + + Executor N: + + .. code:: sql + + SELECT ... FROM table + WHERE partition_column >= (lower_bound + (N-1) * stride) + AND partition_column <= upper_bound + + Where ``stride=(upper_bound - lower_bound) / num_partitions``. + + .. note:: + + :obj:`~lower_bound`, :obj:`~upper_bound` and :obj:`~num_partitions` are used just to + calculate the partition stride, **NOT** for filtering the rows in table. + So all rows in the table will be returned (unlike *Incremental* :ref:`strategy`). + + .. note:: + + All queries are executed in parallel. To execute them sequentially, use *Batch* :ref:`strategy`. + + * ``hash`` + Allocate each executor a set of values based on hash of the :obj:`~partition_column` column. + + Spark generates for each executor an SQL query like: + + Executor 1: + + .. code:: sql + + SELECT ... FROM table + WHERE (some_hash(partition_column) mod num_partitions) = 0 -- lower_bound + + Executor 2: + + .. code:: sql + + SELECT ... FROM table + WHERE (some_hash(partition_column) mod num_partitions) = 1 -- lower_bound + 1 + + ... + + Executor N: + + .. code:: sql + + SELECT ... FROM table + WHERE (some_hash(partition_column) mod num_partitions) = num_partitions-1 -- upper_bound + + .. note:: + + The hash function implementation depends on RDBMS. It can be ``MD5`` or any other fast hash function, + or expression based on this function call. 
+ + * ``mod`` + Allocate each executor a set of values based on modulus of the :obj:`~partition_column` column. + + Spark generates for each executor an SQL query like: + + Executor 1: + + .. code:: sql + + SELECT ... FROM table + WHERE (partition_column mod num_partitions) = 0 -- lower_bound + + Executor 2: + + .. code:: sql + + SELECT ... FROM table + WHERE (partition_column mod num_partitions) = 1 -- lower_bound + 1 + + Executor N: + + .. code:: sql + + SELECT ... FROM table + WHERE (partition_column mod num_partitions) = num_partitions-1 -- upper_bound + + Examples + -------- + + Read data in 10 parallel jobs by range of values in ``id_column`` column: + + .. code:: python + + JDBC.ReadOptions( + partitioning_mode="range", # default mode, can be omitted + partition_column="id_column", + num_partitions=10, + # if you're using DBReader, options below can be omitted + # because they are calculated by automatically as + # MIN and MAX values of `partition_column` + lower_bound=0, + upper_bound=100_000, + ) + + Read data in 10 parallel jobs by hash of values in ``some_column`` column: + + .. code:: python + + JDBC.ReadOptions( + partitioning_mode="hash", + partition_column="some_column", + num_partitions=10, + # lower_bound and upper_bound are automatically set to `0` and `9` + ) + + Read data in 10 parallel jobs by modulus of values in ``id_column`` column: + + .. code:: python + + JDBC.ReadOptions( + partitioning_mode="mod", + partition_column="id_column", + num_partitions=10, + # lower_bound and upper_bound are automatically set to `0` and `9` + ) + """ + + @root_validator + def _partitioning_mode_actions(cls, values): + mode = values["partitioning_mode"] + num_partitions = values.get("num_partitions") + partition_column = values.get("partition_column") + lower_bound = values.get("lower_bound") + upper_bound = values.get("upper_bound") + + if not partition_column: + if num_partitions == 1: + return values + + raise ValueError("You should set partition_column to enable partitioning") + + elif num_partitions == 1: + raise ValueError("You should set num_partitions > 1 to enable partitioning") + + if mode == JDBCPartitioningMode.RANGE: + return values + + values["lower_bound"] = lower_bound if lower_bound is not None else 0 + values["upper_bound"] = upper_bound if upper_bound is not None else num_partitions + return values + + +class JDBCWriteOptions(JDBCOptions): + """Spark JDBC writing options. + + .. note :: + + You can pass any value + `supported by Spark `_, + even if it is not mentioned in this documentation. **Option names should be in** ``camelCase``! + + The set of supported options depends on Spark version. See link above. + + Examples + -------- + + Write options initialization + + .. code:: python + + options = JDBC.WriteOptions(if_exists="append", batchsize=20_000, someNewOption="value") + """ + + class Config: + known_options = WRITE_OPTIONS | READ_WRITE_OPTIONS + prohibited_options = JDBCOptions.Config.prohibited_options | GENERIC_PROHIBITED_OPTIONS | READ_OPTIONS + alias_generator = to_camel + + if_exists: JDBCTableExistBehavior = Field(default=JDBCTableExistBehavior.APPEND, alias="mode") + """Behavior of writing data into existing table. + + Possible values: + * ``append`` (default) + Adds new rows into existing table. + + .. dropdown:: Behavior in details + + * Table does not exist + Table is created using options provided by user + (``createTableOptions``, ``createTableColumnTypes``, etc). + + * Table exists + Data is appended to a table. 
Table has the same DDL as before writing data + + .. warning:: + + This mode does not check whether table already contains + rows from dataframe, so duplicated rows can be created. + + Also Spark does not support passing custom options to + insert statement, like ``ON CONFLICT``, so don't try to + implement deduplication using unique indexes or constraints. + + Instead, write to staging table and perform deduplication + using :obj:`~execute` method. + + * ``replace_entire_table`` + **Table is dropped and then created, or truncated**. + + .. dropdown:: Behavior in details + + * Table does not exist + Table is created using options provided by user + (``createTableOptions``, ``createTableColumnTypes``, etc). + + * Table exists + Table content is replaced with dataframe content. + + After writing is completed, target table could either have the same DDL as + before writing data (``truncate=True``), or can be recreated (``truncate=False`` + or source does not support truncation). + + .. note:: + + ``error`` and ``ignore`` modes are not supported. + """ + + batchsize: int = 20_000 + """How many rows can be inserted per round trip. + + Tuning this option can influence performance of writing. + + .. warning:: + + Default value is different from Spark. + + Spark uses quite small value ``1000``, which is absolutely not usable + in BigData world. + + Thus we've overridden default value with ``20_000``, + which should increase writing performance. + + You can increase it even more, up to ``50_000``, + but it depends on your database load and number of columns in the row. + Higher values do not increase performance. + """ + + isolation_level: str = "READ_UNCOMMITTED" + """The transaction isolation level, which applies to the current connection. + + Possible values: + * ``NONE`` (as string, not Python's ``None``) + * ``READ_COMMITTED`` + * ``READ_UNCOMMITTED`` + * ``REPEATABLE_READ`` + * ``SERIALIZABLE`` + + Values correspond to transaction isolation levels defined by JDBC standard. + Please refer to the documentation for + `java.sql.Connection `_. + """ + + @root_validator(pre=True) + def _mode_is_deprecated(cls, values): + if "mode" in values: + warnings.warn( + "Option `WriteOptions(mode=...)` is deprecated since v0.9.0 and will be removed in v1.0.0. " + "Use `WriteOptions(if_exists=...)` instead", + category=UserWarning, + stacklevel=3, + ) + return values + + +@deprecated( + version="0.5.0", + reason="Please use 'ReadOptions' or 'WriteOptions' class instead. Will be removed in v1.0.0", + action="always", + category=UserWarning, +) +class JDBCLegacyOptions(JDBCReadOptions, JDBCWriteOptions): + class Config: + prohibited_options = JDBCOptions.Config.prohibited_options diff --git a/onetl/connection/db_connection/jdbc_mixin/__init__.py b/onetl/connection/db_connection/jdbc_mixin/__init__.py new file mode 100644 index 000000000..062fdb74d --- /dev/null +++ b/onetl/connection/db_connection/jdbc_mixin/__init__.py @@ -0,0 +1,19 @@ +# Copyright 2023 MTS (Mobile Telesystems) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
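Returning to the ``if_exists`` notes above: since ``spark.write.jdbc()`` cannot pass ``ON CONFLICT``-style clauses, deduplication has to happen on the database side. A hedged sketch of the staging-table pattern the docstring hints at; the ``DBWriter`` import path, the ``postgres`` connection, ``df`` and all table/column names are illustrative, not part of this diff.

.. code:: python

    # import path may differ between onETL versions
    from onetl.db import DBWriter

    # `postgres` stands for any concrete JDBC connection, `df` is an existing DataFrame
    writer = DBWriter(
        connection=postgres,
        target="staging.orders_tmp",  # hypothetical staging table
        options=postgres.WriteOptions(if_exists="replace_entire_table", batchsize=20_000),
    )
    writer.run(df)

    # deduplicate into the final table using execute(), which runs on the Spark driver
    postgres.execute(
        """
        INSERT INTO public.orders
        SELECT * FROM staging.orders_tmp
        ON CONFLICT (order_id) DO NOTHING
        """,
    )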
+# See the License for the specific language governing permissions and +# limitations under the License. + +from onetl.connection.db_connection.jdbc_mixin.connection import ( + JDBCMixin, + JDBCStatementType, +) +from onetl.connection.db_connection.jdbc_mixin.options import JDBCOptions diff --git a/onetl/connection/db_connection/jdbc_mixin.py b/onetl/connection/db_connection/jdbc_mixin/connection.py similarity index 91% rename from onetl/connection/db_connection/jdbc_mixin.py rename to onetl/connection/db_connection/jdbc_mixin/connection.py index 7f4bd1a73..c02fb82f1 100644 --- a/onetl/connection/db_connection/jdbc_mixin.py +++ b/onetl/connection/db_connection/jdbc_mixin/connection.py @@ -25,9 +25,12 @@ from onetl._internal import clear_statement, stringify from onetl._util.java import get_java_gateway, try_import_java_class from onetl._util.spark import get_spark_version +from onetl.connection.db_connection.jdbc_mixin.options import ( + JDBCOptions as JDBCMixinOptions, +) from onetl.exception import MISSING_JVM_CLASS_MSG from onetl.hooks import slot, support_hooks -from onetl.impl import FrozenModel, GenericOptions +from onetl.impl import FrozenModel from onetl.log import log_lines if TYPE_CHECKING: @@ -48,7 +51,7 @@ ) -class StatementType(Enum): +class JDBCStatementType(Enum): GENERIC = auto() PREPARED = auto() CALL = auto() @@ -66,44 +69,14 @@ class JDBCMixin(FrozenModel): spark: SparkSession = Field(repr=False) user: str password: SecretStr - DRIVER: ClassVar[str] - _CHECK_QUERY: ClassVar[str] = "SELECT 1" - class JDBCOptions(GenericOptions): - """Generic options, related to specific JDBC driver. + JDBCOptions = JDBCMixinOptions - .. note :: - - You can pass any value - supported by underlying JDBC driver class, - even if it is not mentioned in this documentation. - """ - - class Config: - prohibited_options = PROHIBITED_OPTIONS - extra = "allow" - - query_timeout: Optional[int] = Field(default=None, alias="queryTimeout") - """The number of seconds the driver will wait for a statement to execute. - Zero means there is no limit. - - This option depends on driver implementation, - some drivers can check the timeout of each query instead of an entire JDBC batch. - """ - - fetchsize: Optional[int] = None - """How many rows to fetch per round trip. - - Tuning this option can influence performance of reading. - - .. warning:: - - Default value depends on driver. For example, Oracle has - default ``fetchsize=10``. - """ + DRIVER: ClassVar[str] + _CHECK_QUERY: ClassVar[str] = "SELECT 1" # cached JDBC connection (Java object), plus corresponding GenericOptions (Python object) - _last_connection_and_options: Optional[Tuple[Any, JDBCOptions]] = PrivateAttr(default=None) + _last_connection_and_options: Optional[Tuple[Any, JDBCMixinOptions]] = PrivateAttr(default=None) @property @abstractmethod @@ -176,7 +149,7 @@ def check(self): def fetch( self, query: str, - options: JDBCMixin.JDBCOptions | dict | None = None, + options: JDBCMixinOptions | dict | None = None, ) -> DataFrame: """ **Immediately** execute SELECT statement **on Spark driver** and return in-memory DataFrame. |support_hooks| @@ -274,7 +247,7 @@ def fetch( def execute( self, statement: str, - options: JDBCMixin.JDBCOptions | dict | None = None, + options: JDBCMixinOptions | dict | None = None, ) -> DataFrame | None: """ **Immediately** execute DDL, DML or procedure/function **on Spark driver**. 
|support_hooks| @@ -407,11 +380,11 @@ def _check_java_class_imported(cls, spark): def _query_on_driver( self, query: str, - options: JDBCMixin.JDBCOptions, + options: JDBCMixinOptions, ) -> DataFrame: return self._execute_on_driver( statement=query, - statement_type=StatementType.PREPARED, + statement_type=JDBCStatementType.PREPARED, callback=self._statement_to_dataframe, options=options, read_only=True, @@ -420,11 +393,11 @@ def _query_on_driver( def _query_optional_on_driver( self, query: str, - options: JDBCMixin.JDBCOptions, + options: JDBCMixinOptions, ) -> DataFrame | None: return self._execute_on_driver( statement=query, - statement_type=StatementType.PREPARED, + statement_type=JDBCStatementType.PREPARED, callback=self._statement_to_optional_dataframe, options=options, read_only=True, @@ -433,11 +406,11 @@ def _query_optional_on_driver( def _call_on_driver( self, query: str, - options: JDBCMixin.JDBCOptions, + options: JDBCMixinOptions, ) -> DataFrame | None: return self._execute_on_driver( statement=query, - statement_type=StatementType.CALL, + statement_type=JDBCStatementType.CALL, callback=self._statement_to_optional_dataframe, options=options, read_only=False, @@ -445,7 +418,7 @@ def _call_on_driver( def _get_jdbc_properties( self, - options: JDBCOptions, + options: JDBCMixinOptions, **kwargs, ) -> dict: """ @@ -463,7 +436,7 @@ def _get_jdbc_properties( return stringify(result) - def _options_to_connection_properties(self, options: JDBCOptions): + def _options_to_connection_properties(self, options: JDBCMixinOptions): """ Converts human-readable Options class to ``java.util.Properties``. @@ -485,7 +458,7 @@ def _options_to_connection_properties(self, options: JDBCOptions): ) return jdbc_options.asConnectionProperties() - def _get_jdbc_connection(self, options: JDBCOptions): + def _get_jdbc_connection(self, options: JDBCMixinOptions): with suppress(Exception): # nothing cached, or JVM failed last_connection, last_options = self._last_connection_and_options if options == last_options and not last_connection.isClosed(): @@ -516,9 +489,9 @@ def _get_statement_args(self) -> tuple[int, ...]: def _execute_on_driver( self, statement: str, - statement_type: StatementType, + statement_type: JDBCStatementType, callback: Callable[..., T], - options: JDBCOptions, + options: JDBCMixinOptions, read_only: bool, ) -> T: """ @@ -540,7 +513,7 @@ def _execute_statement( self, jdbc_statement, statement: str, - options: JDBCOptions, + options: JDBCMixinOptions, callback: Callable[..., T], read_only: bool, ) -> T: @@ -580,7 +553,7 @@ def _execute_statement( @staticmethod def _build_statement( statement: str, - statement_type: StatementType, + statement_type: JDBCStatementType, jdbc_connection, statement_args, ): @@ -596,10 +569,10 @@ def _build_statement( * https://github.com/apache/spark/blob/v2.3.0/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala#L633 """ - if statement_type == StatementType.PREPARED: + if statement_type == JDBCStatementType.PREPARED: return jdbc_connection.prepareStatement(statement, *statement_args) - if statement_type == StatementType.CALL: + if statement_type == JDBCStatementType.CALL: return jdbc_connection.prepareCall(statement, *statement_args) return jdbc_connection.createStatement(*statement_args) diff --git a/onetl/connection/db_connection/jdbc_mixin/options.py b/onetl/connection/db_connection/jdbc_mixin/options.py new file mode 100644 index 000000000..dd889b8fc --- /dev/null +++ b/onetl/connection/db_connection/jdbc_mixin/options.py 
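Renaming ``StatementType`` to ``JDBCStatementType`` and extracting ``JDBCMixinOptions`` does not change the public entry points: ``fetch()`` and ``execute()`` still run statements on the Spark driver and accept options either as a class instance or a plain dict. A short usage sketch; the ``clickhouse`` connection and the statements are illustrative.

.. code:: python

    # `clickhouse` stands for any connection using JDBCMixin (Clickhouse, Postgres, ...)
    df = clickhouse.fetch(
        "SELECT id, updated_at FROM schema.table WHERE updated_at >= CURRENT_DATE - 1",
        options={"queryTimeout": 60, "fetchsize": 10_000},  # plain dict is parsed into JDBCOptions
    )
    df.show()

    # DDL/DML returns None; procedures/functions may return a DataFrame
    clickhouse.execute("TRUNCATE TABLE schema.staging_table")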
@@ -0,0 +1,65 @@ +# Copyright 2023 MTS (Mobile Telesystems) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +from typing import Optional + +from pydantic import Field + +from onetl.impl import GenericOptions + +# options generated by JDBCMixin methods +PROHIBITED_OPTIONS = frozenset( + ( + "user", + "password", + "driver", + "url", + ), +) + + +class JDBCOptions(GenericOptions): + """Generic options, related to specific JDBC driver. + + .. note :: + + You can pass any value + supported by underlying JDBC driver class, + even if it is not mentioned in this documentation. + """ + + class Config: + prohibited_options = PROHIBITED_OPTIONS + extra = "allow" + + query_timeout: Optional[int] = Field(default=None, alias="queryTimeout") + """The number of seconds the driver will wait for a statement to execute. + Zero means there is no limit. + + This option depends on driver implementation, + some drivers can check the timeout of each query instead of an entire JDBC batch. + """ + + fetchsize: Optional[int] = None + """How many rows to fetch per round trip. + + Tuning this option can influence performance of reading. + + .. warning:: + + Default value depends on driver. For example, Oracle has + default ``fetchsize=10``. 
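``JDBCOptions`` keeps ``extra = "allow"``, so driver-specific options pass through untouched, while ``PROHIBITED_OPTIONS`` above blocks anything the connection object itself manages. A sketch of the expected behaviour; the exact exception type is an assumption based on onETL's ``GenericOptions`` validation.

.. code:: python

    from onetl.connection.db_connection.jdbc_mixin.options import JDBCOptions

    # camelCase aliases and unknown driver-specific options are accepted
    options = JDBCOptions(queryTimeout=30, fetchsize=5_000, ssl="true")
    assert options.query_timeout == 30

    # "user", "password", "driver" and "url" are set by the connection itself,
    # so passing them here should be rejected during validation
    try:
        JDBCOptions(user="someone")
    except ValueError as exc:  # pydantic ValidationError subclasses ValueError
        print(f"rejected as expected: {exc}")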
+ """ diff --git a/onetl/connection/db_connection/kafka/connection.py b/onetl/connection/db_connection/kafka/connection.py index d49630139..3aa8f0fd2 100644 --- a/onetl/connection/db_connection/kafka/connection.py +++ b/onetl/connection/db_connection/kafka/connection.py @@ -192,14 +192,12 @@ class Kafka(DBConnection): kafka = Kafka( addresses=["mybroker:9092", "anotherbroker:9092"], cluster="my-cluster", - protocol=( - Kafka.SSLProtocol( - keystore_type="PEM", - keystore_certificate_chain=Path("path/to/user.crt").read_text(), - keystore_key=Path("path/to/user.key").read_text(), - truststore_type="PEM", - truststore_certificates=Path("/path/to/server.crt").read_text(), - ), + protocol=Kafka.SSLProtocol( + keystore_type="PEM", + keystore_certificate_chain=Path("path/to/user.crt").read_text(), + keystore_key=Path("path/to/user.key").read_text(), + truststore_type="PEM", + truststore_certificates=Path("/path/to/server.crt").read_text(), ), auth=Kafka.ScramAuth( user="me", @@ -271,11 +269,14 @@ def read_source_as_df( options: KafkaReadOptions = KafkaReadOptions(), # noqa: B008, WPS404 ) -> DataFrame: log.info("|%s| Reading data from topic %r", self.__class__.__name__, source) + if source not in self._get_topics(): + raise ValueError(f"Topic {source!r} doesn't exist") + result_options = {f"kafka.{key}": value for key, value in self._get_connection_properties().items()} result_options.update(options.dict(by_alias=True, exclude_none=True)) result_options["subscribe"] = source df = self.spark.read.format("kafka").options(**result_options).load() - log.info("|%s| Dataframe is successfully created", self.__class__.__name__) + log.info("|%s| Dataframe is successfully created.", self.__class__.__name__) return df @slot @@ -286,21 +287,27 @@ def write_df_to_target( options: KafkaWriteOptions = KafkaWriteOptions(), # noqa: B008, WPS404 ) -> None: # Check that the DataFrame doesn't contain any columns not in the schema - schema: StructType = self.get_df_schema(target) - required_columns = [field.name for field in schema.fields if not field.nullable] - optional_columns = [field.name for field in schema.fields if field.nullable] - schema_field_names = {field.name for field in schema.fields} - df_column_names = set(df.columns) - if not df_column_names.issubset(schema_field_names): - invalid_columns = df_column_names - schema_field_names + required_columns = {"value"} + optional_columns = {"key", "partition", "headers"} + allowed_columns = required_columns | optional_columns | {"topic"} + df_columns = set(df.columns) + if not df_columns.issubset(allowed_columns): + invalid_columns = df_columns - allowed_columns raise ValueError( - f"Invalid column names: {invalid_columns}. Expected columns: {required_columns} (required)," - f" {optional_columns} (optional)", + f"Invalid column names: {sorted(invalid_columns)}. 
" + f"Expected columns: {sorted(required_columns)} (required)," + f" {sorted(optional_columns)} (optional)", ) # Check that the DataFrame doesn't contain a 'headers' column with includeHeaders=False - if not getattr(options, "includeHeaders", True) and "headers" in df.columns: - raise ValueError("Cannot write 'headers' column with kafka.WriteOptions(includeHeaders=False)") + if not options.include_headers and "headers" in df.columns: + raise ValueError("Cannot write 'headers' column with kafka.WriteOptions(include_headers=False)") + + spark_version = get_spark_version(self.spark) + if options.include_headers and spark_version.major < 3: + raise ValueError( + f"kafka.WriteOptions(include_headers=True) requires Spark 3.x, got {spark_version}", + ) if "topic" in df.columns: log.warning("The 'topic' column in the DataFrame will be overridden with value %r", target) @@ -318,7 +325,7 @@ def write_df_to_target( log.info("|%s| Saving data to a topic %r", self.__class__.__name__, target) df.write.format("kafka").mode(mode).options(**write_options).save() - log.info("|%s| Data is successfully written to topic %r", self.__class__.__name__, target) + log.info("|%s| Data is successfully written to topic %r.", self.__class__.__name__, target) @slot def get_df_schema( @@ -357,6 +364,7 @@ def get_df_schema( ], ), ), + nullable=True, ), ], ) @@ -469,12 +477,12 @@ def _get_addresses_by_cluster(cls, values): @validator("cluster") def _validate_cluster_name(cls, cluster): - log.debug("|%s| Normalizing cluster %r name ...", cls.__name__, cluster) + log.debug("|%s| Normalizing cluster %r name...", cls.__name__, cluster) validated_cluster = cls.Slots.normalize_cluster_name(cluster) or cluster if validated_cluster != cluster: log.debug("|%s| Got %r", cls.__name__, validated_cluster) - log.debug("|%s| Checking if cluster %r is a known cluster ...", cls.__name__, validated_cluster) + log.debug("|%s| Checking if cluster %r is a known cluster...", cls.__name__, validated_cluster) known_clusters = cls.Slots.get_known_clusters() if known_clusters and validated_cluster not in known_clusters: raise ValueError( @@ -487,7 +495,7 @@ def _validate_cluster_name(cls, cluster): def _validate_addresses(cls, value, values): cluster = values.get("cluster") - log.debug("|%s| Normalizing addresses %r names ...", cls.__name__, value) + log.debug("|%s| Normalizing addresses %r names...", cls.__name__, value) validated_addresses = [cls.Slots.normalize_address(address, cluster) or address for address in value] if validated_addresses != value: @@ -562,8 +570,7 @@ def _get_topics(self, timeout: int = 10) -> set[str]: return set(topics) def _log_parameters(self): - log.info("|Spark| Using connection parameters:") - log_with_indent(log, "type = %s", self.__class__.__name__) + log.info("|%s| Using connection parameters:", self.__class__.__name__) log_with_indent(log, "cluster = %r", self.cluster) log_collection(log, "addresses", self.addresses, max_items=10) log_with_indent(log, "protocol = %r", self.protocol) diff --git a/onetl/connection/db_connection/kafka/dialect.py b/onetl/connection/db_connection/kafka/dialect.py index a8f4baff9..e8c35ccaa 100644 --- a/onetl/connection/db_connection/kafka/dialect.py +++ b/onetl/connection/db_connection/kafka/dialect.py @@ -18,13 +18,14 @@ import logging from onetl._util.spark import get_spark_version -from onetl.connection.db_connection.db_connection import BaseDBConnection, DBConnection +from onetl.base import BaseDBConnection +from onetl.connection.db_connection.db_connection.dialect import DBDialect 
from onetl.connection.db_connection.dialect_mixins import ( SupportColumnsNone, SupportDfSchemaNone, SupportHintNone, SupportHWMExpressionNone, - SupportTableWithoutDBSchema, + SupportNameAny, SupportWhereNone, ) @@ -36,9 +37,9 @@ class KafkaDialect( # noqa: WPS215 SupportDfSchemaNone, SupportHintNone, SupportWhereNone, - SupportTableWithoutDBSchema, + SupportNameAny, SupportHWMExpressionNone, - DBConnection.Dialect, + DBDialect, ): valid_hwm_columns = {"offset", "timestamp"} diff --git a/onetl/connection/db_connection/kafka/kafka_ssl_protocol.py b/onetl/connection/db_connection/kafka/kafka_ssl_protocol.py index b5a73c357..600d56876 100644 --- a/onetl/connection/db_connection/kafka/kafka_ssl_protocol.py +++ b/onetl/connection/db_connection/kafka/kafka_ssl_protocol.py @@ -48,28 +48,24 @@ class KafkaSSLProtocol(KafkaProtocol, GenericOptions): from pathlib import Path # Just read existing files located on host, and pass key and certificates as strings - protocol = ( - Kafka.SSLProtocol( - keystore_type="PEM", - keystore_certificate_chain=Path("path/to/user.crt").read_text(), - keystore_key=Path("path/to/user.key").read_text(), - truststore_type="PEM", - truststore_certificates=Path("/path/to/server.crt").read_text(), - ), + protocol = Kafka.SSLProtocol( + keystore_type="PEM", + keystore_certificate_chain=Path("path/to/user.crt").read_text(), + keystore_key=Path("path/to/user.key").read_text(), + truststore_type="PEM", + truststore_certificates=Path("/path/to/server.crt").read_text(), ) Pass PEM key and certificates as raw strings: .. code:: python - protocol = ( - Kafka.SSLProtocol( - keystore_type="PEM", - keystore_certificate_chain="-----BEGIN CERTIFICATE-----\\nMIIDZjC...\\n-----END CERTIFICATE-----", - keystore_key="-----BEGIN PRIVATE KEY-----\\nMIIEvg..\\n-----END PRIVATE KEY-----", - truststore_type="PEM", - truststore_certificates="-----BEGIN CERTIFICATE-----\\nMICC...\\n-----END CERTIFICATE-----", - ), + protocol = Kafka.SSLProtocol( + keystore_type="PEM", + keystore_certificate_chain="-----BEGIN CERTIFICATE-----\\nMIIDZjC...\\n-----END CERTIFICATE-----", + keystore_key="-----BEGIN PRIVATE KEY-----\\nMIIEvg..\\n-----END PRIVATE KEY-----", + truststore_type="PEM", + truststore_certificates="-----BEGIN CERTIFICATE-----\\nMICC...\\n-----END CERTIFICATE-----", ) Pass custom options: diff --git a/onetl/connection/db_connection/kafka/options.py b/onetl/connection/db_connection/kafka/options.py index 88f5a15b6..6d7e4ef4c 100644 --- a/onetl/connection/db_connection/kafka/options.py +++ b/onetl/connection/db_connection/kafka/options.py @@ -12,6 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + from enum import Enum from pydantic import Field, root_validator @@ -34,14 +36,6 @@ ), ) -KNOWN_READ_WRITE_OPTIONS = frozenset( - ( - # not adding this to class itself because headers support was added to Spark only in 3.0 - # https://issues.apache.org/jira/browse/SPARK-23539 - "includeHeaders", - ), -) - KNOWN_READ_OPTIONS = frozenset( ( "failOnDataLoss", @@ -88,9 +82,8 @@ class KafkaReadOptions(GenericOptions): * ``startingTimestamp`` * ``subscribe`` * ``subscribePattern`` - * ``topic`` - populated from connection attributes, and cannot be set in ``KafkaReadOptions`` class and be overridden + are populated from connection attributes, and cannot be set in ``KafkaReadOptions`` class and be overridden by the user to avoid issues. Examples @@ -101,14 +94,21 @@ class KafkaReadOptions(GenericOptions): .. 
code:: python options = Kafka.ReadOptions( + include_headers=False, minPartitions=50, - includeHeaders=True, ) """ + include_headers: bool = Field(default=False, alias="includeHeaders") + """ + If ``True``, add ``headers`` column to output DataFrame. + + If ``False``, column will not be added. + """ + class Config: prohibited_options = PROHIBITED_OPTIONS - known_options = KNOWN_READ_OPTIONS | KNOWN_READ_WRITE_OPTIONS + known_options = KNOWN_READ_OPTIONS extra = "allow" @@ -126,18 +126,10 @@ class KafkaWriteOptions(GenericOptions): .. warning:: Options: - * ``assign`` - * ``endingOffsets`` - * ``endingOffsetsByTimestamp`` * ``kafka.*`` - * ``startingOffsets`` - * ``startingOffsetsByTimestamp`` - * ``startingTimestamp`` - * ``subscribe`` - * ``subscribePattern`` * ``topic`` - populated from connection attributes, and cannot be set in ``KafkaWriteOptions`` class and be overridden + are populated from connection attributes, and cannot be set in ``KafkaWriteOptions`` class and be overridden by the user to avoid issues. Examples @@ -149,7 +141,7 @@ class KafkaWriteOptions(GenericOptions): options = Kafka.WriteOptions( if_exists="append", - includeHeaders=False, + include_headers=True, ) """ @@ -163,9 +155,16 @@ class KafkaWriteOptions(GenericOptions): * ``error`` - Raises an error if topic already exists. """ + include_headers: bool = Field(default=False, alias="includeHeaders") + """ + If ``True``, ``headers`` column from dataframe can be written to Kafka (requires Kafka 2.0+). + + If ``False`` and dataframe contains ``headers`` column, an exception will be raised. + """ + class Config: prohibited_options = PROHIBITED_OPTIONS | KNOWN_READ_OPTIONS - known_options = KNOWN_READ_WRITE_OPTIONS + known_options: frozenset[str] = frozenset() extra = "allow" @root_validator(pre=True) diff --git a/onetl/connection/db_connection/mongodb/__init__.py b/onetl/connection/db_connection/mongodb/__init__.py new file mode 100644 index 000000000..0ec92b2d7 --- /dev/null +++ b/onetl/connection/db_connection/mongodb/__init__.py @@ -0,0 +1,22 @@ +# Copyright 2023 MTS (Mobile Telesystems) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
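With ``includeHeaders`` dropped from the shared known-options set, header support is now an explicit, typed ``include_headers`` field on both option classes, keeping the ``includeHeaders`` alias for Spark. A short sketch combining both sides, assuming an imported ``Kafka`` class:

.. code:: python

    read_options = Kafka.ReadOptions(
        include_headers=True,  # adds a `headers` column to the resulting DataFrame
        minPartitions=50,  # any other Spark option can still be passed as-is
    )

    write_options = Kafka.WriteOptions(
        if_exists="append",
        include_headers=True,  # allows writing the `headers` column (Spark 3.x only)
    )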
+ +from onetl.connection.db_connection.mongodb.connection import MongoDB, MongoDBExtra +from onetl.connection.db_connection.mongodb.dialect import MongoDBDialect +from onetl.connection.db_connection.mongodb.options import ( + MongoDBCollectionExistBehavior, + MongoDBPipelineOptions, + MongoDBReadOptions, + MongoDBWriteOptions, +) diff --git a/onetl/connection/db_connection/mongodb.py b/onetl/connection/db_connection/mongodb/connection.py similarity index 56% rename from onetl/connection/db_connection/mongodb.py rename to onetl/connection/db_connection/mongodb/connection.py index 1bb95bbd1..8e6110f14 100644 --- a/onetl/connection/db_connection/mongodb.py +++ b/onetl/connection/db_connection/mongodb/connection.py @@ -14,31 +14,26 @@ from __future__ import annotations -import json import logging -import operator import warnings -from datetime import datetime -from enum import Enum -from typing import TYPE_CHECKING, Any, Callable, ClassVar, Dict, Iterable, Mapping +from typing import TYPE_CHECKING, Any from urllib import parse as parser from etl_entities.instance import Host -from pydantic import Field, SecretStr, root_validator, validator +from pydantic import SecretStr, validator from onetl._util.classproperty import classproperty from onetl._util.java import try_import_java_class from onetl._util.scala import get_default_scala_version from onetl._util.spark import get_spark_version from onetl._util.version import Version -from onetl.base.base_db_connection import BaseDBConnection from onetl.connection.db_connection.db_connection import DBConnection -from onetl.connection.db_connection.dialect_mixins import ( - SupportColumnsNone, - SupportDfSchemaStruct, - SupportHWMColumnStr, - SupportHWMExpressionNone, - SupportTableWithoutDBSchema, +from onetl.connection.db_connection.mongodb.dialect import MongoDBDialect +from onetl.connection.db_connection.mongodb.options import ( + MongoDBCollectionExistBehavior, + MongoDBPipelineOptions, + MongoDBReadOptions, + MongoDBWriteOptions, ) from onetl.exception import MISSING_JVM_CLASS_MSG from onetl.hooks import slot, support_hooks @@ -53,125 +48,9 @@ log = logging.getLogger(__name__) -_upper_level_operators = frozenset( # noqa: WPS527 - [ - "$addFields", - "$bucket", - "$bucketAuto", - "$changeStream", - "$collStats", - "$count", - "$currentOp", - "$densify", - "$documents", - "$facet", - "$fill", - "$geoNear", - "$graphLookup", - "$group", - "$indexStats", - "$limit", - "$listLocalSessions", - "$listSessions", - "$lookup", - "$merge", - "$out", - "$planCacheStats", - "$project", - "$redact", - "$replaceRoot", - "$replaceWith", - "$sample", - "$search", - "$searchMeta", - "$set", - "$setWindowFields", - "$shardedDataDistribution", - "$skip", - "$sort", - "$sortByCount", - "$unionWith", - "$unset", - "$unwind", - ], -) - - -class MongoDBCollectionExistBehavior(str, Enum): - APPEND = "append" - REPLACE_ENTIRE_COLLECTION = "replace_entire_collection" - - def __str__(self) -> str: - return str(self.value) - - @classmethod # noqa: WPS120 - def _missing_(cls, value: object): # noqa: WPS120 - if str(value) == "overwrite": - warnings.warn( - "Mode `overwrite` is deprecated since v0.9.0 and will be removed in v1.0.0. 
" - "Use `replace_entire_collection` instead", - category=UserWarning, - stacklevel=4, - ) - return cls.REPLACE_ENTIRE_COLLECTION - - -PIPELINE_PROHIBITED_OPTIONS = frozenset( - ( - "uri", - "database", - "collection", - "pipeline", - ), -) - -PROHIBITED_OPTIONS = frozenset( - ( - "uri", - "database", - "collection", - "pipeline", - "hint", - ), -) - -KNOWN_READ_OPTIONS = frozenset( - ( - "localThreshold", - "readPreference.name", - "readPreference.tagSets", - "readConcern.level", - "sampleSize", - "samplePoolSize", - "partitioner", - "partitionerOptions", - "registerSQLHelperFunctions", - "sql.inferschema.mapTypes.enabled", - "sql.inferschema.mapTypes.minimumKeys", - "sql.pipeline.includeNullFilters", - "sql.pipeline.includeFiltersAndProjections", - "pipeline", - "hint", - "collation", - "allowDiskUse", - "batchSize", - ), -) - -KNOWN_WRITE_OPTIONS = frozenset( - ( - "extendedBsonTypes", - "localThreshold", - "replaceDocument", - "maxBatchSize", - "writeConcern.w", - "writeConcern.journal", - "writeConcern.wTimeoutMS", - "shardKey", - "forceInsert", - "ordered", - ), -) +class MongoDBExtra(GenericOptions): + class Config: + extra = "allow" @support_hooks @@ -262,16 +141,18 @@ class MongoDB(DBConnection): ) """ - class Extra(GenericOptions): - class Config: - extra = "allow" - database: str host: Host user: str password: SecretStr port: int = 27017 - extra: Extra = Extra() + extra: MongoDBExtra = MongoDBExtra() + + Dialect = MongoDBDialect + ReadOptions = MongoDBReadOptions + WriteOptions = MongoDBWriteOptions + PipelineOptions = MongoDBPipelineOptions + Extra = MongoDBExtra @slot @classmethod @@ -357,295 +238,13 @@ def package_spark_3_4(cls) -> str: warnings.warn(msg, UserWarning, stacklevel=3) return "org.mongodb.spark:mongo-spark-connector_2.12:10.1.1" - class PipelineOptions(GenericOptions): - """Aggregation pipeline options for MongoDB connector. - - The only difference from :obj:`~ReadOptions` that it is allowed to pass the 'hint' parameter. - - .. note :: - - You can pass any value - `supported by connector `_, - even if it is not mentioned in this documentation. - - The set of supported options depends on connector version. See link above. - - .. warning:: - - Options ``uri``, ``database``, ``collection``, ``pipeline`` are populated from connection attributes, - and cannot be set in ``PipelineOptions`` class. - - Examples - -------- - - Pipeline options initialization - - .. code:: python - - MongoDB.PipelineOptions( - hint="{'_id': 1}", - ) - """ - - class Config: - prohibited_options = PIPELINE_PROHIBITED_OPTIONS - known_options = KNOWN_READ_OPTIONS - extra = "allow" - - class ReadOptions(GenericOptions): - """Reading options for MongoDB connector. - - .. note :: - - You can pass any value - `supported by connector `_, - even if it is not mentioned in this documentation. - - The set of supported options depends on connector version. See link above. - - .. warning:: - - Options ``uri``, ``database``, ``collection``, ``pipeline``, ``hint`` are populated from connection - attributes, and cannot be set in ``ReadOptions`` class. - - Examples - -------- - - Read options initialization - - .. code:: python - - MongoDB.ReadOptions( - batchSize=10000, - ) - """ - - class Config: - prohibited_options = PROHIBITED_OPTIONS - known_options = KNOWN_READ_OPTIONS - extra = "allow" - - class WriteOptions(GenericOptions): - """Writing options for MongoDB connector. - - .. note :: - - You can pass any value - `supported by connector `_, - even if it is not mentioned in this documentation. 
- - The set of supported options depends on connector version. See link above. - - .. warning:: - - Options ``uri``, ``database``, ``collection`` are populated from connection attributes, - and cannot be set in ``WriteOptions`` class. - - Examples - -------- - - Write options initialization - - .. code:: python - - options = MongoDB.WriteOptions( - if_exists="append", - sampleSize=500, - localThreshold=20, - ) - """ - - if_exists: MongoDBCollectionExistBehavior = Field(default=MongoDBCollectionExistBehavior.APPEND, alias="mode") - """Behavior of writing data into existing collection. - - Possible values: - * ``append`` (default) - Adds new objects into existing collection. - - .. dropdown:: Behavior in details - - * Collection does not exist - Collection is created using options provided by user - (``shardkey`` and others). - - * Collection exists - Data is appended to a collection. - - .. warning:: - - This mode does not check whether collection already contains - objects from dataframe, so duplicated objects can be created. - - * ``replace_entire_collection`` - **Collection is deleted and then created**. - - .. dropdown:: Behavior in details - - * Collection does not exist - Collection is created using options provided by user - (``shardkey`` and others). - - * Collection exists - Collection content is replaced with dataframe content. - - .. note:: - - ``error`` and ``ignore`` modes are not supported. - """ - - class Config: - prohibited_options = PROHIBITED_OPTIONS - known_options = KNOWN_WRITE_OPTIONS - extra = "allow" - - @root_validator(pre=True) - def mode_is_deprecated(cls, values): - if "mode" in values: - warnings.warn( - "Option `MongoDB.WriteOptions(mode=...)` is deprecated since v0.9.0 and will be removed in v1.0.0. " - "Use `MongoDB.WriteOptions(if_exists=...)` instead", - category=UserWarning, - stacklevel=3, - ) - return values - - class Dialect( # noqa: WPS215 - SupportTableWithoutDBSchema, - SupportHWMExpressionNone, - SupportColumnsNone, - SupportDfSchemaStruct, - SupportHWMColumnStr, - DBConnection.Dialect, - ): - _compare_statements: ClassVar[Dict[Callable, str]] = { - operator.ge: "$gte", - operator.gt: "$gt", - operator.le: "$lte", - operator.lt: "$lt", - operator.eq: "$eq", - operator.ne: "$ne", - } - - @classmethod - def validate_where( - cls, - connection: BaseDBConnection, - where: Any, - ) -> dict | None: - if where is None: - return None - - if not isinstance(where, dict): - raise ValueError( - f"{connection.__class__.__name__} requires 'where' parameter type to be 'dict', " - f"got {where.__class__.__name__!r}", - ) - - for key in where: - cls._validate_top_level_keys_in_where_parameter(key) - return where - - @classmethod - def validate_hint( - cls, - connection: BaseDBConnection, - hint: Any, - ) -> dict | None: - if hint is None: - return None - - if not isinstance(hint, dict): - raise ValueError( - f"{connection.__class__.__name__} requires 'hint' parameter type to be 'dict', " - f"got {hint.__class__.__name__!r}", - ) - return hint - - @classmethod - def prepare_pipeline( - cls, - pipeline: Any, - ) -> Any: - """ - Prepares pipeline (list or dict) to MongoDB syntax, but without converting it to string. 
- """ - - if isinstance(pipeline, datetime): - return {"$date": pipeline.astimezone().isoformat()} - - if isinstance(pipeline, Mapping): - return {cls.prepare_pipeline(key): cls.prepare_pipeline(value) for key, value in pipeline.items()} - - if isinstance(pipeline, Iterable) and not isinstance(pipeline, str): - return [cls.prepare_pipeline(item) for item in pipeline] - - return pipeline - - @classmethod - def convert_to_str( - cls, - value: Any, - ) -> str: - """ - Converts the given dictionary, list or primitive to a string. - """ - - return json.dumps(cls.prepare_pipeline(value)) - - @classmethod - def _merge_conditions(cls, conditions: list[Any]) -> Any: - if len(conditions) == 1: - return conditions[0] - - return {"$and": conditions} - - @classmethod - def _get_compare_statement(cls, comparator: Callable, arg1: Any, arg2: Any) -> dict: - """ - Returns the comparison statement in MongoDB syntax: - - .. code:: - - { - "field": { - "$gt": "some_value", - } - } - """ - return { - arg1: { - cls._compare_statements[comparator]: arg2, - }, - } - - @classmethod - def _validate_top_level_keys_in_where_parameter(cls, key: str): - """ - Checks the 'where' parameter for illegal operators, such as ``$match``, ``$merge`` or ``$changeStream``. - - 'where' clause can contain only filtering operators, like ``{"col1" {"$eq": 1}}`` or ``{"$and": [...]}``. - """ - if key.startswith("$"): - if key == "$match": - raise ValueError( - "'$match' operator not allowed at the top level of the 'where' parameter dictionary. " - "This error most likely occurred due to the fact that you used the MongoDB format for the " - "pipeline {'$match': {'column': ...}}. In the onETL paradigm, you do not need to specify the " - "'$match' keyword, but write the filtering condition right away, like {'column': ...}", - ) - if key in _upper_level_operators: # noqa: WPS220 - raise ValueError( # noqa: WPS220 - f"An invalid parameter {key!r} was specified in the 'where' " - "field. You cannot use aggregations or 'groupBy' clauses in 'where'", - ) - @slot def pipeline( self, collection: str, pipeline: dict | list[dict], df_schema: StructType | None = None, - options: PipelineOptions | dict | None = None, + options: MongoDBPipelineOptions | dict | None = None, ): """ Execute a pipeline for a specific collection, and return DataFrame. 
|support_hooks| @@ -806,9 +405,9 @@ def get_min_max_bounds( expression: str | None = None, # noqa: U100 hint: dict | None = None, # noqa: U100 where: dict | None = None, - options: ReadOptions | dict | None = None, + options: MongoDBReadOptions | dict | None = None, ) -> tuple[Any, Any]: - log.info("|Spark| Getting min and max values for column %r", column) + log.info("|%s| Getting min and max values for column %r ...", self.__class__.__name__, column) read_options = self.ReadOptions.parse(options).dict(by_alias=True, exclude_none=True) @@ -836,7 +435,7 @@ def get_min_max_bounds( min_value = row["min"] max_value = row["max"] - log.info("|Spark| Received values:") + log.info("|%s| Received values:", self.__class__.__name__) log_with_indent(log, "MIN(%s) = %r", column, min_value) log_with_indent(log, "MAX(%s) = %r", column, max_value) @@ -852,7 +451,7 @@ def read_source_as_df( df_schema: StructType | None = None, start_from: Statement | None = None, end_at: Statement | None = None, - options: ReadOptions | dict | None = None, + options: MongoDBReadOptions | dict | None = None, ) -> DataFrame: read_options = self.ReadOptions.parse(options).dict(by_alias=True, exclude_none=True) final_where = self.Dialect._condition_assembler( @@ -893,7 +492,7 @@ def write_df_to_target( self, df: DataFrame, target: str, - options: WriteOptions | dict | None = None, + options: MongoDBWriteOptions | dict | None = None, ) -> None: write_options = self.WriteOptions.parse(options) write_options_dict = write_options.dict(by_alias=True, exclude_none=True, exclude={"if_exists"}) diff --git a/onetl/connection/db_connection/mongodb/dialect.py b/onetl/connection/db_connection/mongodb/dialect.py new file mode 100644 index 000000000..d3d388f72 --- /dev/null +++ b/onetl/connection/db_connection/mongodb/dialect.py @@ -0,0 +1,204 @@ +# Copyright 2023 MTS (Mobile Telesystems) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
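The ``pipeline()`` signature above now takes the extracted ``MongoDBPipelineOptions`` class, but the call pattern stays the same. A usage sketch; the ``mongo`` connection, collection name and aggregation stages are illustrative.

.. code:: python

    df = mongo.pipeline(
        collection="orders",
        pipeline=[
            {"$match": {"status": "shipped"}},
            {"$group": {"_id": "$customer_id", "total": {"$sum": "$amount"}}},
        ],
        options=MongoDB.PipelineOptions(hint="{'_id': 1}"),
    )
    df.show()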
+ +from __future__ import annotations + +import json +import operator +from datetime import datetime +from typing import Any, Callable, ClassVar, Dict, Iterable, Mapping + +from onetl.base.base_db_connection import BaseDBConnection +from onetl.connection.db_connection.db_connection.dialect import DBDialect +from onetl.connection.db_connection.dialect_mixins import ( + SupportColumnsNone, + SupportDfSchemaStruct, + SupportHWMColumnStr, + SupportHWMExpressionNone, + SupportNameAny, +) + +_upper_level_operators = frozenset( # noqa: WPS527 + [ + "$addFields", + "$bucket", + "$bucketAuto", + "$changeStream", + "$collStats", + "$count", + "$currentOp", + "$densify", + "$documents", + "$facet", + "$fill", + "$geoNear", + "$graphLookup", + "$group", + "$indexStats", + "$limit", + "$listLocalSessions", + "$listSessions", + "$lookup", + "$merge", + "$out", + "$planCacheStats", + "$project", + "$redact", + "$replaceRoot", + "$replaceWith", + "$sample", + "$search", + "$searchMeta", + "$set", + "$setWindowFields", + "$shardedDataDistribution", + "$skip", + "$sort", + "$sortByCount", + "$unionWith", + "$unset", + "$unwind", + ], +) + + +class MongoDBDialect( # noqa: WPS215 + SupportNameAny, + SupportHWMExpressionNone, + SupportColumnsNone, + SupportDfSchemaStruct, + SupportHWMColumnStr, + DBDialect, +): + _compare_statements: ClassVar[Dict[Callable, str]] = { + operator.ge: "$gte", + operator.gt: "$gt", + operator.le: "$lte", + operator.lt: "$lt", + operator.eq: "$eq", + operator.ne: "$ne", + } + + @classmethod + def validate_where( + cls, + connection: BaseDBConnection, + where: Any, + ) -> dict | None: + if where is None: + return None + + if not isinstance(where, dict): + raise ValueError( + f"{connection.__class__.__name__} requires 'where' parameter type to be 'dict', " + f"got {where.__class__.__name__!r}", + ) + + for key in where: + cls._validate_top_level_keys_in_where_parameter(key) + return where + + @classmethod + def validate_hint( + cls, + connection: BaseDBConnection, + hint: Any, + ) -> dict | None: + if hint is None: + return None + + if not isinstance(hint, dict): + raise ValueError( + f"{connection.__class__.__name__} requires 'hint' parameter type to be 'dict', " + f"got {hint.__class__.__name__!r}", + ) + return hint + + @classmethod + def prepare_pipeline( + cls, + pipeline: Any, + ) -> Any: + """ + Prepares pipeline (list or dict) to MongoDB syntax, but without converting it to string. + """ + + if isinstance(pipeline, datetime): + return {"$date": pipeline.astimezone().isoformat()} + + if isinstance(pipeline, Mapping): + return {cls.prepare_pipeline(key): cls.prepare_pipeline(value) for key, value in pipeline.items()} + + if isinstance(pipeline, Iterable) and not isinstance(pipeline, str): + return [cls.prepare_pipeline(item) for item in pipeline] + + return pipeline + + @classmethod + def convert_to_str( + cls, + value: Any, + ) -> str: + """ + Converts the given dictionary, list or primitive to a string. + """ + + return json.dumps(cls.prepare_pipeline(value)) + + @classmethod + def _merge_conditions(cls, conditions: list[Any]) -> Any: + if len(conditions) == 1: + return conditions[0] + + return {"$and": conditions} + + @classmethod + def _get_compare_statement(cls, comparator: Callable, arg1: Any, arg2: Any) -> dict: + """ + Returns the comparison statement in MongoDB syntax: + + .. 
code:: + + { + "field": { + "$gt": "some_value", + } + } + """ + return { + arg1: { + cls._compare_statements[comparator]: arg2, + }, + } + + @classmethod + def _validate_top_level_keys_in_where_parameter(cls, key: str): + """ + Checks the 'where' parameter for illegal operators, such as ``$match``, ``$merge`` or ``$changeStream``. + + 'where' clause can contain only filtering operators, like ``{"col1" {"$eq": 1}}`` or ``{"$and": [...]}``. + """ + if key.startswith("$"): + if key == "$match": + raise ValueError( + "'$match' operator not allowed at the top level of the 'where' parameter dictionary. " + "This error most likely occurred due to the fact that you used the MongoDB format for the " + "pipeline {'$match': {'column': ...}}. In the onETL paradigm, you do not need to specify the " + "'$match' keyword, but write the filtering condition right away, like {'column': ...}", + ) + if key in _upper_level_operators: # noqa: WPS220 + raise ValueError( # noqa: WPS220 + f"An invalid parameter {key!r} was specified in the 'where' " + "field. You cannot use aggregations or 'groupBy' clauses in 'where'", + ) diff --git a/onetl/connection/db_connection/mongodb/options.py b/onetl/connection/db_connection/mongodb/options.py new file mode 100644 index 000000000..85f1935a3 --- /dev/null +++ b/onetl/connection/db_connection/mongodb/options.py @@ -0,0 +1,253 @@ +# Copyright 2023 MTS (Mobile Telesystems) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import warnings +from enum import Enum + +from pydantic import Field, root_validator + +from onetl.impl import GenericOptions + +PIPELINE_PROHIBITED_OPTIONS = frozenset( + ( + "uri", + "database", + "collection", + "pipeline", + ), +) + +PROHIBITED_OPTIONS = frozenset( + ( + "uri", + "database", + "collection", + "pipeline", + "hint", + ), +) + +KNOWN_READ_OPTIONS = frozenset( + ( + "localThreshold", + "readPreference.name", + "readPreference.tagSets", + "readConcern.level", + "sampleSize", + "samplePoolSize", + "partitioner", + "partitionerOptions", + "registerSQLHelperFunctions", + "sql.inferschema.mapTypes.enabled", + "sql.inferschema.mapTypes.minimumKeys", + "sql.pipeline.includeNullFilters", + "sql.pipeline.includeFiltersAndProjections", + "pipeline", + "hint", + "collation", + "allowDiskUse", + "batchSize", + ), +) + +KNOWN_WRITE_OPTIONS = frozenset( + ( + "extendedBsonTypes", + "localThreshold", + "replaceDocument", + "maxBatchSize", + "writeConcern.w", + "writeConcern.journal", + "writeConcern.wTimeoutMS", + "shardKey", + "forceInsert", + "ordered", + ), +) + + +class MongoDBCollectionExistBehavior(str, Enum): + APPEND = "append" + REPLACE_ENTIRE_COLLECTION = "replace_entire_collection" + + def __str__(self) -> str: + return str(self.value) + + @classmethod # noqa: WPS120 + def _missing_(cls, value: object): # noqa: WPS120 + if str(value) == "overwrite": + warnings.warn( + "Mode `overwrite` is deprecated since v0.9.0 and will be removed in v1.0.0. 
" + "Use `replace_entire_collection` instead", + category=UserWarning, + stacklevel=4, + ) + return cls.REPLACE_ENTIRE_COLLECTION + + +class MongoDBPipelineOptions(GenericOptions): + """Aggregation pipeline options for MongoDB connector. + + The only difference from :obj:`MongoDBReadOptions` that it is allowed to pass the ``hint`` parameter. + + .. note :: + + You can pass any value + `supported by connector `_, + even if it is not mentioned in this documentation. + + The set of supported options depends on connector version. See link above. + + .. warning:: + + Options ``uri``, ``database``, ``collection``, ``pipeline`` are populated from connection attributes, + and cannot be set in ``PipelineOptions`` class. + + Examples + -------- + + Pipeline options initialization + + .. code:: python + + MongoDB.PipelineOptions( + hint="{'_id': 1}", + ) + """ + + class Config: + prohibited_options = PIPELINE_PROHIBITED_OPTIONS + known_options = KNOWN_READ_OPTIONS + extra = "allow" + + +class MongoDBReadOptions(GenericOptions): + """Reading options for MongoDB connector. + + .. note :: + + You can pass any value + `supported by connector `_, + even if it is not mentioned in this documentation. + + The set of supported options depends on connector version. See link above. + + .. warning:: + + Options ``uri``, ``database``, ``collection``, ``pipeline``, ``hint`` are populated from connection + attributes, and cannot be set in ``ReadOptions`` class. + + Examples + -------- + + Read options initialization + + .. code:: python + + MongoDB.ReadOptions( + batchSize=10000, + ) + """ + + class Config: + prohibited_options = PROHIBITED_OPTIONS + known_options = KNOWN_READ_OPTIONS + extra = "allow" + + +class MongoDBWriteOptions(GenericOptions): + """Writing options for MongoDB connector. + + .. note :: + + You can pass any value + `supported by connector `_, + even if it is not mentioned in this documentation. + + The set of supported options depends on connector version. See link above. + + .. warning:: + + Options ``uri``, ``database``, ``collection`` are populated from connection attributes, + and cannot be set in ``WriteOptions`` class. + + Examples + -------- + + Write options initialization + + .. code:: python + + options = MongoDB.WriteOptions( + if_exists="append", + sampleSize=500, + localThreshold=20, + ) + """ + + if_exists: MongoDBCollectionExistBehavior = Field(default=MongoDBCollectionExistBehavior.APPEND, alias="mode") + """Behavior of writing data into existing collection. + + Possible values: + * ``append`` (default) + Adds new objects into existing collection. + + .. dropdown:: Behavior in details + + * Collection does not exist + Collection is created using options provided by user + (``shardkey`` and others). + + * Collection exists + Data is appended to a collection. + + .. warning:: + + This mode does not check whether collection already contains + objects from dataframe, so duplicated objects can be created. + + * ``replace_entire_collection`` + **Collection is deleted and then created**. + + .. dropdown:: Behavior in details + + * Collection does not exist + Collection is created using options provided by user + (``shardkey`` and others). + + * Collection exists + Collection content is replaced with dataframe content. + + .. note:: + + ``error`` and ``ignore`` modes are not supported. 
+ """ + + class Config: + prohibited_options = PROHIBITED_OPTIONS + known_options = KNOWN_WRITE_OPTIONS + extra = "allow" + + @root_validator(pre=True) + def _mode_is_deprecated(cls, values): + if "mode" in values: + warnings.warn( + "Option `MongoDB.WriteOptions(mode=...)` is deprecated since v0.9.0 and will be removed in v1.0.0. " + "Use `MongoDB.WriteOptions(if_exists=...)` instead", + category=UserWarning, + stacklevel=3, + ) + return values diff --git a/onetl/connection/db_connection/mssql/__init__.py b/onetl/connection/db_connection/mssql/__init__.py new file mode 100644 index 000000000..efd6e7072 --- /dev/null +++ b/onetl/connection/db_connection/mssql/__init__.py @@ -0,0 +1,16 @@ +# Copyright 2023 MTS (Mobile Telesystems) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from onetl.connection.db_connection.mssql.connection import MSSQL, MSSQLExtra +from onetl.connection.db_connection.mssql.dialect import MSSQLDialect diff --git a/onetl/connection/db_connection/mssql.py b/onetl/connection/db_connection/mssql/connection.py similarity index 85% rename from onetl/connection/db_connection/mssql.py rename to onetl/connection/db_connection/mssql/connection.py index a187ea7dd..49fc825d9 100644 --- a/onetl/connection/db_connection/mssql.py +++ b/onetl/connection/db_connection/mssql/connection.py @@ -15,17 +15,24 @@ from __future__ import annotations import warnings -from datetime import date, datetime from typing import ClassVar from onetl._util.classproperty import classproperty from onetl._util.version import Version from onetl.connection.db_connection.jdbc_connection import JDBCConnection +from onetl.connection.db_connection.mssql.dialect import MSSQLDialect from onetl.hooks import slot, support_hooks +from onetl.impl import GenericOptions # do not import PySpark here, as we allow user to use `MSSQL.get_packages()` for creating Spark session +class MSSQLExtra(GenericOptions): + class Config: + extra = "allow" + prohibited_options = frozenset(("databaseName",)) + + @support_hooks class MSSQL(JDBCConnection): """MSSQL JDBC connection. 
|support_hooks| @@ -161,13 +168,12 @@ class MSSQL(JDBCConnection): """ - class Extra(JDBCConnection.Extra): - class Config: - prohibited_options = frozenset(("databaseName",)) - database: str port: int = 1433 - extra: Extra = Extra() + extra: MSSQLExtra = MSSQLExtra() + + Extra = MSSQLExtra + Dialect = MSSQLDialect DRIVER: ClassVar[str] = "com.microsoft.sqlserver.jdbc.SQLServerDriver" _CHECK_QUERY: ClassVar[str] = "SELECT 1 AS field" @@ -211,29 +217,6 @@ def package(cls) -> str: warnings.warn(msg, UserWarning, stacklevel=3) return "com.microsoft.sqlserver:mssql-jdbc:12.2.0.jre8" - class Dialect(JDBCConnection.Dialect): - @classmethod - def _get_datetime_value_sql(cls, value: datetime) -> str: - result = value.isoformat() - return f"CAST('{result}' AS datetime2)" - - @classmethod - def _get_date_value_sql(cls, value: date) -> str: - result = value.isoformat() - return f"CAST('{result}' AS date)" - - class ReadOptions(JDBCConnection.ReadOptions): - # https://docs.microsoft.com/ru-ru/sql/t-sql/functions/hashbytes-transact-sql?view=sql-server-ver16 - @classmethod - def _get_partition_column_hash(cls, partition_column: str, num_partitions: int) -> str: - return f"CONVERT(BIGINT, HASHBYTES ( 'SHA' , {partition_column} )) % {num_partitions}" - - @classmethod - def _get_partition_column_mod(cls, partition_column: str, num_partitions: int) -> str: - return f"{partition_column} % {num_partitions}" - - ReadOptions.__doc__ = JDBCConnection.ReadOptions.__doc__ - @property def jdbc_url(self) -> str: prop = self.extra.dict(by_alias=True) diff --git a/onetl/connection/db_connection/mssql/dialect.py b/onetl/connection/db_connection/mssql/dialect.py new file mode 100644 index 000000000..95e4ff022 --- /dev/null +++ b/onetl/connection/db_connection/mssql/dialect.py @@ -0,0 +1,40 @@ +# Copyright 2023 MTS (Mobile Telesystems) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
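The MSSQL hunk above replaces the nested ``MSSQL.Extra``/``MSSQL.Dialect`` classes with module-level ``MSSQLExtra``/``MSSQLDialect`` and re-exposes them as class attributes, so code that refers to the nested names keeps working. A small sketch of that compatibility surface; the ``extra`` key used here is an illustrative JDBC driver property, not something this diff prescribes.

.. code:: python

    from onetl.connection import MSSQL
    from onetl.connection.db_connection.mssql import MSSQLDialect, MSSQLExtra

    # the old nested-class spelling still resolves to the new module-level classes
    assert MSSQL.Extra is MSSQLExtra
    assert MSSQL.Dialect is MSSQLDialect

    # arbitrary driver properties are allowed (extra = "allow"),
    # but "databaseName" is prohibited because it is derived from the `database` field
    extra = MSSQL.Extra(trustServerCertificate="true")

    mssql = MSSQL(
        host="mssql.domain.com",
        port=1433,
        user="user",
        password="***",
        database="mydb",
        extra=extra,
        spark=spark,  # assumed existing SparkSession
    )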
+ +from __future__ import annotations + +from datetime import date, datetime + +from onetl.connection.db_connection.jdbc_connection import JDBCDialect + + +class MSSQLDialect(JDBCDialect): + @classmethod + def _get_datetime_value_sql(cls, value: datetime) -> str: + result = value.isoformat() + return f"CAST('{result}' AS datetime2)" + + @classmethod + def _get_date_value_sql(cls, value: date) -> str: + result = value.isoformat() + return f"CAST('{result}' AS date)" + + # https://docs.microsoft.com/ru-ru/sql/t-sql/functions/hashbytes-transact-sql?view=sql-server-ver16 + @classmethod + def _get_partition_column_hash(cls, partition_column: str, num_partitions: int) -> str: + return f"CONVERT(BIGINT, HASHBYTES ( 'SHA' , {partition_column} )) % {num_partitions}" + + @classmethod + def _get_partition_column_mod(cls, partition_column: str, num_partitions: int) -> str: + return f"{partition_column} % {num_partitions}" diff --git a/onetl/connection/db_connection/mysql/__init__.py b/onetl/connection/db_connection/mysql/__init__.py new file mode 100644 index 000000000..ba7337b23 --- /dev/null +++ b/onetl/connection/db_connection/mysql/__init__.py @@ -0,0 +1,16 @@ +# Copyright 2023 MTS (Mobile Telesystems) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from onetl.connection.db_connection.mysql.connection import MySQL, MySQLExtra +from onetl.connection.db_connection.mysql.dialect import MySQLDialect diff --git a/onetl/connection/db_connection/mysql.py b/onetl/connection/db_connection/mysql/connection.py similarity index 80% rename from onetl/connection/db_connection/mysql.py rename to onetl/connection/db_connection/mysql/connection.py index fcde3d82c..868731eaf 100644 --- a/onetl/connection/db_connection/mysql.py +++ b/onetl/connection/db_connection/mysql/connection.py @@ -15,16 +15,25 @@ from __future__ import annotations import warnings -from datetime import date, datetime from typing import ClassVar, Optional from onetl._util.classproperty import classproperty from onetl.connection.db_connection.jdbc_connection import JDBCConnection +from onetl.connection.db_connection.mysql.dialect import MySQLDialect from onetl.hooks import slot, support_hooks +from onetl.impl.generic_options import GenericOptions # do not import PySpark here, as we allow user to use `MySQL.get_packages()` for creating Spark session +class MySQLExtra(GenericOptions): + useUnicode: str = "yes" # noqa: N815 + characterEncoding: str = "UTF-8" # noqa: N815 + + class Config: + extra = "allow" + + @support_hooks class MySQL(JDBCConnection): """MySQL JDBC connection. 
|support_hooks| @@ -116,13 +125,12 @@ class MySQL(JDBCConnection): """ - class Extra(JDBCConnection.Extra): - useUnicode: str = "yes" # noqa: N815 - characterEncoding: str = "UTF-8" # noqa: N815 - port: int = 3306 database: Optional[str] = None - extra: Extra = Extra() + extra: MySQLExtra = MySQLExtra() + + Extra = MySQLExtra + Dialect = MySQLDialect DRIVER: ClassVar[str] = "com.mysql.cj.jdbc.Driver" @@ -160,25 +168,3 @@ def jdbc_url(self): return f"jdbc:mysql://{self.host}:{self.port}/{self.database}?{parameters}" return f"jdbc:mysql://{self.host}:{self.port}?{parameters}" - - class Dialect(JDBCConnection.Dialect): - @classmethod - def _get_datetime_value_sql(cls, value: datetime) -> str: - result = value.strftime("%Y-%m-%d %H:%M:%S.%f") - return f"STR_TO_DATE('{result}', '%Y-%m-%d %H:%i:%s.%f')" - - @classmethod - def _get_date_value_sql(cls, value: date) -> str: - result = value.strftime("%Y-%m-%d") - return f"STR_TO_DATE('{result}', '%Y-%m-%d')" - - class ReadOptions(JDBCConnection.ReadOptions): - @classmethod - def _get_partition_column_hash(cls, partition_column: str, num_partitions: int) -> str: - return f"MOD(CONV(CONV(RIGHT(MD5({partition_column}), 16),16, 2), 2, 10), {num_partitions})" - - @classmethod - def _get_partition_column_mod(cls, partition_column: str, num_partitions: int) -> str: - return f"MOD({partition_column}, {num_partitions})" - - ReadOptions.__doc__ = JDBCConnection.ReadOptions.__doc__ diff --git a/onetl/connection/db_connection/mysql/dialect.py b/onetl/connection/db_connection/mysql/dialect.py new file mode 100644 index 000000000..b3cd70a55 --- /dev/null +++ b/onetl/connection/db_connection/mysql/dialect.py @@ -0,0 +1,43 @@ +# Copyright 2023 MTS (Mobile Telesystems) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from __future__ import annotations + +from datetime import date, datetime + +from onetl.connection.db_connection.jdbc_connection import JDBCDialect + + +class MySQLDialect(JDBCDialect): + @classmethod + def _escape_column(cls, value: str) -> str: + return f"`{value}`" + + @classmethod + def _get_datetime_value_sql(cls, value: datetime) -> str: + result = value.strftime("%Y-%m-%d %H:%M:%S.%f") + return f"STR_TO_DATE('{result}', '%Y-%m-%d %H:%i:%s.%f')" + + @classmethod + def _get_date_value_sql(cls, value: date) -> str: + result = value.strftime("%Y-%m-%d") + return f"STR_TO_DATE('{result}', '%Y-%m-%d')" + + @classmethod + def _get_partition_column_hash(cls, partition_column: str, num_partitions: int) -> str: + return f"MOD(CONV(CONV(RIGHT(MD5({partition_column}), 16),16, 2), 2, 10), {num_partitions})" + + @classmethod + def _get_partition_column_mod(cls, partition_column: str, num_partitions: int) -> str: + return f"MOD({partition_column}, {num_partitions})" diff --git a/onetl/connection/db_connection/oracle/__init__.py b/onetl/connection/db_connection/oracle/__init__.py new file mode 100644 index 000000000..79b1b9278 --- /dev/null +++ b/onetl/connection/db_connection/oracle/__init__.py @@ -0,0 +1,16 @@ +# Copyright 2023 MTS (Mobile Telesystems) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
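``MySQLExtra`` keeps the two defaults that previously lived in the nested ``Extra`` class (``useUnicode`` and ``characterEncoding``) and still accepts arbitrary driver properties. A short sketch, assuming the usual onETL constructor fields for host/user/password/spark; the ``requireSSL`` key is only an illustrative driver property.

.. code:: python

    from onetl.connection import MySQL

    mysql = MySQL(
        host="mysql.domain.com",
        port=3306,
        user="user",
        password="***",
        database="mydb",
        # defaults useUnicode="yes" and characterEncoding="UTF-8" are applied automatically;
        # extra keys end up as query parameters of the jdbc:mysql:// URL
        extra={"requireSSL": "true"},
        spark=spark,  # assumed existing SparkSession
    )

    assert mysql.extra.useUnicode == "yes"
    assert mysql.extra.characterEncoding == "UTF-8"
    assert MySQL.Extra is type(mysql.extra)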
+ +from onetl.connection.db_connection.oracle.connection import Oracle, OracleExtra +from onetl.connection.db_connection.oracle.dialect import OracleDialect diff --git a/onetl/connection/db_connection/oracle.py b/onetl/connection/db_connection/oracle/connection.py similarity index 88% rename from onetl/connection/db_connection/oracle.py rename to onetl/connection/db_connection/oracle/connection.py index 538a5b65a..69d7e2c5b 100644 --- a/onetl/connection/db_connection/oracle.py +++ b/onetl/connection/db_connection/oracle/connection.py @@ -20,9 +20,9 @@ import warnings from collections import OrderedDict from dataclasses import dataclass -from datetime import date, datetime +from decimal import Decimal from textwrap import indent -from typing import TYPE_CHECKING, ClassVar, Optional +from typing import TYPE_CHECKING, Any, ClassVar, Optional from pydantic import root_validator @@ -30,7 +30,11 @@ from onetl._util.classproperty import classproperty from onetl._util.version import Version from onetl.connection.db_connection.jdbc_connection import JDBCConnection +from onetl.connection.db_connection.jdbc_connection.options import JDBCReadOptions +from onetl.connection.db_connection.jdbc_mixin.options import JDBCOptions +from onetl.connection.db_connection.oracle.dialect import OracleDialect from onetl.hooks import slot, support_hooks +from onetl.impl import GenericOptions from onetl.log import BASE_LOG_INDENT, log_lines # do not import PySpark here, as we allow user to use `Oracle.get_packages()` for creating Spark session @@ -65,6 +69,11 @@ def sort_key(self) -> tuple[int, int, int]: return 100 - self.level, self.line, self.position +class OracleExtra(GenericOptions): + class Config: + extra = "allow" + + @support_hooks class Oracle(JDBCConnection): """Oracle JDBC connection. 
|support_hooks| @@ -173,6 +182,10 @@ class Oracle(JDBCConnection): port: int = 1521 sid: Optional[str] = None service_name: Optional[str] = None + extra: OracleExtra = OracleExtra() + + Extra = OracleExtra + Dialect = OracleDialect DRIVER: ClassVar[str] = "oracle.jdbc.driver.OracleDriver" _CHECK_QUERY: ClassVar[str] = "SELECT 1 FROM dual" @@ -216,41 +229,6 @@ def package(cls) -> str: warnings.warn(msg, UserWarning, stacklevel=3) return "com.oracle.database.jdbc:ojdbc8:23.2.0.0" - @root_validator - def only_one_of_sid_or_service_name(cls, values): - sid = values.get("sid") - service_name = values.get("service_name") - - if sid and service_name: - raise ValueError("Only one of parameters ``sid``, ``service_name`` can be set, got both") - - if not sid and not service_name: - raise ValueError("One of parameters ``sid``, ``service_name`` should be set, got none") - - return values - - class Dialect(JDBCConnection.Dialect): - @classmethod - def _get_datetime_value_sql(cls, value: datetime) -> str: - result = value.strftime("%Y-%m-%d %H:%M:%S") - return f"TO_DATE('{result}', 'YYYY-MM-DD HH24:MI:SS')" - - @classmethod - def _get_date_value_sql(cls, value: date) -> str: - result = value.strftime("%Y-%m-%d") - return f"TO_DATE('{result}', 'YYYY-MM-DD')" - - class ReadOptions(JDBCConnection.ReadOptions): - @classmethod - def _get_partition_column_hash(cls, partition_column: str, num_partitions: int) -> str: - return f"ora_hash({partition_column}, {num_partitions})" - - @classmethod - def _get_partition_column_mod(cls, partition_column: str, num_partitions: int) -> str: - return f"MOD({partition_column}, {num_partitions})" - - ReadOptions.__doc__ = JDBCConnection.ReadOptions.__doc__ - @property def jdbc_url(self) -> str: extra = self.extra.dict(by_alias=True) @@ -268,11 +246,37 @@ def instance_url(self) -> str: return f"{super().instance_url}/{self.service_name}" + @slot + def get_min_max_bounds( + self, + source: str, + column: str, + expression: str | None = None, + hint: str | None = None, + where: str | None = None, + options: JDBCReadOptions | None = None, + ) -> tuple[Any, Any]: + min_value, max_value = super().get_min_max_bounds( + source=source, + column=column, + expression=expression, + hint=hint, + where=where, + options=options, + ) + # Oracle does not have Integer type, only Numeric, which is represented as Decimal in Python + # If number does not have decimal part, convert it to integer to use as lowerBound/upperBound + if isinstance(min_value, Decimal) and min_value == round(min_value): + min_value = int(min_value) + if isinstance(max_value, Decimal) and max_value == round(max_value): + max_value = int(max_value) + return min_value, max_value + @slot def execute( self, statement: str, - options: Oracle.JDBCOptions | dict | None = None, # noqa: WPS437 + options: JDBCOptions | dict | None = None, # noqa: WPS437 ) -> DataFrame | None: statement = clear_statement(statement) @@ -294,6 +298,19 @@ def execute( log.info("|%s| Execution succeeded, nothing returned", self.__class__.__name__) return df + @root_validator + def _only_one_of_sid_or_service_name(cls, values): + sid = values.get("sid") + service_name = values.get("service_name") + + if sid and service_name: + raise ValueError("Only one of parameters ``sid``, ``service_name`` can be set, got both") + + if not sid and not service_name: + raise ValueError("One of parameters ``sid``, ``service_name`` should be set, got none") + + return values + def _parse_create_statement(self, statement: str) -> tuple[str, str, str] | None: """ Parses 
``CREATE ... type_name [schema.]object_name ...`` statement @@ -323,7 +340,7 @@ def _get_compile_errors( type_name: str, schema: str, object_name: str, - options: Oracle.JDBCOptions, + options: JDBCOptions, ) -> list[tuple[ErrorPosition, str]]: """ Get compile errors for the object. @@ -393,7 +410,7 @@ def _build_error_message(self, aggregated_errors: OrderedDict[ErrorPosition, str def _handle_compile_errors( self, statement: str, - options: Oracle.JDBCOptions, + options: JDBCOptions, ) -> None: """ Oracle does not return compilation errors immediately. diff --git a/onetl/connection/db_connection/oracle/dialect.py b/onetl/connection/db_connection/oracle/dialect.py new file mode 100644 index 000000000..fb3fa715d --- /dev/null +++ b/onetl/connection/db_connection/oracle/dialect.py @@ -0,0 +1,39 @@ +# Copyright 2023 MTS (Mobile Telesystems) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +from datetime import date, datetime + +from onetl.connection.db_connection.jdbc_connection import JDBCDialect + + +class OracleDialect(JDBCDialect): + @classmethod + def _get_datetime_value_sql(cls, value: datetime) -> str: + result = value.strftime("%Y-%m-%d %H:%M:%S") + return f"TO_DATE('{result}', 'YYYY-MM-DD HH24:MI:SS')" + + @classmethod + def _get_date_value_sql(cls, value: date) -> str: + result = value.strftime("%Y-%m-%d") + return f"TO_DATE('{result}', 'YYYY-MM-DD')" + + @classmethod + def _get_partition_column_hash(cls, partition_column: str, num_partitions: int) -> str: + return f"ora_hash({partition_column}, {num_partitions})" + + @classmethod + def _get_partition_column_mod(cls, partition_column: str, num_partitions: int) -> str: + return f"MOD({partition_column}, {num_partitions})" diff --git a/onetl/connection/db_connection/postgres/__init__.py b/onetl/connection/db_connection/postgres/__init__.py new file mode 100644 index 000000000..42bad7d54 --- /dev/null +++ b/onetl/connection/db_connection/postgres/__init__.py @@ -0,0 +1,16 @@ +# Copyright 2023 MTS (Mobile Telesystems) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
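The new ``Oracle.get_min_max_bounds`` override shown above addresses the fact that Oracle has no dedicated integer type: JDBC returns ``NUMBER`` columns as Python ``Decimal``, so whole-number bounds are narrowed to ``int`` before being used as ``lowerBound``/``upperBound``. The narrowing check in isolation, mirrored as a standalone helper:

.. code:: python

    from decimal import Decimal


    def narrow_bound(value):
        # mirror of the check used in Oracle.get_min_max_bounds:
        # a Decimal without a fractional part becomes a plain int,
        # anything else is returned untouched
        if isinstance(value, Decimal) and value == round(value):
            return int(value)
        return value


    assert narrow_bound(Decimal("100")) == 100
    assert narrow_bound(Decimal("100.5")) == Decimal("100.5")
    assert narrow_bound(42) == 42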
+ +from onetl.connection.db_connection.postgres.connection import Postgres, PostgresExtra +from onetl.connection.db_connection.postgres.dialect import PostgresDialect diff --git a/onetl/connection/db_connection/postgres.py b/onetl/connection/db_connection/postgres/connection.py similarity index 75% rename from onetl/connection/db_connection/postgres.py rename to onetl/connection/db_connection/postgres/connection.py index aee317974..eb07a68f6 100644 --- a/onetl/connection/db_connection/postgres.py +++ b/onetl/connection/db_connection/postgres/connection.py @@ -15,28 +15,23 @@ from __future__ import annotations import warnings -from datetime import date, datetime from typing import ClassVar from onetl._util.classproperty import classproperty -from onetl.connection.db_connection.db_connection import DBConnection -from onetl.connection.db_connection.dialect_mixins import ( - SupportColumnsList, - SupportDfSchemaNone, - SupportHintNone, - SupportHWMColumnStr, - SupportHWMExpressionStr, - SupportWhereStr, -) -from onetl.connection.db_connection.dialect_mixins.support_table_with_dbschema import ( - SupportTableWithDBSchema, -) from onetl.connection.db_connection.jdbc_connection import JDBCConnection +from onetl.connection.db_connection.jdbc_mixin.options import JDBCOptions +from onetl.connection.db_connection.postgres.dialect import PostgresDialect from onetl.hooks import slot, support_hooks +from onetl.impl import GenericOptions # do not import PySpark here, as we allow user to use `Postgres.get_packages()` for creating Spark session +class PostgresExtra(GenericOptions): + class Config: + extra = "allow" + + @support_hooks class Postgres(JDBCConnection): """PostgreSQL JDBC connection. |support_hooks| @@ -130,6 +125,10 @@ class Postgres(JDBCConnection): database: str port: int = 5432 + extra: PostgresExtra = PostgresExtra() + + Extra = PostgresExtra + Dialect = PostgresDialect DRIVER: ClassVar[str] = "org.postgresql.Driver" @@ -158,38 +157,6 @@ def package(cls) -> str: warnings.warn(msg, UserWarning, stacklevel=3) return "org.postgresql:postgresql:42.6.0" - class Dialect( # noqa: WPS215 - SupportTableWithDBSchema, - SupportColumnsList, - SupportDfSchemaNone, - SupportWhereStr, - SupportHWMExpressionStr, - SupportHWMColumnStr, - SupportHintNone, - DBConnection.Dialect, - ): - @classmethod - def _get_datetime_value_sql(cls, value: datetime) -> str: - result = value.isoformat() - return f"'{result}'::timestamp" - - @classmethod - def _get_date_value_sql(cls, value: date) -> str: - result = value.isoformat() - return f"'{result}'::date" - - class ReadOptions(JDBCConnection.ReadOptions): - # https://stackoverflow.com/a/9812029 - @classmethod - def _get_partition_column_hash(cls, partition_column: str, num_partitions: int) -> str: - return f"('x'||right(md5('{partition_column}'), 16))::bit(32)::bigint % {num_partitions}" - - @classmethod - def _get_partition_column_mod(cls, partition_column: str, num_partitions: int) -> str: - return f"{partition_column} % {num_partitions}" - - ReadOptions.__doc__ = JDBCConnection.ReadOptions.__doc__ - @property def jdbc_url(self) -> str: extra = self.extra.dict(by_alias=True) @@ -202,7 +169,7 @@ def jdbc_url(self) -> str: def instance_url(self) -> str: return f"{super().instance_url}/{self.database}" - def _options_to_connection_properties(self, options: JDBCConnection.JDBCOptions): # noqa: WPS437 + def _options_to_connection_properties(self, options: JDBCOptions): # noqa: WPS437 # See https://github.com/pgjdbc/pgjdbc/pull/1252 # Since 42.2.9 Postgres JDBC Driver 
added new option readOnlyMode=transaction # Which is not a desired behavior, because `.fetch()` method should always be read-only diff --git a/onetl/connection/db_connection/postgres/dialect.py b/onetl/connection/db_connection/postgres/dialect.py new file mode 100644 index 000000000..05a44471e --- /dev/null +++ b/onetl/connection/db_connection/postgres/dialect.py @@ -0,0 +1,41 @@ +# Copyright 2023 MTS (Mobile Telesystems) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +from datetime import date, datetime + +from onetl.connection.db_connection.dialect_mixins import SupportHintNone +from onetl.connection.db_connection.jdbc_connection import JDBCDialect + + +class PostgresDialect(SupportHintNone, JDBCDialect): + @classmethod + def _get_datetime_value_sql(cls, value: datetime) -> str: + result = value.isoformat() + return f"'{result}'::timestamp" + + @classmethod + def _get_date_value_sql(cls, value: date) -> str: + result = value.isoformat() + return f"'{result}'::date" + + # https://stackoverflow.com/a/9812029 + @classmethod + def _get_partition_column_hash(cls, partition_column: str, num_partitions: int) -> str: + return f"('x'||right(md5('{partition_column}'), 16))::bit(32)::bigint % {num_partitions}" + + @classmethod + def _get_partition_column_mod(cls, partition_column: str, num_partitions: int) -> str: + return f"{partition_column} % {num_partitions}" diff --git a/onetl/connection/db_connection/teradata/__init__.py b/onetl/connection/db_connection/teradata/__init__.py new file mode 100644 index 000000000..9a70a22f8 --- /dev/null +++ b/onetl/connection/db_connection/teradata/__init__.py @@ -0,0 +1,16 @@ +# Copyright 2023 MTS (Mobile Telesystems) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
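``PostgresDialect`` above carries the md5-based hash expression used when a read is partitioned by hashing a column. As an illustration of where that expression ends up, here is a hedged ``DBReader`` sketch; the option names ``partition_column``, ``num_partitions`` and ``partitioning_mode`` come from onETL's JDBC read options and are assumptions as far as this diff is concerned.

.. code:: python

    from onetl.connection import Postgres
    from onetl.db import DBReader

    postgres = Postgres(
        host="postgres.domain.com",
        user="user",
        password="***",
        database="mydb",
        spark=spark,  # assumed existing SparkSession
    )

    reader = DBReader(
        connection=postgres,
        source="public.big_table",
        options=Postgres.ReadOptions(
            partition_column="id",
            num_partitions=8,
            # "hash" wraps partition_column into the
            # ('x'||right(md5(...), 16))::bit(32)::bigint % num_partitions
            # expression defined by the dialect above
            partitioning_mode="hash",
        ),
    )
    df = reader.run()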
+ +from onetl.connection.db_connection.teradata.connection import Teradata, TeradataExtra +from onetl.connection.db_connection.teradata.dialect import TeradataDialect diff --git a/onetl/connection/db_connection/teradata.py b/onetl/connection/db_connection/teradata/connection.py similarity index 80% rename from onetl/connection/db_connection/teradata.py rename to onetl/connection/db_connection/teradata/connection.py index 4c04e6f6e..7e730f9eb 100644 --- a/onetl/connection/db_connection/teradata.py +++ b/onetl/connection/db_connection/teradata/connection.py @@ -15,16 +15,29 @@ from __future__ import annotations import warnings -from datetime import date, datetime from typing import ClassVar, Optional from onetl._util.classproperty import classproperty from onetl.connection.db_connection.jdbc_connection import JDBCConnection +from onetl.connection.db_connection.teradata.dialect import TeradataDialect from onetl.hooks import slot +from onetl.impl import GenericOptions # do not import PySpark here, as we allow user to use `Teradata.get_packages()` for creating Spark session +class TeradataExtra(GenericOptions): + CHARSET: str = "UTF8" + COLUMN_NAME: str = "ON" + FLATTEN: str = "ON" + MAYBENULL: str = "ON" + STRICT_NAMES: str = "OFF" + + class Config: + extra = "allow" + prohibited_options = frozenset(("DATABASE", "DBS_PORT")) + + class Teradata(JDBCConnection): """Teradata JDBC connection. |support_hooks| @@ -131,19 +144,12 @@ class Teradata(JDBCConnection): """ - class Extra(JDBCConnection.Extra): - CHARSET: str = "UTF8" - COLUMN_NAME: str = "ON" - FLATTEN: str = "ON" - MAYBENULL: str = "ON" - STRICT_NAMES: str = "OFF" - - class Config: - prohibited_options = frozenset(("DATABASE", "DBS_PORT")) - port: int = 1025 database: Optional[str] = None - extra: Extra = Extra() + extra: TeradataExtra = TeradataExtra() + + Extra = TeradataExtra + Dialect = TeradataDialect DRIVER: ClassVar[str] = "com.teradata.jdbc.TeraDriver" _CHECK_QUERY: ClassVar[str] = "SELECT 1 AS check_result" @@ -184,26 +190,3 @@ def jdbc_url(self) -> str: conn = ",".join(f"{k}={v}" for k, v in sorted(prop.items())) return f"jdbc:teradata://{self.host}/{conn}" - - class Dialect(JDBCConnection.Dialect): - @classmethod - def _get_datetime_value_sql(cls, value: datetime) -> str: - result = value.isoformat() - return f"CAST('{result}' AS TIMESTAMP)" - - @classmethod - def _get_date_value_sql(cls, value: date) -> str: - result = value.isoformat() - return f"CAST('{result}' AS DATE)" - - class ReadOptions(JDBCConnection.ReadOptions): - # https://docs.teradata.com/r/w4DJnG9u9GdDlXzsTXyItA/lkaegQT4wAakj~K_ZmW1Dg - @classmethod - def _get_partition_column_hash(cls, partition_column: str, num_partitions: int) -> str: - return f"HASHAMP(HASHBUCKET(HASHROW({partition_column}))) mod {num_partitions}" - - @classmethod - def _get_partition_column_mod(cls, partition_column: str, num_partitions: int) -> str: - return f"{partition_column} mod {num_partitions}" - - ReadOptions.__doc__ = JDBCConnection.ReadOptions.__doc__ diff --git a/onetl/connection/db_connection/teradata/dialect.py b/onetl/connection/db_connection/teradata/dialect.py new file mode 100644 index 000000000..c449debc6 --- /dev/null +++ b/onetl/connection/db_connection/teradata/dialect.py @@ -0,0 +1,40 @@ +# Copyright 2023 MTS (Mobile Telesystems) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +from datetime import date, datetime + +from onetl.connection.db_connection.jdbc_connection import JDBCDialect + + +class TeradataDialect(JDBCDialect): + @classmethod + def _get_datetime_value_sql(cls, value: datetime) -> str: + result = value.isoformat() + return f"CAST('{result}' AS TIMESTAMP)" + + @classmethod + def _get_date_value_sql(cls, value: date) -> str: + result = value.isoformat() + return f"CAST('{result}' AS DATE)" + + # https://docs.teradata.com/r/w4DJnG9u9GdDlXzsTXyItA/lkaegQT4wAakj~K_ZmW1Dg + @classmethod + def _get_partition_column_hash(cls, partition_column: str, num_partitions: int) -> str: + return f"HASHAMP(HASHBUCKET(HASHROW({partition_column}))) mod {num_partitions}" + + @classmethod + def _get_partition_column_mod(cls, partition_column: str, num_partitions: int) -> str: + return f"{partition_column} mod {num_partitions}" diff --git a/onetl/connection/file_connection/file_connection.py b/onetl/connection/file_connection/file_connection.py index 914e989e2..39e27f2c6 100644 --- a/onetl/connection/file_connection/file_connection.py +++ b/onetl/connection/file_connection/file_connection.py @@ -129,7 +129,7 @@ def check(self): try: self.list_dir("/") - log.info("|%s| Connection is available", self.__class__.__name__) + log.info("|%s| Connection is available.", self.__class__.__name__) except (RuntimeError, ValueError): # left validation errors intact log.exception("|%s| Connection is unavailable", self.__class__.__name__) @@ -719,10 +719,9 @@ def _extract_stat_from_entry(self, top: RemotePath, entry) -> PathStatProtocol: """ def _log_parameters(self): - log.info("|onETL| Using connection parameters:") - log_with_indent(log, "type = %s", self.__class__.__name__) + log.info("|%s| Using connection parameters:", self.__class__.__name__) parameters = self.dict(exclude_none=True) - for attr, value in sorted(parameters.items()): + for attr, value in parameters.items(): if isinstance(value, os.PathLike): log_with_indent(log, "%s = %s", attr, path_repr(value)) else: diff --git a/onetl/connection/file_connection/hdfs/__init__.py b/onetl/connection/file_connection/hdfs/__init__.py new file mode 100644 index 000000000..56be3211c --- /dev/null +++ b/onetl/connection/file_connection/hdfs/__init__.py @@ -0,0 +1,16 @@ +# Copyright 2023 MTS (Mobile Telesystems) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
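``TeradataExtra`` in the Teradata hunks above moves the connection-string defaults (``CHARSET``, ``FLATTEN``, ``MAYBENULL``, and so on) out of the nested class, while ``DATABASE`` and ``DBS_PORT`` stay prohibited because they are filled in from the ``database`` and ``port`` fields. Based on the ``jdbc_url`` property shown in the hunk, the URL is built from comma-separated, alphabetically sorted ``KEY=VALUE`` pairs; the exact string in the comment below is an approximation, not verbatim output.

.. code:: python

    from onetl.connection import Teradata

    teradata = Teradata(
        host="teradata.domain.com",
        port=1025,
        user="user",
        password="***",
        database="mydb",
        spark=spark,  # assumed existing SparkSession
    )

    # roughly:
    # jdbc:teradata://teradata.domain.com/CHARSET=UTF8,COLUMN_NAME=ON,DATABASE=mydb,
    # DBS_PORT=1025,FLATTEN=ON,MAYBENULL=ON,STRICT_NAMES=OFF
    print(teradata.jdbc_url)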
+ +from onetl.connection.file_connection.hdfs.connection import HDFS +from onetl.connection.file_connection.hdfs.slots import HDFSSlots diff --git a/onetl/connection/file_connection/hdfs.py b/onetl/connection/file_connection/hdfs/connection.py similarity index 67% rename from onetl/connection/file_connection/hdfs.py rename to onetl/connection/file_connection/hdfs/connection.py index 31a4c3012..2419aae2f 100644 --- a/onetl/connection/file_connection/hdfs.py +++ b/onetl/connection/file_connection/hdfs/connection.py @@ -25,6 +25,7 @@ from onetl.base import PathStatProtocol from onetl.connection.file_connection.file_connection import FileConnection +from onetl.connection.file_connection.hdfs.slots import HDFSSlots from onetl.connection.file_connection.mixins.rename_dir_mixin import RenameDirMixin from onetl.connection.kerberos_helpers import kinit from onetl.hooks import slot, support_hooks @@ -88,7 +89,8 @@ class HDFS(FileConnection, RenameDirMixin): Used for: * HWM and lineage (as instance name for file paths), if set. * Validation of ``host`` value, - if latter is passed and if some hooks are bound to :obj:`~slots.get_cluster_namenodes`. + if latter is passed and if some hooks are bound to + :obj:`Slots.get_cluster_namenodes ` .. warning: @@ -100,7 +102,8 @@ class HDFS(FileConnection, RenameDirMixin): Should be an active namenode (NOT standby). If value is not set, but there are some hooks bound to - :obj:`~slots.get_cluster_namenodes` and :obj:`~slots.is_namenode_active`, + :obj:`Slots.get_cluster_namenodes ` + and :obj:`Slots.is_namenode_active `, onETL will iterate over cluster namenodes to detect which one is active. .. warning: @@ -110,7 +113,8 @@ class HDFS(FileConnection, RenameDirMixin): webhdfs_port : int, default: ``50070`` Port of Hadoop namenode (WebHDFS protocol). - If omitted, but there are some hooks bound to :obj:`~slots.get_webhdfs_port` slot, + If omitted, but there are some hooks bound to + :obj:`Slots.get_webhdfs_port ` slot, onETL will try to detect port number for a specific ``cluster``. user : str, optional @@ -202,287 +206,74 @@ class HDFS(FileConnection, RenameDirMixin): ).check() """ - @support_hooks - class Slots: - """Slots that could be implemented by third-party plugins""" - - @slot - @staticmethod - def normalize_cluster_name(cluster: str) -> str | None: - """ - Normalize cluster name passed into HDFS constructor. - - If hooks didn't return anything, cluster name is left intact. - - Parameters - ---------- - cluster : :obj:`str` - Cluster name - - Returns - ------- - str | None - Normalized cluster name. - - If hook cannot be applied to a specific cluster, it should return ``None``. - - Examples - -------- - - .. code:: python - - from onetl.connection import HDFS - from onetl.hooks import hook - - - @HDFS.Slots.normalize_cluster_name.bind - @hook - def normalize_cluster_name(cluster: str) -> str: - return cluster.lower() - """ - - @slot - @staticmethod - def normalize_namenode_host(host: str, cluster: str | None) -> str | None: - """ - Normalize namenode host passed into HDFS constructor. - - If hooks didn't return anything, host is left intact. - - Parameters - ---------- - host : :obj:`str` - Namenode host (raw) - - cluster : :obj:`str` or :obj:`None` - Cluster name (normalized), if set - - Returns - ------- - str | None - Normalized namenode host name. - - If hook cannot be applied to a specific host name, it should return ``None``. - - Examples - -------- - - .. 
code:: python - - from onetl.connection import HDFS - from onetl.hooks import hook - - - @HDFS.Slots.normalize_namenode_host.bind - @hook - def normalize_namenode_host(host: str, cluster: str) -> str | None: - if cluster == "rnd-dwh": - if not host.endswith(".domain.com"): - # fix missing domain name - host += ".domain.com" - return host - - return None - """ - - @slot - @staticmethod - def get_known_clusters() -> set[str] | None: - """ - Return collection of known clusters. - - Cluster passed into HDFS constructor should be present in this list. - If hooks didn't return anything, no validation will be performed. - - Returns - ------- - set[str] | None - Collection of cluster names (in normalized form). - - If hook cannot be applied, it should return ``None``. - - Examples - -------- - - .. code:: python - - from onetl.connection import HDFS - from onetl.hooks import hook - - - @HDFS.Slots.get_known_clusters.bind - @hook - def get_known_clusters() -> str[str]: - return {"rnd-dwh", "rnd-prod"} - """ - - @slot - @staticmethod - def get_cluster_namenodes(cluster: str) -> set[str] | None: - """ - Return collection of known namenodes for the cluster. - - Namenode host passed into HDFS constructor should be present in this list. - If hooks didn't return anything, no validation will be performed. - - Parameters - ---------- - cluster : :obj:`str` - Cluster name (normalized) - - Returns - ------- - set[str] | None - Collection of host names (in normalized form). - - If hook cannot be applied, it should return ``None``. - - Examples - -------- - - .. code:: python - - from onetl.connection import HDFS - from onetl.hooks import hook - - - @HDFS.Slots.get_cluster_namenodes.bind - @hook - def get_cluster_namenodes(cluster: str) -> str[str] | None: - if cluster == "rnd-dwh": - return {"namenode1.domain.com", "namenode2.domain.com"} - return None - """ - - @slot - @staticmethod - def get_current_cluster() -> str | None: - """ - Get current cluster name. - - Used in :obj:`~get_current_cluster` to automatically fill up ``cluster`` attribute of a connection. - If hooks didn't return anything, calling the method above will raise an exception. - - Returns - ------- - str | None - Current cluster name (in normalized form). - - If hook cannot be applied, it should return ``None``. - - Examples - -------- - - .. code:: python - - from onetl.connection import HDFS - from onetl.hooks import hook - - - @HDFS.Slots.get_current_cluster.bind - @hook - def get_current_cluster() -> str: - # some magic here - return "rnd-dwh" - """ - - @slot - @staticmethod - def get_webhdfs_port(cluster: str) -> int | None: - """ - Get WebHDFS port number for a specific cluster. - - Used by constructor to automatically set port number if omitted. - - Parameters - ---------- - cluster : :obj:`str` - Cluster name (normalized) - - Returns - ------- - int | None - WebHDFS port number. - - If hook cannot be applied, it should return ``None``. - - Examples - -------- - - .. code:: python - - from onetl.connection import HDFS - from onetl.hooks import hook - - - @HDFS.Slots.get_webhdfs_port.bind - @hook - def get_webhdfs_port(cluster: str) -> int | None: - if cluster == "rnd-dwh": - return 50007 # Cloudera - return None - """ - - @slot - @staticmethod - def is_namenode_active(host: str, cluster: str | None) -> bool | None: - """ - Check whether a namenode of a specified cluster is active (=not standby) or not. 
+ cluster: Optional[Cluster] = None + host: Optional[Host] = None + webhdfs_port: int = Field(alias="port", default=50070) + user: Optional[str] = None + password: Optional[SecretStr] = None + keytab: Optional[FilePath] = None + timeout: int = 10 - Used for: - * If HDFS connection is created without ``host`` + Slots = HDFSSlots + # TODO: remove in v1.0.0 + slots = Slots - Connector will iterate over :obj:`~get_cluster_namenodes` of a cluster to get active namenode, - and then use it instead of ``host`` attribute. + @slot + @classmethod + def get_current(cls, **kwargs): + """ + Create connection for current cluster. |support_hooks| - * If HDFS connection is created with ``host`` + Automatically sets up current cluster name as ``cluster``. - :obj:`~check` will determine whether this host is active. + .. note:: - Parameters - ---------- - host : :obj:`str` - Namenode host (normalized) + Can be used only if there are a some hooks bound to slot + :obj:`Slots.get_current_cluster ` - cluster : :obj:`str` or :obj:`None` - Cluster name (normalized), if set + Parameters + ---------- + user : str + password : str | None + keytab : str | None + timeout : int - Returns - ------- - bool | None - ``True`` if namenode is active, ``False`` if not. + See :obj:`~HDFS` constructor documentation. - If hook cannot be applied, it should return ``None``. + Examples + -------- - Examples - -------- + .. code:: python - .. code:: python + from onetl.connection import HDFS - from onetl.connection import HDFS - from onetl.hooks import hook + # injecting current cluster name via hooks mechanism + hdfs = HDFS.get_current(user="me", password="pass") + """ + log.info("|%s| Detecting current cluster...", cls.__name__) + current_cluster = cls.Slots.get_current_cluster() + if not current_cluster: + raise RuntimeError( + f"{cls.__name__}.get_current() can be used only if there are " + f"some hooks bound to {cls.__name__}.Slots.get_current_cluster", + ) - @HDFS.Slots.is_namenode_active.bind - @hook - def is_namenode_active(host: str, cluster: str | None) -> bool: - # some magic here - return True - """ + log.info("|%s| Got %r", cls.__name__, current_cluster) + return cls(cluster=current_cluster, **kwargs) - # TODO: remove in v1.0.0 - slots = Slots + @property + def instance_url(self) -> str: + if self.cluster: + return self.cluster + return f"hdfs://{self.host}:{self.webhdfs_port}" - cluster: Optional[Cluster] = None - host: Optional[Host] = None - webhdfs_port: int = Field(alias="port", default=50070) - user: Optional[str] = None - password: Optional[SecretStr] = None - keytab: Optional[FilePath] = None - timeout: int = 10 + @slot + def path_exists(self, path: os.PathLike | str) -> bool: + return self.client.status(os.fspath(path), strict=False) @validator("user", pre=True) - def validate_packages(cls, user): + def _validate_packages(cls, user): if user: try: from hdfs.ext.kerberos import KerberosClient as CheckForKerberosSupport @@ -507,7 +298,7 @@ def validate_packages(cls, user): return user @root_validator - def validate_cluster_or_hostname_set(cls, values): + def _validate_cluster_or_hostname_set(cls, values): host = values.get("host") cluster = values.get("cluster") @@ -517,13 +308,13 @@ def validate_cluster_or_hostname_set(cls, values): return values @validator("cluster") - def validate_cluster_name(cls, cluster): - log.debug("|%s| Normalizing cluster %r name ...", cls.__name__, cluster) + def _validate_cluster_name(cls, cluster): + log.debug("|%s| Normalizing cluster %r name...", cls.__name__, cluster) 
validated_cluster = cls.Slots.normalize_cluster_name(cluster) or cluster if validated_cluster != cluster: log.debug("|%s| Got %r", cls.__name__, validated_cluster) - log.debug("|%s| Checking if cluster %r is a known cluster ...", cls.__name__, validated_cluster) + log.debug("|%s| Checking if cluster %r is a known cluster...", cls.__name__, validated_cluster) known_clusters = cls.Slots.get_known_clusters() if known_clusters and validated_cluster not in known_clusters: raise ValueError( @@ -533,10 +324,10 @@ def validate_cluster_name(cls, cluster): return validated_cluster @validator("host") - def validate_host_name(cls, host, values): + def _validate_host_name(cls, host, values): cluster = values.get("cluster") - log.debug("|%s| Normalizing namenode %r ...", cls.__name__, host) + log.debug("|%s| Normalizing namenode %r host...", cls.__name__, host) namenode = cls.Slots.normalize_namenode_host(host, cluster) or host if namenode != host: log.debug("|%s| Got %r", cls.__name__, namenode) @@ -553,7 +344,7 @@ def validate_host_name(cls, host, values): return namenode @validator("webhdfs_port", always=True) - def validate_port_number(cls, port, values): + def _validate_port_number(cls, port, values): cluster = values.get("cluster") if cluster: log.debug("|%s| Getting WebHDFS port of cluster %r ...", cls.__name__, cluster) @@ -565,7 +356,7 @@ def validate_port_number(cls, port, values): return port @root_validator - def validate_credentials(cls, values): + def _validate_credentials(cls, values): user = values.get("user") password = values.get("password") keytab = values.get("keytab") @@ -577,59 +368,6 @@ def validate_credentials(cls, values): return values - @slot - @classmethod - def get_current(cls, **kwargs): - """ - Create connection for current cluster. |support_hooks| - - Automatically sets up current cluster name as ``cluster``. - - .. note:: - - Can be used only if there are a some hooks bound to slot :obj:`~slots.get_current_cluster`. - - Parameters - ---------- - user : str - password : str | None - keytab : str | None - timeout : int - - See :obj:`~HDFS` constructor documentation. - - Examples - -------- - - .. 
code:: python - - from onetl.connection import HDFS - - # injecting current cluster name via hooks mechanism - hdfs = HDFS.get_current(user="me", password="pass") - """ - - log.info("|%s| Detecting current cluster...", cls.__name__) - current_cluster = cls.Slots.get_current_cluster() - if not current_cluster: - raise RuntimeError( - f"{cls.__name__}.get_current() can be used only if there are " - f"some hooks bound to {cls.__name__}.Slots.get_current_cluster", - ) - - log.info("|%s| Got %r", cls.__name__, current_cluster) - return cls(cluster=current_cluster, **kwargs) - - @property - def instance_url(self) -> str: - if self.cluster: - return self.cluster - return f"hdfs://{self.host}:{self.webhdfs_port}" - - @slot - def path_exists(self, path: os.PathLike | str) -> bool: - return self.client.status(os.fspath(path), strict=False) - def _get_active_namenode(self) -> str: class_name = self.__class__.__name__ log.info("|%s| Detecting active namenode of cluster %r ...", class_name, self.cluster) diff --git a/onetl/connection/file_connection/hdfs/slots.py b/onetl/connection/file_connection/hdfs/slots.py new file mode 100644 index 000000000..c57e69af4 --- /dev/null +++ b/onetl/connection/file_connection/hdfs/slots.py @@ -0,0 +1,286 @@ +# Copyright 2023 MTS (Mobile Telesystems) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +from onetl.hooks import slot, support_hooks + + +@support_hooks +class HDFSSlots: + """Slots that could be implemented by third-party plugins""" + + @slot + @staticmethod + def normalize_cluster_name(cluster: str) -> str | None: + """ + Normalize cluster name passed into HDFS constructor. + + If hooks didn't return anything, cluster name is left intact. + + Parameters + ---------- + cluster : :obj:`str` + Cluster name + + Returns + ------- + str | None + Normalized cluster name. + + If hook cannot be applied to a specific cluster, it should return ``None``. + + Examples + -------- + + .. code:: python + + from onetl.connection import HDFS + from onetl.hooks import hook + + + @HDFS.Slots.normalize_cluster_name.bind + @hook + def normalize_cluster_name(cluster: str) -> str: + return cluster.lower() + """ + + @slot + @staticmethod + def normalize_namenode_host(host: str, cluster: str | None) -> str | None: + """ + Normalize namenode host passed into HDFS constructor. + + If hooks didn't return anything, host is left intact. + + Parameters + ---------- + host : :obj:`str` + Namenode host (raw) + + cluster : :obj:`str` or :obj:`None` + Cluster name (normalized), if set + + Returns + ------- + str | None + Normalized namenode host name. + + If hook cannot be applied to a specific host name, it should return ``None``. + + Examples + -------- + + .. 
code:: python + + from onetl.connection import HDFS + from onetl.hooks import hook + + + @HDFS.Slots.normalize_namenode_host.bind + @hook + def normalize_namenode_host(host: str, cluster: str) -> str | None: + if cluster == "rnd-dwh": + if not host.endswith(".domain.com"): + # fix missing domain name + host += ".domain.com" + return host + + return None + """ + + @slot + @staticmethod + def get_known_clusters() -> set[str] | None: + """ + Return collection of known clusters. + + Cluster passed into HDFS constructor should be present in this list. + If hooks didn't return anything, no validation will be performed. + + Returns + ------- + set[str] | None + Collection of cluster names (in normalized form). + + If hook cannot be applied, it should return ``None``. + + Examples + -------- + + .. code:: python + + from onetl.connection import HDFS + from onetl.hooks import hook + + + @HDFS.Slots.get_known_clusters.bind + @hook + def get_known_clusters() -> str[str]: + return {"rnd-dwh", "rnd-prod"} + """ + + @slot + @staticmethod + def get_cluster_namenodes(cluster: str) -> set[str] | None: + """ + Return collection of known namenodes for the cluster. + + Namenode host passed into HDFS constructor should be present in this list. + If hooks didn't return anything, no validation will be performed. + + Parameters + ---------- + cluster : :obj:`str` + Cluster name (normalized) + + Returns + ------- + set[str] | None + Collection of host names (in normalized form). + + If hook cannot be applied, it should return ``None``. + + Examples + -------- + + .. code:: python + + from onetl.connection import HDFS + from onetl.hooks import hook + + + @HDFS.Slots.get_cluster_namenodes.bind + @hook + def get_cluster_namenodes(cluster: str) -> str[str] | None: + if cluster == "rnd-dwh": + return {"namenode1.domain.com", "namenode2.domain.com"} + return None + """ + + @slot + @staticmethod + def get_current_cluster() -> str | None: + """ + Get current cluster name. + + Used in :obj:`~get_current_cluster` to automatically fill up ``cluster`` attribute of a connection. + If hooks didn't return anything, calling the method above will raise an exception. + + Returns + ------- + str | None + Current cluster name (in normalized form). + + If hook cannot be applied, it should return ``None``. + + Examples + -------- + + .. code:: python + + from onetl.connection import HDFS + from onetl.hooks import hook + + + @HDFS.Slots.get_current_cluster.bind + @hook + def get_current_cluster() -> str: + # some magic here + return "rnd-dwh" + """ + + @slot + @staticmethod + def get_webhdfs_port(cluster: str) -> int | None: + """ + Get WebHDFS port number for a specific cluster. + + Used by constructor to automatically set port number if omitted. + + Parameters + ---------- + cluster : :obj:`str` + Cluster name (normalized) + + Returns + ------- + int | None + WebHDFS port number. + + If hook cannot be applied, it should return ``None``. + + Examples + -------- + + .. code:: python + + from onetl.connection import HDFS + from onetl.hooks import hook + + + @HDFS.Slots.get_webhdfs_port.bind + @hook + def get_webhdfs_port(cluster: str) -> int | None: + if cluster == "rnd-dwh": + return 50007 # Cloudera + return None + """ + + @slot + @staticmethod + def is_namenode_active(host: str, cluster: str | None) -> bool | None: + """ + Check whether a namenode of a specified cluster is active (=not standby) or not. 
+ + Used for: + * If HDFS connection is created without ``host`` + + Connector will iterate over :obj:`~get_cluster_namenodes` of a cluster to get active namenode, + and then use it instead of ``host`` attribute. + + * If HDFS connection is created with ``host`` + + :obj:`~check` will determine whether this host is active. + + Parameters + ---------- + host : :obj:`str` + Namenode host (normalized) + + cluster : :obj:`str` or :obj:`None` + Cluster name (normalized), if set + + Returns + ------- + bool | None + ``True`` if namenode is active, ``False`` if not. + + If hook cannot be applied, it should return ``None``. + + Examples + -------- + + .. code:: python + + from onetl.connection import HDFS + from onetl.hooks import hook + + + @HDFS.Slots.is_namenode_active.bind + @hook + def is_namenode_active(host: str, cluster: str | None) -> bool: + # some magic here + return True + """ diff --git a/onetl/connection/file_df_connection/spark_file_df_connection.py b/onetl/connection/file_df_connection/spark_file_df_connection.py index 206cf9e27..7c1994182 100644 --- a/onetl/connection/file_df_connection/spark_file_df_connection.py +++ b/onetl/connection/file_df_connection/spark_file_df_connection.py @@ -81,7 +81,7 @@ def read_files_as_df( if root: log.info("|%s| Reading data from '%s' ...", self.__class__.__name__, root) else: - log.info("|%s| Reading data ...", self.__class__.__name__) + log.info("|%s| Reading data...", self.__class__.__name__) reader: DataFrameReader = self.spark.read with ExitStack() as stack: @@ -138,7 +138,7 @@ def write_df_as_files( url = self._convert_to_url(path) writer.save(url) - log.info("|%s| Data is successfully saved to '%s'", self.__class__.__name__, path) + log.info("|%s| Data is successfully saved to '%s'.", self.__class__.__name__, path) @abstractmethod def _convert_to_url(self, path: PurePathProtocol) -> str: @@ -183,8 +183,7 @@ def _forward_refs(cls) -> dict[str, type]: return refs def _log_parameters(self): - log.info("|Spark| Using connection parameters:") - log_with_indent(log, "type = %s", self.__class__.__name__) + log.info("|%s| Using connection parameters:", self.__class__.__name__) parameters = self.dict(exclude_none=True, exclude={"spark"}) - for attr, value in sorted(parameters.items()): + for attr, value in parameters.items(): log_with_indent(log, "%s = %r", attr, value) diff --git a/onetl/connection/file_df_connection/spark_hdfs/connection.py b/onetl/connection/file_df_connection/spark_hdfs/connection.py index 021230779..04bdfae48 100644 --- a/onetl/connection/file_df_connection/spark_hdfs/connection.py +++ b/onetl/connection/file_df_connection/spark_hdfs/connection.py @@ -75,7 +75,7 @@ class SparkHDFS(SparkFileDFConnection): Supports only reading files as Spark DataFrame and writing DataFrame to files. Does NOT support file operations, like create, delete, rename, etc. For these operations, - use :obj:`HDFS ` connection. + use :obj:`HDFS ` connection. Parameters ---------- @@ -85,7 +85,8 @@ class SparkHDFS(SparkFileDFConnection): Used for: * HWM and lineage (as instance name for file paths) * Validation of ``host`` value, - if latter is passed and if some hooks are bound to :obj:`~slots.get_cluster_namenodes`. + if latter is passed and if some hooks are bound to + :obj:`Slots.get_cluster_namenodes `. host : str, optional Hadoop namenode host. For example: ``namenode1.domain.com``. 
@@ -270,12 +271,12 @@ def get_current(cls, spark: SparkSession): @validator("cluster") def _validate_cluster_name(cls, cluster): - log.debug("|%s| Normalizing cluster %r name ...", cls.__name__, cluster) + log.debug("|%s| Normalizing cluster %r name...", cls.__name__, cluster) validated_cluster = cls.Slots.normalize_cluster_name(cluster) or cluster if validated_cluster != cluster: log.debug("|%s| Got %r", cls.__name__, validated_cluster) - log.debug("|%s| Checking if cluster %r is a known cluster ...", cls.__name__, validated_cluster) + log.debug("|%s| Checking if cluster %r is a known cluster...", cls.__name__, validated_cluster) known_clusters = cls.Slots.get_known_clusters() if known_clusters and validated_cluster not in known_clusters: raise ValueError( @@ -288,7 +289,7 @@ def _validate_cluster_name(cls, cluster): def _validate_host_name(cls, host, values): cluster = values.get("cluster") - log.debug("|%s| Normalizing namenode %r ...", cls.__name__, host) + log.debug("|%s| Normalizing namenode %r host...", cls.__name__, host) namenode = cls.Slots.normalize_namenode_host(host, cluster) or host if namenode != host: log.debug("|%s| Got %r", cls.__name__, namenode) diff --git a/onetl/connection/file_df_connection/spark_s3/connection.py b/onetl/connection/file_df_connection/spark_s3/connection.py index 44a2955ef..0fd72a0ca 100644 --- a/onetl/connection/file_df_connection/spark_s3/connection.py +++ b/onetl/connection/file_df_connection/spark_s3/connection.py @@ -66,7 +66,7 @@ class SparkS3(SparkFileDFConnection): .. warning:: - See :spark-s3-troubleshooting` guide. + See :ref:`spark-s3-troubleshooting` guide. .. warning:: diff --git a/onetl/file/file_downloader/__init__.py b/onetl/file/file_downloader/__init__.py index f0e741e55..232bbccb7 100644 --- a/onetl/file/file_downloader/__init__.py +++ b/onetl/file/file_downloader/__init__.py @@ -12,5 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-from onetl.file.file_downloader.download_result import DownloadResult from onetl.file.file_downloader.file_downloader import FileDownloader +from onetl.file.file_downloader.options import FileDownloaderOptions +from onetl.file.file_downloader.result import DownloadResult diff --git a/onetl/file/file_downloader/file_downloader.py b/onetl/file/file_downloader/file_downloader.py index ba7c9b18a..4a51cc74b 100644 --- a/onetl/file/file_downloader/file_downloader.py +++ b/onetl/file/file_downloader/file_downloader.py @@ -24,13 +24,14 @@ from etl_entities import HWM, FileHWM, RemoteFolder from ordered_set import OrderedSet -from pydantic import Field, root_validator, validator +from pydantic import Field, validator from onetl._internal import generate_temp_path from onetl.base import BaseFileConnection, BaseFileFilter, BaseFileLimit from onetl.base.path_protocol import PathProtocol, PathWithStatsProtocol from onetl.base.pure_path_protocol import PurePathProtocol -from onetl.file.file_downloader.download_result import DownloadResult +from onetl.file.file_downloader.options import FileDownloaderOptions +from onetl.file.file_downloader.result import DownloadResult from onetl.file.file_set import FileSet from onetl.file.filter.file_hwm import FileHWMFilter from onetl.hooks import slot, support_hooks @@ -39,7 +40,6 @@ FailedRemoteFile, FileExistBehavior, FrozenModel, - GenericOptions, LocalPath, RemoteFile, RemotePath, @@ -209,48 +209,6 @@ class FileDownloader(FrozenModel): """ - class Options(GenericOptions): - """File downloading options""" - - if_exists: FileExistBehavior = Field(default=FileExistBehavior.ERROR, alias="mode") - """ - How to handle existing files in the local directory. - - Possible values: - * ``error`` (default) - do nothing, mark file as failed - * ``ignore`` - do nothing, mark file as ignored - * ``overwrite`` - replace existing file with a new one - * ``delete_all`` - delete local directory content before downloading files - """ - - delete_source: bool = False - """ - If ``True``, remove source file after successful download. - - If download failed, file will left intact. - """ - - workers: int = Field(default=1, ge=1) - """ - Number of workers to create for parallel file download. - - 1 (default) means files will me downloaded sequentially. - 2 or more means files will be downloaded in parallel workers. - - Recommended value is ``min(32, os.cpu_count() + 4)``, e.g. ``5``. - """ - - @root_validator(pre=True) - def mode_is_deprecated(cls, values): - if "mode" in values: - warnings.warn( - "Option `FileDownloader.Options(mode=...)` is deprecated since v0.9.0 and will be removed in v1.0.0. 
" - "Use `FileDownloader.Options(if_exists=...)` instead", - category=UserWarning, - stacklevel=3, - ) - return values - connection: BaseFileConnection local_path: LocalPath @@ -262,7 +220,9 @@ def mode_is_deprecated(cls, values): hwm_type: Optional[Type[FileHWM]] = None - options: Options = Options() + options: FileDownloaderOptions = FileDownloaderOptions() + + Options = FileDownloaderOptions @slot def run(self, files: Iterable[str | os.PathLike] | None = None) -> DownloadResult: # noqa: WPS231 @@ -676,7 +636,7 @@ def _download_files( log.info("|%s| Files to be downloaded:", self.__class__.__name__) log_lines(log, str(files)) log_with_indent(log, "") - log.info("|%s| Starting the download process ...", self.__class__.__name__) + log.info("|%s| Starting the download process...", self.__class__.__name__) self._create_dirs(to_download) @@ -715,11 +675,23 @@ def _bulk_download( to_download: DOWNLOAD_ITEMS_TYPE, ) -> list[tuple[FileDownloadStatus, PurePathProtocol | PathWithStatsProtocol]]: workers = self.options.workers + files_count = len(to_download) result = [] - if workers > 1: + real_workers = workers + if files_count < workers: + log.debug( + "|%s| Asked for %d workers, but there are only %d files", + self.__class__.__name__, + workers, + files_count, + ) + real_workers = files_count + + if real_workers > 1: + log.debug("|%s| Using ThreadPoolExecutor with %d workers", self.__class__.__name__, real_workers) with ThreadPoolExecutor( - max_workers=min(workers, len(to_download)), + max_workers=real_workers, thread_name_prefix=self.__class__.__name__, ) as executor: futures = [ @@ -729,6 +701,7 @@ def _bulk_download( for future in as_completed(futures): result.append(future.result()) else: + log.debug("|%s| Using plain old for-loop", self.__class__.__name__) for source_file, target_file, tmp_file in to_download: result.append( self._download_file( diff --git a/onetl/file/file_downloader/options.py b/onetl/file/file_downloader/options.py new file mode 100644 index 000000000..9ec44ce52 --- /dev/null +++ b/onetl/file/file_downloader/options.py @@ -0,0 +1,64 @@ +# Copyright 2023 MTS (Mobile Telesystems) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import warnings + +from pydantic import Field, root_validator + +from onetl.impl import FileExistBehavior, GenericOptions + + +class FileDownloaderOptions(GenericOptions): + """File downloading options""" + + if_exists: FileExistBehavior = Field(default=FileExistBehavior.ERROR, alias="mode") + """ + How to handle existing files in the local directory. + + Possible values: + * ``error`` (default) - do nothing, mark file as failed + * ``ignore`` - do nothing, mark file as ignored + * ``overwrite`` - replace existing file with a new one + * ``delete_all`` - delete local directory content before downloading files + """ + + delete_source: bool = False + """ + If ``True``, remove source file after successful download. + + If download failed, file will left intact. 
+ """ + + workers: int = Field(default=1, ge=1) + """ + Number of workers to create for parallel file download. + + 1 (default) means files will me downloaded sequentially. + 2 or more means files will be downloaded in parallel workers. + + Recommended value is ``min(32, os.cpu_count() + 4)``, e.g. ``5``. + """ + + @root_validator(pre=True) + def _mode_is_deprecated(cls, values): + if "mode" in values: + warnings.warn( + "Option `FileDownloader.Options(mode=...)` is deprecated since v0.9.0 and will be removed in v1.0.0. " + "Use `FileDownloader.Options(if_exists=...)` instead", + category=UserWarning, + stacklevel=3, + ) + return values diff --git a/onetl/file/file_downloader/download_result.py b/onetl/file/file_downloader/result.py similarity index 100% rename from onetl/file/file_downloader/download_result.py rename to onetl/file/file_downloader/result.py diff --git a/onetl/file/file_mover/__init__.py b/onetl/file/file_mover/__init__.py index acfa16899..fa416ec6f 100644 --- a/onetl/file/file_mover/__init__.py +++ b/onetl/file/file_mover/__init__.py @@ -13,4 +13,5 @@ # limitations under the License. from onetl.file.file_mover.file_mover import FileMover -from onetl.file.file_mover.move_result import MoveResult +from onetl.file.file_mover.options import FileMoverOptions +from onetl.file.file_mover.result import MoveResult diff --git a/onetl/file/file_mover/file_mover.py b/onetl/file/file_mover/file_mover.py index 9c28632a3..7d27eb8a8 100644 --- a/onetl/file/file_mover/file_mover.py +++ b/onetl/file/file_mover/file_mover.py @@ -16,25 +16,24 @@ import logging import os -import warnings from concurrent.futures import ThreadPoolExecutor, as_completed from enum import Enum from typing import Iterable, List, Optional, Tuple from ordered_set import OrderedSet -from pydantic import Field, root_validator, validator +from pydantic import Field, validator from onetl.base import BaseFileConnection, BaseFileFilter, BaseFileLimit from onetl.base.path_protocol import PathProtocol, PathWithStatsProtocol from onetl.base.pure_path_protocol import PurePathProtocol -from onetl.file.file_mover.move_result import MoveResult +from onetl.file.file_mover.options import FileMoverOptions +from onetl.file.file_mover.result import MoveResult from onetl.file.file_set import FileSet from onetl.hooks import slot, support_hooks from onetl.impl import ( FailedRemoteFile, FileExistBehavior, FrozenModel, - GenericOptions, RemoteFile, RemotePath, path_repr, @@ -153,41 +152,6 @@ class FileMover(FrozenModel): """ - class Options(GenericOptions): - """File moving options""" - - if_exists: FileExistBehavior = Field(default=FileExistBehavior.ERROR, alias="mode") - """ - How to handle existing files in the local directory. - - Possible values: - * ``error`` (default) - do nothing, mark file as failed - * ``ignore`` - do nothing, mark file as ignored - * ``overwrite`` - replace existing file with a new one - * ``delete_all`` - delete directory content before moving files - """ - - workers: int = Field(default=1, ge=1) - """ - Number of workers to create for parallel file moving. - - 1 (default) means files will me moved sequentially. - 2 or more means files will be moved in parallel workers. - - Recommended value is ``min(32, os.cpu_count() + 4)``, e.g. ``5``. - """ - - @root_validator(pre=True) - def mode_is_deprecated(cls, values): - if "mode" in values: - warnings.warn( - "Option `FileMover.Options(mode=...)` is deprecated since v0.9.0 and will be removed in v1.0.0. 
" - "Use `FileMover.Options(if_exists=...)` instead", - category=UserWarning, - stacklevel=3, - ) - return values - connection: BaseFileConnection target_path: RemotePath @@ -196,7 +160,9 @@ def mode_is_deprecated(cls, values): filters: List[BaseFileFilter] = Field(default_factory=list) limits: List[BaseFileLimit] = Field(default_factory=list) - options: Options = Options() + options: FileMoverOptions = FileMoverOptions() + + Options = FileMoverOptions @slot def run(self, files: Iterable[str | os.PathLike] | None = None) -> MoveResult: # noqa: WPS231 @@ -478,7 +444,7 @@ def _move_files( log.info("|%s| Files to be moved:", self.__class__.__name__) log_lines(log, str(files)) log_with_indent(log, "") - log.info("|%s| Starting the move process ...", self.__class__.__name__) + log.info("|%s| Starting the move process...", self.__class__.__name__) self._create_dirs(to_move) @@ -512,11 +478,23 @@ def _bulk_move( to_move: MOVE_ITEMS_TYPE, ) -> list[tuple[FileMoveStatus, PurePathProtocol | PathWithStatsProtocol]]: workers = self.options.workers + files_count = len(to_move) result = [] - if workers > 1: + real_workers = workers + if files_count < workers: + log.debug( + "|%s| Asked for %d workers, but there are only %d files", + self.__class__.__name__, + workers, + files_count, + ) + real_workers = files_count + + if real_workers > 1: + log.debug("|%s| Using ThreadPoolExecutor with %d workers", self.__class__.__name__, real_workers) with ThreadPoolExecutor( - max_workers=min(workers, len(to_move)), + max_workers=workers, thread_name_prefix=self.__class__.__name__, ) as executor: futures = [ @@ -525,6 +503,7 @@ def _bulk_move( for future in as_completed(futures): result.append(future.result()) else: + log.debug("|%s| Using plain old for-loop", self.__class__.__name__) for source_file, target_file in to_move: result.append( self._move_file( diff --git a/onetl/file/file_mover/options.py b/onetl/file/file_mover/options.py new file mode 100644 index 000000000..912c0ae1b --- /dev/null +++ b/onetl/file/file_mover/options.py @@ -0,0 +1,57 @@ +# Copyright 2023 MTS (Mobile Telesystems) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import warnings + +from pydantic import Field, root_validator + +from onetl.impl import FileExistBehavior, GenericOptions + + +class FileMoverOptions(GenericOptions): + """File moving options""" + + if_exists: FileExistBehavior = Field(default=FileExistBehavior.ERROR, alias="mode") + """ + How to handle existing files in the local directory. + + Possible values: + * ``error`` (default) - do nothing, mark file as failed + * ``ignore`` - do nothing, mark file as ignored + * ``overwrite`` - replace existing file with a new one + * ``delete_all`` - delete directory content before moving files + """ + + workers: int = Field(default=1, ge=1) + """ + Number of workers to create for parallel file moving. + + 1 (default) means files will me moved sequentially. + 2 or more means files will be moved in parallel workers. 
+ + Recommended value is ``min(32, os.cpu_count() + 4)``, e.g. ``5``. + """ + + @root_validator(pre=True) + def _mode_is_deprecated(cls, values): + if "mode" in values: + warnings.warn( + "Option `FileMover.Options(mode=...)` is deprecated since v0.9.0 and will be removed in v1.0.0. " + "Use `FileMover.Options(if_exists=...)` instead", + category=UserWarning, + stacklevel=3, + ) + return values diff --git a/onetl/file/file_mover/move_result.py b/onetl/file/file_mover/result.py similarity index 100% rename from onetl/file/file_mover/move_result.py rename to onetl/file/file_mover/result.py diff --git a/onetl/file/file_uploader/__init__.py b/onetl/file/file_uploader/__init__.py index cfcf797c3..7f24451f7 100644 --- a/onetl/file/file_uploader/__init__.py +++ b/onetl/file/file_uploader/__init__.py @@ -13,4 +13,5 @@ # limitations under the License. from onetl.file.file_uploader.file_uploader import FileUploader -from onetl.file.file_uploader.upload_result import UploadResult +from onetl.file.file_uploader.options import FileUploaderOptions +from onetl.file.file_uploader.result import UploadResult diff --git a/onetl/file/file_uploader/file_uploader.py b/onetl/file/file_uploader/file_uploader.py index 811c27c5f..e9e8c550a 100644 --- a/onetl/file/file_uploader/file_uploader.py +++ b/onetl/file/file_uploader/file_uploader.py @@ -16,13 +16,12 @@ import logging import os -import warnings from concurrent.futures import ThreadPoolExecutor, as_completed from enum import Enum from typing import Iterable, Optional, Tuple from ordered_set import OrderedSet -from pydantic import Field, root_validator, validator +from pydantic import validator from onetl._internal import generate_temp_path from onetl.base import BaseFileConnection @@ -30,13 +29,13 @@ from onetl.base.pure_path_protocol import PurePathProtocol from onetl.exception import DirectoryNotFoundError, NotAFileError from onetl.file.file_set import FileSet -from onetl.file.file_uploader.upload_result import UploadResult +from onetl.file.file_uploader.options import FileUploaderOptions +from onetl.file.file_uploader.result import UploadResult from onetl.hooks import slot, support_hooks from onetl.impl import ( FailedLocalFile, FileExistBehavior, FrozenModel, - GenericOptions, LocalPath, RemotePath, path_repr, @@ -142,48 +141,6 @@ class FileUploader(FrozenModel): """ - class Options(GenericOptions): - """File uploading options""" - - if_exists: FileExistBehavior = Field(default=FileExistBehavior.ERROR, alias="mode") - """ - How to handle existing files in the target directory. - - Possible values: - * ``error`` (default) - do nothing, mark file as failed - * ``ignore`` - do nothing, mark file as ignored - * ``overwrite`` - replace existing file with a new one - * ``delete_all`` - delete local directory content before downloading files - """ - - delete_local: bool = False - """ - If ``True``, remove local file after successful download. - - If download failed, file will left intact. - """ - - workers: int = Field(default=1, ge=1) - """ - Number of workers to create for parallel file upload. - - 1 (default) means files will me uploaded sequentially. - 2 or more means files will be uploaded in parallel workers. - - Recommended value is ``min(32, os.cpu_count() + 4)``, e.g. ``5``. - """ - - @root_validator(pre=True) - def mode_is_deprecated(cls, values): - if "mode" in values: - warnings.warn( - "Option `FileUploader.Options(mode=...)` is deprecated since v0.9.0 and will be removed in v1.0.0. 
" - "Use `FileUploader.Options(if_exists=...)` instead", - category=UserWarning, - stacklevel=3, - ) - return values - connection: BaseFileConnection target_path: RemotePath @@ -191,7 +148,9 @@ def mode_is_deprecated(cls, values): local_path: Optional[LocalPath] = None temp_path: Optional[RemotePath] = None - options: Options = Options() + options: FileUploaderOptions = FileUploaderOptions() + + Options = FileUploaderOptions @slot def run(self, files: Iterable[str | os.PathLike] | None = None) -> UploadResult: @@ -501,7 +460,7 @@ def _upload_files(self, to_upload: UPLOAD_ITEMS_TYPE) -> UploadResult: log.info("|%s| Files to be uploaded:", self.__class__.__name__) log_lines(log, str(files)) log_with_indent(log, "") - log.info("|%s| Starting the upload process ...", self.__class__.__name__) + log.info("|%s| Starting the upload process...", self.__class__.__name__) self._create_dirs(to_upload) @@ -540,11 +499,23 @@ def _bulk_upload( to_upload: UPLOAD_ITEMS_TYPE, ) -> list[tuple[FileUploadStatus, PurePathProtocol | PathWithStatsProtocol]]: workers = self.options.workers + files_count = len(to_upload) result = [] - if workers > 1: + real_workers = workers + if files_count < workers: + log.debug( + "|%s| Asked for %d workers, but there are only %d files", + self.__class__.__name__, + workers, + files_count, + ) + real_workers = files_count + + if real_workers > 1: + log.debug("|%s| Using ThreadPoolExecutor with %d workers", self.__class__.__name__, real_workers) with ThreadPoolExecutor( - max_workers=min(workers, len(to_upload)), + max_workers=workers, thread_name_prefix=self.__class__.__name__, ) as executor: futures = [ @@ -554,6 +525,7 @@ def _bulk_upload( for future in as_completed(futures): result.append(future.result()) else: + log.debug("|%s| Using plain old for-loop", self.__class__.__name__) for local_file, target_file, tmp_file in to_upload: result.append( self._upload_file( diff --git a/onetl/file/file_uploader/options.py b/onetl/file/file_uploader/options.py new file mode 100644 index 000000000..e3dd78bb3 --- /dev/null +++ b/onetl/file/file_uploader/options.py @@ -0,0 +1,64 @@ +# Copyright 2023 MTS (Mobile Telesystems) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import warnings + +from pydantic import Field, root_validator + +from onetl.impl import FileExistBehavior, GenericOptions + + +class FileUploaderOptions(GenericOptions): + """File uploading options""" + + if_exists: FileExistBehavior = Field(default=FileExistBehavior.ERROR, alias="mode") + """ + How to handle existing files in the target directory. + + Possible values: + * ``error`` (default) - do nothing, mark file as failed + * ``ignore`` - do nothing, mark file as ignored + * ``overwrite`` - replace existing file with a new one + * ``delete_all`` - delete local directory content before downloading files + """ + + delete_local: bool = False + """ + If ``True``, remove local file after successful download. + + If download failed, file will left intact. 
+ """ + + workers: int = Field(default=1, ge=1) + """ + Number of workers to create for parallel file upload. + + 1 (default) means files will me uploaded sequentially. + 2 or more means files will be uploaded in parallel workers. + + Recommended value is ``min(32, os.cpu_count() + 4)``, e.g. ``5``. + """ + + @root_validator(pre=True) + def _mode_is_deprecated(cls, values): + if "mode" in values: + warnings.warn( + "Option `FileUploader.Options(mode=...)` is deprecated since v0.9.0 and will be removed in v1.0.0. " + "Use `FileUploader.Options(if_exists=...)` instead", + category=UserWarning, + stacklevel=3, + ) + return values diff --git a/onetl/file/file_uploader/upload_result.py b/onetl/file/file_uploader/result.py similarity index 100% rename from onetl/file/file_uploader/upload_result.py rename to onetl/file/file_uploader/result.py diff --git a/requirements/core.txt b/requirements/core.txt index edd3095a6..96c8959e5 100644 --- a/requirements/core.txt +++ b/requirements/core.txt @@ -1,5 +1,5 @@ deprecated -etl-entities>=1.3,<1.4 +etl-entities>=1.4,<1.5 evacuator>=1.0,<1.1 frozendict humanize diff --git a/requirements/docs.txt b/requirements/docs.txt index 463d9c398..3e8e1e0e8 100644 --- a/requirements/docs.txt +++ b/requirements/docs.txt @@ -3,8 +3,7 @@ furo importlib-resources<6 numpydoc pygments-csv-lexer -# https://github.com/pradyunsg/furo/discussions/693 -sphinx<7.2.0 +sphinx sphinx-copybutton sphinx-design sphinx-tabs diff --git a/setup.cfg b/setup.cfg index b2b8aab52..6a799b4c5 100644 --- a/setup.cfg +++ b/setup.cfg @@ -307,22 +307,13 @@ per-file-ignores = *connection.py: # WPS437 Found protected attribute usage: spark._sc._gateway WPS437, - onetl/connection/db_connection/mongodb.py: -# WPS437 Found protected attribute usage: self.Dialect._ - WPS437, - onetl/connection/db_connection/jdbc_mixin.py: + onetl/connection/db_connection/jdbc_mixin/connection.py: # too few type annotations TAE001, # WPS219 :Found too deep access level WPS219, # WPS437: Found protected attribute usage: spark._jvm WPS437, - onetl/connection/db_connection/hive.py: -# WPS437 Found protected attribute usage: self.Dialect._ - WPS437, - onetl/connection/db_connection/greenplum.py: -# WPS437 Found protected attribute usage: self.Dialect._ - WPS437, onetl/connection/db_connection/kafka/connection.py: # WPS342: Found implicit raw string \\n WPS342, @@ -338,7 +329,7 @@ per-file-ignores = onetl/connection/file_connection/file_connection.py: # WPS220: Found too deep nesting WPS220, - onetl/connection/file_connection/hdfs.py: + onetl/connection/file_connection/hdfs/connection.py: # E800 Found commented out code E800, # F401 'hdfs.ext.kerberos.KerberosClient as CheckForKerberosSupport' imported but unused diff --git a/tests/tests_integration/test_file_df_connection_integration/test_spark_hdfs_integration.py b/tests/tests_integration/test_file_df_connection_integration/test_spark_hdfs_integration.py index e145fcc7f..3b47df0d0 100644 --- a/tests/tests_integration/test_file_df_connection_integration/test_spark_hdfs_integration.py +++ b/tests/tests_integration/test_file_df_connection_integration/test_spark_hdfs_integration.py @@ -15,12 +15,12 @@ def test_spark_hdfs_check(hdfs_file_df_connection, caplog): with caplog.at_level(logging.INFO): assert hdfs.check() == hdfs - assert "type = SparkHDFS" in caplog.text + assert "|SparkHDFS|" in caplog.text assert f"cluster = '{hdfs.cluster}'" in caplog.text assert f"host = '{hdfs.host}'" in caplog.text assert f"ipc_port = {hdfs.ipc_port}" in caplog.text - assert "Connection is 
available" in caplog.text + assert "Connection is available." in caplog.text def test_spark_hdfs_file_connection_check_failed(spark): diff --git a/tests/tests_integration/test_file_df_connection_integration/test_spark_local_fs_integration.py b/tests/tests_integration/test_file_df_connection_integration/test_spark_local_fs_integration.py index c5d433b75..1f9751e45 100644 --- a/tests/tests_integration/test_file_df_connection_integration/test_spark_local_fs_integration.py +++ b/tests/tests_integration/test_file_df_connection_integration/test_spark_local_fs_integration.py @@ -13,6 +13,6 @@ def test_spark_local_fs_check(spark, caplog): with caplog.at_level(logging.INFO): assert local_fs.check() == local_fs - assert "type = SparkLocalFS" in caplog.text + assert "|SparkLocalFS|" in caplog.text - assert "Connection is available" in caplog.text + assert "Connection is available." in caplog.text diff --git a/tests/tests_integration/test_file_df_connection_integration/test_spark_s3_integration.py b/tests/tests_integration/test_file_df_connection_integration/test_spark_s3_integration.py index 15fa6d5a4..c1c9d9b97 100644 --- a/tests/tests_integration/test_file_df_connection_integration/test_spark_s3_integration.py +++ b/tests/tests_integration/test_file_df_connection_integration/test_spark_s3_integration.py @@ -15,7 +15,7 @@ def test_spark_s3_check(s3_file_df_connection, caplog): with caplog.at_level(logging.INFO): assert s3.check() == s3 - assert "type = SparkS3" in caplog.text + assert "|SparkS3|" in caplog.text assert f"host = '{s3.host}'" in caplog.text assert f"port = {s3.port}" in caplog.text assert f"protocol = '{s3.protocol}'" in caplog.text @@ -27,7 +27,7 @@ def test_spark_s3_check(s3_file_df_connection, caplog): assert "extra = {" in caplog.text assert "'path.style.access': True" in caplog.text - assert "Connection is available" in caplog.text + assert "Connection is available." 
in caplog.text def test_spark_s3_check_failed(spark, s3_server): diff --git a/tests/tests_integration/tests_core_integration/test_file_downloader_integration.py b/tests/tests_integration/tests_core_integration/test_file_downloader_integration.py index 86f7dc95d..ed290ab43 100644 --- a/tests/tests_integration/tests_core_integration/test_file_downloader_integration.py +++ b/tests/tests_integration/tests_core_integration/test_file_downloader_integration.py @@ -53,13 +53,14 @@ def test_file_downloader_view_file(file_connection_with_path_and_files): [str, Path], ids=["run_path_type str", "run_path_type Path"], ) -@pytest.mark.parametrize("workers", [1, 3]) +@pytest.mark.parametrize("workers", [1, 3, 20]) def test_file_downloader_run( file_connection_with_path_and_files, path_type, run_path_type, tmp_path_factory, workers, + caplog, ): file_connection, remote_path, uploaded_files = file_connection_with_path_and_files local_path = tmp_path_factory.mktemp("local_path") @@ -73,7 +74,18 @@ def test_file_downloader_run( ), ) - download_result = downloader.run() + with caplog.at_level(logging.DEBUG): + download_result = downloader.run() + + files_count = len(uploaded_files) + if 1 <= files_count < workers: + assert f"Asked for {workers} workers, but there are only {files_count} files" in caplog.text + + if workers > 1 and files_count > 1: + real_workers = min(workers, files_count) + assert f"Using ThreadPoolExecutor with {real_workers} workers" in caplog.text + else: + assert "Using plain old for-loop" in caplog.text assert not download_result.failed assert not download_result.skipped @@ -372,6 +384,7 @@ def test_file_downloader_run_with_empty_files_input( file_connection_with_path_and_files, pass_source_path, tmp_path_factory, + caplog, ): file_connection, remote_path, _ = file_connection_with_path_and_files local_path = tmp_path_factory.mktemp("local_path") @@ -382,7 +395,11 @@ def test_file_downloader_run_with_empty_files_input( source_path=remote_path if pass_source_path else None, ) - download_result = downloader.run([]) # this argument takes precedence + with caplog.at_level(logging.INFO): + download_result = downloader.run([]) # argument takes precedence over source_path content + + assert "No files to download!" in caplog.text + assert "Starting the download process" not in caplog.text assert not download_result.failed assert not download_result.skipped @@ -390,7 +407,7 @@ def test_file_downloader_run_with_empty_files_input( assert not download_result.successful -def test_file_downloader_run_with_empty_source_path(request, file_connection_with_path, tmp_path_factory): +def test_file_downloader_run_with_empty_source_path(request, file_connection_with_path, tmp_path_factory, caplog): file_connection, remote_path = file_connection_with_path remote_path = PurePosixPath(f"/tmp/test_download_{secrets.token_hex(5)}") @@ -411,7 +428,11 @@ def finalizer(): source_path=remote_path, ) - download_result = downloader.run() + with caplog.at_level(logging.INFO): + download_result = downloader.run() + + assert "No files to download!" 
in caplog.text + assert "Starting the download process" not in caplog.text assert not download_result.failed assert not download_result.skipped diff --git a/tests/tests_integration/tests_core_integration/test_file_mover_integration.py b/tests/tests_integration/tests_core_integration/test_file_mover_integration.py index a3aeb32fb..762e9b9aa 100644 --- a/tests/tests_integration/tests_core_integration/test_file_mover_integration.py +++ b/tests/tests_integration/tests_core_integration/test_file_mover_integration.py @@ -35,12 +35,13 @@ def test_file_mover_view_file(file_connection_with_path_and_files): @pytest.mark.parametrize("path_type", [str, PurePosixPath], ids=["path_type str", "path_type PurePosixPath"]) -@pytest.mark.parametrize("workers", [1, 3]) +@pytest.mark.parametrize("workers", [1, 3, 20]) def test_file_mover_run( request, file_connection_with_path_and_files, path_type, workers, + caplog, ): file_connection, source_path, uploaded_files = file_connection_with_path_and_files target_path = f"/tmp/test_move_{secrets.token_hex(5)}" @@ -68,7 +69,18 @@ def finalizer(): files_content[file_path] = file_connection.read_bytes(file_path) files_size[file_path] = file_connection.get_stat(file_path).st_size - move_result = mover.run() + with caplog.at_level(logging.DEBUG): + move_result = mover.run() + + files_count = len(uploaded_files) + if 1 <= files_count < workers: + assert f"Asked for {workers} workers, but there are only {files_count} files" in caplog.text + + if workers > 1 and files_count > 1: + real_workers = min(workers, files_count) + assert f"Using ThreadPoolExecutor with {real_workers} workers" in caplog.text + else: + assert "Using plain old for-loop" in caplog.text assert not move_result.failed assert not move_result.skipped @@ -344,6 +356,7 @@ def test_file_mover_run_with_empty_files_input( request, file_connection_with_path_and_files, pass_source_path, + caplog, ): file_connection, source_path, _ = file_connection_with_path_and_files target_path = f"/tmp/test_move_{secrets.token_hex(5)}" @@ -359,7 +372,11 @@ def finalizer(): source_path=source_path if pass_source_path else None, ) - move_result = mover.run([]) # this argument takes precedence + with caplog.at_level(logging.INFO): + move_result = mover.run([]) # argument takes precedence over source_path content + + assert "No files to move!" in caplog.text + assert "Starting the moving process" not in caplog.text assert not move_result.failed assert not move_result.skipped @@ -367,7 +384,7 @@ def finalizer(): assert not move_result.successful -def test_file_mover_run_with_empty_source_path(request, file_connection): +def test_file_mover_run_with_empty_source_path(request, file_connection, caplog): source_path = PurePosixPath(f"/tmp/test_move_{secrets.token_hex(5)}") file_connection.create_dir(source_path) @@ -393,7 +410,11 @@ def finalizer2(): source_path=source_path, ) - move_result = mover.run() + with caplog.at_level(logging.INFO): + move_result = mover.run() + + assert "No files to move!" 
in caplog.text + assert "Starting the moving process" not in caplog.text assert not move_result.failed assert not move_result.skipped diff --git a/tests/tests_integration/tests_core_integration/test_file_uploader_integration.py b/tests/tests_integration/tests_core_integration/test_file_uploader_integration.py index cd52202fd..522cf2dd4 100644 --- a/tests/tests_integration/tests_core_integration/test_file_uploader_integration.py +++ b/tests/tests_integration/tests_core_integration/test_file_uploader_integration.py @@ -41,7 +41,7 @@ def test_file_uploader_view_files(file_connection, file_connection_resource_path [str, Path], ids=["run_path_type str", "run_path_type Path"], ) -@pytest.mark.parametrize("workers", [1, 3]) +@pytest.mark.parametrize("workers", [1, 3, 20]) def test_file_uploader_run_with_files( request, file_connection, @@ -49,6 +49,7 @@ def test_file_uploader_run_with_files( run_path_type, path_type, workers, + caplog, ): target_path = path_type(f"/tmp/test_upload_{secrets.token_hex(5)}") test_files = file_connection_test_files @@ -67,7 +68,18 @@ def finalizer(): ), ) - upload_result = uploader.run(run_path_type(file) for file in test_files) + with caplog.at_level(logging.DEBUG): + upload_result = uploader.run(run_path_type(file) for file in test_files) + + files_count = len(test_files) + if 1 <= files_count < workers: + assert f"Asked for {workers} workers, but there are only {files_count} files" in caplog.text + + if workers > 1 and files_count > 1: + real_workers = min(workers, files_count) + assert f"Using ThreadPoolExecutor with {real_workers} workers" in caplog.text + else: + assert "Using plain old for-loop" in caplog.text assert not upload_result.failed assert not upload_result.missing @@ -517,25 +529,29 @@ def test_file_uploader_run_input_is_not_file(file_connection): [False, True], ids=["Without local_path", "With local_path"], ) -def test_file_uploader_run_with_empty_files(file_connection, pass_local_path, tmp_path_factory): +def test_file_uploader_run_with_empty_files(file_connection, pass_local_path, tmp_path_factory, caplog): target_path = PurePosixPath(f"/tmp/test_upload_{secrets.token_hex(5)}") local_path = tmp_path_factory.mktemp("local_path") - downloader = FileUploader( + uploader = FileUploader( connection=file_connection, target_path=target_path, local_path=local_path if pass_local_path else None, ) - download_result = downloader.run([]) + with caplog.at_level(logging.INFO): + upload_result = uploader.run([]) # argument takes precedence over source_path content - assert not download_result.failed - assert not download_result.skipped - assert not download_result.missing - assert not download_result.successful + assert "No files to upload!" 
in caplog.text + assert "Starting the upload process" not in caplog.text + assert not upload_result.failed + assert not upload_result.skipped + assert not upload_result.missing + assert not upload_result.successful -def test_file_uploader_run_with_empty_local_path(request, file_connection, tmp_path_factory): + +def test_file_uploader_run_with_empty_local_path(request, file_connection, tmp_path_factory, caplog): target_path = PurePosixPath(f"/tmp/test_upload_{secrets.token_hex(5)}") local_path = tmp_path_factory.mktemp("local_path") @@ -544,18 +560,22 @@ def finalizer(): request.addfinalizer(finalizer) - downloader = FileUploader( + uploader = FileUploader( connection=file_connection, target_path=target_path, local_path=local_path, ) - download_result = downloader.run() + with caplog.at_level(logging.INFO): + upload_result = uploader.run() - assert not download_result.failed - assert not download_result.skipped - assert not download_result.missing - assert not download_result.successful + assert "No files to upload!" in caplog.text + assert "Starting the upload process" not in caplog.text + + assert not upload_result.failed + assert not upload_result.skipped + assert not upload_result.missing + assert not upload_result.successful def test_file_uploader_without_files_and_without_local_path(file_connection): diff --git a/tests/tests_integration/tests_core_integration/tests_db_reader_integration/test_clickhouse_reader_integration.py b/tests/tests_integration/tests_core_integration/tests_db_reader_integration/test_clickhouse_reader_integration.py index bc693e89f..643187e22 100644 --- a/tests/tests_integration/tests_core_integration/tests_db_reader_integration/test_clickhouse_reader_integration.py +++ b/tests/tests_integration/tests_core_integration/tests_db_reader_integration/test_clickhouse_reader_integration.py @@ -1,7 +1,6 @@ import pytest from onetl.connection import Clickhouse -from onetl.connection.db_connection.jdbc_connection import PartitioningMode from onetl.db import DBReader pytestmark = pytest.mark.clickhouse @@ -30,7 +29,15 @@ def test_clickhouse_reader_snapshot(spark, processing, load_table_data): ) -def test_clickhouse_reader_snapshot_partitioning_mode_mod(spark, processing, load_table_data): +@pytest.mark.parametrize( + "mode, column", + [ + ("range", "id_int"), + ("hash", "text_string"), + ("mod", "id_int"), + ], +) +def test_clickhouse_reader_snapshot_partitioning_mode(mode, column, spark, processing, load_table_data): clickhouse = Clickhouse( host=processing.host, port=processing.port, @@ -43,9 +50,9 @@ def test_clickhouse_reader_snapshot_partitioning_mode_mod(spark, processing, loa reader = DBReader( connection=clickhouse, source=load_table_data.full_name, - options=clickhouse.ReadOptions( - partitioning_mode=PartitioningMode.mod, - partition_column="id_int", + options=Clickhouse.ReadOptions( + partitioning_mode=mode, + partition_column=column, num_partitions=5, ), ) @@ -59,35 +66,7 @@ def test_clickhouse_reader_snapshot_partitioning_mode_mod(spark, processing, loa order_by="id_int", ) - -def test_clickhouse_reader_snapshot_partitioning_mode_hash(spark, processing, load_table_data): - clickhouse = Clickhouse( - host=processing.host, - port=processing.port, - user=processing.user, - password=processing.password, - database=processing.database, - spark=spark, - ) - - reader = DBReader( - connection=clickhouse, - source=load_table_data.full_name, - options=clickhouse.ReadOptions( - partitioning_mode=PartitioningMode.hash, - partition_column="text_string", - num_partitions=5, - 
), - ) - - table_df = reader.run() - - processing.assert_equal_df( - schema=load_table_data.schema, - table=load_table_data.table, - df=table_df, - order_by="id_int", - ) + assert table_df.rdd.getNumPartitions() == 5 def test_clickhouse_reader_snapshot_without_set_database(spark, processing, load_table_data): diff --git a/tests/tests_integration/tests_core_integration/tests_db_reader_integration/test_kafka_reader_integration.py b/tests/tests_integration/tests_core_integration/tests_db_reader_integration/test_kafka_reader_integration.py index ce23a9a23..6f5d3e545 100644 --- a/tests/tests_integration/tests_core_integration/tests_db_reader_integration/test_kafka_reader_integration.py +++ b/tests/tests_integration/tests_core_integration/tests_db_reader_integration/test_kafka_reader_integration.py @@ -22,10 +22,10 @@ def dataframe_schema(): return StructType( [ - StructField("id_int", LongType(), True), - StructField("text_string", StringType(), True), - StructField("hwm_int", LongType(), True), - StructField("float_value", FloatType(), True), + StructField("id_int", LongType(), nullable=True), + StructField("text_string", StringType(), nullable=True), + StructField("hwm_int", LongType(), nullable=True), + StructField("float_value", FloatType(), nullable=True), ], ) @@ -88,6 +88,7 @@ def kafka_schema_with_headers(): ], ), ), + nullable=True, ), ], ) @@ -112,10 +113,8 @@ def create_kafka_data(spark): def test_kafka_reader(spark, kafka_processing, schema): - # Arrange topic, processing, expected_df = kafka_processing - # Act kafka = Kafka( spark=spark, addresses=[f"{processing.host}:{processing.port}"], @@ -128,7 +127,6 @@ def test_kafka_reader(spark, kafka_processing, schema): ) df = reader.run() - # Assert processing.assert_equal_df(processing.json_deserialize(df, df_schema=schema), other_frame=expected_df) @@ -167,9 +165,27 @@ def test_kafka_reader_columns_and_types_with_headers(spark, kafka_processing, ka reader = DBReader( connection=kafka, source=topic, - options=kafka.ReadOptions(includeHeaders=True), + options=Kafka.ReadOptions(includeHeaders=True), ) df = reader.run() assert df.schema == kafka_schema_with_headers + + +def test_kafka_reader_topic_does_not_exist(spark, kafka_processing): + _, processing, _ = kafka_processing + + kafka = Kafka( + spark=spark, + addresses=[f"{processing.host}:{processing.port}"], + cluster="cluster", + ) + + reader = DBReader( + connection=kafka, + source="missing", + ) + + with pytest.raises(ValueError, match="Topic 'missing' doesn't exist"): + reader.run() diff --git a/tests/tests_integration/tests_core_integration/tests_db_reader_integration/test_mssql_reader_integration.py b/tests/tests_integration/tests_core_integration/tests_db_reader_integration/test_mssql_reader_integration.py index a3098656b..12ace609d 100644 --- a/tests/tests_integration/tests_core_integration/tests_db_reader_integration/test_mssql_reader_integration.py +++ b/tests/tests_integration/tests_core_integration/tests_db_reader_integration/test_mssql_reader_integration.py @@ -1,7 +1,6 @@ import pytest from onetl.connection import MSSQL -from onetl.connection.db_connection.jdbc_connection import PartitioningMode from onetl.db import DBReader pytestmark = pytest.mark.mssql @@ -31,7 +30,15 @@ def test_mssql_reader_snapshot(spark, processing, load_table_data): ) -def test_mssql_reader_snapshot_partitioning_mode_mod(spark, processing, load_table_data): +@pytest.mark.parametrize( + "mode, column", + [ + ("range", "id_int"), + ("hash", "text_string"), + ("mod", "id_int"), + ], +) +def 
test_mssql_reader_snapshot_partitioning_mode(mode, column, spark, processing, load_table_data): mssql = MSSQL( host=processing.host, port=processing.port, @@ -45,9 +52,9 @@ def test_mssql_reader_snapshot_partitioning_mode_mod(spark, processing, load_tab reader = DBReader( connection=mssql, source=load_table_data.full_name, - options=mssql.ReadOptions( - partitioning_mode=PartitioningMode.mod, - partition_column="id_int", + options=MSSQL.ReadOptions( + partitioning_mode=mode, + partition_column=column, num_partitions=5, ), ) @@ -61,36 +68,7 @@ def test_mssql_reader_snapshot_partitioning_mode_mod(spark, processing, load_tab order_by="id_int", ) - -def test_mssql_reader_snapshot_partitioning_mode_hash(spark, processing, load_table_data): - mssql = MSSQL( - host=processing.host, - port=processing.port, - user=processing.user, - password=processing.password, - database=processing.database, - spark=spark, - extra={"trustServerCertificate": "true"}, - ) - - reader = DBReader( - connection=mssql, - source=load_table_data.full_name, - options=mssql.ReadOptions( - partitioning_mode=PartitioningMode.hash, - partition_column="text_string", - num_partitions=5, - ), - ) - - table_df = reader.run() - - processing.assert_equal_df( - schema=load_table_data.schema, - table=load_table_data.table, - df=table_df, - order_by="id_int", - ) + assert table_df.rdd.getNumPartitions() == 5 def test_mssql_reader_snapshot_with_columns(spark, processing, load_table_data): diff --git a/tests/tests_integration/tests_core_integration/tests_db_reader_integration/test_mysql_reader_integration.py b/tests/tests_integration/tests_core_integration/tests_db_reader_integration/test_mysql_reader_integration.py index e05ce3826..e0865866a 100644 --- a/tests/tests_integration/tests_core_integration/tests_db_reader_integration/test_mysql_reader_integration.py +++ b/tests/tests_integration/tests_core_integration/tests_db_reader_integration/test_mysql_reader_integration.py @@ -1,7 +1,6 @@ import pytest from onetl.connection import MySQL -from onetl.connection.db_connection.jdbc_connection import PartitioningMode from onetl.db import DBReader pytestmark = pytest.mark.mysql @@ -31,7 +30,15 @@ def test_mysql_reader_snapshot(spark, processing, load_table_data): ) -def test_mysql_reader_snapshot_partitioning_mode_mod(spark, processing, load_table_data): +@pytest.mark.parametrize( + "mode, column", + [ + ("range", "id_int"), + ("hash", "text_string"), + ("mod", "id_int"), + ], +) +def test_mysql_reader_snapshot_partitioning_mode(mode, column, spark, processing, load_table_data): mysql = MySQL( host=processing.host, port=processing.port, @@ -44,9 +51,9 @@ def test_mysql_reader_snapshot_partitioning_mode_mod(spark, processing, load_tab reader = DBReader( connection=mysql, source=load_table_data.full_name, - options=mysql.ReadOptions( - partitioning_mode=PartitioningMode.mod, - partition_column="id_int", + options=MySQL.ReadOptions( + partitioning_mode=mode, + partition_column=column, num_partitions=5, ), ) @@ -60,35 +67,7 @@ def test_mysql_reader_snapshot_partitioning_mode_mod(spark, processing, load_tab order_by="id_int", ) - -def test_mysql_reader_snapshot_partitioning_mode_hash(spark, processing, load_table_data): - mysql = MySQL( - host=processing.host, - port=processing.port, - user=processing.user, - password=processing.password, - database=processing.database, - spark=spark, - ) - - reader = DBReader( - connection=mysql, - source=load_table_data.full_name, - options=mysql.ReadOptions( - partitioning_mode=PartitioningMode.hash, - 
partition_column="text_string", - num_partitions=5, - ), - ) - - table_df = reader.run() - - processing.assert_equal_df( - schema=load_table_data.schema, - table=load_table_data.table, - df=table_df, - order_by="id_int", - ) + assert table_df.rdd.getNumPartitions() == 5 def test_mysql_reader_snapshot_with_not_set_database(spark, processing, load_table_data): diff --git a/tests/tests_integration/tests_core_integration/tests_db_reader_integration/test_oracle_reader_integration.py b/tests/tests_integration/tests_core_integration/tests_db_reader_integration/test_oracle_reader_integration.py index 02f4d5607..b379923ef 100644 --- a/tests/tests_integration/tests_core_integration/tests_db_reader_integration/test_oracle_reader_integration.py +++ b/tests/tests_integration/tests_core_integration/tests_db_reader_integration/test_oracle_reader_integration.py @@ -1,7 +1,6 @@ import pytest from onetl.connection import Oracle -from onetl.connection.db_connection.jdbc_connection import PartitioningMode from onetl.db import DBReader pytestmark = pytest.mark.oracle @@ -31,7 +30,15 @@ def test_oracle_reader_snapshot(spark, processing, load_table_data): ) -def test_oracle_reader_snapshot_partitioning_mode_mod(spark, processing, load_table_data): +@pytest.mark.parametrize( + "mode, column", + [ + ("range", "id_int"), + ("hash", "text_string"), + ("mod", "id_int"), + ], +) +def test_oracle_reader_snapshot_partitioning_mode(mode, column, spark, processing, load_table_data): oracle = Oracle( host=processing.host, port=processing.port, @@ -45,9 +52,9 @@ def test_oracle_reader_snapshot_partitioning_mode_mod(spark, processing, load_ta reader = DBReader( connection=oracle, source=load_table_data.full_name, - options=oracle.ReadOptions( - partitioning_mode=PartitioningMode.mod, - partition_column="id_int", + options=Oracle.ReadOptions( + partitioning_mode=mode, + partition_column=column, num_partitions=5, ), ) @@ -61,36 +68,7 @@ def test_oracle_reader_snapshot_partitioning_mode_mod(spark, processing, load_ta order_by="id_int", ) - -def test_oracle_reader_snapshot_partitioning_mode_hash(spark, processing, load_table_data): - oracle = Oracle( - host=processing.host, - port=processing.port, - user=processing.user, - password=processing.password, - spark=spark, - sid=processing.sid, - service_name=processing.service_name, - ) - - reader = DBReader( - connection=oracle, - source=load_table_data.full_name, - options=oracle.ReadOptions( - partitioning_mode=PartitioningMode.hash, - partition_column="text_string", - num_partitions=5, - ), - ) - - table_df = reader.run() - - processing.assert_equal_df( - schema=load_table_data.schema, - table=load_table_data.table, - df=table_df, - order_by="id_int", - ) + assert table_df.rdd.getNumPartitions() == 5 def test_oracle_reader_snapshot_with_columns(spark, processing, load_table_data): diff --git a/tests/tests_integration/tests_core_integration/tests_db_reader_integration/test_postgres_reader_integration.py b/tests/tests_integration/tests_core_integration/tests_db_reader_integration/test_postgres_reader_integration.py index 738c9b6e3..617eba903 100644 --- a/tests/tests_integration/tests_core_integration/tests_db_reader_integration/test_postgres_reader_integration.py +++ b/tests/tests_integration/tests_core_integration/tests_db_reader_integration/test_postgres_reader_integration.py @@ -1,7 +1,6 @@ import pytest from onetl.connection import Postgres -from onetl.connection.db_connection.jdbc_connection import PartitioningMode from onetl.db import DBReader pytestmark = 
pytest.mark.postgres @@ -30,7 +29,15 @@ def test_postgres_reader_snapshot(spark, processing, load_table_data): ) -def test_postgres_reader_snapshot_partitioning_mode_mod(spark, processing, load_table_data): +@pytest.mark.parametrize( + "mode, column", + [ + ("range", "id_int"), + ("hash", "text_string"), + ("mod", "id_int"), + ], +) +def test_postgres_reader_snapshot_partitioning_mode(mode, column, spark, processing, load_table_data): postgres = Postgres( host=processing.host, port=processing.port, @@ -43,9 +50,9 @@ def test_postgres_reader_snapshot_partitioning_mode_mod(spark, processing, load_ reader = DBReader( connection=postgres, source=load_table_data.full_name, - options=postgres.ReadOptions( - partitioning_mode=PartitioningMode.mod, - partition_column="id_int", + options=Postgres.ReadOptions( + partitioning_mode=mode, + partition_column=column, num_partitions=5, ), ) @@ -59,35 +66,7 @@ def test_postgres_reader_snapshot_partitioning_mode_mod(spark, processing, load_ order_by="id_int", ) - -def test_postgres_reader_snapshot_partitioning_mode_hash(spark, processing, load_table_data): - postgres = Postgres( - host=processing.host, - port=processing.port, - user=processing.user, - password=processing.password, - database=processing.database, - spark=spark, - ) - - reader = DBReader( - connection=postgres, - source=load_table_data.full_name, - options=postgres.ReadOptions( - partitioning_mode=PartitioningMode.hash, - partition_column="text_string", - num_partitions=5, - ), - ) - - table_df = reader.run() - - processing.assert_equal_df( - schema=load_table_data.schema, - table=load_table_data.table, - df=table_df, - order_by="id_int", - ) + assert table_df.rdd.getNumPartitions() == 5 def test_postgres_reader_snapshot_with_columns(spark, processing, load_table_data): diff --git a/tests/tests_integration/tests_core_integration/tests_db_writer_integration/test_greenplum_writer_integration.py b/tests/tests_integration/tests_core_integration/tests_db_writer_integration/test_greenplum_writer_integration.py index 7899b3257..c97105a44 100644 --- a/tests/tests_integration/tests_core_integration/tests_db_writer_integration/test_greenplum_writer_integration.py +++ b/tests/tests_integration/tests_core_integration/tests_db_writer_integration/test_greenplum_writer_integration.py @@ -6,7 +6,17 @@ pytestmark = pytest.mark.greenplum -def test_greenplum_writer_snapshot(spark, processing, prepare_schema_table): +@pytest.mark.parametrize( + "options", + [ + {}, + {"if_exists": "append"}, + {"if_exists": "replace_entire_table"}, + {"if_exists": "error"}, + {"if_exists": "ignore"}, + ], +) +def test_greenplum_writer_snapshot(spark, processing, get_schema_table, options): df = processing.create_spark_df(spark=spark) greenplum = Greenplum( @@ -21,20 +31,21 @@ def test_greenplum_writer_snapshot(spark, processing, prepare_schema_table): writer = DBWriter( connection=greenplum, - target=prepare_schema_table.full_name, + target=get_schema_table.full_name, + options=Greenplum.WriteOptions(**options), ) writer.run(df) processing.assert_equal_df( - schema=prepare_schema_table.schema, - table=prepare_schema_table.table, + schema=get_schema_table.schema, + table=get_schema_table.table, df=df, order_by="id_int", ) -def test_greenplum_writer_mode_append(spark, processing, prepare_schema_table): +def test_greenplum_writer_if_exists_append(spark, processing, prepare_schema_table): df = processing.create_spark_df(spark=spark, min_id=1, max_id=1500) df1 = df[df.id_int < 1001] df2 = df[df.id_int > 1000] @@ -66,7 +77,7 @@ def 
test_greenplum_writer_mode_append(spark, processing, prepare_schema_table): ) -def test_greenplum_writer_mode(spark, processing, prepare_schema_table): +def test_greenplum_writer_if_exists_overwrite(spark, processing, prepare_schema_table): df = processing.create_spark_df(spark=spark, min_id=1, max_id=1500) df1 = df[df.id_int < 1001] df2 = df[df.id_int > 1000] @@ -96,3 +107,62 @@ def test_greenplum_writer_mode(spark, processing, prepare_schema_table): df=df2, order_by="id_int", ) + + +def test_greenplum_writer_if_exists_error(spark, processing, prepare_schema_table): + from py4j.java_gateway import Py4JJavaError + + df = processing.create_spark_df(spark=spark, min_id=1, max_id=1500) + + greenplum = Greenplum( + host=processing.host, + port=processing.port, + user=processing.user, + password=processing.password, + database=processing.database, + spark=spark, + extra=processing.extra, + ) + + writer = DBWriter( + connection=greenplum, + target=prepare_schema_table.full_name, + options=Greenplum.WriteOptions(if_exists="error"), + ) + + with pytest.raises( + Py4JJavaError, + match=f'Table "{prepare_schema_table.schema}"."{prepare_schema_table.table}"' + f" exists, and SaveMode.ErrorIfExists was specified", + ): + writer.run(df) + + +def test_greenplum_writer_if_exists_ignore(spark, processing, prepare_schema_table): + df = processing.create_spark_df(spark=spark, min_id=1, max_id=1500) + + greenplum = Greenplum( + host=processing.host, + port=processing.port, + user=processing.user, + password=processing.password, + database=processing.database, + spark=spark, + extra=processing.extra, + ) + + writer = DBWriter( + connection=greenplum, + target=prepare_schema_table.full_name, + options=Greenplum.WriteOptions(if_exists="ignore"), + ) + + writer.run(df) # The write operation is ignored + + empty_df = spark.createDataFrame([], df.schema) + + processing.assert_equal_df( + schema=prepare_schema_table.schema, + table=prepare_schema_table.table, + df=empty_df, + ) diff --git a/tests/tests_integration/tests_core_integration/tests_db_writer_integration/test_hive_writer_integration.py b/tests/tests_integration/tests_core_integration/tests_db_writer_integration/test_hive_writer_integration.py index c6a43bb24..44553539b 100644 --- a/tests/tests_integration/tests_core_integration/tests_db_writer_integration/test_hive_writer_integration.py +++ b/tests/tests_integration/tests_core_integration/tests_db_writer_integration/test_hive_writer_integration.py @@ -346,8 +346,8 @@ def test_hive_writer_insert_into_append(spark, processing, get_schema_table, ori with caplog.at_level(logging.INFO): writer2.run(df1.union(df3)) - assert f"|Hive| Inserting data into existing table '{get_schema_table.full_name}'" in caplog.text - assert f"|Hive| Data is successfully inserted into table '{get_schema_table.full_name}'" in caplog.text + assert f"|Hive| Inserting data into existing table '{get_schema_table.full_name}' ..." in caplog.text + assert f"|Hive| Data is successfully inserted into table '{get_schema_table.full_name}'." 
in caplog.text new_ddl = hive.sql(f"SHOW CREATE TABLE {get_schema_table.full_name}").collect()[0][0] @@ -405,8 +405,8 @@ def test_hive_writer_insert_into_replace_entire_table( writer2.run(df2.select(*reversed(df2.columns))) # unlike other modes, this creates new table - assert f"|Hive| Saving data to a table '{get_schema_table.full_name}'" in caplog.text - assert f"|Hive| Table '{get_schema_table.full_name}' is successfully created" in caplog.text + assert f"|Hive| Saving data to a table '{get_schema_table.full_name}' ..." in caplog.text + assert f"|Hive| Table '{get_schema_table.full_name}' is successfully created." in caplog.text new_ddl = hive.sql(f"SHOW CREATE TABLE {get_schema_table.full_name}").collect()[0][0] @@ -456,8 +456,8 @@ def test_hive_writer_insert_into_replace_overlapping_partitions_in_non_partition with caplog.at_level(logging.INFO): writer2.run(df2_reversed) - assert f"|Hive| Inserting data into existing table '{get_schema_table.full_name}'" in caplog.text - assert f"|Hive| Data is successfully inserted into table '{get_schema_table.full_name}'" in caplog.text + assert f"|Hive| Inserting data into existing table '{get_schema_table.full_name}' ..." in caplog.text + assert f"|Hive| Data is successfully inserted into table '{get_schema_table.full_name}'." in caplog.text new_ddl = hive.sql(f"SHOW CREATE TABLE {get_schema_table.full_name}").collect()[0][0] diff --git a/tests/tests_integration/tests_core_integration/tests_db_writer_integration/test_kafka_writer_integration.py b/tests/tests_integration/tests_core_integration/tests_db_writer_integration/test_kafka_writer_integration.py index 8a9427983..b35bfabad 100644 --- a/tests/tests_integration/tests_core_integration/tests_db_writer_integration/test_kafka_writer_integration.py +++ b/tests/tests_integration/tests_core_integration/tests_db_writer_integration/test_kafka_writer_integration.py @@ -1,3 +1,4 @@ +import contextlib import logging import re import secrets @@ -97,7 +98,7 @@ def test_kafka_writer_no_value_column_error(spark, kafka_processing, kafka_spark from pyspark.sql.utils import AnalysisException topic, processing = kafka_processing - df = kafka_spark_df.drop("value") + df = kafka_spark_df.select("key") kafka = Kafka( spark=spark, @@ -114,13 +115,28 @@ def test_kafka_writer_no_value_column_error(spark, kafka_processing, kafka_spark writer.run(df) -def test_kafka_writer_invalid_column_error(spark, kafka_processing, kafka_spark_df): +@pytest.mark.parametrize( + "column, value", + [ + ("offset", 0), + ("timestamp", 10000), + ("timestampType", 1), + ("unknown", "str"), + ], +) +def test_kafka_writer_invalid_column_error( + column, + value, + spark, + kafka_processing, + kafka_spark_df, +): from pyspark.sql.functions import lit topic, processing = kafka_processing # Add an unexpected column to the DataFrame - df = kafka_spark_df.withColumn("invalid_column", lit("invalid_value")) + df = kafka_spark_df.withColumn(column, lit(value)) kafka = Kafka( spark=spark, @@ -134,32 +150,13 @@ def test_kafka_writer_invalid_column_error(spark, kafka_processing, kafka_spark_ ) error_msg = ( - "Invalid column names: {'invalid_column'}. Expected columns: ['value'] (required), " - "['key', 'topic', 'partition', 'offset', 'timestamp', 'timestampType', 'headers'] (optional)" + f"Invalid column names: ['{column}']. 
" + "Expected columns: ['value'] (required), ['headers', 'key', 'partition'] (optional)" ) with pytest.raises(ValueError, match=re.escape(error_msg)): writer.run(df) -def test_kafka_writer_with_include_headers_error(spark, kafka_processing, kafka_spark_df): - topic, processing = kafka_processing - - kafka = Kafka( - spark=spark, - addresses=[f"{processing.host}:{processing.port}"], - cluster="cluster", - ) - - writer = DBWriter( - connection=kafka, - table=topic, - options=kafka.WriteOptions(includeHeaders=False), - ) - - with pytest.raises(ValueError, match="Cannot write 'headers' column"): - writer.run(kafka_spark_df) - - def test_kafka_writer_key_column(spark, kafka_processing, kafka_spark_df): topic, processing = kafka_processing df = kafka_spark_df.select("value", "key") @@ -185,6 +182,7 @@ def test_kafka_writer_topic_column(spark, kafka_processing, caplog, kafka_spark_ from pyspark.sql.functions import lit topic, processing = kafka_processing + original_df = kafka_spark_df.select("value") kafka = Kafka( spark=spark, @@ -196,12 +194,10 @@ def test_kafka_writer_topic_column(spark, kafka_processing, caplog, kafka_spark_ connection=kafka, table=topic, ) - writer.run(kafka_spark_df) - + writer.run(original_df) assert processing.topic_exists(topic) - df = kafka_spark_df.withColumn("topic", lit("other_topic")) - + df = original_df.withColumn("topic", lit("other_topic")) with caplog.at_level(logging.WARNING): writer.run(df) assert f"The 'topic' column in the DataFrame will be overridden with value '{topic}'" in caplog.text @@ -234,7 +230,10 @@ def test_kafka_writer_partition_column(spark, kafka_processing, kafka_spark_df): def test_kafka_writer_headers(spark, kafka_processing, kafka_spark_df): if get_spark_version(spark).major < 3: - pytest.skip("Spark 3.x or later is required to write/read 'headers' from Kafka messages") + msg = f"kafka.WriteOptions(include_headers=True) requires Spark 3.x, got {spark.version}" + context_manager = pytest.raises(ValueError, match=re.escape(msg)) + else: + context_manager = contextlib.nullcontext() topic, processing = kafka_processing @@ -247,19 +246,39 @@ def test_kafka_writer_headers(spark, kafka_processing, kafka_spark_df): writer = DBWriter( connection=kafka, table=topic, - options=kafka.WriteOptions(includeHeaders=True), + options=kafka.WriteOptions(include_headers=True), ) df = kafka_spark_df.select("value", "headers") - writer.run(df) + with context_manager: + writer.run(df) + + pd_df = processing.get_expected_df(topic, num_messages=kafka_spark_df.count()) + + processing.assert_equal_df( + df, + other_frame=pd_df.drop(columns=["key", "partition", "topic"], axis=1), + ) - pd_df = processing.get_expected_df(topic, num_messages=kafka_spark_df.count()) - processing.assert_equal_df( - df, - other_frame=pd_df.drop(columns=["key", "partition", "topic"], axis=1), +def test_kafka_writer_headers_without_include_headers_fail(spark, kafka_processing, kafka_spark_df): + topic, processing = kafka_processing + + kafka = Kafka( + spark=spark, + addresses=[f"{processing.host}:{processing.port}"], + cluster="cluster", ) + writer = DBWriter( + connection=kafka, + table=topic, + options=kafka.WriteOptions(include_headers=False), + ) + + with pytest.raises(ValueError, match="Cannot write 'headers' column"): + writer.run(kafka_spark_df) + def test_kafka_writer_mode(spark, kafka_processing, kafka_spark_df): from pyspark.sql.functions import lit @@ -276,7 +295,6 @@ def test_kafka_writer_mode(spark, kafka_processing, kafka_spark_df): writer = DBWriter( connection=kafka, 
table=topic, - options=kafka.WriteOptions(includeHeaders=True), ) writer.run(df) diff --git a/tests/tests_integration/tests_db_connection_integration/test_clickhouse_integration.py b/tests/tests_integration/tests_db_connection_integration/test_clickhouse_integration.py index 96671563a..410a6a02d 100644 --- a/tests/tests_integration/tests_db_connection_integration/test_clickhouse_integration.py +++ b/tests/tests_integration/tests_db_connection_integration/test_clickhouse_integration.py @@ -27,7 +27,7 @@ def test_clickhouse_connection_check(spark, processing, caplog): with caplog.at_level(logging.INFO): assert clickhouse.check() == clickhouse - assert "type = Clickhouse" in caplog.text + assert "|Clickhouse|" in caplog.text assert f"host = '{processing.host}'" in caplog.text assert f"port = {processing.port}" in caplog.text assert f"database = '{processing.database}'" in caplog.text @@ -37,7 +37,7 @@ def test_clickhouse_connection_check(spark, processing, caplog): assert "package = " not in caplog.text assert "spark = " not in caplog.text - assert "Connection is available" in caplog.text + assert "Connection is available." in caplog.text def test_clickhouse_connection_check_fail(spark): diff --git a/tests/tests_integration/tests_db_connection_integration/test_greenplum_integration.py b/tests/tests_integration/tests_db_connection_integration/test_greenplum_integration.py index 8a31d828e..91a52bb9d 100644 --- a/tests/tests_integration/tests_db_connection_integration/test_greenplum_integration.py +++ b/tests/tests_integration/tests_db_connection_integration/test_greenplum_integration.py @@ -27,7 +27,7 @@ def test_greenplum_connection_check(spark, processing, caplog): with caplog.at_level(logging.INFO): assert greenplum.check() == greenplum - assert "type = Greenplum" in caplog.text + assert "|Greenplum|" in caplog.text assert f"host = '{processing.host}'" in caplog.text assert f"port = {processing.port}" in caplog.text assert f"user = '{processing.user}'" in caplog.text @@ -39,7 +39,7 @@ def test_greenplum_connection_check(spark, processing, caplog): assert "package = " not in caplog.text assert "spark = " not in caplog.text - assert "Connection is available" in caplog.text + assert "Connection is available." in caplog.text def test_greenplum_connection_check_fail(spark): diff --git a/tests/tests_integration/tests_db_connection_integration/test_hive_integration.py b/tests/tests_integration/tests_db_connection_integration/test_hive_integration.py index 19ae816f9..8a90578cb 100644 --- a/tests/tests_integration/tests_db_connection_integration/test_hive_integration.py +++ b/tests/tests_integration/tests_db_connection_integration/test_hive_integration.py @@ -20,10 +20,10 @@ def test_hive_check(spark, caplog): with caplog.at_level(logging.INFO): assert hive.check() == hive - assert "type = Hive" in caplog.text + assert "|Hive|" in caplog.text assert "spark = " not in caplog.text - assert "Connection is available" in caplog.text + assert "Connection is available." 
in caplog.text @pytest.mark.parametrize("suffix", ["", ";"]) diff --git a/tests/tests_integration/tests_db_connection_integration/test_kafka_integration.py b/tests/tests_integration/tests_db_connection_integration/test_kafka_integration.py index 5081307cf..b07d6b99e 100644 --- a/tests/tests_integration/tests_db_connection_integration/test_kafka_integration.py +++ b/tests/tests_integration/tests_db_connection_integration/test_kafka_integration.py @@ -20,7 +20,7 @@ def test_kafka_check_plaintext_anonymous(spark, caplog): with caplog.at_level(logging.INFO): assert kafka.check() == kafka - assert "type = Kafka" in caplog.text + assert "|Kafka|" in caplog.text assert "addresses = [" in caplog.text assert f"'{kafka_processing.host}:{kafka_processing.port}'" in caplog.text assert "cluster = 'cluster'" in caplog.text @@ -28,7 +28,7 @@ def test_kafka_check_plaintext_anonymous(spark, caplog): assert "auth = None" in caplog.text assert "extra = {}" in caplog.text - assert "Connection is available" in caplog.text + assert "Connection is available." in caplog.text def test_kafka_check_plaintext_basic_auth(spark, caplog): @@ -48,7 +48,7 @@ def test_kafka_check_plaintext_basic_auth(spark, caplog): with caplog.at_level(logging.INFO): assert kafka.check() == kafka - assert "type = Kafka" in caplog.text + assert "|Kafka|" in caplog.text assert "addresses = [" in caplog.text assert f"'{kafka_processing.host}:{kafka_processing.sasl_port}'" in caplog.text assert "cluster = 'cluster'" in caplog.text @@ -56,7 +56,7 @@ def test_kafka_check_plaintext_basic_auth(spark, caplog): assert f"auth = KafkaBasicAuth(user='{kafka_processing.user}', password=SecretStr('**********'))" in caplog.text assert "extra = {}" in caplog.text - assert "Connection is available" in caplog.text + assert "Connection is available." in caplog.text @pytest.mark.parametrize("digest", ["SHA-256", "SHA-512"]) @@ -78,7 +78,7 @@ def test_kafka_check_plaintext_scram_auth(digest, spark, caplog): with caplog.at_level(logging.INFO): assert kafka.check() == kafka - assert "type = Kafka" in caplog.text + assert "|Kafka|" in caplog.text assert "addresses = [" in caplog.text assert f"'{kafka_processing.host}:{kafka_processing.sasl_port}'" in caplog.text assert "cluster = 'cluster'" in caplog.text @@ -89,7 +89,7 @@ def test_kafka_check_plaintext_scram_auth(digest, spark, caplog): ) assert "extra = {}" in caplog.text - assert "Connection is available" in caplog.text + assert "Connection is available." in caplog.text def test_kafka_check_error(spark): diff --git a/tests/tests_integration/tests_db_connection_integration/test_mongodb_integration.py b/tests/tests_integration/tests_db_connection_integration/test_mongodb_integration.py index 4ea0221bf..6283df1a1 100644 --- a/tests/tests_integration/tests_db_connection_integration/test_mongodb_integration.py +++ b/tests/tests_integration/tests_db_connection_integration/test_mongodb_integration.py @@ -20,7 +20,7 @@ def test_mongodb_connection_check(spark, processing, caplog): with caplog.at_level(logging.INFO): assert mongo.check() == mongo - assert "type = MongoDB" in caplog.text + assert "|MongoDB|" in caplog.text assert f"host = '{processing.host}'" in caplog.text assert f"port = {processing.port}" in caplog.text assert f"database = '{processing.database}'" in caplog.text @@ -31,7 +31,7 @@ def test_mongodb_connection_check(spark, processing, caplog): assert "package = " not in caplog.text assert "spark = " not in caplog.text - assert "Connection is available" in caplog.text + assert "Connection is available." 
in caplog.text def test_mongodb_connection_check_fail(processing, spark): diff --git a/tests/tests_integration/tests_db_connection_integration/test_mssql_integration.py b/tests/tests_integration/tests_db_connection_integration/test_mssql_integration.py index 45a732711..96e23183b 100644 --- a/tests/tests_integration/tests_db_connection_integration/test_mssql_integration.py +++ b/tests/tests_integration/tests_db_connection_integration/test_mssql_integration.py @@ -27,7 +27,7 @@ def test_mssql_connection_check(spark, processing, caplog): with caplog.at_level(logging.INFO): assert mssql.check() == mssql - assert "type = MSSQL" in caplog.text + assert "|MSSQL|" in caplog.text assert f"host = '{processing.host}'" in caplog.text assert f"port = {processing.port}" in caplog.text assert f"database = '{processing.database}'" in caplog.text @@ -39,7 +39,7 @@ def test_mssql_connection_check(spark, processing, caplog): assert "package = " not in caplog.text assert "spark = " not in caplog.text - assert "Connection is available" in caplog.text + assert "Connection is available." in caplog.text def test_mssql_connection_check_fail(spark): diff --git a/tests/tests_integration/tests_db_connection_integration/test_mysql_integration.py b/tests/tests_integration/tests_db_connection_integration/test_mysql_integration.py index d96b1f397..dad717b1c 100644 --- a/tests/tests_integration/tests_db_connection_integration/test_mysql_integration.py +++ b/tests/tests_integration/tests_db_connection_integration/test_mysql_integration.py @@ -26,7 +26,7 @@ def test_mysql_connection_check(spark, processing, caplog): with caplog.at_level(logging.INFO): assert mysql.check() == mysql - assert "type = MySQL" in caplog.text + assert "|MySQL|" in caplog.text assert f"host = '{processing.host}'" in caplog.text assert f"port = {processing.port}" in caplog.text assert f"database = '{processing.database}'" in caplog.text @@ -37,7 +37,7 @@ def test_mysql_connection_check(spark, processing, caplog): assert "package = " not in caplog.text assert "spark = " not in caplog.text - assert "Connection is available" in caplog.text + assert "Connection is available." in caplog.text def test_mysql_connection_check_fail(spark): diff --git a/tests/tests_integration/tests_db_connection_integration/test_oracle_integration.py b/tests/tests_integration/tests_db_connection_integration/test_oracle_integration.py index d6fdad022..40e813279 100644 --- a/tests/tests_integration/tests_db_connection_integration/test_oracle_integration.py +++ b/tests/tests_integration/tests_db_connection_integration/test_oracle_integration.py @@ -28,7 +28,7 @@ def test_oracle_connection_check(spark, processing, caplog): with caplog.at_level(logging.INFO): assert oracle.check() == oracle - assert "type = Oracle" in caplog.text + assert "|Oracle|" in caplog.text assert f"host = '{processing.host}'" in caplog.text assert f"port = {processing.port}" in caplog.text assert "database" not in caplog.text @@ -46,7 +46,7 @@ def test_oracle_connection_check(spark, processing, caplog): assert "package = " not in caplog.text assert "spark = " not in caplog.text - assert "Connection is available" in caplog.text + assert "Connection is available." 
in caplog.text def test_oracle_connection_check_fail(spark): diff --git a/tests/tests_integration/tests_db_connection_integration/test_postgres_integration.py b/tests/tests_integration/tests_db_connection_integration/test_postgres_integration.py index f53ec976f..9f2d2253b 100644 --- a/tests/tests_integration/tests_db_connection_integration/test_postgres_integration.py +++ b/tests/tests_integration/tests_db_connection_integration/test_postgres_integration.py @@ -26,7 +26,7 @@ def test_postgres_connection_check(spark, processing, caplog): with caplog.at_level(logging.INFO): assert postgres.check() == postgres - assert "type = Postgres" in caplog.text + assert "|Postgres|" in caplog.text assert f"host = '{processing.host}'" in caplog.text assert f"port = {processing.port}" in caplog.text assert f"user = '{processing.user}'" in caplog.text @@ -38,7 +38,7 @@ def test_postgres_connection_check(spark, processing, caplog): assert "package = " not in caplog.text assert "spark = " not in caplog.text - assert "Connection is available" in caplog.text + assert "Connection is available." in caplog.text def test_postgres_connection_check_fail(spark): diff --git a/tests/tests_integration/tests_db_connection_integration/test_teradata_integration.py b/tests/tests_integration/tests_db_connection_integration/test_teradata_integration.py index 26ca462f1..f3224025e 100644 --- a/tests/tests_integration/tests_db_connection_integration/test_teradata_integration.py +++ b/tests/tests_integration/tests_db_connection_integration/test_teradata_integration.py @@ -29,7 +29,7 @@ def test_teradata_connection_check(spark, mocker, caplog): with caplog.at_level(logging.INFO): assert teradata.check() == teradata - assert "type = Teradata" in caplog.text + assert "|Teradata|" in caplog.text assert f"host = '{host}'" in caplog.text assert f"port = {port}" in caplog.text assert f"database = '{database}" in caplog.text @@ -40,7 +40,7 @@ def test_teradata_connection_check(spark, mocker, caplog): assert "package = " not in caplog.text assert "spark = " not in caplog.text - assert "Connection is available" in caplog.text + assert "Connection is available." in caplog.text def test_teradata_connection_check_fail(spark): diff --git a/tests/tests_integration/tests_file_connection_integration/test_ftp_file_connection_integration.py b/tests/tests_integration/tests_file_connection_integration/test_ftp_file_connection_integration.py index df99d72d9..87e69227b 100644 --- a/tests/tests_integration/tests_file_connection_integration/test_ftp_file_connection_integration.py +++ b/tests/tests_integration/tests_file_connection_integration/test_ftp_file_connection_integration.py @@ -10,14 +10,14 @@ def test_ftp_file_connection_check_success(ftp_file_connection, caplog): with caplog.at_level(logging.INFO): assert ftp.check() == ftp - assert "type = FTP" in caplog.text + assert "|FTP|" in caplog.text assert f"host = '{ftp.host}'" in caplog.text assert f"port = {ftp.port}" in caplog.text assert f"user = '{ftp.user}'" in caplog.text assert "password = SecretStr('**********')" in caplog.text assert ftp.password.get_secret_value() not in caplog.text - assert "Connection is available" in caplog.text + assert "Connection is available." 
in caplog.text def test_ftp_file_connection_check_anonymous(ftp_server, caplog): @@ -28,13 +28,13 @@ def test_ftp_file_connection_check_anonymous(ftp_server, caplog): with caplog.at_level(logging.INFO): assert anonymous.check() == anonymous - assert "type = FTP" in caplog.text + assert "|FTP|" in caplog.text assert f"host = '{anonymous.host}'" in caplog.text assert f"port = {anonymous.port}" in caplog.text assert "user = " not in caplog.text assert "password = " not in caplog.text - assert "Connection is available" in caplog.text + assert "Connection is available." in caplog.text def test_ftp_file_connection_check_failed(ftp_server): diff --git a/tests/tests_integration/tests_file_connection_integration/test_ftps_file_connection_integration.py b/tests/tests_integration/tests_file_connection_integration/test_ftps_file_connection_integration.py index f3756801e..e82d4b47a 100644 --- a/tests/tests_integration/tests_file_connection_integration/test_ftps_file_connection_integration.py +++ b/tests/tests_integration/tests_file_connection_integration/test_ftps_file_connection_integration.py @@ -10,14 +10,14 @@ def test_ftps_file_connection_check_success(ftps_file_connection, caplog): with caplog.at_level(logging.INFO): assert ftps.check() == ftps - assert "type = FTPS" in caplog.text + assert "|FTPS|" in caplog.text assert f"host = '{ftps.host}'" in caplog.text assert f"port = {ftps.port}" in caplog.text assert f"user = '{ftps.user}'" in caplog.text assert "password = SecretStr('**********')" in caplog.text assert ftps.password.get_secret_value() not in caplog.text - assert "Connection is available" in caplog.text + assert "Connection is available." in caplog.text def test_ftps_file_connection_check_anonymous(ftps_server, caplog): @@ -28,13 +28,13 @@ def test_ftps_file_connection_check_anonymous(ftps_server, caplog): with caplog.at_level(logging.INFO): assert anonymous.check() == anonymous - assert "type = FTP" in caplog.text + assert "|FTPS|" in caplog.text assert f"host = '{anonymous.host}'" in caplog.text assert f"port = {anonymous.port}" in caplog.text assert "user = " not in caplog.text assert "password = " not in caplog.text - assert "Connection is available" in caplog.text + assert "Connection is available." in caplog.text def test_ftps_file_connection_check_failed(ftps_server): diff --git a/tests/tests_integration/tests_file_connection_integration/test_hdfs_file_connection_integration.py b/tests/tests_integration/tests_file_connection_integration/test_hdfs_file_connection_integration.py index a7600e166..f92a1917a 100644 --- a/tests/tests_integration/tests_file_connection_integration/test_hdfs_file_connection_integration.py +++ b/tests/tests_integration/tests_file_connection_integration/test_hdfs_file_connection_integration.py @@ -17,7 +17,7 @@ def test_hdfs_file_connection_check_anonymous(hdfs_file_connection, caplog): with caplog.at_level(logging.INFO): assert hdfs.check() == hdfs - assert "type = HDFS" in caplog.text + assert "|HDFS|" in caplog.text assert f"host = '{hdfs.host}'" in caplog.text assert f"webhdfs_port = {hdfs.webhdfs_port}" in caplog.text assert "timeout = 10" in caplog.text @@ -25,14 +25,14 @@ def test_hdfs_file_connection_check_anonymous(hdfs_file_connection, caplog): assert "keytab =" not in caplog.text assert "password =" not in caplog.text - assert "Connection is available" in caplog.text + assert "Connection is available." 
in caplog.text def test_hdfs_file_connection_check_with_keytab(mocker, hdfs_server, caplog, request, tmp_path_factory): from onetl.connection import HDFS - from onetl.connection.file_connection import hdfs + from onetl.connection.file_connection.hdfs import connection - mocker.patch.object(hdfs, "kinit") + mocker.patch.object(connection, "kinit") folder: Path = tmp_path_factory.mktemp("keytab") folder.mkdir(exist_ok=True, parents=True) @@ -50,7 +50,7 @@ def finalizer(): with caplog.at_level(logging.INFO): assert hdfs.check() - assert "type = HDFS" in caplog.text + assert "|HDFS|" in caplog.text assert f"host = '{hdfs.host}'" in caplog.text assert f"webhdfs_port = {hdfs.webhdfs_port}" in caplog.text assert f"user = '{hdfs.user}'" in caplog.text @@ -58,21 +58,21 @@ def finalizer(): assert f"keytab = '{keytab}' (kind='file'" in caplog.text assert "password =" not in caplog.text - assert "Connection is available" in caplog.text + assert "Connection is available." in caplog.text def test_hdfs_file_connection_check_with_password(mocker, hdfs_server, caplog): from onetl.connection import HDFS - from onetl.connection.file_connection import hdfs + from onetl.connection.file_connection.hdfs import connection - mocker.patch.object(hdfs, "kinit") + mocker.patch.object(connection, "kinit") hdfs = HDFS(host=hdfs_server.host, port=hdfs_server.webhdfs_port, user=getuser(), password="somepass") with caplog.at_level(logging.INFO): assert hdfs.check() - assert "type = HDFS" in caplog.text + assert "|HDFS|" in caplog.text assert f"host = '{hdfs.host}'" in caplog.text assert f"webhdfs_port = {hdfs.webhdfs_port}" in caplog.text assert "timeout = 10" in caplog.text @@ -81,7 +81,7 @@ def test_hdfs_file_connection_check_with_password(mocker, hdfs_server, caplog): assert "password = SecretStr('**********')" in caplog.text assert "somepass" not in caplog.text - assert "Connection is available" in caplog.text + assert "Connection is available." in caplog.text def test_hdfs_file_connection_check_failed(): diff --git a/tests/tests_integration/tests_file_connection_integration/test_s3_file_connection_integration.py b/tests/tests_integration/tests_file_connection_integration/test_s3_file_connection_integration.py index 6f9276919..1d6c2a407 100644 --- a/tests/tests_integration/tests_file_connection_integration/test_s3_file_connection_integration.py +++ b/tests/tests_integration/tests_file_connection_integration/test_s3_file_connection_integration.py @@ -11,7 +11,7 @@ def test_s3_file_connection_check_success(caplog, s3_file_connection): with caplog.at_level(logging.INFO): assert s3.check() == s3 - assert "type = S3" in caplog.text + assert "|S3|" in caplog.text assert f"host = '{s3.host}'" in caplog.text assert f"port = {s3.port}" in caplog.text assert f"protocol = '{s3.protocol}'" in caplog.text @@ -21,7 +21,7 @@ def test_s3_file_connection_check_success(caplog, s3_file_connection): assert s3.secret_key.get_secret_value() not in caplog.text assert "session_token =" not in caplog.text - assert "Connection is available" in caplog.text + assert "Connection is available." 
in caplog.text def test_s3_file_connection_check_failed(s3_server): diff --git a/tests/tests_integration/tests_file_connection_integration/test_sftp_file_connection_integration.py b/tests/tests_integration/tests_file_connection_integration/test_sftp_file_connection_integration.py index 70faf498e..8c95659ba 100644 --- a/tests/tests_integration/tests_file_connection_integration/test_sftp_file_connection_integration.py +++ b/tests/tests_integration/tests_file_connection_integration/test_sftp_file_connection_integration.py @@ -10,7 +10,7 @@ def test_sftp_file_connection_check_success(sftp_file_connection, caplog): with caplog.at_level(logging.INFO): assert sftp.check() == sftp - assert "type = SFTP" in caplog.text + assert "|SFTP|" in caplog.text assert f"host = '{sftp.host}'" in caplog.text assert f"port = {sftp.port}" in caplog.text assert f"user = '{sftp.user}'" in caplog.text @@ -21,7 +21,7 @@ def test_sftp_file_connection_check_success(sftp_file_connection, caplog): assert "password = SecretStr('**********')" in caplog.text assert sftp.password.get_secret_value() not in caplog.text - assert "Connection is available" in caplog.text + assert "Connection is available." in caplog.text def test_sftp_file_connection_check_failed(sftp_server): diff --git a/tests/tests_integration/tests_file_connection_integration/test_webdav_file_connection_integration.py b/tests/tests_integration/tests_file_connection_integration/test_webdav_file_connection_integration.py index 3a91523ce..435d424aa 100644 --- a/tests/tests_integration/tests_file_connection_integration/test_webdav_file_connection_integration.py +++ b/tests/tests_integration/tests_file_connection_integration/test_webdav_file_connection_integration.py @@ -10,7 +10,7 @@ def test_webdav_file_connection_check_success(webdav_file_connection, caplog): with caplog.at_level(logging.INFO): assert webdav.check() == webdav - assert "type = WebDAV" in caplog.text + assert "|WebDAV|" in caplog.text assert f"host = '{webdav.host}'" in caplog.text assert f"port = {webdav.port}" in caplog.text assert f"protocol = '{webdav.protocol}'" in caplog.text @@ -19,7 +19,7 @@ def test_webdav_file_connection_check_success(webdav_file_connection, caplog): assert "password = SecretStr('**********')" in caplog.text assert webdav.password.get_secret_value() not in caplog.text - assert "Connection is available" in caplog.text + assert "Connection is available." 
in caplog.text def test_webdav_file_connection_check_failed(webdav_server): diff --git a/tests/tests_unit/test_db/test_db_reader_unit/test_clickhouse_reader_unit.py b/tests/tests_unit/test_db/test_db_reader_unit/test_clickhouse_reader_unit.py index 6a46f25c1..ba16f231b 100644 --- a/tests/tests_unit/test_db/test_db_reader_unit/test_clickhouse_reader_unit.py +++ b/tests/tests_unit/test_db/test_db_reader_unit/test_clickhouse_reader_unit.py @@ -39,7 +39,7 @@ def test_clickhouse_reader_snapshot_error_pass_df_schema(spark_mock): def test_clickhouse_reader_wrong_table_name(spark_mock, table): clickhouse = Clickhouse(host="some_host", user="user", database="database", password="passwd", spark=spark_mock) - with pytest.raises(ValueError, match="Table name should be passed in `schema.name` format"): + with pytest.raises(ValueError, match="Name should be passed in `schema.name` format"): DBReader( connection=clickhouse, table=table, # Required format: table="shema.table" diff --git a/tests/tests_unit/test_db/test_db_reader_unit/test_greenplum_reader_unit.py b/tests/tests_unit/test_db/test_db_reader_unit/test_greenplum_reader_unit.py index 57d60f8a2..3770bcf19 100644 --- a/tests/tests_unit/test_db/test_db_reader_unit/test_greenplum_reader_unit.py +++ b/tests/tests_unit/test_db/test_db_reader_unit/test_greenplum_reader_unit.py @@ -40,7 +40,7 @@ def test_greenplum_reader_snapshot_error_pass_df_schema(spark_mock): def test_greenplum_reader_wrong_table_name(spark_mock, table): greenplum = Greenplum(host="some_host", user="user", database="database", password="passwd", spark=spark_mock) - with pytest.raises(ValueError, match="Table name should be passed in `schema.name` format"): + with pytest.raises(ValueError, match="Name should be passed in `schema.name` format"): DBReader( connection=greenplum, table=table, # Required format: table="shema.table" diff --git a/tests/tests_unit/test_db/test_db_reader_unit/test_kafka_reader_unit.py b/tests/tests_unit/test_db/test_db_reader_unit/test_kafka_reader_unit.py index a6493c57a..5fd228f5b 100644 --- a/tests/tests_unit/test_db/test_db_reader_unit/test_kafka_reader_unit.py +++ b/tests/tests_unit/test_db/test_db_reader_unit/test_kafka_reader_unit.py @@ -29,30 +29,6 @@ def df_schema(): ) -def test_kafka_reader_invalid_table(spark_mock): - kafka = Kafka( - addresses=["localhost:9092"], - cluster="my_cluster", - spark=spark_mock, - ) - with pytest.raises( - ValueError, - match="Table name should be passed in `mytable` format", - ): - DBReader( - connection=kafka, - table="schema.table", # Includes schema. Required format: table="table" - ) - with pytest.raises( - ValueError, - match="Table name should be passed in `schema.name` format", - ): - DBReader( - connection=kafka, - table="schema.table.subtable", # Includes subtable. 
Required format: table="table" - ) - - def test_kafka_reader_unsupported_parameters(spark_mock, df_schema): kafka = Kafka( addresses=["localhost:9092"], diff --git a/tests/tests_unit/test_db/test_db_reader_unit/test_mongodb_reader_unit.py b/tests/tests_unit/test_db/test_db_reader_unit/test_mongodb_reader_unit.py index bd92a8638..2a164c16d 100644 --- a/tests/tests_unit/test_db/test_db_reader_unit/test_mongodb_reader_unit.py +++ b/tests/tests_unit/test_db/test_db_reader_unit/test_mongodb_reader_unit.py @@ -28,24 +28,6 @@ def df_schema(): ) -def test_mongodb_reader_with_dbschema(spark_mock): - mongo = MongoDB( - host="host", - user="user", - password="password", - database="database", - spark=spark_mock, - ) - with pytest.raises( - ValueError, - match="Table name should be passed in `mytable` format", - ): - DBReader( - connection=mongo, - table="schema.table", # Includes schema. Required format: table="table" - ) - - def test_mongodb_reader_wrong_hint_type(spark_mock, df_schema): mongo = MongoDB( host="host", diff --git a/tests/tests_unit/test_db/test_db_reader_unit/test_mssql_reader_unit.py b/tests/tests_unit/test_db/test_db_reader_unit/test_mssql_reader_unit.py index febcf8d8c..c8c483447 100644 --- a/tests/tests_unit/test_db/test_db_reader_unit/test_mssql_reader_unit.py +++ b/tests/tests_unit/test_db/test_db_reader_unit/test_mssql_reader_unit.py @@ -40,7 +40,7 @@ def test_mssql_reader_snapshot_error_pass_df_schema(spark_mock): def test_mssql_reader_wrong_table_name(spark_mock, table): mssql = MSSQL(host="some_host", user="user", database="database", password="passwd", spark=spark_mock) - with pytest.raises(ValueError, match="Table name should be passed in `schema.name` format"): + with pytest.raises(ValueError, match="Name should be passed in `schema.name` format"): DBReader( connection=mssql, table=table, # Required format: table="shema.table" diff --git a/tests/tests_unit/test_db/test_db_reader_unit/test_mysql_reader_unit.py b/tests/tests_unit/test_db/test_db_reader_unit/test_mysql_reader_unit.py index 7e5848544..19c4c662f 100644 --- a/tests/tests_unit/test_db/test_db_reader_unit/test_mysql_reader_unit.py +++ b/tests/tests_unit/test_db/test_db_reader_unit/test_mysql_reader_unit.py @@ -40,7 +40,7 @@ def test_mysql_reader_snapshot_error_pass_df_schema(spark_mock): def test_mysql_reader_wrong_table_name(spark_mock, table): mysql = MySQL(host="some_host", user="user", database="database", password="passwd", spark=spark_mock) - with pytest.raises(ValueError, match="Table name should be passed in `schema.name` format"): + with pytest.raises(ValueError, match="Name should be passed in `schema.name` format"): DBReader( connection=mysql, table=table, # Required format: table="shema.table" diff --git a/tests/tests_unit/test_db/test_db_reader_unit/test_oracle_reader_unit.py b/tests/tests_unit/test_db/test_db_reader_unit/test_oracle_reader_unit.py index 3fa67b130..1011f17ec 100644 --- a/tests/tests_unit/test_db/test_db_reader_unit/test_oracle_reader_unit.py +++ b/tests/tests_unit/test_db/test_db_reader_unit/test_oracle_reader_unit.py @@ -39,7 +39,7 @@ def test_oracle_reader_error_df_schema(spark_mock): @pytest.mark.parametrize("table", ["table", "table.table.table"]) def test_oracle_reader_wrong_table_name(spark_mock, table): oracle = Oracle(host="some_host", user="user", sid="sid", password="passwd", spark=spark_mock) - with pytest.raises(ValueError, match="Table name should be passed in `schema.name` format"): + with pytest.raises(ValueError, match="Name should be passed in `schema.name` format"): 
DBReader( connection=oracle, table=table, # Required format: table="shema.table" diff --git a/tests/tests_unit/test_db/test_db_reader_unit/test_postgres_reader_unit.py b/tests/tests_unit/test_db/test_db_reader_unit/test_postgres_reader_unit.py index 1b15437e8..c541e33cd 100644 --- a/tests/tests_unit/test_db/test_db_reader_unit/test_postgres_reader_unit.py +++ b/tests/tests_unit/test_db/test_db_reader_unit/test_postgres_reader_unit.py @@ -40,7 +40,7 @@ def test_postgres_reader_snapshot_error_pass_df_schema(spark_mock): def test_postgres_reader_wrong_table_name(spark_mock, table): postgres = Postgres(host="some_host", user="user", database="database", password="passwd", spark=spark_mock) - with pytest.raises(ValueError, match="Table name should be passed in `schema.name` format"): + with pytest.raises(ValueError, match="Name should be passed in `schema.name` format"): DBReader( connection=postgres, table=table, # Required format: table="shema.table" diff --git a/tests/tests_unit/test_db/test_db_reader_unit/test_teradata_reader_unit.py b/tests/tests_unit/test_db/test_db_reader_unit/test_teradata_reader_unit.py index bc5140d27..262897fde 100644 --- a/tests/tests_unit/test_db/test_db_reader_unit/test_teradata_reader_unit.py +++ b/tests/tests_unit/test_db/test_db_reader_unit/test_teradata_reader_unit.py @@ -40,7 +40,7 @@ def test_teradata_reader_snapshot_error_pass_df_schema(spark_mock): def test_teradata_reader_wrong_table_name(spark_mock, table): teradata = Teradata(host="some_host", user="user", database="database", password="passwd", spark=spark_mock) - with pytest.raises(ValueError, match="Table name should be passed in `schema.name` format"): + with pytest.raises(ValueError, match="Name should be passed in `schema.name` format"): DBReader( connection=teradata, table=table, # Required format: table="shema.table" diff --git a/tests/tests_unit/test_db/test_db_writer_unit/test_clickhouse_writer_unit.py b/tests/tests_unit/test_db/test_db_writer_unit/test_clickhouse_writer_unit.py index c82ff4af4..eb4b34028 100644 --- a/tests/tests_unit/test_db/test_db_writer_unit/test_clickhouse_writer_unit.py +++ b/tests/tests_unit/test_db/test_db_writer_unit/test_clickhouse_writer_unit.py @@ -10,7 +10,7 @@ def test_clickhouse_writer_wrong_table_name(spark_mock, table): clickhouse = Clickhouse(host="some_host", user="user", database="database", password="passwd", spark=spark_mock) - with pytest.raises(ValueError, match="Table name should be passed in `schema.name` format"): + with pytest.raises(ValueError, match="Name should be passed in `schema.name` format"): DBWriter( connection=clickhouse, table=table, # Required format: table="shema.table" diff --git a/tests/tests_unit/test_db/test_db_writer_unit/test_greenplum_writer_unit.py b/tests/tests_unit/test_db/test_db_writer_unit/test_greenplum_writer_unit.py index 6b2c63e0a..fb3614cdd 100644 --- a/tests/tests_unit/test_db/test_db_writer_unit/test_greenplum_writer_unit.py +++ b/tests/tests_unit/test_db/test_db_writer_unit/test_greenplum_writer_unit.py @@ -10,7 +10,7 @@ def test_greenplum_writer_wrong_table_name(spark_mock, table): greenplum = Greenplum(host="some_host", user="user", database="database", password="passwd", spark=spark_mock) - with pytest.raises(ValueError, match="Table name should be passed in `schema.name` format"): + with pytest.raises(ValueError, match="Name should be passed in `schema.name` format"): DBWriter( connection=greenplum, table=table, # Required format: table="shema.table" diff --git 
a/tests/tests_unit/test_db/test_db_writer_unit/test_hive_writer_unit.py b/tests/tests_unit/test_db/test_db_writer_unit/test_hive_writer_unit.py index f480d4b6a..2fbfdb573 100644 --- a/tests/tests_unit/test_db/test_db_writer_unit/test_hive_writer_unit.py +++ b/tests/tests_unit/test_db/test_db_writer_unit/test_hive_writer_unit.py @@ -10,7 +10,7 @@ def test_hive_writer_wrong_table_name(spark_mock, table): hive = Hive(cluster="rnd-dwh", spark=spark_mock) - with pytest.raises(ValueError, match="Table name should be passed in `schema.name` format"): + with pytest.raises(ValueError, match="Name should be passed in `schema.name` format"): DBWriter( connection=hive, table=table, # Required format: table="shema.table" diff --git a/tests/tests_unit/test_db/test_db_writer_unit/test_mongodb_writer_unit.py b/tests/tests_unit/test_db/test_db_writer_unit/test_mongodb_writer_unit.py deleted file mode 100644 index 99ce4218a..000000000 --- a/tests/tests_unit/test_db/test_db_writer_unit/test_mongodb_writer_unit.py +++ /dev/null @@ -1,22 +0,0 @@ -import pytest - -from onetl.connection import MongoDB -from onetl.db import DBWriter - -pytestmark = pytest.mark.mongodb - - -def test_mongodb_writer_wrong_table_name(spark_mock): - mongo = MongoDB( - host="host", - user="user", - password="password", - database="database", - spark=spark_mock, - ) - - with pytest.raises(ValueError, match="Table name should be passed in `mytable` format"): - DBWriter( - connection=mongo, - table="schema.table", # Includes schema. Required format: table="table" - ) diff --git a/tests/tests_unit/test_db/test_db_writer_unit/test_mssql_writer_unit.py b/tests/tests_unit/test_db/test_db_writer_unit/test_mssql_writer_unit.py index 57de4faaf..44618ff11 100644 --- a/tests/tests_unit/test_db/test_db_writer_unit/test_mssql_writer_unit.py +++ b/tests/tests_unit/test_db/test_db_writer_unit/test_mssql_writer_unit.py @@ -10,7 +10,7 @@ def test_mssql_writer_wrong_table_name(spark_mock, table): mssql = MSSQL(host="some_host", user="user", database="database", password="passwd", spark=spark_mock) - with pytest.raises(ValueError, match="Table name should be passed in `schema.name` format"): + with pytest.raises(ValueError, match="Name should be passed in `schema.name` format"): DBWriter( connection=mssql, table=table, # Required format: table="shema.table" diff --git a/tests/tests_unit/test_db/test_db_writer_unit/test_mysql_writer_unit.py b/tests/tests_unit/test_db/test_db_writer_unit/test_mysql_writer_unit.py index a35754457..8eb54f397 100644 --- a/tests/tests_unit/test_db/test_db_writer_unit/test_mysql_writer_unit.py +++ b/tests/tests_unit/test_db/test_db_writer_unit/test_mysql_writer_unit.py @@ -10,7 +10,7 @@ def test_mysql_writer_wrong_table_name(spark_mock, table): mysql = MySQL(host="some_host", user="user", database="database", password="passwd", spark=spark_mock) - with pytest.raises(ValueError, match="Table name should be passed in `schema.name` format"): + with pytest.raises(ValueError, match="Name should be passed in `schema.name` format"): DBWriter( connection=mysql, table=table, # Required format: table="shema.table" diff --git a/tests/tests_unit/test_db/test_db_writer_unit/test_oracle_writer_unit.py b/tests/tests_unit/test_db/test_db_writer_unit/test_oracle_writer_unit.py index 2a734e6e1..63668cacf 100644 --- a/tests/tests_unit/test_db/test_db_writer_unit/test_oracle_writer_unit.py +++ b/tests/tests_unit/test_db/test_db_writer_unit/test_oracle_writer_unit.py @@ -10,7 +10,7 @@ def test_oracle_writer_wrong_table_name(spark_mock, table): 
oracle = Oracle(host="some_host", user="user", sid="sid", password="passwd", spark=spark_mock) - with pytest.raises(ValueError, match="Table name should be passed in `schema.name` format"): + with pytest.raises(ValueError, match="Name should be passed in `schema.name` format"): DBWriter( connection=oracle, table=table, # Required format: table="shema.table" diff --git a/tests/tests_unit/test_db/test_db_writer_unit/test_postgres_writer_unit.py b/tests/tests_unit/test_db/test_db_writer_unit/test_postgres_writer_unit.py index db17a0119..d7322b5f1 100644 --- a/tests/tests_unit/test_db/test_db_writer_unit/test_postgres_writer_unit.py +++ b/tests/tests_unit/test_db/test_db_writer_unit/test_postgres_writer_unit.py @@ -10,7 +10,7 @@ def test_postgres_writer_wrong_table_name(spark_mock, table): postgres = Postgres(host="some_host", user="user", database="database", password="passwd", spark=spark_mock) - with pytest.raises(ValueError, match="Table name should be passed in `schema.name` format"): + with pytest.raises(ValueError, match="Name should be passed in `schema.name` format"): DBWriter( connection=postgres, table=table, # Required format: table="shema.table" diff --git a/tests/tests_unit/test_db/test_db_writer_unit/test_teradata_writer_unit.py b/tests/tests_unit/test_db/test_db_writer_unit/test_teradata_writer_unit.py index f13bb1957..6b9434ba7 100644 --- a/tests/tests_unit/test_db/test_db_writer_unit/test_teradata_writer_unit.py +++ b/tests/tests_unit/test_db/test_db_writer_unit/test_teradata_writer_unit.py @@ -10,7 +10,7 @@ def test_teradata_writer_wrong_table_name(spark_mock, table): teradata = Teradata(host="some_host", user="user", database="database", password="passwd", spark=spark_mock) - with pytest.raises(ValueError, match="Table name should be passed in `schema.name` format"): + with pytest.raises(ValueError, match="Name should be passed in `schema.name` format"): DBWriter( connection=teradata, table=table, # Required format: table="shema.table" diff --git a/tests/tests_unit/tests_db_connection_unit/test_db_options_unit.py b/tests/tests_unit/tests_db_connection_unit/test_db_options_unit.py index 72415207f..8e51d0a89 100644 --- a/tests/tests_unit/tests_db_connection_unit/test_db_options_unit.py +++ b/tests/tests_unit/tests_db_connection_unit/test_db_options_unit.py @@ -36,13 +36,13 @@ def test_db_options_connection_parameters_cannot_be_passed(options_class, arg, v @pytest.mark.parametrize( "options_class, options_class_name, known_options", [ - (Hive.WriteOptions, "WriteOptions", {"if_exists": "replace_overlapping_partitions"}), - (Hive.Options, "Options", {"if_exists": "replace_overlapping_partitions"}), - (Postgres.ReadOptions, "ReadOptions", {"fetchsize": 10, "keytab": "a/b/c"}), - (Postgres.WriteOptions, "WriteOptions", {"if_exists": "replace_entire_table", "keytab": "a/b/c"}), - (Postgres.Options, "Options", {"if_exists": "replace_entire_table", "keytab": "a/b/c"}), - (Greenplum.ReadOptions, "ReadOptions", {"partitions": 10}), - (Greenplum.WriteOptions, "WriteOptions", {"if_exists": "replace_entire_table"}), + (Hive.WriteOptions, "HiveWriteOptions", {"if_exists": "replace_overlapping_partitions"}), + (Hive.Options, "HiveLegacyOptions", {"if_exists": "replace_overlapping_partitions"}), + (Postgres.ReadOptions, "JDBCReadOptions", {"fetchsize": 10, "keytab": "a/b/c"}), + (Postgres.WriteOptions, "JDBCWriteOptions", {"if_exists": "replace_entire_table", "keytab": "a/b/c"}), + (Postgres.Options, "JDBCLegacyOptions", {"if_exists": "replace_entire_table", "keytab": "a/b/c"}), + 
(Greenplum.ReadOptions, "GreenplumReadOptions", {"partitions": 10}), + (Greenplum.WriteOptions, "GreenplumWriteOptions", {"if_exists": "replace_entire_table"}), ], ) def test_db_options_warn_for_unknown(options_class, options_class_name, known_options, caplog): diff --git a/tests/tests_unit/tests_db_connection_unit/test_greenplum_unit.py b/tests/tests_unit/tests_db_connection_unit/test_greenplum_unit.py index 6954186c1..f54eec0a4 100644 --- a/tests/tests_unit/tests_db_connection_unit/test_greenplum_unit.py +++ b/tests/tests_unit/tests_db_connection_unit/test_greenplum_unit.py @@ -176,60 +176,49 @@ def test_greenplum_write_options_default(): options = Greenplum.WriteOptions() assert options.if_exists == GreenplumTableExistBehavior.APPEND - assert options.query_timeout is None + + +@pytest.mark.parametrize( + "klass, name", + [ + (Greenplum.ReadOptions, "GreenplumReadOptions"), + (Greenplum.WriteOptions, "GreenplumWriteOptions"), + (Greenplum.JDBCOptions, "JDBCOptions"), + (Greenplum.Extra, "GreenplumExtra"), + ], +) +def test_greenplum_jdbc_options_populated_by_connection_class(klass, name): + error_msg = rf"Options \['driver', 'password', 'url', 'user'\] are not allowed to use in a {name}" + with pytest.raises(ValueError, match=error_msg): + klass(user="me", password="abc", driver="some.Class", url="jdbc:postgres://some/db") def test_greenplum_read_write_options_populated_by_connection_class(): - error_msg = r"Options \['dbschema', 'dbtable'\] are not allowed to use in a ReadOptions" + error_msg = r"Options \['dbschema', 'dbtable'\] are not allowed to use in a GreenplumReadOptions" with pytest.raises(ValueError, match=error_msg): Greenplum.ReadOptions(dbschema="myschema", dbtable="mytable") - error_msg = r"Options \['dbschema', 'dbtable'\] are not allowed to use in a WriteOptions" + error_msg = r"Options \['dbschema', 'dbtable'\] are not allowed to use in a GreenplumWriteOptions" with pytest.raises(ValueError, match=error_msg): Greenplum.WriteOptions(dbschema="myschema", dbtable="mytable") - error_msg = r"Options \['dbschema', 'dbtable'\] are not allowed to use in a Extra" - with pytest.raises(ValueError, match=error_msg): - Greenplum.Extra(dbschema="myschema", dbtable="mytable") - # JDBCOptions does not have such restriction options = Greenplum.JDBCOptions(dbschema="myschema", dbtable="mytable") assert options.dbschema == "myschema" assert options.dbtable == "mytable" -@pytest.mark.parametrize( - "options_class", - [ - Greenplum.ReadOptions, - Greenplum.WriteOptions, - ], -) -@pytest.mark.parametrize( - "arg, value", - [ - ("server.port", 8000), - ("pool.maxSize", "40"), - ], -) -def test_greenplum_read_write_options_prohibited(arg, value, options_class): - with pytest.raises(ValueError, match=rf"Options \['{arg}'\] are not allowed to use in a {options_class.__name__}"): - options_class.parse({arg: value}) - - @pytest.mark.parametrize( "arg, value", [ ("mode", "append"), ("truncate", "true"), ("distributedBy", "abc"), - ("distributed_by", "abc"), ("iteratorOptimization", "true"), - ("iterator_optimization", "true"), ], ) def test_greenplum_write_options_cannot_be_used_in_read_options(arg, value): - error_msg = rf"Options \['{arg}'\] are not allowed to use in a ReadOptions" + error_msg = rf"Options \['{arg}'\] are not allowed to use in a GreenplumReadOptions" with pytest.raises(ValueError, match=error_msg): Greenplum.ReadOptions.parse({arg: value}) @@ -238,14 +227,12 @@ def test_greenplum_write_options_cannot_be_used_in_read_options(arg, value): "arg, value", [ ("partitions", 10), - 
("num_partitions", 10), ("numPartitions", 10), ("partitionColumn", "abc"), - ("partition_column", "abc"), ], ) def test_greenplum_read_options_cannot_be_used_in_write_options(arg, value): - error_msg = rf"Options \['{arg}'\] are not allowed to use in a WriteOptions" + error_msg = rf"Options \['{arg}'\] are not allowed to use in a GreenplumWriteOptions" with pytest.raises(ValueError, match=error_msg): Greenplum.WriteOptions.parse({arg: value}) @@ -256,6 +243,8 @@ def test_greenplum_read_options_cannot_be_used_in_write_options(arg, value): ({}, GreenplumTableExistBehavior.APPEND), ({"if_exists": "append"}, GreenplumTableExistBehavior.APPEND), ({"if_exists": "replace_entire_table"}, GreenplumTableExistBehavior.REPLACE_ENTIRE_TABLE), + ({"if_exists": "error"}, GreenplumTableExistBehavior.ERROR), + ({"if_exists": "ignore"}, GreenplumTableExistBehavior.IGNORE), ], ) def test_greenplum_write_options_if_exists(options, value): @@ -283,6 +272,18 @@ def test_greenplum_write_options_if_exists(options, value): "Mode `overwrite` is deprecated since v0.9.0 and will be removed in v1.0.0. " "Use `replace_entire_table` instead", ), + ( + {"mode": "ignore"}, + GreenplumTableExistBehavior.IGNORE, + "Option `Greenplum.WriteOptions(mode=...)` is deprecated since v0.9.0 and will be removed in v1.0.0. " + "Use `Greenplum.WriteOptions(if_exists=...)` instead", + ), + ( + {"mode": "error"}, + GreenplumTableExistBehavior.ERROR, + "Option `Greenplum.WriteOptions(mode=...)` is deprecated since v0.9.0 and will be removed in v1.0.0. " + "Use `Greenplum.WriteOptions(if_exists=...)` instead", + ), ], ) def test_greenplum_write_options_mode_deprecated(options, value, message): @@ -294,10 +295,6 @@ def test_greenplum_write_options_mode_deprecated(options, value, message): @pytest.mark.parametrize( "options", [ - # disallowed modes - {"mode": "error"}, - {"mode": "ignore"}, - # wrong mode {"mode": "wrong_mode"}, ], ) diff --git a/tests/tests_unit/tests_db_connection_unit/test_jdbc_options_unit.py b/tests/tests_unit/tests_db_connection_unit/test_jdbc_options_unit.py index 988bb8a27..ae81402cc 100644 --- a/tests/tests_unit/tests_db_connection_unit/test_jdbc_options_unit.py +++ b/tests/tests_unit/tests_db_connection_unit/test_jdbc_options_unit.py @@ -45,11 +45,11 @@ def test_jdbc_options_default(): ], ) def test_jdbc_read_write_options_populated_by_connection_class(arg, value): - error_msg = rf"Options \['{arg}'\] are not allowed to use in a ReadOptions" + error_msg = rf"Options \['{arg}'\] are not allowed to use in a JDBCReadOptions" with pytest.raises(ValueError, match=error_msg): Postgres.ReadOptions.parse({arg: value}) - error_msg = rf"Options \['{arg}'\] are not allowed to use in a WriteOptions" + error_msg = rf"Options \['{arg}'\] are not allowed to use in a JDBCWriteOptions" with pytest.raises(ValueError, match=error_msg): Postgres.WriteOptions.parse({arg: value}) @@ -73,7 +73,7 @@ def test_jdbc_read_write_options_populated_by_connection_class(arg, value): ], ) def test_jdbc_write_options_cannot_be_used_in_read_options(arg, value): - error_msg = rf"Options \['{arg}'\] are not allowed to use in a ReadOptions" + error_msg = rf"Options \['{arg}'\] are not allowed to use in a JDBCReadOptions" with pytest.raises(ValueError, match=error_msg): Postgres.ReadOptions.parse({arg: value}) @@ -101,7 +101,7 @@ def test_jdbc_write_options_cannot_be_used_in_read_options(arg, value): ], ) def test_jdbc_read_options_cannot_be_used_in_write_options(arg, value): - error_msg = rf"Options \['{arg}'\] are not allowed to use in a WriteOptions" 
+ error_msg = rf"Options \['{arg}'\] are not allowed to use in a JDBCWriteOptions" with pytest.raises(ValueError, match=error_msg): Postgres.WriteOptions.parse({arg: value}) diff --git a/tests/tests_unit/tests_db_connection_unit/test_kafka_unit.py b/tests/tests_unit/tests_db_connection_unit/test_kafka_unit.py index d47270e39..1524d8ad7 100644 --- a/tests/tests_unit/tests_db_connection_unit/test_kafka_unit.py +++ b/tests/tests_unit/tests_db_connection_unit/test_kafka_unit.py @@ -71,7 +71,7 @@ def test_kafka_missing_package(spark_no_packages): @pytest.mark.parametrize( - "arg, value", + "option, value", [ ("assign", "assign_value"), ("subscribe", "subscribe_value"), @@ -87,17 +87,36 @@ def test_kafka_missing_package(spark_no_packages): ("topic", "topic_value"), ], ) -def test_kafka_prohibited_options_error(arg, value): - error_msg = rf"Options \['{arg}'\] are not allowed to use in a KafkaReadOptions" - with pytest.raises(ValueError, match=error_msg): - Kafka.ReadOptions.parse({arg: value}) - error_msg = rf"Options \['{arg}'\] are not allowed to use in a KafkaWriteOptions" +@pytest.mark.parametrize( + "options_class, class_name", + [ + (Kafka.ReadOptions, "KafkaReadOptions"), + (Kafka.WriteOptions, "KafkaWriteOptions"), + ], +) +def test_kafka_options_prohibited(option, value, options_class, class_name): + error_msg = rf"Options \['{option}'\] are not allowed to use in a {class_name}" with pytest.raises(ValueError, match=error_msg): - Kafka.WriteOptions.parse({arg: value}) + options_class.parse({option: value}) + + +@pytest.mark.parametrize( + "options_class, class_name", + [ + (Kafka.ReadOptions, "KafkaReadOptions"), + (Kafka.WriteOptions, "KafkaWriteOptions"), + ], +) +def test_kafka_options_unknown(caplog, options_class, class_name): + with caplog.at_level(logging.WARNING): + options = options_class(unknown="abc") + assert options.unknown == "abc" + + assert f"Options ['unknown'] are not known by {class_name}, are you sure they are valid?" 
in caplog.text @pytest.mark.parametrize( - "arg, value", + "option, value", [ ("failOnDataLoss", "false"), ("kafkaConsumer.pollTimeoutMs", "30000"), @@ -108,30 +127,21 @@ def test_kafka_prohibited_options_error(arg, value): ("maxTriggerDelay", "2000"), ("minPartitions", "2"), ("groupIdPrefix", "testPrefix"), - ("includeHeaders", "true"), ], ) -def test_kafka_allowed_read_options_no_error(arg, value): - try: - Kafka.ReadOptions.parse({arg: value}) - except ValidationError: - pytest.fail("ValidationError for ReadOptions raised unexpectedly!") +def test_kafka_read_options_allowed(option, value): + options = Kafka.ReadOptions.parse({option: value}) + assert getattr(options, option) == value -@pytest.mark.parametrize( - "arg, value", - [ - ("includeHeaders", "true"), - ], -) -def test_kafka_allowed_write_options_no_error(arg, value): - try: - Kafka.WriteOptions.parse({arg: value}) - except ValidationError: - pytest.fail("ValidationError for Write options raised unexpectedly!") +@pytest.mark.parametrize("value", [True, False]) +@pytest.mark.parametrize("options_class", [Kafka.ReadOptions, Kafka.WriteOptions]) +def test_kafka_options_include_headers(options_class, value): + options = options_class(includeHeaders=value) + assert options.include_headers == value -def test_kafka_basic_auth(spark_mock): +def test_kafka_basic_auth_get_jaas_conf(spark_mock): conn = Kafka( spark=spark_mock, cluster="some_cluster", @@ -237,7 +247,7 @@ def test_kafka_empty_cluster(spark_mock): @pytest.mark.parametrize( - "arg, value", + "option, value", [ ("bootstrap.servers", "kafka.bootstrap.servers_value"), ("security.protocol", "ssl"), @@ -251,23 +261,23 @@ def test_kafka_empty_cluster(spark_mock): ("value.deserializer", "org.apache.kafka.common.serialization.ByteArrayDeserializer"), ], ) -def test_kafka_invalid_extras(arg, value): +def test_kafka_invalid_extras(option, value): msg = re.escape("are not allowed to use in a KafkaExtra") with pytest.raises(ValueError, match=msg): - KafkaExtra.parse({arg: value}) + KafkaExtra.parse({option: value}) with pytest.raises(ValueError, match=msg): - KafkaExtra.parse({"kafka." + arg: value}) + KafkaExtra.parse({"kafka." 
+ option: value}) @pytest.mark.parametrize( - "arg, value", + "option, value", [ ("kafka.group.id", "group_id"), ("group.id", "group_id"), ], ) -def test_kafka_valid_extras(arg, value): - extra_dict = KafkaExtra.parse({arg: value}).dict() +def test_kafka_valid_extras(option, value): + extra_dict = KafkaExtra.parse({option: value}).dict() assert extra_dict["group.id"] == value diff --git a/tests/tests_unit/tests_db_connection_unit/test_mssql_unit.py b/tests/tests_unit/tests_db_connection_unit/test_mssql_unit.py index 6846b94c6..e6cd8eb89 100644 --- a/tests/tests_unit/tests_db_connection_unit/test_mssql_unit.py +++ b/tests/tests_unit/tests_db_connection_unit/test_mssql_unit.py @@ -110,7 +110,7 @@ def test_mssql_with_extra(spark_mock): def test_mssql_with_extra_prohibited(spark_mock): - with pytest.raises(ValueError, match=r"Options \['databaseName'\] are not allowed to use in a Extra"): + with pytest.raises(ValueError, match=r"Options \['databaseName'\] are not allowed to use in a MSSQLExtra"): MSSQL( host="some_host", user="user", diff --git a/tests/tests_unit/tests_db_connection_unit/test_teradata_unit.py b/tests/tests_unit/tests_db_connection_unit/test_teradata_unit.py index 92d5b7f2c..1daf14dc4 100644 --- a/tests/tests_unit/tests_db_connection_unit/test_teradata_unit.py +++ b/tests/tests_unit/tests_db_connection_unit/test_teradata_unit.py @@ -115,7 +115,10 @@ def test_teradata_with_extra(spark_mock): def test_teradata_with_extra_prohibited(spark_mock): - with pytest.raises(ValueError, match=r"Options \['DATABASE', 'DBS_PORT'\] are not allowed to use in a Extra"): + with pytest.raises( + ValueError, + match=r"Options \['DATABASE', 'DBS_PORT'\] are not allowed to use in a TeradataExtra", + ): Teradata( host="some_host", user="user", diff --git a/tests/util/to_pandas.py b/tests/util/to_pandas.py index 3b3e3d9b2..142a024d6 100644 --- a/tests/util/to_pandas.py +++ b/tests/util/to_pandas.py @@ -23,6 +23,8 @@ def fix_pyspark_df(df: SparkDataFrame) -> SparkDataFrame: TypeError: Casting to unit-less dtype 'datetime64' is not supported. Pass e.g. 'datetime64[ns]' instead. This method converts dates and timestamps to strings, to convert them back to original type later. + + TODO: remove after https://issues.apache.org/jira/browse/SPARK-43194 """ from pyspark.sql.functions import date_format from pyspark.sql.types import DateType, TimestampType