diff --git a/superset/db_engine_specs/postgres.py b/superset/db_engine_specs/postgres.py index 7f542365dc452..07be634d0777e 100644 --- a/superset/db_engine_specs/postgres.py +++ b/superset/db_engine_specs/postgres.py @@ -182,6 +182,19 @@ def fetch_data(cls, cursor: Any, limit: int | None = None) -> list[tuple[Any, .. def epoch_to_dttm(cls) -> str: return "(timestamp 'epoch' + {col} * interval '1 second')" + @classmethod + def convert_dttm( + cls, target_type: str, dttm: datetime, db_extra: dict[str, Any] | None = None + ) -> str | None: + sqla_type = cls.get_sqla_column_type(target_type) + + if isinstance(sqla_type, Date): + return f"TO_DATE('{dttm.date().isoformat()}', 'YYYY-MM-DD')" + if isinstance(sqla_type, DateTime): + dttm_formatted = dttm.isoformat(sep=" ", timespec="microseconds") + return f"""TO_TIMESTAMP('{dttm_formatted}', 'YYYY-MM-DD HH24:MI:SS.US')""" + return None + class PostgresEngineSpec(BasicParametersMixin, PostgresBaseEngineSpec): engine = "postgresql" @@ -357,19 +370,6 @@ def get_table_names( inspector.get_foreign_table_names(schema) ) - @classmethod - def convert_dttm( - cls, target_type: str, dttm: datetime, db_extra: dict[str, Any] | None = None - ) -> str | None: - sqla_type = cls.get_sqla_column_type(target_type) - - if isinstance(sqla_type, Date): - return f"TO_DATE('{dttm.date().isoformat()}', 'YYYY-MM-DD')" - if isinstance(sqla_type, DateTime): - dttm_formatted = dttm.isoformat(sep=" ", timespec="microseconds") - return f"""TO_TIMESTAMP('{dttm_formatted}', 'YYYY-MM-DD HH24:MI:SS.US')""" - return None - @staticmethod def get_extra_params(database: Database) -> dict[str, Any]: """ diff --git a/superset/db_engine_specs/redshift.py b/superset/db_engine_specs/redshift.py index 925e0424e41e2..44281d2c36f86 100644 --- a/superset/db_engine_specs/redshift.py +++ b/superset/db_engine_specs/redshift.py @@ -14,10 +14,12 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. +from __future__ import annotations + import logging import re from re import Pattern -from typing import Any, Optional +from typing import Any import pandas as pd from flask_babel import gettext as __ @@ -148,7 +150,7 @@ def _mutate_label(label: str) -> str: return label.lower() @classmethod - def get_cancel_query_id(cls, cursor: Any, query: Query) -> Optional[str]: + def get_cancel_query_id(cls, cursor: Any, query: Query) -> str | None: """ Get Redshift PID that will be used to cancel all other running queries in the same session. diff --git a/tests/unit_tests/db_engine_specs/test_redshift.py b/tests/unit_tests/db_engine_specs/test_redshift.py new file mode 100644 index 0000000000000..ddd2c1a5eb2ea --- /dev/null +++ b/tests/unit_tests/db_engine_specs/test_redshift.py @@ -0,0 +1,47 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from datetime import datetime +from typing import Optional + +import pytest + +from tests.unit_tests.db_engine_specs.utils import assert_convert_dttm +from tests.unit_tests.fixtures.common import dttm + + +@pytest.mark.parametrize( + "target_type,expected_result", + [ + ("Date", "TO_DATE('2019-01-02', 'YYYY-MM-DD')"), + ( + "DateTime", + "TO_TIMESTAMP('2019-01-02 03:04:05.678900', 'YYYY-MM-DD HH24:MI:SS.US')", + ), + ( + "TimeStamp", + "TO_TIMESTAMP('2019-01-02 03:04:05.678900', 'YYYY-MM-DD HH24:MI:SS.US')", + ), + ("UnknownType", None), + ], +) +def test_convert_dttm( + target_type: str, expected_result: Optional[str], dttm: datetime +) -> None: + from superset.db_engine_specs.redshift import RedshiftEngineSpec as spec + + assert_convert_dttm(spec, target_type, expected_result, dttm)