Fix some underlying issues with tensor core sample #1336

Merged 4 commits on Jul 29, 2023
87 changes: 36 additions & 51 deletions samples/codegen/tensor_cores.py
@@ -27,6 +27,7 @@
 from dace.sdfg.graph import MultiConnectorEdge
 from dace.sdfg.state import StateSubgraphView
 from dace.codegen.prettycode import CodeIOStream
+from dace.codegen.dispatcher import DefinedType
 from typing import Any, List
 
 # Other imports
@@ -76,6 +77,9 @@ def __init__(self, frame_codegen: DaCeCodeGenerator, sdfg: dace.SDFG):
     def allocate_array(self, sdfg: dace.SDFG, dfg: StateSubgraphView, state_id: int, node: nodes.AccessNode,
                        nodedesc: dt.Array, function_stream: CodeIOStream, declaration_stream: CodeIOStream,
                        allocation_stream: CodeIOStream):
+        # Make sure the codegen includes the appropriate header files
+        _include_mma(sdfg)
+
         name = node.data
 
         # Based on the hardware, the total size must be 16^2
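
To make the added `_include_mma(sdfg)` call concrete: below is a minimal sketch of what such a helper appends to the SDFG's global CUDA code so that generated kernels can use the WMMA API. It is an assumption based on the `sdfg.append_global_code(global_code, 'cuda')` line visible later in this diff; the real helper's body (including any duplicate-insertion guard) is only partially shown.

import dace

def _include_mma_sketch(sdfg: dace.SDFG):
    # The CUDA WMMA API is declared in <mma.h> under namespace nvcuda.
    global_code = '''
#include <mma.h>
using namespace nvcuda;
'''
    # Append to the CUDA-specific global code section, as the real
    # helper does via sdfg.append_global_code(global_code, 'cuda').
    sdfg.append_global_code(global_code, 'cuda')
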
@@ -85,14 +89,16 @@ def allocate_array(self, sdfg: dace.SDFG, dfg: StateSubgraphView, state_id: int,
 
         # Write a fragment based on the storage type
         if nodedesc.storage == dace.StorageType.TensorCore_Accumulator:
-            declaration_stream.write('wmma::fragment<wmma::accumulator, '
-                                     '16, 16, 16, float> {};'.format(name), sdfg, state_id, node)
+            ctype = 'wmma::fragment<wmma::accumulator, 16, 16, 16, float>'
+            declaration_stream.write(f'{ctype} {name};', sdfg, state_id, node)
         else:
-            declaration_stream.write(
-                'wmma::fragment<wmma::matrix_{mat}, '
-                '16, 16, 16, half, wmma::{maj}_major> '
-                '{name};'.format(mat=('a' if 'A' in nodedesc.storage.name else 'b'), maj=maj, name=name), sdfg,
-                state_id, node)
+            ctype = 'wmma::fragment<wmma::matrix_{mat}, 16, 16, 16, half, wmma::{maj}_major>'.format(
+                mat=('a' if 'A' in nodedesc.storage.name else 'b'), maj=maj)
+            declaration_stream.write(f'{ctype} {name};', sdfg, state_id, node)
+
+        # Add the ctype to defined_vars so that the codegen can properly pass
+        # fragments to functions as an object reference.
+        self._dispatcher.defined_vars.add(name, DefinedType.Stream, ctype)
 
     def deallocate_array(self, sdfg: dace.SDFG, dfg: StateSubgraphView, state_id: int, node: nodes.AccessNode,
                          nodedesc: dt.Array, function_stream: CodeIOStream, callsite_stream: CodeIOStream):
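
As a self-contained illustration of the refactored declaration logic above, the hypothetical function below mirrors the ctype string construction (it is not part of the sample; the name and arguments are illustrative):

def fragment_declaration(name: str, storage_name: str, maj: str = 'row') -> str:
    # Mirrors allocate_array above: accumulators are 16x16x16 float fragments,
    # A/B operands are half-precision fragments with an explicit memory layout.
    if storage_name == 'TensorCore_Accumulator':
        ctype = 'wmma::fragment<wmma::accumulator, 16, 16, 16, float>'
    else:
        mat = 'a' if 'A' in storage_name else 'b'
        ctype = f'wmma::fragment<wmma::matrix_{mat}, 16, 16, 16, half, wmma::{maj}_major>'
    return f'{ctype} {name};'

print(fragment_declaration('afrag', 'TensorCore_A'))
# wmma::fragment<wmma::matrix_a, 16, 16, 16, half, wmma::row_major> afrag;
print(fragment_declaration('cfrag', 'TensorCore_Accumulator'))
# wmma::fragment<wmma::accumulator, 16, 16, 16, float> cfrag;

Building ctype once also makes it available for the defined_vars registration, which is what lets downstream codegen pass the fragment to functions as an object reference instead of re-deriving its type.
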
@@ -187,50 +193,29 @@ def _include_mma(sdfg: dace.SDFG):
     sdfg.append_global_code(global_code, 'cuda')
 
 
-@replaces('frag_fill')
-def frag_fill(pv: ProgramVisitor, sdfg: dace.SDFG, state: dace.SDFGState, frag: str, fill: Any) -> List[str]:
-    # Replacement functions receive the SDFG and the current state as the first
-    # two arguments, followed by all the other arguments. Here we treat them as
-    # two strings representing the array name to fill and what to fill it with.
-
-    # NOTE: If a slice is used in the `frag` argument, the Python frontend
-    # automatically creates a new array for it, and uses the correct string as
-    # the argument.
-    wnode = state.add_write(frag)
-    tasklet = state.add_tasklet('fill',
-                                set(), {'out'},
-                                '''
-wmma::fill_fragment(out, %s);''' % fill,
-                                language=dace.Language.CPP)
-
-    state.add_edge(tasklet, 'out', wnode, None, dace.Memlet.from_array(frag, wnode.desc(sdfg)))
-
-    _include_mma(sdfg)
-
-    # Function has no return value
-    return []
-
-
-@replaces('wmma')
-def wmma(pv: ProgramVisitor, sdfg: dace.SDFG, state: dace.SDFGState, a_frag: str, b_frag: str,
-         c_frag: str) -> List[str]:
-    # Implemented similarly to `frag_fill`, but with inputs and outputs.
-    anode = state.add_read(a_frag)
-    bnode = state.add_read(b_frag)
-    cnode = state.add_write(c_frag)
-    tasklet = state.add_tasklet('wmma', {'afrag', 'bfrag'}, {'cfrag'},
-                                '''
-wmma::mma_sync(cfrag, afrag, bfrag, cfrag);''',
-                                language=dace.Language.CPP)
-
-    state.add_edge(anode, None, tasklet, 'afrag', dace.Memlet.from_array(a_frag, anode.desc(sdfg)))
-    state.add_edge(bnode, None, tasklet, 'bfrag', dace.Memlet.from_array(b_frag, bnode.desc(sdfg)))
-    state.add_edge(tasklet, 'cfrag', cnode, None, dace.Memlet.from_array(c_frag, cnode.desc(sdfg)))
-
-    _include_mma(sdfg)
-
-    # Function has no return value
-    return []
+def frag_fill(frag, fill):
+    # Define a tasklet with the appropriate input and output connectors.
+    # Then we can directly emit CUDA for the tasklet.
+    with dace.tasklet(dace.Language.CPP):
+        val << fill
+        out >> frag
+        """
+        wmma::fill_fragment(out, val);
+        """
+
+def wmma(a_frag, b_frag, c_frag):
+    # We do the same here as we did with frag_fill. Since c_frag is used
+    # as both an input and an output, we specify two separate variables
+    # to be passed to mma_sync and declare c_frag as an input to one and
+    # an output to the other. This ensures proper dataflow.
+    with dace.tasklet(dace.Language.CPP):
+        afrag << a_frag
+        bfrag << b_frag
+        cfrag << c_frag
+        dfrag >> c_frag
+        """
+        wmma::mma_sync(dfrag, afrag, bfrag, cfrag);
+        """
 
 
 ############################################################################
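
Finally, a hedged sketch of how the rewritten frag_fill and wmma helpers are meant to be called from a @dace.program. Everything below is illustrative: the program name, shapes, and the dace.define_local setup are assumptions, and in the actual sample these calls sit inside GPU device and warp maps rather than at top level.

import dace

@dace.program
def mma_tile(A: dace.float16[16, 16], B: dace.float16[16, 16], C: dace.float32[16, 16]):
    # Fragments live in the custom TensorCore storage types that
    # allocate_array above turns into wmma::fragment declarations.
    atile = dace.define_local([16, 16], dace.float16, storage=dace.StorageType.TensorCore_A)
    btile = dace.define_local([16, 16], dace.float16, storage=dace.StorageType.TensorCore_B)
    ctile = dace.define_local([16, 16], dace.float32, storage=dace.StorageType.TensorCore_Accumulator)

    atile[:] = A               # load the A operand fragment
    btile[:] = B               # load the B operand fragment
    frag_fill(ctile, 0.0)      # zero the accumulator via the tasklet helper
    wmma(atile, btile, ctile)  # ctile += atile @ btile
    C[:] = ctile               # store the accumulator
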