Skip to content

Commit

Permalink
simpler becomes_map update [pr] (tinygrad#9201)
Browse files Browse the repository at this point in the history
* simpler becomes_map update

* err, no metadata for device

* simpler tensor metadata mapping + tests [pr]

* remove kernel metadata

* don't map nones

* pruning

* linter
  • Loading branch information
Qazalin authored Feb 22, 2025
1 parent 4578c3e commit e6d20c4
Showing 1 changed file with 12 additions and 20 deletions.
32 changes: 12 additions & 20 deletions tinygrad/engine/schedule.py
Original file line number Diff line number Diff line change
Expand Up @@ -373,35 +373,27 @@ def create_schedule_with_vars(big_sink:UOp) -> tuple[list[ScheduleItem], dict[Va
# display the cleaned up tensor graph
if getenv("VIZ"): graph_rewrite(tensor_map[big_sink], PatternMatcher([]), name="View Tensor Graph")

# do_realize + group_realizes
# get realizes
sink = tensor_map[big_sink]
realize_map = group_realizes(sink)

# map tensors to new uops
becomes_map: dict[UOp, UOp] = {}
rev_tensor_map: dict[UOp, list[UOp]] = {}
for k,v in tensor_map.items():
rev_tensor_map.setdefault(v, []).append(k)
if k is v: continue
if v.base.op is Ops.BUFFER:
# VIEW isn't a valid tensor uop, we need to backtrack to the movement op that created it
if v.op is Ops.VIEW:
mop = [x for x in k.toposort if (xs:=tensor_map[x]).base is v.base and xs.st == v.st][0]
if k is not mop: becomes_map[k] = mop
else: becomes_map[k] = v
elif v.base.op is Ops.CONST and all_int(v.shape): becomes_map[k] = v

# map tensor metadata to simplified ops
ops_metadata = {v:k.metadata for k,v in tensor_map.items() if k.base.op not in {Ops.CONST, Ops.DEVICE} and isinstance(k.metadata, Metadata)}
# create kernels
kernel_map = graph_rewrite_map(sink, create_kernels, ctx=KernelContext(realize_map, ops_metadata), bottom_up=True)
sched_sink = kernel_map[sink]
type_verify(list(sched_sink.toposort), kernel_spec)

# map realized tensors to buffers
for k,v in kernel_map.items():
if k is v or v.op is not Ops.ASSIGN: continue
for t in rev_tensor_map[k]: becomes_map[t] = t.src[0] if t.op is Ops.ASSIGN else v.buf_uop.reshape(t.shape)
# map tensors to buffer/const
becomes_map: dict[UOp, UOp] = {}
for k,v in tensor_map.items():
if (a:=kernel_map.get(v)) is not None and a.op is Ops.ASSIGN: becomes_map[k] = k.src[0] if k.op is Ops.ASSIGN else a.buf_uop.reshape(k.shape)
if v is k: continue
if v.base.op is Ops.BUFFER:
# VIEW isn't a valid tensor uop, we need to backtrack to the movement op that created it
if v.op is Ops.VIEW: v = next(iter(x for x in k.toposort if (xs:=tensor_map[x]).base is v.base and xs.st == v.st))
if k is not v: becomes_map[k] = v
elif v.base.op is Ops.CONST:
if all_int(v.shape): becomes_map[k] = v

# if a kernel depends on a buffer, and that buffer is later assigned to, make the assign depend on the kernel's assign
kernel_assign: dict[UOp, UOp] = {}
Expand Down

0 comments on commit e6d20c4

Please sign in to comment.