Skip to content

Commit

Permalink
SVE fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
tbennun committed Oct 31, 2024
1 parent 9947fd3 commit 44ce520
Show file tree
Hide file tree
Showing 6 changed files with 26 additions and 23 deletions.
24 changes: 12 additions & 12 deletions dace/codegen/targets/sve/codegen.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,14 +154,14 @@ def generate_read(self, sdfg: SDFG, state: SDFGState, map: nodes.Map, edge: grap
src_type = edge.src.out_connectors[edge.src_conn]
if util.is_vector(src_type) and util.is_vector(dst_type):
# Directly read from shared vector register
code.write(f'{util.TYPE_TO_SVE[dst_type.type]} {dst_name} = {edge.data.data};')
code.write(f'{util.TYPE_TO_SVE[dst_type.base_type]} {dst_name} = {edge.data.data};')
elif util.is_scalar(src_type) and util.is_scalar(dst_type):
# Directly read from shared scalar register
code.write(f'{dst_type} {dst_name} = {edge.data.data};')
elif util.is_scalar(src_type) and util.is_vector(dst_type):
# Scalar broadcast from shared scalar register
code.write(
f'{util.TYPE_TO_SVE[dst_type.type]} {dst_name} = svdup_{util.TYPE_TO_SVE_SUFFIX[dst_type.type]}({edge.data.data});'
f'{util.TYPE_TO_SVE[dst_type.base_type]} {dst_name} = svdup_{util.TYPE_TO_SVE_SUFFIX[dst_type.base_type]}({edge.data.data});'
)
else:
raise util.NotSupportedError('Unsupported Code->Code edge')
Expand All @@ -183,13 +183,13 @@ def generate_read(self, sdfg: SDFG, state: SDFGState, map: nodes.Map, edge: grap
stride = edge.data.get_stride(sdfg, map)

# First part of the declaration is `type name`
load_lhs = '{} {}'.format(util.TYPE_TO_SVE[dst_type.type], dst_name)
load_lhs = '{} {}'.format(util.TYPE_TO_SVE[dst_type.base_type], dst_name)

# long long issue casting
ptr_cast = ''
if dst_type == dtypes.int64:
if dst_type.base_type == dtypes.int64:
ptr_cast = '(int64_t*) '
elif dst_type == dtypes.uint64:
elif dst_type.base_type == dtypes.uint64:
ptr_cast = '(uint64_t*) '

# Regular load and gather share the first arguments
Expand All @@ -212,14 +212,14 @@ def generate_read(self, sdfg: SDFG, state: SDFGState, map: nodes.Map, edge: grap
src_type = desc.dtype
if util.is_vector(src_type) and util.is_vector(dst_type):
# Directly read from shared vector register
code.write(f'{util.TYPE_TO_SVE[dst_type.type]} {dst_name} = {edge.data.data};')
code.write(f'{util.TYPE_TO_SVE[dst_type.base_type]} {dst_name} = {edge.data.data};')
elif util.is_scalar(src_type) and util.is_scalar(dst_type):
# Directly read from shared scalar register
code.write(f'{dst_type} {dst_name} = {edge.data.data};')
elif util.is_scalar(src_type) and util.is_vector(dst_type):
# Scalar broadcast from shared scalar register
code.write(
f'{util.TYPE_TO_SVE[dst_type.type]} {dst_name} = svdup_{util.TYPE_TO_SVE_SUFFIX[dst_type.type]}({edge.data.data});'
f'{util.TYPE_TO_SVE[dst_type.base_type]} {dst_name} = svdup_{util.TYPE_TO_SVE_SUFFIX[dst_type.base_type]}({edge.data.data});'
)
else:
raise util.NotSupportedError('Unsupported Scalar->Code edge')
Expand Down Expand Up @@ -259,7 +259,7 @@ def generate_out_register(self,
# Create temporary registers
ctype = None
if util.is_vector(src_type):
ctype = util.TYPE_TO_SVE[src_type.type]
ctype = util.TYPE_TO_SVE[src_type.base_type]
elif util.is_scalar(src_type):
ctype = src_type.ctype
else:
Expand Down Expand Up @@ -295,7 +295,7 @@ def generate_writeback(self, sdfg: SDFG, state: SDFGState, map: nodes.Map,
code.write(f'{edge.data.data} = {src_name};')
elif util.is_scalar(src_type) and util.is_vector(dst_type):
# Scalar broadcast to shared vector register
code.write(f'{edge.data.data} = svdup_{util.TYPE_TO_SVE_SUFFIX[dst_type.type]}({src_name});')
code.write(f'{edge.data.data} = svdup_{util.TYPE_TO_SVE_SUFFIX[dst_type]}({src_name});')
else:
raise util.NotSupportedError('Unsupported Code->Code edge')
elif isinstance(dst_node, nodes.AccessNode):
Expand All @@ -315,9 +315,9 @@ def generate_writeback(self, sdfg: SDFG, state: SDFGState, map: nodes.Map,

# long long fix
ptr_cast = ''
if src_type == dtypes.int64:
if src_type.base_type == dtypes.int64:
ptr_cast = '(int64_t*) '
elif src_type == dtypes.uint64:
elif src_type.base_type == dtypes.uint64:
ptr_cast = '(uint64_t*) '

store_args = '{}, {}'.format(
Expand Down Expand Up @@ -368,7 +368,7 @@ def allocate_array(self, sdfg: SDFG, cfg: state.ControlFlowRegion, dfg: SDFGStat
nodedesc: data.Data, global_stream: CodeIOStream, declaration_stream: CodeIOStream,
allocation_stream: CodeIOStream) -> None:
if nodedesc.storage == dtypes.StorageType.SVE_Register:
sve_type = util.TYPE_TO_SVE[nodedesc.dtype]
sve_type = util.TYPE_TO_SVE[nodedesc.dtype.base_type]
self.dispatcher.defined_vars.add(node.data, DefinedType.Scalar, sve_type)
return

Expand Down
2 changes: 1 addition & 1 deletion dace/codegen/targets/sve/type_compatibility.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ def assert_type_compatibility(defined_symbols: collections.OrderedDict, types: t

# Check if we can represent the types in SVE
for t in types:
if util.get_base_type(t).type not in util.TYPE_TO_SVE:
if util.get_base_type(t) not in util.TYPE_TO_SVE:
raise IncompatibleTypeError('Not available in SVE', types)

# Check if we have different vector types (would require casting, not implemented yet)
Expand Down
14 changes: 7 additions & 7 deletions dace/codegen/targets/sve/unparse.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,14 +155,14 @@ def dispatch_expect(self, tree: ast.AST, expect: dtypes.typeclass):
# Unparsing a scalar
if isinstance(expect, dtypes.vector):
# Expecting a vector: duplicate the scalar
if expect in [dtypes.bool_, dtypes.bool]:
if expect.base_type in [dtypes.bool_, dtypes.bool]:
# Special case for duplicating boolean into predicate
suffix = f'b{self.pred_bits}'
#self.write(f'svptrue_{suffix}()')
self.dispatch_expect(tree, expect.base_type)
self.write(f' ? svptrue_{suffix}() : svpfalse_b()')
else:
self.write(f'svdup_{util.TYPE_TO_SVE_SUFFIX[expect.type]}(')
self.write(f'svdup_{util.TYPE_TO_SVE_SUFFIX[expect.base_type]}(')
self.dispatch_expect(tree, expect.base_type)
self.write(')')

Expand Down Expand Up @@ -271,9 +271,9 @@ def push_to_stream(self, t, target):

# Casting in case of `long long`
stream_type = copy.copy(stream_type)
if stream_type == dtypes.int64:
if stream_type == dtypes.int64 or getattr(stream_type, 'base_type', False) == dtypes.int64:
stream_type.ctype = 'int64_t'
elif stream_type == dtypes.uint64:
elif stream_type == dtypes.uint64 or getattr(stream_type, 'base_type', False) == dtypes.uint64:
stream_type.ctype = 'uint64_t'

# Create a temporary array on the heap, where we will copy the SVE register contents to
Expand Down Expand Up @@ -337,9 +337,9 @@ def vector_reduction_expr(self, edge, dtype, rhs):
ptr_cast = ''
src_type = edge.src.out_connectors[edge.src_conn]

if src_type == dtypes.int64:
if src_type.base_type == dtypes.int64:
ptr_cast = '(int64_t*) '
elif src_type == dtypes.uint64:
elif src_type.base_type == dtypes.uint64:
ptr_cast = '(uint64_t*) '

store_args = '{}, {}'.format(
Expand Down Expand Up @@ -396,7 +396,7 @@ def _Assign(self, t):
lhs_type = rhs_type
if isinstance(rhs_type, dtypes.vector):
# SVE register is possible (declare it as svXXX_t)
self.fill(util.TYPE_TO_SVE[rhs_type.type])
self.fill(util.TYPE_TO_SVE[rhs_type.base_type])
self.write(' ')
# Define the new symbol as vector
self.defined_symbols.update({target.id: rhs_type})
Expand Down
4 changes: 3 additions & 1 deletion dace/codegen/targets/sve/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,7 +137,9 @@ def get_internal_symbols() -> dict:
res = {}

for func, type in itertools.product(FUSED_OPERATION_TO_SVE, TYPE_TO_SVE_SUFFIX):
res[f'{func}_{TYPE_TO_SVE_SUFFIX[type.type if isinstance(type, dace.dtypes.typeclass) else type]}'] = dtypes.vector(
if type == dace.vector:
continue
res[f'{func}_{TYPE_TO_SVE_SUFFIX[type]}'] = dtypes.vector(
type if isinstance(type, dtypes.typeclass) else dtypes.typeclass(type), SVE_LEN)
return res

Expand Down
4 changes: 2 additions & 2 deletions dace/transformation/dataflow/sve/vectorization.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,12 +64,12 @@ def can_be_applied(self, state: SDFGState, expr_index, sdfg: SDFG, permissive=Fa
for conn in node.in_connectors:
t = inferred[(node, conn, True)]
bit_widths.add(util.get_base_type(t).bytes)
if not t.type in sve.util.TYPE_TO_SVE:
if not t.base_type in sve.util.TYPE_TO_SVE:
return False
for conn in node.out_connectors:
t = inferred[(node, conn, False)]
bit_widths.add(util.get_base_type(t).bytes)
if not t.type in sve.util.TYPE_TO_SVE:
if not t.base_type in sve.util.TYPE_TO_SVE:
return False

# Multiple different bit widths occurring (messes up the predicates)
Expand Down
1 change: 1 addition & 0 deletions tests/codegen/sve/ast_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import dace.dtypes
from tests.codegen.sve.common import get_code
import pytest
import math
from dace.codegen.targets.sve.type_compatibility import IncompatibleTypeError

N = dace.symbol('N')
Expand Down

0 comments on commit 44ce520

Please sign in to comment.