Skip to content

Commit

Permalink
CuPy fixes and special cases for HIP (spcl#1492)
Browse files Browse the repository at this point in the history
CuPy with AMD HIP does not support `__cuda_array_interface__`. This PR
adds special cases to support CuPy with HIP.
The PR also fixes the reduce node's GPUAuto schedule with respect to
warp size.
  • Loading branch information
tbennun authored Jan 5, 2024
1 parent e427617 commit fa305d2
Show file tree
Hide file tree
Showing 6 changed files with 65 additions and 18 deletions.
10 changes: 9 additions & 1 deletion dace/codegen/compiled_sdfg.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,8 +159,16 @@ def _array_interface_ptr(array: Any, storage: dtypes.StorageType) -> int:
"""
if hasattr(array, 'data_ptr'):
return array.data_ptr()

if storage == dtypes.StorageType.GPU_Global:
return array.__cuda_array_interface__['data'][0]
try:
return array.__cuda_array_interface__['data'][0]
except AttributeError:
# Special case for CuPy with HIP
if hasattr(array, 'data') and hasattr(array.data, 'ptr'):
return array.data.ptr
raise

return array.__array_interface__['data'][0]


Expand Down
10 changes: 8 additions & 2 deletions dace/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,9 +73,15 @@ def create_datadescriptor(obj, no_custom_desc=False):
else:
dtype = dtypes.typeclass(obj.dtype.type)
return Array(dtype=dtype, strides=tuple(s // obj.itemsize for s in obj.strides), shape=obj.shape)
# special case for torch tensors. Maybe __array__ could be used here for a more
# general solution, but torch doesn't support __array__ for cuda tensors.
elif type(obj).__module__ == "cupy" and type(obj).__name__ == "ndarray":
# special case for CuPy and HIP, which does not support __cuda_array_interface__
storage = dtypes.StorageType.GPU_Global
dtype = dtypes.typeclass(obj.dtype.type)
itemsize = obj.itemsize
return Array(dtype=dtype, shape=obj.shape, strides=tuple(s // itemsize for s in obj.strides), storage=storage)
elif type(obj).__module__ == "torch" and type(obj).__name__ == "Tensor":
# special case for torch tensors. Maybe __array__ could be used here for a more
# general solution, but torch doesn't support __array__ for cuda tensors.
try:
# If torch is importable, define translations between typeclasses and torch types. These are reused by daceml.
# conversion happens here in pytorch:
Expand Down
7 changes: 7 additions & 0 deletions dace/dtypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -1557,6 +1557,8 @@ def is_array(obj: Any) -> bool:
return hasattr(obj, 'shape') and len(obj.shape) > 0
except TypeError: # PyTorch scalar objects define an attribute called shape that cannot be used
return False
if hasattr(obj, 'data') and hasattr(obj.data, 'ptr'): # CuPy special case with HIP
return True
return False


Expand All @@ -1577,4 +1579,9 @@ def is_gpu_array(obj: Any) -> bool:
# In PyTorch, accessing this attribute throws a runtime error for
# variables that require grad, or KeyError when a boolean array is used
return False

if hasattr(obj, 'data') and hasattr(obj.data, 'ptr'): # CuPy special case with HIP
if hasattr(obj, 'device') and getattr(obj.device, 'id', -1) >= 0:
return True

return False
31 changes: 17 additions & 14 deletions dace/libraries/standard/nodes/reduce.py
Original file line number Diff line number Diff line change
Expand Up @@ -1088,7 +1088,7 @@ class ExpandReduceGPUAuto(pm.ExpandTransformation):
"""
GPU implementation of the reduce node. This expansion aims to map the reduction inputs to an optimal GPU schedule.
"""
environments = [CUDA]
environments = []

@staticmethod
def expansion(node: 'Reduce', state: SDFGState, sdfg: SDFG):
Expand All @@ -1099,13 +1099,16 @@ def expansion(node: 'Reduce', state: SDFGState, sdfg: SDFG):
:param state: the state in which the node is in
:param sdfg: the SDFG in which the node is in
"""
from dace.codegen import common

node.validate(sdfg, state)
inedge: graph.MultiConnectorEdge = state.in_edges(node)[0]
outedge: graph.MultiConnectorEdge = state.out_edges(node)[0]
insubset = dcpy(inedge.data.subset)
isqdim = insubset.squeeze()
raw_input_data = sdfg.arrays[inedge.data.data]
raw_output_data = sdfg.arrays[outedge.data.data]
warp_size = 64 if common.get_gpu_backend() == 'hip' else 32

in_type = raw_input_data.dtype

Expand All @@ -1132,7 +1135,7 @@ def expansion(node: 'Reduce', state: SDFGState, sdfg: SDFG):
axes = [axis for axis in axes if axis in isqdim]

# call the planner script
schedule = red_planner.get_reduction_schedule(raw_input_data, axes)
schedule = red_planner.get_reduction_schedule(raw_input_data, axes, warp_size=warp_size)

if schedule.error:
# return pure expansion if error
Expand Down Expand Up @@ -1340,25 +1343,25 @@ def expansion(node: 'Reduce', state: SDFGState, sdfg: SDFG):
real_state = nested_sdfg.add_state('real_state')

nested_sdfg.add_edge(start_state, real_state,
dace.InterstateEdge(f'_b1 + 32 * _g < {schedule.in_shape[-1]}'))
dace.InterstateEdge(f'_b1 + {warp_size} * _g < {schedule.in_shape[-1]}'))

reset_outm = dace.Memlet(f'_out[{",".join(["_o%d" % i for i in range(len(schedule.out_shape))])}]')
if len(schedule.out_shape) > 1:
outm = dace.Memlet(
f'_out[{",".join(["_o%d" % i for i in range(len(schedule.out_shape) - 1)])},_g * 32 + _b]',
f'_out[{",".join(["_o%d" % i for i in range(len(schedule.out_shape) - 1)])},_g * {warp_size} + _b]',
dynamic=True)
outm_wcr = dace.Memlet(
f'_out[{",".join(["_o%d" % i for i in range(len(schedule.out_shape) - 1)])},_g * 32 + _b]',
f'_out[{",".join(["_o%d" % i for i in range(len(schedule.out_shape) - 1)])},_g * {warp_size} + _b]',
dynamic=True,
wcr=node.wcr)

else:
outm = dace.Memlet(f'_out[_g * 32 + _b]', dynamic=True)
outm_wcr = dace.Memlet(f'_out[_g * 32 + _b]', dynamic=True, wcr=node.wcr)
outm = dace.Memlet(f'_out[_g * {warp_size} + _b]', dynamic=True)
outm_wcr = dace.Memlet(f'_out[_g * {warp_size} + _b]', dynamic=True, wcr=node.wcr)

input_subset = input_subset[:-2]
input_subset.append(f'0:{schedule.sequential[0]}')
input_subset.append('_g * 32 + _b1')
input_subset.append(f'_g * {warp_size} + _b1')
inmm = dace.Memlet(f'_in[{",".join(input_subset)}]', dynamic=True)

if schedule.multi_axes:
Expand Down Expand Up @@ -1401,13 +1404,13 @@ def expansion(node: 'Reduce', state: SDFGState, sdfg: SDFG):
schedule=dtypes.ScheduleType.GPU_ThreadBlock)

else:
bme1, bmx1 = nstate.add_map('block', {'_b': f'0:32'}, schedule=dtypes.ScheduleType.GPU_ThreadBlock)
bme1, bmx1 = nstate.add_map('block', {'_b': f'0:{warp_size}'}, schedule=dtypes.ScheduleType.GPU_ThreadBlock)

bme2, bmx2 = nstate.add_map('block', {f'_b{i}': f'0:{sz}'
for i, sz in enumerate(schedule.block)},
schedule=dtypes.ScheduleType.GPU_ThreadBlock)

# add shared memory of size 32 to outer sdfg
# add shared memory of warp size to outer sdfg
nsdfg.add_array('s_mem', [schedule.shared_mem_size],
nsdfg.arrays['_in'].dtype,
dtypes.StorageType.GPU_Shared,
Expand Down Expand Up @@ -1482,11 +1485,11 @@ def expansion(node: 'Reduce', state: SDFGState, sdfg: SDFG):
if mini_warps:
cond_tasklet = nstate.add_tasklet(
'cond_write', {'_input'}, {'_output'},
f'if _b + 32 * _g < {schedule.out_shape[-1]} and _bb == 0 and _mwid == 0: _output = _input')
f'if _b + {warp_size} * _g < {schedule.out_shape[-1]} and _bb == 0 and _mwid == 0: _output = _input')
else:
cond_tasklet = nstate.add_tasklet(
'cond_write', {'_input'}, {'_output'},
f'if _b + 32 * _g < {schedule.out_shape[-1]} and _bb == 0: _output = _input')
f'if _b + {warp_size} * _g < {schedule.out_shape[-1]} and _bb == 0: _output = _input')

# connect accumulator to identity tasklet
real_state.add_memlet_path(accread, ime, id, dst_conn='a', memlet=dace.Memlet('acc[0]'))
Expand All @@ -1511,8 +1514,8 @@ def expansion(node: 'Reduce', state: SDFGState, sdfg: SDFG):
nstate.add_memlet_path(s_mem3, bme3, cond_tasklet, dst_conn='_input', memlet=dace.Memlet('s_mem[_b]'))
else:
bme3, bmx3 = nstate.add_map('block', {
'_bb': '0:16',
'_b': f'0:32'
'_bb': f'0:{512//warp_size}',
'_b': f'0:{warp_size}'
},
schedule=dtypes.ScheduleType.GPU_ThreadBlock)
nstate.add_memlet_path(s_mem3, bme3, cond_tasklet, dst_conn='_input', memlet=dace.Memlet('s_mem[_b]'))
Expand Down
5 changes: 4 additions & 1 deletion dace/libraries/standard/nodes/ttranspose.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,10 @@ def expansion(node, parent_state, parent_sdfg):
out_mem = dace.Memlet(expr=f"_out_tensor[{','.join([map_params[i] for i in node.axes])}]")
inputs = {"_inp": inp_mem}
outputs = {"_out": out_mem}
code = f"_out = {node.alpha} * _inp"
if node.alpha == 1:
code = "_out = _inp"
else:
code = f"_out = decltype(_inp)({node.alpha}) * _inp"
if node.beta != 0:
inputs["_inout"] = out_mem
code = f"_out = {node.alpha} * _inp + {node.beta} * _inout"
Expand Down
20 changes: 20 additions & 0 deletions dace/runtime/include/dace/reduction.h
Original file line number Diff line number Diff line change
Expand Up @@ -592,7 +592,9 @@ namespace dace {
cub::TransformInputIterator<int, decltype(conversion_op), decltype(counting_iterator)> itr(counting_iterator, conversion_op);
return itr;
}
#endif

#if defined(__CUDACC__)
template <ReductionType REDTYPE, typename T>
struct warpReduce {
static DACE_DFI T reduce(T v)
Expand All @@ -610,6 +612,24 @@ namespace dace {
return v;
}
};
#elif defined(__HIPCC__)
template <ReductionType REDTYPE, typename T>
struct warpReduce {
static DACE_DFI T reduce(T v)
{
for (int i = 1; i < warpSize; i = i * 2)
v = _wcr_fixed<REDTYPE, T>()(v, __shfl_xor(v, i));
return v;
}

template<int NUM_MW>
static DACE_DFI T mini_reduce(T v)
{
for (int i = 1; i < NUM_MW; i = i * 2)
v = _wcr_fixed<REDTYPE, T>()(v, __shfl_xor(v, i));
return v;
}
};
#endif

} // namespace dace
Expand Down

0 comments on commit fa305d2

Please sign in to comment.