Skip to content

Commit

Permalink
move copy kernel to out of schedule ordering (#4530)
Browse files Browse the repository at this point in the history
* delete from sorting

* move the logic
  • Loading branch information
Qazalin committed May 11, 2024
1 parent 2fb564c commit 4871476
Showing 1 changed file with 5 additions and 5 deletions.
10 changes: 5 additions & 5 deletions tinygrad/engine/schedule.py
@@ -1,6 +1,6 @@
import sys, pickle, atexit
from collections import defaultdict, deque
from dataclasses import dataclass, replace
from dataclasses import dataclass
from typing import Tuple, List, Dict, Optional, Set, DefaultDict
from tinygrad.ops import LoadOps, ScheduleItem, BufferOps, LazyOp, ReduceOps, ConstBuffer, MemBuffer, UNSAFE_PAD_OPS, UnaryOps
from tinygrad.features.graph import log_lazybuffer, realized_lazybuffer
Expand Down Expand Up @@ -82,7 +82,10 @@ def _schedule_group(outs:Tuple[LazyBuffer, ...], realizes:Dict[LazyBuffer, None]
inputs: List[LazyBuffer] = []
ast: List[LazyOp] = []
var_vals: Dict[Variable, int] = merge_dicts([out.st.var_vals.copy() for out in outs])
if outs[0].op in {LoadOps.CUSTOM, LoadOps.COPY, LoadOps.EMPTY, LoadOps.VIEW}:
if outs[0].op is LoadOps.COPY and getenv("USE_COPY_KERNEL") and outs[0].device.split(":")[0] == outs[0].srcs[0].device.split(":")[0]:
rd = LazyOp(BufferOps.LOAD, (), MemBuffer(1, dtypes.uint8, st:=ShapeTracker.from_shape((outs[0].arg,))))
ast, inputs = [LazyOp(BufferOps.STORE, (rd,), MemBuffer(0, dtypes.uint8, st))], [x.base for x in outs[0].srcs]
elif outs[0].op in {LoadOps.CUSTOM, LoadOps.COPY, LoadOps.EMPTY, LoadOps.VIEW}:
ast, inputs = [LazyOp(outs[0].op, (), outs[0].arg)], [x.base for x in outs[0].srcs]
else:
for i, out in enumerate(outs):
Expand Down Expand Up @@ -267,9 +270,6 @@ def create_schedule_with_vars(outs:List[LazyBuffer], seen:Optional[Set[LazyBuffe
for out in ps.outputs: realized_lazybuffer(out, kernel_number)
var_vals = merge_dicts([var_vals, ps.var_vals])
for out in ps.outputs: del out.srcs # can only schedule once
if getenv("USE_COPY_KERNEL") and ps.ast[0].op == LoadOps.COPY and ps.outputs[0].device.split(":")[0] == ps.inputs[0].device.split(":")[0]:
rd = LazyOp(BufferOps.LOAD, (), MemBuffer(1, dtypes.uint8, st:=ShapeTracker.from_shape((ps.ast[0].arg,))))
ps = replace(ps, ast=(LazyOp(BufferOps.STORE, (rd,), MemBuffer(0, dtypes.uint8, st)),))
schedule.append(si:=ScheduleItem(ps.ast, tuple(x.buffer for x in (ps.outputs+ps.inputs) if x.size != 0)))
if logops and si.ast[0].op not in LoadOps and not any(i.device.startswith("DISK:") for i in si.inputs): logops.write(str(si.ast)+"\n")
for x in graph[ps.outputs[0]]:
Expand Down

0 comments on commit 4871476

Please sign in to comment.