[AMD] Add a tt.pointer_range_32 specialization (#4910)
This PR adds an attribute in the HIP backend that checks whether a tensor's
storage fits within 2 GB. This enables support for buffer operations.
giuseros authored Oct 17, 2024
1 parent 1883703 commit 692143c
Showing 4 changed files with 115 additions and 15 deletions.
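
As a rough illustration of the check described in the commit message, here is a minimal, self-contained sketch (the helper name and the use of plain CPU tensors are illustrative assumptions, not the backend's actual entry points):

import torch

INT32_MAX = 2**31 - 1  # largest byte offset representable as a signed 32-bit integer


def storage_fits_in_2gb(tensor: torch.Tensor) -> bool:
    # Mirrors the check added in this PR: if the tensor's underlying storage is
    # at most 2**31 - 1 bytes, every byte offset into it fits in 32 bits, which
    # is what the AMD buffer load/store instructions require.
    return tensor.untyped_storage().size() <= INT32_MAX


# Each of these allocates roughly 2 GB, just like the test below.
small = torch.empty(2**31 - 1, dtype=torch.int8)  # exactly at the limit
large = torch.empty(2**31, dtype=torch.int8)      # one byte over

assert storage_fits_in_2gb(small)
assert not storage_fits_in_2gb(large)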
27 changes: 27 additions & 0 deletions python/test/unit/runtime/test_cache.py
@@ -9,6 +9,7 @@
import triton
import triton.language as tl
from triton.runtime.jit import JITFunction
from triton._internal_testing import is_hip


@triton.jit
@@ -572,3 +573,29 @@ def compiled_hook(*args, **kwargs):
assert specialization_data is not None and specialization_data_compiled == specialization_data
assert is_warmup is True
assert key in kernel_add.cache[getattr(torch, device).current_device()]


@pytest.mark.skipif(reason="within_2g is a HIP specific optimization", condition=not is_hip())
def test_within_2gb(device, fresh_triton_cache) -> None:

@triton.jit
def kernel_add(a):
tl.load(a)

# This is the attribute we want to test
pointer_range_32 = None

def cache_hook(*args, **kwargs):
nonlocal pointer_range_32
pointer_range_32 = kwargs["compile"]["configs"][0].pointer_range_32

JITFunction.cache_hook = cache_hook
# In warmup we assume that the pointer range is 32 bits
kernel_add.warmup(torch.float32, grid=(1, ))
assert pointer_range_32 == [0]
# Torch tensor > 2GB
kernel_add[(1, 0)](torch.empty(2**31, dtype=torch.int8, device=device))
assert len(pointer_range_32) == 0
# Torch tensor <= 2GB
kernel_add[(1, 0)](torch.empty(2**31 - 1, dtype=torch.int8, device=device))
assert pointer_range_32 == [0]
52 changes: 38 additions & 14 deletions python/triton/backends/compiler.py
@@ -8,7 +8,21 @@
from typing import Dict, List, Tuple, Union
from types import ModuleType

# Table that associates strings with AttrsDescriptor (sub)classes.
# This lets us dynamically select the correct class
# constructor
_descriptor_table = {}


def register_descriptor(cls):
"""
Register a descriptor into the descriptor table
"""
_descriptor_table[cls.__name__] = cls
return cls


@register_descriptor
class AttrsDescriptor:
"""
This class handles compile-time properties for specific function parameters.
@@ -135,18 +149,28 @@ def hash(self):
return hashlib.sha256(key.encode("utf-8")).hexdigest()

def to_dict(self):
return self.arg_properties
"""
Store the fields of this class in a serializable dictionary
"""
# We only need to store the `arg_properties` field. To initialize the
# other fields we rely on the class type. We store it as a string in
# the dictionary so that we can use it to invoke the appropriate
# (sub)class constructor in the `from_dict` method.
return {"arg_properties": self.arg_properties, "cls": type(self).__name__}

@staticmethod
def from_dict(data):
attrsDescriptor = AttrsDescriptor()
for prop_name, param_ids in data.items():
attrsDescriptor.arg_properties[prop_name] = param_ids
attrsDescriptor._init_slots()
return attrsDescriptor

@staticmethod
def from_hints(hints: List[Tuple[int, int]]):
"""
Create the object from a serializable dictionary
"""
attrs_descriptor = _descriptor_table[data["cls"]]()
for prop_name, param_ids in data["arg_properties"].items():
attrs_descriptor.arg_properties[prop_name] = param_ids
attrs_descriptor._init_slots()
return attrs_descriptor

@classmethod
def from_hints(cls, hints: List[Tuple[int, int]]):
"""
Create the class from a set of hints that are passed in.
@@ -156,11 +180,11 @@ def from_hints(hints: List[Tuple[int, int]]):
then we insert `param_index` into the correct list (e.g., in
`arg_properties[prop0]`)
"""
attrsDescriptor = AttrsDescriptor()
for prop_name, prop_val in attrsDescriptor.property_values.items():
attrsDescriptor.arg_properties[prop_name] = [i for i, h in hints.items() if h == prop_val]
attrsDescriptor._init_slots()
return attrsDescriptor
attrs_descriptor = cls()
for prop_name, prop_val in attrs_descriptor.property_values.items():
attrs_descriptor.arg_properties[prop_name] = [i for i, h in hints.items() if h == prop_val]
attrs_descriptor._init_slots()
return attrs_descriptor

@staticmethod
def is_divisible_by_16(x):
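
To see why `to_dict` now records the class name, here is a hedged, standalone imitation of the `register_descriptor` / `_descriptor_table` round trip (toy class names, not Triton's actual ones): serialization stores only the data plus the concrete class name, and deserialization looks the name up in the table so the right (sub)class constructor is invoked.

# Toy registry mirroring register_descriptor / _descriptor_table above.
_table = {}


def register(cls):
    _table[cls.__name__] = cls
    return cls


@register
class BaseDescriptor:

    def __init__(self):
        self.arg_properties = {}

    def to_dict(self):
        # Store the data plus the concrete class name ...
        return {"arg_properties": self.arg_properties, "cls": type(self).__name__}

    @staticmethod
    def from_dict(data):
        # ... so the right (sub)class can be reconstructed from the table.
        obj = _table[data["cls"]]()
        obj.arg_properties = dict(data["arg_properties"])
        return obj


@register
class ToyHIPDescriptor(BaseDescriptor):
    pass


restored = BaseDescriptor.from_dict(ToyHIPDescriptor().to_dict())
assert type(restored) is ToyHIPDescriptor  # the subclass survives the round trip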
4 changes: 4 additions & 0 deletions python/triton/runtime/jit.py
@@ -879,6 +879,10 @@ def __init__(self, dtype):
def data_ptr():
return 0 # optimistically assumes multiple of 16

@staticmethod
def ptr_range():
return 0 # optimistically assumes 32 bit pointer range


class TensorWrapper:

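
This placeholder is what the warmup path in the test above sees: a dtype-only stand-in reports `ptr_range() == 0`, which trivially satisfies the 2 GB check, so warmup optimistically specializes for a 32-bit pointer range (hence `pointer_range_32 == [0]` right after `warmup`). A minimal sketch of that interaction, reusing the shape of the HIP check from the next file (names here are illustrative):

class DtypePlaceholder:
    # Imitates the warmup stand-in added above: there is no real allocation,
    # so it optimistically reports a zero pointer range.

    @staticmethod
    def ptr_range():
        return 0


def is_within_2gb(arg):
    # Same shape as HIPAttrsDescriptor.is_within2gb in the next hunk.
    if hasattr(arg, "ptr_range"):
        return arg.ptr_range() <= 2**31 - 1
    return False


assert is_within_2gb(DtypePlaceholder())  # the warmup argument passes the check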
47 changes: 46 additions & 1 deletion third_party/amd/backend/compiler.py
@@ -1,4 +1,4 @@
from triton.backends.compiler import BaseBackend, GPUTarget
from triton.backends.compiler import BaseBackend, GPUTarget, AttrsDescriptor, register_descriptor
from triton._C.libtriton import ir, passes, llvm, amd
from dataclasses import dataclass
from typing import Any, Dict, Tuple
@@ -72,6 +72,44 @@ def hash(self):
return hashlib.sha256(key.encode("utf-8")).hexdigest()


@register_descriptor
class HIPAttrsDescriptor(AttrsDescriptor):
# This property records whether the underlying storage of a given pointer
# can be represented as a 32-bit integer. When this is true, we can be
# sure that all indices into the tensor behind that pointer can use 32-bit
# indexing. That opens the door for the AMD backend to use buffer load/store
# intrinsics, which require this property. Buffer load/store intrinsics
# give direct out-of-bounds support and simplify index calculation for
# lower register pressure.
__slots__ = ("pointer_range_32")

def _add_backend_properties(self, params=None, values=None):
self.property_values["tt.pointer_range"] = 32
if params is None or values is None:
return

self.arg_properties["tt.pointer_range"] = [
param.num for param, arg in zip(params, values) if HIPAttrsDescriptor.is_within2gb(arg)
and not param.do_not_specialize and not param.do_not_specialize_on_alignment
]

@staticmethod
def is_within2gb(arg):
if hasattr(arg, "ptr_range"):
return arg.ptr_range() <= 2**31 - 1
if "torch.Tensor" in str(type(arg)) and hasattr(arg, "untyped_storage"):
# Note that 2**31 - 1 is the maximum positive int32 value
return arg.untyped_storage().size() <= 2**31 - 1
return False

@staticmethod
def get_property_key(val, align):
generic_key = AttrsDescriptor.get_property_key(val, align)
hip_key = "S" if HIPAttrsDescriptor.is_within2gb(val) else "N"
key = (generic_key + hip_key).replace("N", "")
return key if key else "N"


class HIPBackend(BaseBackend):

@staticmethod
@@ -118,6 +156,13 @@ def get_module_map(self) -> Dict[str, ModuleType]:
def load_dialects(self, ctx):
amd.load_dialects(ctx)

def get_attrs_descriptor(self, params, args):
return HIPAttrsDescriptor(params, args)

@staticmethod
def compute_spec_key(arg, align):
return HIPAttrsDescriptor.get_property_key(arg, align)

@staticmethod
def path_to_rocm_lld():
# Check env path for ld.lld
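
To make the key composition in `get_property_key` concrete, here is a hedged walk-through as a standalone snippet. The base-class letters ("D" for a 16-byte-aligned pointer argument, "N" for no property) are assumptions about the generic key; the "S"/"N" handling follows the HIP code above.

def hip_key(generic_key: str, within_2gb: bool) -> str:
    # Mirrors HIPAttrsDescriptor.get_property_key: append "S" when the storage
    # fits in 2 GB, drop every "N" (the "no property" marker), and fall back to
    # "N" if nothing is left.
    key = (generic_key + ("S" if within_2gb else "N")).replace("N", "")
    return key if key else "N"


# Assuming the generic key is "D" for a 16-byte-aligned pointer and "N" otherwise:
assert hip_key("D", within_2gb=True) == "DS"   # aligned and within 2 GB
assert hip_key("D", within_2gb=False) == "D"   # aligned only
assert hip_key("N", within_2gb=True) == "S"    # within 2 GB only
assert hip_key("N", within_2gb=False) == "N"   # nothing to specialize on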
