Commit

scatter reduce lowering with include_self=False
apbose committed Sep 11, 2024
1 parent 501a1e1 commit 9a8c124
Showing 2 changed files with 492 additions and 39 deletions.
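For context, include_self controls whether the values already present in the input tensor take part in the reduction at the scattered positions. A minimal eager-mode sketch of the torch.scatter_reduce behaviour this lowering has to reproduce (the tensors below are illustrative, not taken from the commit):

```python
import torch

input_tensor = torch.tensor([1.0, 2.0, 3.0, 4.0])
index = torch.tensor([0, 0, 2, 2])
src = torch.tensor([10.0, 20.0, 30.0, 40.0])

# include_self=True: the original input values participate in the reduction.
torch.scatter_reduce(input_tensor, 0, index, src, reduce="amin", include_self=True)
# tensor([1., 2., 3., 4.])

# include_self=False: only the scattered src values are reduced; positions that
# are never indexed keep their original values.
torch.scatter_reduce(input_tensor, 0, index, src, reduce="amin", include_self=False)
# tensor([10., 2., 30., 4.])
```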
52 changes: 44 additions & 8 deletions py/torch_tensorrt/dynamo/lowering/_decompositions.py
@@ -303,16 +303,22 @@ def __new__(cls, description, func):
         obj.func = func
         return obj
 
-    def reduce_operation_with_scatter(
-        self, operation_lhs, initial_tensor, dim, index_tensor, src_tensor
+    def reduce_operation_with_scatter_include_self(
+        self, operation_lhs, initial_tensor, dim, index_tensor, src_tensor, min_ele = float('-inf'), max_ele = float('inf'), include_self=True
     ):
         scatter_tensor = None
         if self == ReduceOperation.SUM or self == ReduceOperation.MEAN:
             scatter_tensor = torch.zeros_like(initial_tensor)
         elif self == ReduceOperation.PROD:
             scatter_tensor = torch.ones_like(initial_tensor)
-        elif self == ReduceOperation.AMIN or self == ReduceOperation.AMAX:
+        elif self == ReduceOperation.AMAX:
             scatter_tensor = initial_tensor
+            if(not(include_self)):
+                scatter_tensor = torch.full_like(initial_tensor, min_ele)
+        elif self == ReduceOperation.AMIN:
+            scatter_tensor = initial_tensor
+            if(not(include_self)):
+                scatter_tensor = torch.full_like(initial_tensor, max_ele)
         else:
             # This case would not be encountered from torch itself
             print("Invalid Operation for Reduce op!!")
@@ -336,13 +342,31 @@ def scatter_reduce_decomposition(
     include_self: bool = True,
 ) -> torch.Tensor:
     scatter_loop_tensor = input_tensor
+    MAX_ELE = 0
+    MIN_ELE = 0
+    if(src_tensor.dtype == torch.int32 or input_tensor.dtype == torch.int32):
+        MAX_ELE = 2147483647
+        MIN_ELE = -2147483648
+    else:
+        MAX_ELE = float('inf')
+        MIN_ELE = float('-inf')
+    if(not(include_self)):
+        if (reduce == "sum" or reduce == "mean"):
+            scatter_loop_tensor = torch.scatter(scatter_loop_tensor, dim, index, torch.zeros_like(src_tensor))
+        if (reduce == "prod"):
+            scatter_loop_tensor = torch.scatter(scatter_loop_tensor, dim, index, torch.ones_like(src_tensor))
+        if (reduce == "amax"):
+            src_red_tensor = torch.full_like(src_tensor, MIN_ELE)
+            scatter_loop_tensor = torch.scatter(scatter_loop_tensor, dim, index, src_red_tensor)
+        if (reduce == "amin"):
+            src_red_tensor = torch.full_like(src_tensor, MAX_ELE)
+            scatter_loop_tensor = torch.scatter(scatter_loop_tensor, dim, index, src_red_tensor)
+
     device_input_tensor = input_tensor.device
     # required for mean reduce operation
     scatter_count_tensor = torch.zeros_like(input_tensor)
     src_shape = list(src_tensor.shape)
     src_dim = src_shape[dim]
-    if include_self == False:
-        raise AssertionError("include_self False for scatter reduce not yet supported")
     for i in range(0, src_dim):
         src_slice = torch.select(src_tensor, dim, i)
         index_slice = torch.select(index, dim, i)
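The block added in this hunk pre-seeds the indexed positions with the identity element of the requested reduction (0 for sum/mean, 1 for prod, the integer or floating extremes for amax/amin), so that the original input values cannot contribute when include_self=False. A small standalone sketch of that idea for the sum case (illustrative values, not the library code):

```python
import torch

input_tensor = torch.tensor([1.0, 2.0, 3.0, 4.0])
index = torch.tensor([0, 0, 2])
src = torch.tensor([10.0, 20.0, 30.0])

# Zero the indexed positions first so the additive accumulation starts from the
# identity element there, while untouched positions keep their original values.
seeded = torch.scatter(input_tensor, 0, index, torch.zeros_like(src))
# seeded == tensor([0., 2., 0., 4.])

acc = seeded
for i in range(src.numel()):
    acc = acc + torch.scatter(
        torch.zeros_like(input_tensor), 0, index[i : i + 1], src[i : i + 1]
    )
# acc == tensor([30., 2., 30., 4.]), matching
# torch.scatter_reduce(input_tensor, 0, index, src, reduce="sum", include_self=False)
```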
@@ -366,20 +390,32 @@
                     dim,
                     index_slice,
                     torch.ones_like(src_slice),
+                    MIN_ELE,
+                    MAX_ELE,
+                    include_self
                 )
             elif reduce == "amax":
                 reduceOp = ReduceOperation.AMAX
             elif reduce == "amin":
                 reduceOp = ReduceOperation.AMIN
-            scatter_loop_tensor = reduceOp.reduce_operation_with_scatter(
-                scatter_loop_tensor, input_tensor, dim, index_slice, src_slice
+            scatter_loop_tensor = reduceOp.reduce_operation_with_scatter_include_self(
+                scatter_loop_tensor, input_tensor, dim, index_slice, src_slice, MIN_ELE, MAX_ELE, include_self
             )
         if reduce == "mean":
            scatter_loop_tensor = torch.div(
                scatter_loop_tensor,
-               torch.add(scatter_count_tensor, torch.ones_like(scatter_count_tensor)),
+               torch.add(scatter_count_tensor, torch.ones_like(scatter_count_tensor)) if include_self else scatter_count_tensor,
               rounding_mode="trunc",
           )
+    # For amax and amin with include_self=False, additional processing is required:
+    # after the loop, the indexed positions hold the reduced maxima (amax) or minima (amin),
+    # while the remaining positions may still hold -inf / INT_MIN (amax) or +inf / INT_MAX (amin).
+    if reduce == "amax" and not(include_self):
+        # keep the reduced maxima at the indexed positions, restore the original input elsewhere
+        return torch.max(scatter_loop_tensor, torch.scatter(input_tensor, dim, index, torch.full_like(src_tensor, MIN_ELE)))
+    if reduce == "amin" and not(include_self):
+        # keep the reduced minima at the indexed positions, restore the original input elsewhere
+        return torch.min(scatter_loop_tensor, torch.scatter(input_tensor, dim, index, torch.full_like(src_tensor, MAX_ELE)))
     return scatter_loop_tensor


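The amax/amin branches added at the end of the decomposition restore the original input at positions the index never touches. A sketch of that fix-up in isolation (the tensors below are illustrative; loop_result stands in for whatever the accumulation loop leaves behind):

```python
import torch

MIN_ELE = float("-inf")
input_tensor = torch.tensor([1.0, 2.0, 3.0, 4.0])
index = torch.tensor([0, 2])
src = torch.tensor([10.0, 30.0])

# Suppose the accumulation loop left MIN_ELE at the positions that were never indexed.
loop_result = torch.tensor([10.0, MIN_ELE, 30.0, MIN_ELE])

# Mask the indexed positions of the input down to MIN_ELE, then take the
# element-wise max: indexed positions keep the reduced value, the rest revert
# to the original input.
masked_input = torch.scatter(input_tensor, 0, index, torch.full_like(src, MIN_ELE))
final = torch.max(loop_result, masked_input)
# final == tensor([10., 2., 30., 4.])
# == torch.scatter_reduce(input_tensor, 0, index, src, reduce="amax", include_self=False)
```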
