Skip to content

Commit

Permalink
multi reduce linearizer tests start (#4529)
Browse files Browse the repository at this point in the history
* test_end_local

* test_early_end_local

* todos

* mean+std

* skip no locals
  • Loading branch information
Qazalin committed May 11, 2024
1 parent 3cba229 commit 2fb564c
Showing 1 changed file with 39 additions and 0 deletions.
39 changes: 39 additions & 0 deletions test/test_linearizer.py
Expand Up @@ -97,6 +97,45 @@ def test_multioutput(self):
assert len(mutable_bufs) == len(stores) == 2
assert [u.arg[0] for u in mutable_bufs] == [0, 1]

def test_end_local(self):
if not (opts:=Device[Device.DEFAULT].renderer).has_local or not opts.has_shared: self.skipTest("device does not support locals")
load = MemBuffer(idx=1, dtype=dtypes.int, st=ShapeTracker.from_shape((32,)))
store = MemBuffer(idx=0, dtype=dtypes.int, st=ShapeTracker.from_shape((1,)))
ast = LazyOp(op=BufferOps.STORE, src=(LazyOp(op=ReduceOps.SUM, src=(LazyOp(op=BufferOps.LOAD, arg=load),), arg=(0,)),), arg=store),

load_t = Tensor.full(load.st.shape, 1).contiguous().realize()
k = helper_linearizer_ast(ast, [load_t], wanna_output=[load_t.numpy().sum()])[1]
self.assertEqual(k.uops.uops[-1].uop, UOps.ENDIF)
self.assertLess(k.uops.uops.index([x for x in k.uops.uops if x.uop is UOps.STORE][-1]), k.uops.uops.index(k.uops.uops[-1]))

@unittest.expectedFailure
def test_early_end_local(self):
shape, output_shape = (32,), (1,)
load0 = MemBuffer(idx=1, dtype=dtypes.float, st=ShapeTracker.from_shape(shape))
load1 = MemBuffer(idx=2, dtype=dtypes.float, st=ShapeTracker.from_shape(shape))
store = MemBuffer(idx=0, dtype=dtypes.float, st=ShapeTracker.from_shape(output_shape))
ast = LazyOp(op=BufferOps.STORE, src=(LazyOp(op=ReduceOps.SUM, src=(LazyOp(op=BinaryOps.ADD, src=(
LazyOp(op=ReduceOps.SUM, src=(LazyOp(op=BufferOps.LOAD, arg=load0),)),LazyOp(op=BufferOps.LOAD, arg=load1), )),)),), arg=store),

load0_t = Tensor.randn(shape).realize()
load1_t = Tensor.randn(shape).realize()
k = helper_linearizer_ast(ast, [load0_t, load1_t], wanna_output=[(load0_t.numpy().sum() + load1_t.numpy()).sum()])[1]
self.assertEqual(len(endifs:=[x for x in k.uops.uops if x.uop is UOps.ENDIF]), len(ifs:=[x for x in k.uops.uops if x.uop is UOps.IF]))
self.assertEqual(len(barriers:=[x for x in k.uops.uops if x.uop is UOps.BARRIER]), 3)
self.assertEqual(k.uops.uops[k.uops.uops.index(endifs[0])-1].uop, UOps.STORE)
self.assertEqual(k.uops.uops[k.uops.uops.index(endifs[0])+1], barriers[1])
self.assertEqual(k.uops.uops[k.uops.uops.index(endifs[0])+2].uop, UOps.LOAD)
self.assertLess(k.uops.uops.index(barriers[0]), k.uops.uops.index(ifs[0]))
self.assertLess(k.uops.uops.index(ifs[0]), k.uops.uops.index(endifs[0]))
self.assertLess(k.uops.uops.index(barriers[1]), k.uops.uops.index(ifs[1]))

@unittest.expectedFailure
def test_mean_std_multireduce(self):
ast = LazyOp(op=BufferOps.STORE, src=(LazyOp(op=UnaryOps.SQRT, src=(LazyOp(op=BinaryOps.MUL, src=(LazyOp(op=ReduceOps.SUM, src=(LazyOp(op=BinaryOps.MUL, src=(LazyOp(op=BinaryOps.SUB, src=(LazyOp(op=BufferOps.LOAD, src=(), arg=MemBuffer(idx=1, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(15, 25, 35), strides=(875, 35, 1), offset=0, mask=None, contiguous=True),)))), LazyOp(op=BinaryOps.MUL, src=(LazyOp(op=ReduceOps.SUM, src=(LazyOp(op=BufferOps.LOAD, src=(), arg=MemBuffer(idx=1, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(15, 25, 35), strides=(875, 35, 1), offset=0, mask=None, contiguous=True),)))),), arg=(0, 1, 2)), LazyOp(op=BufferOps.CONST, src=(), arg=ConstBuffer(val=7.619047619047618e-05, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(15, 25, 35), strides=(0, 0, 0), offset=0, mask=None, contiguous=False),))))), arg=None)), arg=None), LazyOp(op=BinaryOps.SUB, src=(LazyOp(op=BufferOps.LOAD, src=(), arg=MemBuffer(idx=1, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(15, 25, 35), strides=(875, 35, 1), offset=0, mask=None, contiguous=True),)))), LazyOp(op=BinaryOps.MUL, src=(LazyOp(op=ReduceOps.SUM, src=(LazyOp(op=BufferOps.LOAD, src=(), arg=MemBuffer(idx=1, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(15, 25, 35), strides=(875, 35, 1), offset=0, mask=None, contiguous=True),)))),), arg=(0, 1, 2)), LazyOp(op=BufferOps.CONST, src=(), arg=ConstBuffer(val=7.619047619047618e-05, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(15, 25, 35), strides=(0, 0, 0), offset=0, mask=None, contiguous=False),))))), arg=None)), arg=None)), arg=None),), arg=(0, 1, 2)), LazyOp(op=BufferOps.CONST, src=(), arg=ConstBuffer(val=7.619628162145687e-05, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(1, 1, 1), strides=(0, 0, 0), offset=0, mask=None, contiguous=True),))))), arg=None),), arg=None),), arg=MemBuffer(idx=0, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(1, 1, 1), strides=(0, 0, 0), offset=0, mask=None, contiguous=True),)))), # noqa: E501

x = Tensor.randn(15, 25, 35).realize()
helper_linearizer_ast(ast, [x], wanna_output=[x.numpy().std()])

def test_load_dedup(self):
# for different leaves in the AST, the same loads may occur.

Expand Down

0 comments on commit 2fb564c

Please sign in to comment.