Commit 2c6f2e8

No extra vars call (#3054)
* remove unused reciprocal

* comment

* remove unneeded call to vars

* free speedup
geohot committed Jan 9, 2024
1 parent 259bf9b commit 2c6f2e8
Showing 3 changed files with 17 additions and 15 deletions.
1 change: 1 addition & 0 deletions .gitignore
@@ -49,3 +49,4 @@ model.safetensors
 quickstart.py
 .hypothesis
 weights
+*.lprof
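`*.lprof` files are what line_profiler's kernprof runner writes next to a profiled script, so a plausible reading (an assumption; the commit does not say) is that the speedup here was hunted down with line-level profiling, along these lines:

pip install line_profiler                                   # provides kernprof
kernprof -l test/external/external_test_speed_llama.py      # writes external_test_speed_llama.py.lprof
python -m line_profiler external_test_speed_llama.py.lprof  # print per-line timings

kernprof -l only records functions decorated with @profile, so that decorator would be added locally and never committed, which fits ignoring the .lprof output rather than deleting it.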
3 changes: 2 additions & 1 deletion test/external/external_test_speed_llama.py
@@ -54,4 +54,5 @@ def run_llama(st, empty_method_cache=True):
       Device[Device.DEFAULT].compiler = backup_compiler
 
 if __name__ == '__main__':
-  unittest.main()
+  TestLLaMASpeed().test_llama_compile()
+  #unittest.main()
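With `unittest.main()` commented out, running the file executes the benchmark directly instead of going through unittest discovery, which keeps a line profile focused on the compile path. Assuming tinygrad is importable from the repo root (e.g. after `pip install -e .`), the invocation is just:

python3 test/external/external_test_speed_llama.py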
28 changes: 14 additions & 14 deletions tinygrad/lazy.py
@@ -155,37 +155,38 @@ def stride(self, arg:Tuple[int, ...]): return self._view(self.st.stride(arg))
 
 # recursively create a lazyop
 def _recursive_lazyop(buf:LazyBuffer, inputs:List[LazyBuffer], var_vals:Dict[Variable, int], st:ShapeTracker,
-                      realizes:Set[LazyBuffer], first=True, cache=None) -> LazyOp:
-  if cache is None: cache = {}
+                      realizes:Set[LazyBuffer], cache, first=True) -> LazyOp:
   if (buf, st) in cache: return cache[(buf, st)]
   if buf != buf.base:
-    var_vals.update(merge_dicts([var_vals, buf.st.var_vals]))
-    st = buf.st.unbind()+st
+    st = buf.st + st
     buf = buf.base
   # all buffers here are base now
   assert buf.op is not None
 
   # consts are always fused and generated
   if buf.op == LoadOps.CONST:
-    return LazyOp(BufferOps.CONST, (), ConstBuffer(float(buf.arg), buf.dtype, st.simplify()))
+    # TODO: make shapetracker unbind also return var_vals
+    var_vals.update(merge_dicts([var_vals, st.var_vals]))
+    return LazyOp(BufferOps.CONST, (), ConstBuffer(float(buf.arg), buf.dtype, st.simplify().unbind()))
 
   # if we aren't fusing it, it's a load and we add it to the inputs
   if buf.realized or (buf in realizes and not first):
     if buf not in inputs: inputs.append(buf)
-    return LazyOp(BufferOps.LOAD, (), MemBuffer(inputs.index(buf)+1, buf.dtype, st.simplify()))
+    var_vals.update(merge_dicts([var_vals, st.var_vals]))
+    return LazyOp(BufferOps.LOAD, (), MemBuffer(inputs.index(buf)+1, buf.dtype, st.simplify().unbind()))
 
   # if a CONTIGUOUS made it all the way here, just skip it
   if buf.op == LoadOps.CONTIGUOUS:
     assert first
-    return _recursive_lazyop(buf.srcs[0], inputs, var_vals, st, realizes, False, cache)
+    return _recursive_lazyop(buf.srcs[0], inputs, var_vals, st, realizes, cache, False)
 
   # if it's a reduce, we have to change the shapetracker
   if buf.op in ReduceOps:
     assert st.contiguous, "ReduceOps late fusion must be contiguous"
-    st = ShapeTracker.from_shape(buf.srcs[0].shape).unbind()
+    st = ShapeTracker.from_shape(buf.srcs[0].shape)
 
   # otherwise we fuse it like normal
-  cache[(buf, st)] = ret = LazyOp(buf.op, tuple(_recursive_lazyop(x, inputs, var_vals, st, realizes, False, cache) for x in buf.srcs), buf.arg)
+  cache[(buf, st)] = ret = LazyOp(buf.op, tuple(_recursive_lazyop(x, inputs, var_vals, st, realizes, cache, False) for x in buf.srcs), buf.arg)
   return ret
 
 # recursively walk back in the graph to create the schedule
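Worth noting in the hunk above: `unbind()` moves off the intermediate ShapeTracker arithmetic and onto the single point where a BufferOp is emitted, with the Variable bindings folded into `var_vals` right there (hence the new TODO asking for unbind to also return them). A self-contained sketch of the underlying idea, using plain Python stand-ins rather than real tinygrad types:

# Illustrative model only, not tinygrad's API. A "shapetracker" here is just
# a dict of bound symbolic variables; unbind() splits the symbolic part
# (hashable, so compiled kernels can be cached) from the concrete values.
from typing import Dict, Tuple

def unbind(st: Dict[str, int]) -> Tuple[Tuple[str, ...], Dict[str, int]]:
  return tuple(sorted(st)), dict(st)

var_vals: Dict[str, int] = {}
st = {"start_pos": 15}          # e.g. a Variable bound during LLM decoding
sym, bound = unbind(st)
var_vals.update(bound)          # collected once, where the op is emitted
# sym == ("start_pos",) whatever the bound value is, so the kernel compiled
# for start_pos=15 is reused at start_pos=16; only var_vals changes per launch.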
@@ -204,12 +205,11 @@ def _recursive_schedule(out:LazyBuffer, seen:Set[LazyBuffer], realizes:Set[LazyBuffer],
   elif out.op == LoadOps.EMPTY:
     op = LazyOp(LoadOps.EMPTY)
   else:
-    output_st = ShapeTracker.from_shape(reduce_for_op[out].shape if out in reduce_for_op else out.shape).unbind()
-    op = _recursive_lazyop(out, inputs, var_vals, output_st, realizes)
-    op = LazyOp(BufferOps.STORE, (op, ), MemBuffer(0, out.dtype, output_st.simplify()))
+    output_st = ShapeTracker.from_shape(reduce_for_op[out].shape if out in reduce_for_op else out.shape)
+    op = _recursive_lazyop(out, inputs, var_vals, output_st, realizes, cache={})
+    op = LazyOp(BufferOps.STORE, (op, ), MemBuffer(0, out.dtype, output_st.simplify().unbind()))
 
-  return flatten(_recursive_schedule(x.base, seen, realizes, reduce_for_op) for x in inputs) + \
-    [ScheduleItem(op, out, tuple(inputs), {k:var_vals[k] for k in op.vars()})]
+  return flatten(_recursive_schedule(x.base, seen, realizes, reduce_for_op) for x in inputs) + [ScheduleItem(op, out, tuple(inputs), var_vals)]
 
 # recursively search the entire graph for all LazyBuffers, insert realizes after expands
 def _recurse_lb(buf:LazyBuffer, realizes:Set[LazyBuffer], allbufs:Dict[LazyBuffer, None],
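The last change is the one the commit title names. The old code rebuilt the bindings per ScheduleItem as `{k:var_vals[k] for k in op.vars()}`, and `op.vars()` traverses the LazyOp graph on every call; handing over the `var_vals` dict that `_recursive_lazyop` already accumulated skips that walk, the "free speedup" of the commit message. A toy model of the difference (again illustrative, not tinygrad code):

# Illustrative model only: vars() visits every node of the op tree, so
# rebuilding the dict per ScheduleItem repeats a traversal that reusing
# the accumulated dict makes unnecessary.
from dataclasses import dataclass
from typing import Tuple

@dataclass(frozen=True)
class Op:
  vars_here: Tuple[str, ...] = ()
  srcs: Tuple["Op", ...] = ()
  def vars(self) -> set:
    out = set(self.vars_here)            # O(tree size) on every call
    for s in self.srcs: out |= s.vars()
    return out

op = Op(srcs=(Op(vars_here=("start_pos",)),))
var_vals = {"start_pos": 15}
old_way = {k: var_vals[k] for k in op.vars()}  # per-item traversal and rebuild
new_way = var_vals                             # reuse the dict already built
assert old_way == new_way

The visible trade-off is that a ScheduleItem may now carry bindings its op never reads; presumably harmless, since unused entries are simply ignored at launch.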