From 898aafe6fde266fbc202826eb7538b951fd4dccf Mon Sep 17 00:00:00 2001
From: qazal <77887910+Qazalin@users.noreply.github.com>
Date: Sun, 23 Feb 2025 18:30:04 +0200
Subject: [PATCH] move split_reduceop to scheduler + enable it for multi
 (#9214)

* move split_reduceop to scheduler + enable it for multi

* merge r and _reduceop
---
 tinygrad/engine/schedule.py | 25 +++++++++++++++++++++++--
 tinygrad/ops.py             | 27 ++-------------------------
 2 files changed, 25 insertions(+), 27 deletions(-)

diff --git a/tinygrad/engine/schedule.py b/tinygrad/engine/schedule.py
index 308fb9d55671a..6643555d8cc2f 100644
--- a/tinygrad/engine/schedule.py
+++ b/tinygrad/engine/schedule.py
@@ -5,7 +5,7 @@
 from tinygrad.ops import can_pad, identity_element, resolve, view_left, merge_views
 from tinygrad.codegen.symbolic import symbolic_simple
 from tinygrad.helpers import Context, ContextVar, Metadata, all_int, all_same, colored, diskcache_put, prod, dedup, unwrap, flatten, getenv
-from tinygrad.helpers import FUSE_CONV_BW, FUSE_ARANGE, DEBUG, CAPTURE_PROCESS_REPLAY, DONT_REALIZE_EXPAND
+from tinygrad.helpers import FUSE_CONV_BW, FUSE_ARANGE, DEBUG, CAPTURE_PROCESS_REPLAY, DONT_REALIZE_EXPAND, SPLIT_REDUCEOP
 from tinygrad.dtype import ImageDType
 from tinygrad.shape.shapetracker import ShapeTracker
 from tinygrad.shape.view import View, strides_for_shape
@@ -29,6 +29,23 @@ def simplify_stride0_reduce(reduce:UOp, x:UOp):
     case Ops.MUL: return ret.pow(prshape)
     case Ops.MAX: return ret # NOTE: Ops.MAX is passthrough

+def split_reduceop(reduce:UOp, x:UOp):
+  if not SPLIT_REDUCEOP or not all_int(x.shape) or (prod(x.shape)//prod(reduce.shape)) < getenv("REDUCEOP_SPLIT_THRESHOLD", 32768): return None
+  # if there are few globals, make some reduces into globals by splitting into two kernels
+  # cap output buffer to 2**22: heuristic number of global outputs to achieve max occupancy with enough locals+upcasts for gemm
+  #   ~2**10 should be enough if GROUP is used
+  # 256 split maximum should be "negligible reduce" for low prod(new_shape), 8 split minimum.
+  # split is moved to the end to provide maximum locality for the second phase reduce.
+  real_strides = unwrap(x.st).real_strides(ignore_valid=True)
+  if not (split_candidates:=[(i,d) for i in reduce.arg[1] for d in range(min(256, 2**getenv("REDUCEOP_SPLIT_SIZE", 22)//prod(reduce.shape)), 8-1, -1)
+                             if x.shape[i] % d == 0 and real_strides[i] != 0]): return None
+  dim_to_split, divisor = split_candidates[0]
+  splitted_shape = x.shape[:dim_to_split] + (divisor,) + (x.shape[dim_to_split]//divisor,) + x.shape[dim_to_split+1:]
+  splitted = x.reshape(splitted_shape).permute(tuple([d for d in range(len(splitted_shape)) if d != dim_to_split]+[dim_to_split]))
+  if DEBUG >= 3: print(f"split {divisor}: {x.shape} -> {splitted.shape} -> {reduce.shape}")
+  # reduce original axes, then split
+  return splitted.r(*reduce.arg).r(reduce.arg[0], (len(reduce.shape),)).reshape(reduce.shape)
+
 def found_contiguous(ctx:dict[UOp, UOp], contig:UOp, src:UOp):
   if (sti:=unwrap(src.st).invert(src.base.shape)) is not None: ctx[src.base] = contig.view(sti)
 def replace_contiguous(ctx:dict[UOp, UOp], alu:UOp):
@@ -48,6 +65,8 @@ def replace_contiguous(ctx:dict[UOp, UOp], alu:UOp):
    lambda reduce,x: reduce.const_like(identity_element(reduce.arg[0], reduce.dtype)) if x.size == 0 and reduce.size != 0 else None),
   # reduce on stride 0 is collapsed
   (UPat(Ops.REDUCE_AXIS, name="reduce", src=(UPat.var("x"),)), simplify_stride0_reduce),
+  # split_reduceop
+  (UPat(Ops.REDUCE_AXIS, name="reduce", src=(UPat.var("x"),)), split_reduceop),
   # COPY(CONST) creates a new CONST on the destination device
   (UPat(Ops.COPY, name="root", src=(UPat(), UPat.cvar("x"),)), lambda root,x: root.const_like(x.arg)),
   # no COPY to same device, except clone (arg is True)
@@ -386,7 +405,9 @@ def create_schedule_with_vars(big_sink:UOp) -> tuple[list[ScheduleItem], dict[Va
   # map tensors to buffer/const
   becomes_map: dict[UOp, UOp] = {}
   for k,v in tensor_map.items():
-    if (a:=kernel_map.get(v)) is not None and a.op is Ops.ASSIGN: becomes_map[k] = k.src[0] if k.op is Ops.ASSIGN else a.buf_uop.reshape(k.shape)
+    # NOTE: tensors can also map to a VIEW, if it's contiguous and we can reshape it it's fine
+    if (a:=kernel_map.get(v.base)) is not None and a.op is Ops.ASSIGN and a.size == k.size and unwrap(v.st).contiguous:
+      becomes_map[k] = k.src[0] if k.op is Ops.ASSIGN else a.buf_uop.reshape(k.shape)
     if v is k: continue
     if v.base.op is Ops.BUFFER:
       # VIEW isn't a valid tensor uop, we need to backtrack to the movement op that created it
diff --git a/tinygrad/ops.py b/tinygrad/ops.py
index 7b16ffb61624a..d856d2051a22b 100644
--- a/tinygrad/ops.py
+++ b/tinygrad/ops.py
@@ -5,7 +5,7 @@
 from dataclasses import dataclass, field
 from tinygrad.dtype import ConstType, ImageDType, dtypes, DType, truncate
 from tinygrad.helpers import ContextVar, all_int, prod, getenv, all_same, Context, partition, temp, unwrap, T, argfix, Metadata, _METADATA, flatten
-from tinygrad.helpers import PICKLE_BUFFERS, SPLIT_REDUCEOP, DEBUG, dedup
+from tinygrad.helpers import PICKLE_BUFFERS, dedup
 if TYPE_CHECKING:
   from tinygrad.shape.shapetracker import ShapeTracker
   from tinygrad.device import Buffer
@@ -393,32 +393,9 @@ def valid(self, st:ShapeTracker):
     return UOp(Ops.VALID, dtypes.bool, (st.to_uop(),)).where(self.replace(src=(unmasked_st,)), UOp.const(self.dtype, 0).replace(src=(unmasked_st,)))
   @staticmethod
   def range(dtype:DType, start:sint, end:sint, idx:int): return UOp(Ops.RANGE, dtype=dtype, src=(sint_to_uop(start), sint_to_uop(end)), arg=idx)
-  def _reduce_op(self, op:Ops, axis:tuple[int, ...]):
+  def r(self, op:Ops, axis:tuple[int, ...]):
     axis = tuple(sorted([x for x in axis if resolve(self.shape[x] != 1)]))
     return self if len(axis) == 0 else UOp(Ops.REDUCE_AXIS, self.dtype, (self,), (op, axis))
-  def r(self, op:Ops, axis:tuple[int, ...]) -> UOp:
-    new_shape = unwrap(self.st).reduce(axis)
-
-    # TODO: can we split symbolic shape if the reduce axis is not symbolic?
-    # TODO: this shouldn't be here, it belongs in scheduler! that's why it broke multi
-    if not SPLIT_REDUCEOP or isinstance(self._device, tuple) or not all_int(self.shape) or (0 in self.shape) or \
-      prod(self.shape) // prod(new_shape) < getenv("REDUCEOP_SPLIT_THRESHOLD", 32768):
-      return self._reduce_op(op, axis)
-
-    # if there are few globals, make some reduces into globals by splitting into two kernels
-    # cap output buffer to 2**22: heuristic number of global outputs to achieve max occupancy with enough locals+upcasts for gemm
-    #   ~2**10 should be enough if GROUP is used
-    # 256 split maximum should be "negligible reduce" for low prod(new_shape), 8 split minimum.
-    # split is moved to the end to provide maximum locality for the second phase reduce.
-    self_real_strides = unwrap(self.st).real_strides(ignore_valid=True)
-    split_candidates = [(i, x) for i in axis for x in range(min(256,2**getenv("REDUCEOP_SPLIT_SIZE",22)//prod(new_shape)),8-1,-1)
-                        if self.shape[i] % x == 0 and self_real_strides[i] != 0]
-    if not split_candidates: return self._reduce_op(op, axis)
-    dim_to_split, divisor = split_candidates[0]
-    splitted_shape = self.shape[:dim_to_split] + (divisor,) + (self.shape[dim_to_split]//divisor,) + self.shape[dim_to_split+1:]
-    splitted = self.reshape(splitted_shape).permute(tuple([x for x in range(len(splitted_shape)) if x != dim_to_split]+[dim_to_split]))
-    if DEBUG >= 3: print(f"split {divisor}: {self.shape} -> {splitted.shape} -> {new_shape}")
-    return splitted._reduce_op(op, axis)._reduce_op(op, (len(new_shape),)).reshape(new_shape)  # reduce original axes, then split
   def assign(self, x:UOp): return UOp(Ops.ASSIGN, self.dtype, (self,x))
   def contiguous(self): return self.alu(Ops.CONTIGUOUS)
   def contiguous_backward(self): return self.alu(Ops.CONTIGUOUS_BACKWARD)
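
Note (not part of the patch): the following is a minimal standalone sketch of the shape arithmetic that split_reduceop performs, for readers following the heuristic. plan_split is a hypothetical helper name, the REDUCEOP_SPLIT_* knobs are hard-coded instead of read via getenv, and the real_strides check on candidates is omitted.

# Standalone sketch of the split heuristic in split_reduceop above.
# Shapes are plain tuples; the UOp/ShapeTracker machinery is left out.
from math import prod

REDUCEOP_SPLIT_THRESHOLD = 32768  # minimum elements reduced per output before splitting
REDUCEOP_SPLIT_SIZE = 22          # cap first-phase output buffer around 2**22 elements

def plan_split(shape:tuple[int, ...], axis:tuple[int, ...]):
  reduced_shape = tuple(1 if i in axis else s for i,s in enumerate(shape))
  if prod(shape)//prod(reduced_shape) < REDUCEOP_SPLIT_THRESHOLD: return None
  # try divisors from at most 256 down to 8; the first reduce axis that divides evenly wins
  candidates = [(i,d) for i in axis for d in range(min(256, 2**REDUCEOP_SPLIT_SIZE//prod(reduced_shape)), 8-1, -1)
                if shape[i] % d == 0]
  if not candidates: return None
  dim, divisor = candidates[0]
  # reshape the reduce dim into (divisor, dim//divisor), then move the divisor axis to the end
  split_shape = shape[:dim] + (divisor, shape[dim]//divisor) + shape[dim+1:]
  permuted = tuple(s for i,s in enumerate(split_shape) if i != dim) + (divisor,)
  # kernel 1 reduces the original axes of `permuted`, kernel 2 reduces the trailing divisor axis
  return {"divisor": divisor, "split": split_shape, "permuted": permuted, "out": reduced_shape}

# e.g. summing a (4, 1048576) tensor over axis 1: kernel 1 leaves 256 partials per row,
# kernel 2 reduces those 256 partials down to the final (4, 1) output
print(plan_split((4, 1048576), (1,)))

Keeping the divisor axis at the end mirrors the comment in the patch: the split is moved to the end so the second-phase reduce reads its partials with maximum locality.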