Skip to content

Commit

Permalink
refactor(dtype): move all the castable logic to a single function (#8335)
Browse files Browse the repository at this point in the history

This makes it easier to understand the implicit casting rules.
  • Loading branch information
kszucs authored Feb 14, 2024
1 parent a20f44a commit 580536c
Show file tree
Hide file tree
Showing 7 changed files with 171 additions and 141 deletions.
5 changes: 4 additions & 1 deletion ibis/common/temporal.py
Original file line number Diff line number Diff line change
Expand Up @@ -217,7 +217,10 @@ def _from_str(value):
elif lower == "today":
return datetime.datetime.today()

value = dateutil.parser.parse(value)
try:
value = dateutil.parser.parse(value)
except dateutil.parser.ParserError:
raise TypeError(f"Unable to normalize {value} to timestamp")
return value.replace(tzinfo=normalize_timezone(value.tzinfo))


Expand Down
112 changes: 107 additions & 5 deletions ibis/expr/datatypes/cast.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from __future__ import annotations

import functools
from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, Any

from public import public

Expand All @@ -24,16 +24,118 @@ def cast(source: str | dt.DataType, target: str | dt.DataType, **kwargs) -> dt.D
return target


@public
def castable(source: dt.DataType, target: dt.DataType, value: Any = None) -> bool:
    """Return whether `source` is implicitly castable to `target`.

    Parameters
    ----------
    source
        The datatype being cast from.
    target
        The datatype being cast to.
    value
        Optional concrete value of type `source`. When given, it refines the
        answer for value-dependent casts: integer -> boolean (only `0` and
        `1` qualify), integer -> integer (the value must be normalizable to
        `target`) and string -> date/timestamp/time (the string must be
        normalizable to `target`).

    Returns
    -------
    bool
        `True` if the cast is allowed implicitly, `False` otherwise.
    """
    # Imported at call time rather than at module level.
    # NOTE(review): presumably this avoids a circular import between the
    # `cast` and `value` modules -- confirm before hoisting it.
    from ibis.expr.datatypes.value import normalizable

    if source == target:
        return True
    elif source.is_null():
        # The null type is castable to any type, even if the target type is *not*
        # nullable.
        #
        # We handle the promotion of `null + !T -> T` at the `castable` call site.
        #
        # It might be possible to build a system with a single function that tries
        # to promote types and use the exception to indicate castability, but that
        # is a deeper refactor to be tackled later.
        #
        # See https://github.com/ibis-project/ibis/issues/2891 for the bug report
        return True
    elif target.is_boolean():
        if source.is_boolean():
            return True
        elif source.is_integer():
            # Only the literal values 0 and 1 may be treated as booleans; with
            # no value supplied (`None`) this is always False.
            return value in (0, 1)
        else:
            return False
    elif target.is_integer():
        # TODO(kszucs): ideally unsigned to signed shouldn't be allowed but that
        # breaks the integral promotion rule logic in rules.py
        if source.is_integer():
            if value is not None:
                # A concrete value is available: check the value itself
                # instead of comparing type widths.
                return normalizable(target, value)
            else:
                return source.nbytes <= target.nbytes
        else:
            return False
    elif target.is_floating():
        if source.is_floating():
            return source.nbytes <= target.nbytes
        else:
            return source.is_integer()
    elif target.is_decimal():
        if source.is_decimal():
            # A downcast is a move to a *smaller* precision or scale, i.e. the
            # source is strictly larger than the target. An unspecified
            # (`None`) precision or scale on either side is never treated as a
            # downcast.
            downcast_precision = (
                source.precision is not None
                and target.precision is not None
                and source.precision > target.precision
            )
            downcast_scale = (
                source.scale is not None
                and target.scale is not None
                and source.scale > target.scale
            )
            return not (downcast_precision or downcast_scale)
        else:
            return source.is_numeric()
    elif target.is_string():
        return source.is_string() or source.is_uuid()
    elif target.is_uuid():
        return source.is_uuid() or source.is_string()
    elif target.is_date() or target.is_timestamp():
        if source.is_string():
            # A string is only castable when a concrete value is supplied and
            # that value can be normalized to the target type.
            return value is not None and normalizable(target, value)
        else:
            return source.is_timestamp() or source.is_date()
    elif target.is_interval():
        if source.is_interval():
            return source.unit == target.unit
        else:
            return source.is_integer()
    elif target.is_time():
        if source.is_string():
            return value is not None and normalizable(target, value)
        else:
            return source.is_time()
    elif target.is_json():
        return (
            source.is_json()
            or source.is_string()
            or source.is_floating()
            or source.is_integer()
        )
    # For the nested types below, `value` is not forwarded to the recursive
    # calls: element types are compared on type information alone.
    elif target.is_array():
        return source.is_array() and castable(source.value_type, target.value_type)
    elif target.is_map():
        return (
            source.is_map()
            and castable(source.key_type, target.key_type)
            and castable(source.value_type, target.value_type)
        )
    elif target.is_struct():
        # Every field named by the target is looked up on the source.
        return source.is_struct() and all(
            castable(source[field], target[field]) for field in target.names
        )
    elif target.is_geospatial():
        return source.is_geospatial() or source.is_array()
    else:
        # Fallback for all remaining types: the target must be an instance of
        # the source's class.
        return isinstance(target, source.__class__)


@public
def higher_precedence(left: dt.DataType, right: dt.DataType) -> dt.DataType:
nullable = left.nullable or right.nullable

if left.castable(right, upcast=True):
if left.castable(right):
return right.copy(nullable=nullable)
elif right.castable(left, upcast=True):
elif right.castable(left):
return left.copy(nullable=nullable)

raise IbisTypeError(f"Cannot compute precedence for `{left}` and `{right}` types")
else:
raise IbisTypeError(
f"Cannot compute precedence for `{left}` and `{right}` types"
)


@public
Expand Down
130 changes: 3 additions & 127 deletions ibis/expr/datatypes/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,7 +147,9 @@ def cast(self, other, **kwargs):

def castable(self, to, **kwargs) -> bool:
"""Check whether this type is castable to another."""
return isinstance(to, self.__class__)
from ibis.expr.datatypes.cast import castable

return castable(self, to, **kwargs)

@classmethod
def from_string(cls, value) -> Self:
Expand Down Expand Up @@ -474,19 +476,6 @@ class Null(Primitive):
scalar = "NullScalar"
column = "NullColumn"

def castable(self, to, **kwargs) -> bool:
# The null type is castable to any type, even if the target type is *not*
# nullable.
#
# We handle the promotion of `null + !T -> T` at the `castable` call site.
#
# It might be possible to build a system with a single function that tries
# to promote types and use the exception to indicate castability, but that
# is a deeper refactor to be tackled later.
#
# See https://github.com/ibis-project/ibis/issues/2891 for the bug report
return True


@public
class Boolean(Primitive):
Expand Down Expand Up @@ -524,11 +513,6 @@ class Integer(Primitive, Numeric):
def nbytes(self) -> int:
"""Return the number of bytes used to store values of this type."""

def castable(self, to, value: int | None = None, **kwargs) -> bool:
return (isinstance(to, Boolean) and value in (0, 1)) or isinstance(
to, (Floating, Decimal, JSON, Interval)
)


@public
class String(Variadic, Singleton):
Expand All @@ -544,24 +528,6 @@ class String(Variadic, Singleton):
scalar = "StringScalar"
column = "StringColumn"

def castable(self, to, value: str | None = None, **kwargs) -> bool:
def can_parse(value: str | None) -> bool:
import pandas as pd

if value is None:
return False

try:
pd.Timestamp(value)
except ValueError:
return False
else:
return True

return isinstance(to, (String, JSON, UUID)) or (
isinstance(to, (Date, Time, Timestamp)) and can_parse(value)
)


@public
class Binary(Variadic, Singleton):
Expand Down Expand Up @@ -593,9 +559,6 @@ class Date(Temporal, Primitive):
scalar = "DateScalar"
column = "DateColumn"

def castable(self, to, **kwargs) -> bool:
return isinstance(to, (Date, Timestamp))


@public
class Time(Temporal, Primitive):
Expand Down Expand Up @@ -660,9 +623,6 @@ def _pretty_piece(self) -> str:
else:
return ""

def castable(self, to, **kwargs) -> bool:
return isinstance(to, (Date, Timestamp))


@public
class SignedInteger(Integer):
Expand All @@ -674,23 +634,6 @@ def bounds(self):
upper = (1 << exp) - 1
return Bounds(lower=~upper, upper=upper)

def castable(self, to, value: int | None = None, **kwargs) -> bool:
if isinstance(to, SignedInteger):
return self.nbytes <= to.nbytes
elif isinstance(to, UnsignedInteger):
if value is not None:
# TODO(kszucs): we may not need to actually check the value since the
# literal construction now checks for bounds and doesn't use castable()
# anymore
return to.bounds.lower <= value <= to.bounds.upper
else:
return (
to.bounds.upper - to.bounds.lower
>= self.bounds.upper - self.bounds.lower
)
else:
return super().castable(to, value=value, **kwargs)


@public
class UnsignedInteger(Integer):
Expand All @@ -702,22 +645,6 @@ def bounds(self):
upper = (1 << exp) - 1
return Bounds(lower=0, upper=upper)

def castable(self, to, value: int | None = None, **kwargs) -> bool:
if isinstance(to, UnsignedInteger):
return self.nbytes <= to.nbytes
elif isinstance(to, SignedInteger):
if value is not None:
# TODO(kszucs): we may not need to actually check the value since the
# literal construction now checks for bounds and doesn't use castable()
# anymore
return to.bounds.lower <= value <= to.bounds.upper
else:
return (to.bounds.upper - to.bounds.lower) >= (
self.bounds.upper - self.bounds.lower
)
else:
return super().castable(to, value=value, **kwargs)


@public
class Floating(Primitive, Numeric):
Expand All @@ -731,14 +658,6 @@ class Floating(Primitive, Numeric):
def nbytes(self) -> int: # pragma: no cover
"""Return the number of bytes used to store values of this type."""

def castable(self, to, upcast: bool = False, **kwargs) -> bool:
return isinstance(to, (Decimal, JSON)) or (
isinstance(to, Floating)
# double -> float must be allowed because
# float literals are inferred as doubles
and ((not upcast) or to.nbytes >= self.nbytes)
)


@public
class Int8(SignedInteger):
Expand Down Expand Up @@ -872,25 +791,6 @@ def _pretty_piece(self) -> str:

return f"({', '.join(args)})"

def castable(self, to, **kwargs) -> bool:
if not isinstance(to, Decimal):
return False

to_prec = to.precision
self_prec = self.precision
to_sc = to.scale
self_sc = self.scale
return (
# If either sides precision and scale are both `None`, return `True`.
to_prec is None
and to_sc is None
or self_prec is None
and self_sc is None
# Otherwise, return `True` unless we are downcasting precision or scale.
or (to_prec is None or (self_prec is not None and to_prec >= self_prec))
and (to_sc is None or (self_sc is not None and to_sc >= self_sc))
)


@public
class Interval(Parametric):
Expand All @@ -911,9 +811,6 @@ def resolution(self):
def _pretty_piece(self) -> str:
return f"('{self.unit.value}')"

def castable(self, to, **kwargs) -> bool:
return isinstance(to, Interval) and self.unit == to.unit


@public
class Struct(Parametric, MapSet):
Expand Down Expand Up @@ -972,12 +869,6 @@ def _pretty_piece(self) -> str:
pairs = ", ".join(map("{}: {}".format, self.names, self.types))
return f"<{pairs}>"

def castable(self, to, **kwargs) -> bool:
return isinstance(to, Struct) and all(
self[field].castable(to[field], **kwargs)
for field in self.keys() & to.keys()
)


T = TypeVar("T", bound=DataType, covariant=True)

Expand All @@ -995,11 +886,6 @@ class Array(Variadic, Parametric, Generic[T]):
def _pretty_piece(self) -> str:
return f"<{self.value_type}>"

def castable(self, to, **kwargs) -> bool:
return isinstance(to, GeoSpatial) or (
isinstance(to, Array) and self.value_type.castable(to.value_type, **kwargs)
)


K = TypeVar("K", bound=DataType, covariant=True)
V = TypeVar("V", bound=DataType, covariant=True)
Expand All @@ -1019,13 +905,6 @@ class Map(Variadic, Parametric, Generic[K, V]):
def _pretty_piece(self) -> str:
return f"<{self.key_type}, {self.value_type}>"

def castable(self, to, **kwargs) -> bool:
return (
isinstance(to, Map)
and self.key_type.castable(to.key_type, **kwargs)
and self.value_type.castable(to.value_type, **kwargs)
)


@public
class JSON(Variadic):
Expand Down Expand Up @@ -1117,9 +996,6 @@ class UUID(DataType):
scalar = "UUIDScalar"
column = "UUIDColumn"

def castable(self, to, **kwargs) -> bool:
return super().castable(to, **kwargs) or isinstance(to, String)


@public
class MACADDR(DataType):
Expand Down
Loading

0 comments on commit 580536c

Please sign in to comment.