diff --git a/sqlalchemy-stubs/sql/sqltypes.pyi b/sqlalchemy-stubs/sql/sqltypes.pyi index b3d84ba..c887e57 100644 --- a/sqlalchemy-stubs/sql/sqltypes.pyi +++ b/sqlalchemy-stubs/sql/sqltypes.pyi @@ -95,16 +95,19 @@ class Numeric(_LookupExpressionAdapter, TypeEngine[decimal.Decimal]): def result_processor(self, dialect: Dialect, coltype: Any) -> Optional[Callable[[Optional[Any]], Optional[Union[float, decimal.Decimal]]]]: ... -class Float(Numeric): +class Float(_LookupExpressionAdapter, TypeEngine[float]): __visit_name__: str = ... - scale: Optional[int] = ... precision: Optional[int] = ... - asdecimal: bool = ... + scale: Optional[int] = ... decimal_return_scale: Optional[int] = ... + asdecimal: bool = ... def __init__(self, precision: Optional[int] = ..., asdecimal: bool = ..., decimal_return_scale: Optional[int] = ..., **kwargs: Any) -> None: ... - def result_processor(self, dialect: Dialect, coltype: Any) -> Optional[Callable[[Optional[Any]], - Optional[Union[float, decimal.Decimal]]]]: ... + def literal_processor(self, dialect: Dialect) -> Callable[[float], str]: ... + @property + def python_type(self) -> Type[float]: ... + def bind_processor(self, dialect: Dialect) -> Optional[Callable[[Optional[str]], float]]: ... + def result_processor(self, dialect: Dialect, coltype: Any) -> Optional[Callable[[Optional[Any]], Optional[float]]]: ... class DateTime(_LookupExpressionAdapter, TypeEngine[datetime]): __visit_name__: str = ... diff --git a/test/test-data/sqlalchemy-sql-sqltypes.test b/test/test-data/sqlalchemy-sql-sqltypes.test index 806acd8..9bf31c1 100644 --- a/test/test-data/sqlalchemy-sql-sqltypes.test +++ b/test/test-data/sqlalchemy-sql-sqltypes.test @@ -36,3 +36,23 @@ main:6: error: No overload variant of "UnicodeText" matches argument types "int" main:6: note: Possible overload variants: main:6: note: def __init__(self, length: Optional[int] = ..., collation: Optional[str] = ..., convert_unicode: bool = ..., _warn_on_bytestring: bool = ...) -> UnicodeText main:6: note: def __init__(self, length: Optional[int] = ..., collation: Optional[str] = ..., convert_unicode: str = ..., unicode_error: Optional[str] = ..., _warn_on_bytestring: bool = ...) -> UnicodeText + +[case testFloatAndDecimal] +from decimal import Decimal +from sqlalchemy import Column, Float, Numeric, Integer +from sqlalchemy.ext.declarative import declarative_base + +Base = declarative_base() + +class Numbers(Base): + __tablename__ = "numbers" + id_ = Column(Integer, primary_key=True) + + c_numeric = Column(Numeric, nullable=False) + c_float = Column(Float, nullable=False) + +numbers_float = Numbers(c_numeric=1.0, c_float=1.0) +numbers_decimal = Numbers(c_numeric=Decimal(1.0), c_float=Decimal(1.0)) +[out] +main:14: error: Incompatible type for "c_numeric" of "Numbers" (got "float", expected "Decimal") +main:15: error: Incompatible type for "c_float" of "Numbers" (got "Decimal", expected "float")