move split_reduceop to scheduler + enable it for multi (tinygrad#9214)
* move split_reduceop to scheduler + enable it for multi

* merge r and _reduce_op
Qazalin authored Feb 23, 2025
1 parent 05e3202 commit 898aafe
Showing 2 changed files with 25 additions and 27 deletions.
25 changes: 23 additions & 2 deletions tinygrad/engine/schedule.py
@@ -5,7 +5,7 @@
from tinygrad.ops import can_pad, identity_element, resolve, view_left, merge_views
from tinygrad.codegen.symbolic import symbolic_simple
from tinygrad.helpers import Context, ContextVar, Metadata, all_int, all_same, colored, diskcache_put, prod, dedup, unwrap, flatten, getenv
from tinygrad.helpers import FUSE_CONV_BW, FUSE_ARANGE, DEBUG, CAPTURE_PROCESS_REPLAY, DONT_REALIZE_EXPAND
from tinygrad.helpers import FUSE_CONV_BW, FUSE_ARANGE, DEBUG, CAPTURE_PROCESS_REPLAY, DONT_REALIZE_EXPAND, SPLIT_REDUCEOP
from tinygrad.dtype import ImageDType
from tinygrad.shape.shapetracker import ShapeTracker
from tinygrad.shape.view import View, strides_for_shape
@@ -29,6 +29,23 @@ def simplify_stride0_reduce(reduce:UOp, x:UOp):
case Ops.MUL: return ret.pow(prshape)
case Ops.MAX: return ret # NOTE: Ops.MAX is passthrough

def split_reduceop(reduce:UOp, x:UOp):
if not SPLIT_REDUCEOP or not all_int(x.shape) or (prod(x.shape)//prod(reduce.shape))<getenv("REDUCEOP_SPLIT_THRESHOLD", 32768): return None
# if there are few globals, make some reduces into globals by splitting into two kernels
# cap output buffer to 2**22: heuristic number of global outputs to achieve max occupancy with enough locals+upcasts for gemm
# ~2**10 should be enough if GROUP is used
# 256 split maximum should be "negligible reduce" for low prod(reduce.shape), 8 split minimum.
# split is moved to the end to provide maximum locality for the second phase reduce.
real_strides = unwrap(x.st).real_strides(ignore_valid=True)
if not (split_candidates:=[(i,d) for i in reduce.arg[1] for d in range(min(256,2**getenv("REDUCEOP_SPLIT_SIZE",22)//prod(reduce.shape)),8-1,-1)
if x.shape[i]%d==0 and real_strides[i]!=0]): return None
dim_to_split, divisor = split_candidates[0]
splitted_shape = x.shape[:dim_to_split]+(divisor,)+(x.shape[dim_to_split]//divisor,)+x.shape[dim_to_split+1:]
splitted = x.reshape(splitted_shape).permute(tuple([d for d in range(len(splitted_shape)) if d!=dim_to_split]+[dim_to_split]))
if DEBUG >= 3: print(f"split {divisor}: {x.shape} -> {splitted.shape} -> {reduce.shape}")
# reduce original axes, then split
return splitted.r(*reduce.arg).r(reduce.arg[0], (len(reduce.shape),)).reshape(reduce.shape)

def found_contiguous(ctx:dict[UOp, UOp], contig:UOp, src:UOp):
if (sti:=unwrap(src.st).invert(src.base.shape)) is not None: ctx[src.base] = contig.view(sti)
def replace_contiguous(ctx:dict[UOp, UOp], alu:UOp):
@@ -48,6 +65,8 @@ def replace_contiguous(ctx:dict[UOp, UOp], alu:UOp):
lambda reduce,x: reduce.const_like(identity_element(reduce.arg[0], reduce.dtype)) if x.size == 0 and reduce.size != 0 else None),
# reduce on stride 0 is collapsed
(UPat(Ops.REDUCE_AXIS, name="reduce", src=(UPat.var("x"),)), simplify_stride0_reduce),
# split_reduceop
(UPat(Ops.REDUCE_AXIS, name="reduce", src=(UPat.var("x"),)), split_reduceop),
# COPY(CONST) creates a new CONST on the destination device
(UPat(Ops.COPY, name="root", src=(UPat(), UPat.cvar("x"),)), lambda root,x: root.const_like(x.arg)),
# no COPY to same device, except clone (arg is True)
@@ -386,7 +405,9 @@ def create_schedule_with_vars(big_sink:UOp) -> tuple[list[ScheduleItem], dict[Va
# map tensors to buffer/const
becomes_map: dict[UOp, UOp] = {}
for k,v in tensor_map.items():
if (a:=kernel_map.get(v)) is not None and a.op is Ops.ASSIGN: becomes_map[k] = k.src[0] if k.op is Ops.ASSIGN else a.buf_uop.reshape(k.shape)
# NOTE: tensors can also map to a VIEW, if it's contiguous and we can reshape it it's fine
if (a:=kernel_map.get(v.base)) is not None and a.op is Ops.ASSIGN and a.size == k.size and unwrap(v.st).contiguous:
becomes_map[k] = k.src[0] if k.op is Ops.ASSIGN else a.buf_uop.reshape(k.shape)
if v is k: continue
if v.base.op is Ops.BUFFER:
# VIEW isn't a valid tensor uop, we need to backtrack to the movement op that created it
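For reference, a minimal numpy sketch (not tinygrad code) of the shape algebra split_reduceop performs: a full sum over a (1024, 1024) tensor collapses ~1M elements to one output, which clears the default REDUCEOP_SPLIT_THRESHOLD of 32768, so the first reduce axis is split by the largest allowed divisor (256, from min(256, 2**REDUCEOP_SPLIT_SIZE // prod(reduce.shape))), moved to the end, and reduced in a second phase. The shapes and divisor below are illustrative assumptions, not values taken from the diff.

    # Illustrative sketch of the two-phase reduce produced by split_reduceop (numpy, not tinygrad)
    import numpy as np

    x = np.random.rand(1024, 1024)        # prod(x.shape)//prod(reduce.shape) = 2**20 >= 32768
    reduce_axes, divisor = (0, 1), 256    # full sum; 256 = min(256, 2**22 // 1) and 1024 % 256 == 0

    # reshape the split dimension into (divisor, rest) and move the split axis to the end
    splitted = x.reshape(divisor, 1024 // divisor, 1024).transpose(1, 2, 0)   # -> (4, 1024, 256)

    # first kernel: reduce the original axes, leaving one partial result per split slice
    partial = splitted.sum(axis=reduce_axes, keepdims=True)                   # -> (1, 1, 256)

    # second kernel: reduce the split axis, then reshape back to the original output shape
    out = partial.sum(axis=2, keepdims=True).reshape(1, 1)

    assert np.allclose(out, x.sum(keepdims=True))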
27 changes: 2 additions & 25 deletions tinygrad/ops.py
@@ -5,7 +5,7 @@
from dataclasses import dataclass, field
from tinygrad.dtype import ConstType, ImageDType, dtypes, DType, truncate
from tinygrad.helpers import ContextVar, all_int, prod, getenv, all_same, Context, partition, temp, unwrap, T, argfix, Metadata, _METADATA, flatten
from tinygrad.helpers import PICKLE_BUFFERS, SPLIT_REDUCEOP, DEBUG, dedup
from tinygrad.helpers import PICKLE_BUFFERS, dedup
if TYPE_CHECKING:
from tinygrad.shape.shapetracker import ShapeTracker
from tinygrad.device import Buffer
@@ -393,32 +393,9 @@ def valid(self, st:ShapeTracker):
return UOp(Ops.VALID, dtypes.bool, (st.to_uop(),)).where(self.replace(src=(unmasked_st,)), UOp.const(self.dtype, 0).replace(src=(unmasked_st,)))
@staticmethod
def range(dtype:DType, start:sint, end:sint, idx:int): return UOp(Ops.RANGE, dtype=dtype, src=(sint_to_uop(start), sint_to_uop(end)), arg=idx)
def _reduce_op(self, op:Ops, axis:tuple[int, ...]):
def r(self, op:Ops, axis:tuple[int, ...]):
axis = tuple(sorted([x for x in axis if resolve(self.shape[x] != 1)]))
return self if len(axis) == 0 else UOp(Ops.REDUCE_AXIS, self.dtype, (self,), (op, axis))
def r(self, op:Ops, axis:tuple[int, ...]) -> UOp:
new_shape = unwrap(self.st).reduce(axis)

# TODO: can we split symbolic shape if the reduce axis is not symbolic?
# TODO: this shouldn't be here, it belongs in scheduler! that's why it broke multi
if not SPLIT_REDUCEOP or isinstance(self._device, tuple) or not all_int(self.shape) or (0 in self.shape) or \
prod(self.shape) // prod(new_shape) < getenv("REDUCEOP_SPLIT_THRESHOLD", 32768):
return self._reduce_op(op, axis)

# if there are few globals, make some reduces into globals by splitting into two kernels
# cap output buffer to 2**22: heuristic number of global outputs to achieve max occupancy with enough locals+upcasts for gemm
# ~2**10 should be enough if GROUP is used
# 256 split maximum should be "negligible reduce" for low prod(new_shape), 8 split minimum.
# split is moved to the end to provide maximum locality for the second phase reduce.
self_real_strides = unwrap(self.st).real_strides(ignore_valid=True)
split_candidates = [(i, x) for i in axis for x in range(min(256,2**getenv("REDUCEOP_SPLIT_SIZE",22)//prod(new_shape)),8-1,-1)
if self.shape[i] % x == 0 and self_real_strides[i] != 0]
if not split_candidates: return self._reduce_op(op, axis)
dim_to_split, divisor = split_candidates[0]
splitted_shape = self.shape[:dim_to_split] + (divisor,) + (self.shape[dim_to_split]//divisor,) + self.shape[dim_to_split+1:]
splitted = self.reshape(splitted_shape).permute(tuple([x for x in range(len(splitted_shape)) if x != dim_to_split]+[dim_to_split]))
if DEBUG >= 3: print(f"split {divisor}: {self.shape} -> {splitted.shape} -> {new_shape}")
return splitted._reduce_op(op, axis)._reduce_op(op, (len(new_shape),)).reshape(new_shape) # reduce original axes, then split
def assign(self, x:UOp): return UOp(Ops.ASSIGN, self.dtype, (self,x))
def contiguous(self): return self.alu(Ops.CONTIGUOUS)
def contiguous_backward(self): return self.alu(Ops.CONTIGUOUS_BACKWARD)
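With the split logic removed, the merged r only drops reduce axes of size 1 (reducing over them is a no-op) and emits a single REDUCE_AXIS; any splitting now happens in the scheduler rewrite above, which is what lets multi-device tensors (where the old code bailed on isinstance(self._device, tuple)) benefit. Below is a minimal plain-Python sketch of that axis filtering; the helper name and shapes are hypothetical, for illustration only.

    # Sketch of the axis filtering the merged r() keeps (plain Python, not tinygrad):
    # size-1 axes are dropped because reducing over them changes nothing, and an empty
    # result means the reduce is the identity (r returns self instead of a REDUCE_AXIS)
    def kept_reduce_axes(shape, axis):
        return tuple(sorted(a for a in axis if shape[a] != 1))

    assert kept_reduce_axes((4, 1, 8), (0, 1, 2)) == (0, 2)   # size-1 axis dropped
    assert kept_reduce_axes((1, 1), (0, 1)) == ()             # identity: no REDUCE_AXIS emitted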
