Skip to content

Commit

Permalink
fix(mssql db_engine_spec): adds uniqueidentifier to column_type_mappi…
Browse files Browse the repository at this point in the history
…ngs (apache#30618)
  • Loading branch information
rparsonsbb authored Oct 30, 2024
1 parent a729f04 commit 58edc79
Show file tree
Hide file tree
Showing 3 changed files with 71 additions and 0 deletions.
8 changes: 8 additions & 0 deletions superset/db_engine_specs/mssql.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from __future__ import annotations

import logging
import re
from datetime import datetime
Expand All @@ -27,6 +29,7 @@
from superset.constants import TimeGrain
from superset.db_engine_specs.base import BaseEngineSpec, LimitMethod
from superset.errors import SupersetErrorType
from superset.models.sql_types.mssql_sql_types import GUID
from superset.utils.core import GenericDataType

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -87,6 +90,11 @@ class MssqlEngineSpec(BaseEngineSpec):
SMALLDATETIME(),
GenericDataType.TEMPORAL,
),
(
re.compile(r"^uniqueidentifier.*", re.IGNORECASE),
GUID(),
GenericDataType.STRING,
),
)

custom_errors: dict[Pattern[str], tuple[str, SupersetErrorType, dict[str, Any]]] = {
Expand Down
61 changes: 61 additions & 0 deletions superset/models/sql_types/mssql_sql_types.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

# pylint: disable=abstract-method
import uuid
from typing import Any, Optional

from sqlalchemy.engine.interfaces import Dialect
from sqlalchemy.sql.sqltypes import CHAR
from sqlalchemy.sql.visitors import Visitable
from sqlalchemy.types import TypeDecorator

# _compiler_dispatch is defined to help with type compilation


class GUID(TypeDecorator):
"""
A type for SQL Server's uniqueidentifier, stored as stringified UUIDs.
"""

impl = CHAR

@property
def python_type(self) -> type[uuid.UUID]:
"""The Python type for this SQL type is `uuid.UUID`."""
return uuid.UUID

@classmethod
def _compiler_dispatch(cls, _visitor: Visitable, **_kw: Any) -> str:
"""Return the SQL type for the GUID type, which is CHAR(36) in SQL Server."""
return "CHAR(36)"

def process_bind_param(self, value: str, dialect: Dialect) -> Optional[str]:
"""Prepare the UUID value for binding to the database."""
if value is None:
return None
if not isinstance(value, uuid.UUID):
return str(uuid.UUID(value)) # Convert to string UUID if needed
return str(value)

def process_result_value(
self, value: Optional[str], dialect: Dialect
) -> Optional[uuid.UUID]:
"""Convert the string back to a UUID when retrieving from the database."""
if value is None:
return None
return uuid.UUID(value)
2 changes: 2 additions & 0 deletions tests/unit_tests/db_engine_specs/test_mssql.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
from sqlalchemy.types import String, TypeEngine, UnicodeText

from superset.errors import ErrorLevel, SupersetError, SupersetErrorType
from superset.models.sql_types.mssql_sql_types import GUID
from superset.utils.core import GenericDataType
from tests.unit_tests.db_engine_specs.utils import (
assert_column_spec,
Expand All @@ -46,6 +47,7 @@
("NCHAR(10)", UnicodeText, None, GenericDataType.STRING, False),
("NVARCHAR(10)", UnicodeText, None, GenericDataType.STRING, False),
("NTEXT", UnicodeText, None, GenericDataType.STRING, False),
("uniqueidentifier", GUID, None, GenericDataType.STRING, False),
],
)
def test_get_column_spec(
Expand Down

0 comments on commit 58edc79

Please sign in to comment.