Skip to content

Commit

Permalink
fix: Pre-query normalization with custom SQL (apache#30389)
Browse files Browse the repository at this point in the history
  • Loading branch information
michael-s-molina authored Sep 25, 2024
1 parent 69d5f76 commit ad29985
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 3 deletions.
4 changes: 2 additions & 2 deletions superset/connectors/sqla/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -1688,10 +1688,10 @@ def _normalize_prequery_result_type(
if isinstance(value, np.generic):
value = value.item()

column_ = columns_by_name[dimension]
column_ = columns_by_name.get(dimension)
db_extra: dict[str, Any] = self.database.get_extra()

if column_.type and column_.is_temporal and isinstance(value, str):
if column_ and column_.type and column_.is_temporal and isinstance(value, str):
sql = self.db_engine_spec.convert_dttm(
column_.type, dateutil.parser.parse(value), db_extra=db_extra
)
Expand Down
26 changes: 25 additions & 1 deletion tests/unit_tests/connectors/sqla/models_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,13 +15,14 @@
# specific language governing permissions and limitations
# under the License.

import pandas as pd
import pytest
from pytest_mock import MockerFixture
from sqlalchemy import create_engine
from sqlalchemy.exc import IntegrityError
from sqlalchemy.orm.session import Session

from superset.connectors.sqla.models import SqlaTable
from superset.connectors.sqla.models import SqlaTable, TableColumn
from superset.daos.dataset import DatasetDAO
from superset.exceptions import OAuth2RedirectError
from superset.models.core import Database
Expand Down Expand Up @@ -263,3 +264,26 @@ def test_dataset_uniqueness(session: Session) -> None:
database,
Table("table", "schema", "some_catalog"),
)


def test_normalize_prequery_result_type_custom_sql() -> None:
"""
Test that the `_normalize_prequery_result_type` can hanndle custom SQL.
"""
sqla_table = SqlaTable(
table_name="my_sqla_table",
columns=[],
metrics=[],
database=Database(database_name="my_db", sqlalchemy_uri="sqlite://"),
)
row: pd.Series = {
"custom_sql": "Car",
}
dimension: str = "custom_sql"
columns_by_name: dict[str, TableColumn] = {
"product_line": TableColumn(column_name="product_line"),
}
assert (
sqla_table._normalize_prequery_result_type(row, dimension, columns_by_name)
== "Car"
)

0 comments on commit ad29985

Please sign in to comment.