From 50eb4a3a291f81ccaac8e0869b75eab869bf90c0 Mon Sep 17 00:00:00 2001 From: Angus Gibson Date: Thu, 7 Nov 2024 11:02:33 +1100 Subject: [PATCH] Refine dtype logic for Power node We know `0 ** x` and `x ** 0`, but we can't assume that `x` is an instance of `Constant` with a defined dtype. --- gem/gem.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/gem/gem.py b/gem/gem.py index c37e4352..350a2283 100644 --- a/gem/gem.py +++ b/gem/gem.py @@ -415,12 +415,18 @@ def __new__(cls, base, exponent): # Constant folding if isinstance(base, Zero): - dtype = numpy.result_type(base.dtype, exponent.dtype) + if isinstance(exponent, Constant): + dtype = numpy.result_type(base.dtype, exponent.dtype) + else: + dtype = base.dtype if isinstance(exponent, Zero): raise ValueError("cannot solve 0^0") return Zero(dtype=dtype) elif isinstance(exponent, Zero): - dtype = numpy.result_type(base.dtype, exponent.dtype) + if isinstance(base, Constant): + dtype = numpy.result_type(base.dtype, exponent.dtype) + else: + dtype = exponent.dtype return Literal(1, dtype=dtype) elif isinstance(base, Constant) and isinstance(exponent, Constant): dtype = numpy.result_type(base.dtype, exponent.dtype)