diff --git a/gem/gem.py b/gem/gem.py index c37e4352..350a2283 100644 --- a/gem/gem.py +++ b/gem/gem.py @@ -415,12 +415,18 @@ def __new__(cls, base, exponent): # Constant folding if isinstance(base, Zero): - dtype = numpy.result_type(base.dtype, exponent.dtype) + if isinstance(exponent, Constant): + dtype = numpy.result_type(base.dtype, exponent.dtype) + else: + dtype = base.dtype if isinstance(exponent, Zero): raise ValueError("cannot solve 0^0") return Zero(dtype=dtype) elif isinstance(exponent, Zero): - dtype = numpy.result_type(base.dtype, exponent.dtype) + if isinstance(base, Constant): + dtype = numpy.result_type(base.dtype, exponent.dtype) + else: + dtype = exponent.dtype return Literal(1, dtype=dtype) elif isinstance(base, Constant) and isinstance(exponent, Constant): dtype = numpy.result_type(base.dtype, exponent.dtype)