diff --git a/dace/codegen/cppunparse.py b/dace/codegen/cppunparse.py
index e4456e3e18..18ee00721b 100644
--- a/dace/codegen/cppunparse.py
+++ b/dace/codegen/cppunparse.py
@@ -746,9 +746,9 @@ def _Repr(self, t):
     def _Num(self, t):
         t_n = t.value if sys.version_info >= (3, 8) else t.n
         repr_n = repr(t_n)
-        # For complex values, use DTYPE_TO_TYPECLASS dictionary
+        # For complex values, use ``dtype_to_typeclass``
         if isinstance(t_n, complex):
-            dtype = dtypes.DTYPE_TO_TYPECLASS[complex]
+            dtype = dtypes.dtype_to_typeclass(complex)

         # Handle large integer values
         if isinstance(t_n, int):
diff --git a/dace/dtypes.py b/dace/dtypes.py
index f0bac23958..f3ccddbfb7 100644
--- a/dace/dtypes.py
+++ b/dace/dtypes.py
@@ -360,6 +360,7 @@ class typeclass(object):
         2. Enabling declaration syntax: `dace.float32[M,N]`
         3. Enabling extensions such as `dace.struct` and `dace.vector`
     """
+
     def __init__(self, wrapped_type, typename=None):
         # Convert python basic types
         if isinstance(wrapped_type, str):
@@ -600,6 +601,7 @@ def result_type_of(lhs, *rhs):

 class opaque(typeclass):
     """ A data type for an opaque object, useful for C bindings/libnodes, i.e., MPI_Request. """
+
     def __init__(self, typename):
         self.type = typename
         self.ctype = typename
@@ -635,6 +637,7 @@ class pointer(typeclass):

     Example use: `dace.pointer(dace.struct(x=dace.float32, y=dace.float32))`.
     """
+
     def __init__(self, wrapped_typeclass):
         self._typeclass = wrapped_typeclass
         self.type = wrapped_typeclass.type
@@ -680,6 +683,7 @@ class vector(typeclass):

     Example use: `dace.vector(dace.float32, 4)` becomes float4.
     """
+
     def __init__(self, dtype: typeclass, vector_length: int):
         self.vtype = dtype
         self.type = dtype.type
@@ -737,6 +741,7 @@ class stringtype(pointer):
     Python/generated code marshalling.
     Used internally when `str` types are given
     """
+
     def __init__(self):
         super().__init__(int8)

@@ -756,6 +761,7 @@ class struct(typeclass):

     Example use: `dace.struct(a=dace.int32, b=dace.float64)`.
     """
+
     def __init__(self, name, **fields_and_types):
         # self._data = fields_and_types
         self.type = ctypes.Structure
@@ -859,6 +865,7 @@ class pyobject(opaque):
     It cannot be used inside a DaCe program, but can be passed back to other Python callbacks.
     Use with caution, and ensure the value is not removed by the garbage collector or the program will crash.
     """
+
     def __init__(self):
         super().__init__('pyobject')
         self.bytes = ctypes.sizeof(ctypes.c_void_p)
@@ -892,6 +899,7 @@ def example(A: dace.float64[20], constant: dace.compiletime):

     In the above code, ``constant`` will be replaced with its value at call time during parsing.
     """
+
     @staticmethod
     def __descriptor__():
         raise ValueError('All compile-time arguments must be provided in order to compile the SDFG ahead-of-time.')
@@ -914,6 +922,7 @@ class callback(typeclass):
     """
     Looks like ``dace.callback([None, ], *types)``
     """
+
     def __init__(self, return_types, *variadic_args):
         from dace import data
         if return_types is None:
@@ -1240,31 +1249,39 @@ class Typeclasses(aenum.AutoNumberEnum):
     complex128 = complex128


-DTYPE_TO_TYPECLASS = {
-    bool: typeclass(bool),
-    int: typeclass(int),
-    float: typeclass(float),
-    complex: typeclass(complex),
-    numpy.bool_: bool_,
-    numpy.int8: int8,
-    numpy.int16: int16,
-    numpy.int32: int32,
-    numpy.int64: int64,
-    numpy.intc: int32,
-    numpy.uint8: uint8,
-    numpy.uint16: uint16,
-    numpy.uint32: uint32,
-    numpy.uint64: uint64,
-    numpy.uintc: uint32,
-    numpy.float16: float16,
-    numpy.float32: float32,
-    numpy.float64: float64,
-    numpy.complex64: complex64,
-    numpy.complex128: complex128,
-    # FIXME
-    numpy.longlong: int64,
-    numpy.ulonglong: uint64
-}
+_bool = bool
+
+
+def dtype_to_typeclass(dtype=None):
+    DTYPE_TO_TYPECLASS = {
+        _bool: typeclass(_bool),
+        int: typeclass(int),
+        float: typeclass(float),
+        complex: typeclass(complex),
+        numpy.bool_: bool_,
+        numpy.int8: int8,
+        numpy.int16: int16,
+        numpy.int32: int32,
+        numpy.int64: int64,
+        numpy.intc: int32,
+        numpy.uint8: uint8,
+        numpy.uint16: uint16,
+        numpy.uint32: uint32,
+        numpy.uint64: uint64,
+        numpy.uintc: uint32,
+        numpy.float16: float16,
+        numpy.float32: float32,
+        numpy.float64: float64,
+        numpy.complex64: complex64,
+        numpy.complex128: complex128,
+        # FIXME
+        numpy.longlong: int64,
+        numpy.ulonglong: uint64
+    }
+    if dtype is None:
+        return DTYPE_TO_TYPECLASS
+    return DTYPE_TO_TYPECLASS[dtype]
+

 # Since this overrides the builtin bool, this should be after the
 # DTYPE_TO_TYPECLASS dictionary
@@ -1354,6 +1371,7 @@ def isallowed(var, allow_recursive=False):

 class DebugInfo:
     """ Source code location identifier of a node/edge in an SDFG. Used for IDE and debugging purposes. """
+
     def __init__(self, start_line, start_column=0, end_line=-1, end_column=0, filename=None):
         self.start_line = start_line
         self.end_line = end_line if end_line >= 0 else start_line
@@ -1397,6 +1415,7 @@ def json_to_typeclass(obj, context=None):
 def paramdec(dec):
     """ Parameterized decorator meta-decorator. Enables using `@decorator`, `@decorator()`, and `@decorator(...)` with the same function.
     """
+
     @wraps(dec)
     def layer(*args, **kwargs):
         from dace import data
@@ -1478,20 +1497,22 @@ def can_allocate(storage: StorageType, schedule: ScheduleType):
     # Host-only allocation
     if storage in [StorageType.CPU_Heap, StorageType.CPU_Pinned, StorageType.CPU_ThreadLocal]:
         return schedule in [
-            ScheduleType.CPU_Multicore, ScheduleType.CPU_Persistent, ScheduleType.Sequential, ScheduleType.MPI, ScheduleType.GPU_Default
+            ScheduleType.CPU_Multicore, ScheduleType.CPU_Persistent, ScheduleType.Sequential, ScheduleType.MPI,
+            ScheduleType.GPU_Default
         ]

     # GPU-global memory
     if storage is StorageType.GPU_Global:
         return schedule in [
-            ScheduleType.CPU_Multicore, ScheduleType.CPU_Persistent, ScheduleType.Sequential, ScheduleType.MPI, ScheduleType.GPU_Default
+            ScheduleType.CPU_Multicore, ScheduleType.CPU_Persistent, ScheduleType.Sequential, ScheduleType.MPI,
+            ScheduleType.GPU_Default
         ]

     # FPGA-global memory
     if storage is StorageType.FPGA_Global:
         return schedule in [
-            ScheduleType.CPU_Multicore, ScheduleType.CPU_Persistent, ScheduleType.Sequential, ScheduleType.MPI, ScheduleType.FPGA_Device,
-            ScheduleType.GPU_Default
+            ScheduleType.CPU_Multicore, ScheduleType.CPU_Persistent, ScheduleType.Sequential, ScheduleType.MPI,
+            ScheduleType.FPGA_Device, ScheduleType.GPU_Default
         ]

     # FPGA-local memory
diff --git a/dace/frontend/python/newast.py b/dace/frontend/python/newast.py
index 733c3c7f62..2f77bd430d 100644
--- a/dace/frontend/python/newast.py
+++ b/dace/frontend/python/newast.py
@@ -3240,7 +3240,7 @@ def _visit_assign(self, node, node_target, op, dtype=None, is_return=False):
                 raise DaceSyntaxError(self, target, 'Variable "{}" used before definition'.format(name))

         new_data, rng = None, None
-        dtype_keys = tuple(dtypes.DTYPE_TO_TYPECLASS.keys())
+        dtype_keys = tuple(dtypes.dtype_to_typeclass().keys())
         if not (result in self.sdfg.symbols or symbolic.issymbolic(result) or isinstance(result, dtype_keys)
                 or (isinstance(result, str) and result in self.sdfg.arrays)):
             raise DaceSyntaxError(
@@ -4653,14 +4653,14 @@ def visit_Num(self, node: NumConstant):
         if isinstance(node.n, bool):
             return dace.bool_(node.n)
         if isinstance(node.n, (int, float, complex)):
-            return dtypes.DTYPE_TO_TYPECLASS[type(node.n)](node.n)
+            return dtypes.dtype_to_typeclass(type(node.n))(node.n)
         return node.n

     def visit_Constant(self, node: ast.Constant):
         if isinstance(node.value, bool):
             return dace.bool_(node.value)
         if isinstance(node.value, (int, float, complex)):
-            return dtypes.DTYPE_TO_TYPECLASS[type(node.value)](node.value)
+            return dtypes.dtype_to_typeclass(type(node.value))(node.value)
         if isinstance(node.value, (str, bytes)):
             return StringLiteral(node.value)
         return node.value
@@ -4745,7 +4745,7 @@ def _gettype(self, opnode: ast.AST) -> List[Tuple[str, str]]:
                 result.append((operand, type(self.sdfg.arrays[operand])))
             elif isinstance(operand, str) and operand in self.scope_arrays:
                 result.append((operand, type(self.scope_arrays[operand])))
-            elif isinstance(operand, tuple(dtypes.DTYPE_TO_TYPECLASS.keys())):
+            elif isinstance(operand, tuple(dtypes.dtype_to_typeclass().keys())):
                 if isinstance(operand, (bool, numpy.bool_)):
                     result.append((operand, 'BoolConstant'))
                 else:
diff --git a/dace/frontend/python/replacements.py b/dace/frontend/python/replacements.py
index f55a65eabb..2e34b3077d 100644
--- a/dace/frontend/python/replacements.py
+++ b/dace/frontend/python/replacements.py
@@ -289,7 +289,7 @@ def _numpy_full(pv: ProgramVisitor,
     """
     is_data = False
     if isinstance(fill_value, (Number, np.bool_)):
-        vtype = dtypes.DTYPE_TO_TYPECLASS[type(fill_value)]
+        vtype = dtypes.dtype_to_typeclass(type(fill_value))
     elif isinstance(fill_value, sp.Expr):
         vtype = _sym_type(fill_value)
     else:
@@ -546,10 +546,10 @@ def _arange(pv: ProgramVisitor, sdfg: SDFG, state: SDFGState, *args, **kwargs):
     if 'dtype' in kwargs and kwargs['dtype'] != None:
         dtype = kwargs['dtype']
         if not isinstance(dtype, dtypes.typeclass):
-            dtype = dtypes.DTYPE_TO_TYPECLASS[dtype]
+            dtype = dtypes.dtype_to_typeclass(dtype)
         outname, outarr = sdfg.add_temp_transient(shape, dtype)
     else:
-        dtype = dtypes.DTYPE_TO_TYPECLASS[type(shape[0])]
+        dtype = dtypes.dtype_to_typeclass(type(shape[0]))
         outname, outarr = sdfg.add_temp_transient(shape, dtype)

     state.add_mapped_tasklet(name="_numpy_arange_",
@@ -1076,8 +1076,8 @@ def _array_array_where(visitor: ProgramVisitor,
     left_arr = sdfg.arrays.get(left_operand, None)
     right_arr = sdfg.arrays.get(right_operand, None)

-    left_type = left_arr.dtype if left_arr else dtypes.DTYPE_TO_TYPECLASS[type(left_operand)]
-    right_type = right_arr.dtype if right_arr else dtypes.DTYPE_TO_TYPECLASS[type(right_operand)]
+    left_type = left_arr.dtype if left_arr else dtypes.dtype_to_typeclass(type(left_operand))
+    right_type = right_arr.dtype if right_arr else dtypes.dtype_to_typeclass(type(right_operand))

     # Implicit Python coversion implemented as casting
     arguments = [cond_arr, left_arr or left_type, right_arr or right_type]
@@ -1356,11 +1356,11 @@ def _np_result_type(nptypes):
     # Fix for np.result_type returning platform-dependent types,
     # e.g. np.longlong
     restype = np.result_type(*nptypes)
-    if restype.type not in dtypes.DTYPE_TO_TYPECLASS.keys():
-        for k in dtypes.DTYPE_TO_TYPECLASS.keys():
+    if restype.type not in dtypes.dtype_to_typeclass().keys():
+        for k in dtypes.dtype_to_typeclass().keys():
             if k == restype.type:
-                return dtypes.DTYPE_TO_TYPECLASS[k]
-    return dtypes.DTYPE_TO_TYPECLASS[restype.type]
+                return dtypes.dtype_to_typeclass(k)
+    return dtypes.dtype_to_typeclass(restype.type)


 def _sym_type(expr: Union[symbolic.symbol, sp.Basic]) -> dtypes.typeclass:
@@ -1393,7 +1393,7 @@ def _result_type(arguments: Sequence[Union[str, Number, symbolic.symbol, sp.Basi
             datatypes.append(arg.dtype)
             dtypes_for_result.append(_representative_num(arg.dtype))
         elif isinstance(arg, (Number, np.bool_)):
-            datatypes.append(dtypes.DTYPE_TO_TYPECLASS[type(arg)])
+            datatypes.append(dtypes.dtype_to_typeclass(type(arg)))
             dtypes_for_result.append(arg)
         elif symbolic.issymbolic(arg):
             datatypes.append(_sym_type(arg))
@@ -1668,13 +1668,13 @@ def _array_const_binop(visitor: ProgramVisitor, sdfg: SDFG, state: SDFGState, le
         left_shape = left_arr.shape
         storage = left_arr.storage
         right_arr = None
-        right_type = dtypes.DTYPE_TO_TYPECLASS[type(right_operand)]
+        right_type = dtypes.dtype_to_typeclass(type(right_operand))
         right_shape = [1]
         arguments = [left_arr, right_operand]
         tasklet_args = ['__in1', f'({str(right_operand)})']
     else:
         left_arr = None
-        left_type = dtypes.DTYPE_TO_TYPECLASS[type(left_operand)]
+        left_type = dtypes.dtype_to_typeclass(type(left_operand))
         left_shape = [1]
         right_arr = sdfg.arrays[right_operand]
         right_type = right_arr.dtype
@@ -2229,7 +2229,7 @@ def _matmult(visitor: ProgramVisitor, sdfg: SDFG, state: SDFGState, op1: str, op

         type1 = arr1.dtype.type
         type2 = arr2.dtype.type
-        restype = dace.DTYPE_TO_TYPECLASS[np.result_type(type1, type2).type]
+        restype = dace.dtype_to_typeclass(np.result_type(type1, type2).type)

         op3, arr3 = sdfg.add_temp_transient(output_shape, restype, arr1.storage)

@@ -3517,7 +3517,7 @@ def implement_ufunc(visitor: ProgramVisitor, ast_node: ast.Call, sdfg: SDFG, sta
                                            ufunc_impl['operator'])
     if 'dtype' in kwargs.keys():
         dtype = kwargs['dtype']
-        if dtype in dtypes.DTYPE_TO_TYPECLASS.keys():
+        if dtype in dtypes.dtype_to_typeclass().keys():
             result_type = dtype

     # Create output data (if needed)
@@ -3709,7 +3709,7 @@ def implement_ufunc_reduce(visitor: ProgramVisitor, ast_node: ast.Call, sdfg: SD
             datadesc = sdfg.arrays[arg]
             result_type = datadesc.dtype
         elif isinstance(arg, (Number, np.bool_)):
-            result_type = dtypes.DTYPE_TO_TYPECLASS[type(arg)]
+            result_type = dtypes.dtype_to_typeclass(type(arg))
         elif isinstance(arg, sp.Basic):
             result_type = _sym_type(arg)

@@ -4018,7 +4018,7 @@ def implement_ufunc_outer(visitor: ProgramVisitor, ast_node: ast.Call, sdfg: SDF
                                            ufunc_impl['operator'])
     if 'dtype' in kwargs.keys():
         dtype = kwargs['dtype']
-        if dtype in dtypes.DTYPE_TO_TYPECLASS.keys():
+        if dtype in dtypes.dtype_to_typeclass().keys():
             result_type = dtype

     # Create output data (if needed)
@@ -4412,9 +4412,9 @@ def _make_datatype_converter(typeclass: str):
     if typeclass == "bool":
         dtype = dace.bool
     elif typeclass in {"int", "float", "complex"}:
-        dtype = dtypes.DTYPE_TO_TYPECLASS[eval(typeclass)]
+        dtype = dtypes.dtype_to_typeclass(eval(typeclass))
     else:
-        dtype = dtypes.DTYPE_TO_TYPECLASS[eval("np.{}".format(typeclass))]
+        dtype = dtypes.dtype_to_typeclass(eval("np.{}".format(typeclass)))

     @oprepo.replaces(typeclass)
     @oprepo.replaces("dace.{}".format(typeclass))
@@ -4711,7 +4711,7 @@ def _cupy_full(pv: ProgramVisitor,
                    the fill value.
     """
     if isinstance(fill_value, (Number, np.bool_)):
-        vtype = dtypes.DTYPE_TO_TYPECLASS[type(fill_value)]
+        vtype = dtypes.dtype_to_typeclass(type(fill_value))
     elif isinstance(fill_value, sp.Expr):
         vtype = _sym_type(fill_value)
     else:
diff --git a/dace/libraries/blas/nodes/gemm.py b/dace/libraries/blas/nodes/gemm.py
index 83be99d78b..d78e54eb6e 100644
--- a/dace/libraries/blas/nodes/gemm.py
+++ b/dace/libraries/blas/nodes/gemm.py
@@ -30,12 +30,12 @@ def _cast_to_dtype_str(value, dtype: dace.dtypes.typeclass) -> str:
         cast_value = complex(value)

         return "dace.{type}({real}, {imag})".format(
-            type=dace.DTYPE_TO_TYPECLASS[dtype].to_string(),
+            type=dace.dtype_to_typeclass(dtype).to_string(),
             real=cast_value.real,
             imag=cast_value.imag,
         )
     else:
-        return "dace.{}({})".format(dace.DTYPE_TO_TYPECLASS[dtype].to_string(), value)
+        return "dace.{}({})".format(dace.dtype_to_typeclass(dtype).to_string(), value)


 @dace.library.expansion
@@ -52,7 +52,7 @@ def make_sdfg(node, parent_state, parent_sdfg):

         dtype_a = outer_array_a.dtype.type
         dtype_b = outer_array_b.dtype.type
-        dtype_c = dace.DTYPE_TO_TYPECLASS[np.result_type(dtype_a, dtype_b).type]
+        dtype_c = dace.dtype_to_typeclass(np.result_type(dtype_a, dtype_b).type)

         if node.transA:
             trans_shape_a = list(reversed(shape_a))
@@ -518,7 +518,7 @@ def expansion(node, parent_state, parent_sdfg, num_pes=32, tile_size_m=None):

         dtype_a = outer_array_a.dtype.type
         dtype_b = outer_array_b.dtype.type
-        dtype_c = dace.DTYPE_TO_TYPECLASS[np.result_type(dtype_a, dtype_b).type]
+        dtype_c = dace.dtype_to_typeclass(np.result_type(dtype_a, dtype_b).type)
         shape_c = (shape_a[0], shape_b[1])
         if node.transA:
             raise NotImplementedError("GEMM FPGA expansion not implemented for transposed A.")
diff --git a/dace/libraries/sparse/nodes/csrmm.py b/dace/libraries/sparse/nodes/csrmm.py
index d5707b400d..b21867b0e9 100644
--- a/dace/libraries/sparse/nodes/csrmm.py
+++ b/dace/libraries/sparse/nodes/csrmm.py
@@ -28,12 +28,12 @@ def _cast_to_dtype_str(value, dtype: dace.dtypes.typeclass) -> str:
         cast_value = complex(value)

         return "dace.{type}({real}, {imag})".format(
-            type=dace.DTYPE_TO_TYPECLASS[dtype].to_string(),
+            type=dace.dtype_to_typeclass(dtype).to_string(),
             real=cast_value.real,
             imag=cast_value.imag,
         )
     else:
-        return "dace.{}({})".format(dace.DTYPE_TO_TYPECLASS[dtype].to_string(), value)
+        return "dace.{}({})".format(dace.dtype_to_typeclass(dtype).to_string(), value)


 def _get_csrmm_operands(node,
diff --git a/dace/libraries/sparse/nodes/csrmv.py b/dace/libraries/sparse/nodes/csrmv.py
index 7b69a7af00..cc3e98eec4 100644
--- a/dace/libraries/sparse/nodes/csrmv.py
+++ b/dace/libraries/sparse/nodes/csrmv.py
@@ -27,12 +27,12 @@ def _cast_to_dtype_str(value, dtype: dace.dtypes.typeclass) -> str:
         cast_value = complex(value)

         return "dace.{type}({real}, {imag})".format(
-            type=dace.DTYPE_TO_TYPECLASS[dtype].to_string(),
+            type=dace.dtype_to_typeclass(dtype).to_string(),
             real=cast_value.real,
             imag=cast_value.imag,
         )
     else:
-        return "dace.{}({})".format(dace.DTYPE_TO_TYPECLASS[dtype].to_string(), value)
+        return "dace.{}({})".format(dace.dtype_to_typeclass(dtype).to_string(), value)


 def _get_csrmv_operands(node: dace.sdfg.nodes.LibraryNode,
diff --git a/dace/sdfg/sdfg.py b/dace/sdfg/sdfg.py
index 2e35218a3d..eb43a99a54 100644
--- a/dace/sdfg/sdfg.py
+++ b/dace/sdfg/sdfg.py
@@ -769,7 +769,7 @@ def add_symbol(self, name, stype):
         if name in self.symbols:
             raise FileExistsError('Symbol "%s" already exists in SDFG' % name)
         if not isinstance(stype, dtypes.typeclass):
-            stype = dtypes.DTYPE_TO_TYPECLASS[stype]
+            stype = dtypes.dtype_to_typeclass(stype)
         self.symbols[name] = stype

     def remove_symbol(self, name):
diff --git a/dace/serialize.py b/dace/serialize.py
index ef07530905..4afaef69ee 100644
--- a/dace/serialize.py
+++ b/dace/serialize.py
@@ -47,7 +47,7 @@ def to_json(obj):
             return None

         try:
-            dtype_json = dace.dtypes.DTYPE_TO_TYPECLASS[obj.dtype.type].to_json()
+            dtype_json = dace.dtypes.dtype_to_typeclass(obj.dtype.type).to_json()
         except KeyError:
             dtype_json = str(obj.dtype)

@@ -69,12 +69,19 @@ def to_json(obj):
     # All classes annotated with the make_properties decorator will register
     # themselves here.
 }
-# Also register each of the basic types
-_DACE_SERIALIZE_TYPES.update({v.to_string(): v for v in dace.dtypes.DTYPE_TO_TYPECLASS.values()})


 def get_serializer(type_name):
-    return _DACE_SERIALIZE_TYPES[type_name]
+    if type_name in _DACE_SERIALIZE_TYPES:
+        return _DACE_SERIALIZE_TYPES[type_name]
+
+    # Also try each of the basic types
+    basic_dtypes = {v.to_string(): v for v in dace.dtypes.dtype_to_typeclass().values()}
+    if type_name in basic_dtypes:
+        return basic_dtypes[type_name]
+
+    raise KeyError(f'Serializer for type "{type_name}" was not found. Object type does not support serialization. '
+                   'Please implement serialization by decorating the class with ``@serializable``.')


 # Decorator for objects that should be serializable, but don't call
@@ -144,7 +151,7 @@ def from_json(obj, context=None, known_type=None):

     if t:
         try:
-            deserialized = _DACE_SERIALIZE_TYPES[t].from_json(obj, context=context)
+            deserialized = get_serializer(t).from_json(obj, context=context)
         except Exception as ex:
             if config.Config.get_bool('testing', 'deserialize_exception'):
                 raise
diff --git a/dace/symbolic.py b/dace/symbolic.py
index f3dfcfb36d..8342725349 100644
--- a/dace/symbolic.py
+++ b/dace/symbolic.py
@@ -42,7 +42,7 @@ def __new__(cls, name=None, dtype=DEFAULT_SYMBOL_TYPE, **assumptions):
         if not isinstance(dtype, dtypes.typeclass):
             raise TypeError('dtype must be a DaCe type, got %s' % str(dtype))

-        dkeys = [k for k, v in dtypes.DTYPE_TO_TYPECLASS.items() if v == dtype]
+        dkeys = [k for k, v in dtypes.dtype_to_typeclass().items() if v == dtype]
         is_integer = [issubclass(k, int) or issubclass(k, numpy.integer) for k in dkeys]
         if 'integer' in assumptions or not numpy.any(is_integer):
            # Using __xnew__ as the regular __new__ is cached, which leads
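
For reference, a minimal sketch of how call sites migrate to the new API introduced in dace/dtypes.py above; the variable names are illustrative and not part of the patch:

    import numpy as np
    from dace import dtypes

    # Old (removed by this patch): dtypes.DTYPE_TO_TYPECLASS[np.float32]
    tc = dtypes.dtype_to_typeclass(np.float32)   # single-type lookup, returns the dace.float32 typeclass
    mapping = dtypes.dtype_to_typeclass()        # no argument returns the full dict, e.g. for .keys() or .values()

The function rebuilds the mapping on each call and refers to the builtin bool through the module-level alias _bool, since dace.dtypes later shadows bool (see the "Since this overrides the builtin bool" comment in the hunk above). Call sites that previously indexed or iterated over DTYPE_TO_TYPECLASS now call dtype_to_typeclass with or without an argument, as in the hunks above.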