test_linearizer_correctness (#4458)
* test helper

* uops asserts

* cleanup args

* nits
Qazalin committed May 11, 2024
1 parent b3d9fd4 commit 3cba229
Showing 1 changed file with 22 additions and 8 deletions.

test/test_linearizer.py
@@ -88,8 +88,9 @@ def test_multioutput(self):
     out0 = LazyOp(BufferOps.STORE, (LazyOp(op=BinaryOps.ADD, src=(a,b)),), MemBuffer(idx=0, dtype=dtype, st=st))
     out1 = LazyOp(BufferOps.STORE, (LazyOp(op=BinaryOps.MUL, src=(a,b)),), MemBuffer(idx=1, dtype=dtype, st=st))

-    lin = Linearizer(out0, out1)
-    lin.linearize()
+    a_t = Tensor.full(st.shape, 2).contiguous().realize()
+    b_t = Tensor.full(st.shape, 3).contiguous().realize()
+    lin = helper_linearizer_ast((out0, out1), [a_t, b_t], wanna_output=[a_t.numpy()+b_t.numpy(), a_t.numpy()*b_t.numpy()])[0]

     stores = [u for u in lin.uops if u.uop is UOps.STORE]
     mutable_bufs = [u for u in lin.uops if u.uop is UOps.DEFINE_GLOBAL and u.arg[-1]]
@@ -584,18 +585,25 @@ def test_matvec(self):
     assert k.local_dims == 1
     assert k.upcasted == 1

-def helper_linearizer_opt(r:Tensor, opts=[], apply_tc=False, atol=1e-4, rtol=1e-4, color_sizes=[]):
+def helper_linearizer_ast(ast:Tuple[LazyOp, ...], inputs:List[Tensor], *args, **kwargs):
+  inbufs = [x.lazydata.buffer for x in inputs]
+  outbufs = [Buffer(inbufs[-1].device, out.arg.st.size, out.arg.dtype).allocate() for out in ast]
+  return _helper_linearizer_opt_ast(ast, outbufs+inbufs, *args, **kwargs)
+
+def helper_linearizer_opt(r:Tensor, *args, **kwargs):
   realized_ast, real_bufs = helper_realized_ast(r)
-  return helper_linearizer_opt_ast((realized_ast, ), real_bufs, opts, apply_tc, atol, rtol, color_sizes)
+  return _helper_linearizer_opt_ast((realized_ast, ), real_bufs, *args, **kwargs)

-def helper_linearizer_opt_ast(realized_ast:Tuple[LazyOp, ...], real_bufs:List[Buffer], opts=[], apply_tc=False, atol=1e-4, rtol=1e-4, color_sizes=[]):
-  wanna_output = []
+def _helper_linearizer_opt_ast(realized_ast:Tuple[LazyOp, ...], real_bufs:List[Buffer], opts=[],
+                               apply_tc=False, atol=1e-4, rtol=1e-4, color_sizes=[], wanna_output=[]) -> List[Linearizer]:
   lins: List[Linearizer] = []
+  outbufs = [real_bufs[i] for i in range(len(realized_ast))]

   def get_prg(k:Linearizer): return CompiledRunner(replace(k.to_program(), dname=Device.DEFAULT))

   def check_opt(opts, create_k, expected_color_size):
     k = create_k()
+    lins.append(k)
     if apply_tc:
       assert k.apply_tensor_cores(1, extra_opts=opts), "no tensor core triggered"
     else:
@@ -610,14 +618,19 @@ def check_opt(opts, create_k, expected_color_size):
     for i, buf in enumerate(outbufs):
       np.testing.assert_allclose(np.frombuffer(buf.as_buffer(), buf.dtype.np), wanna_output[i], atol=atol, rtol=rtol)

-  # Get baseline, which is not optimized at all.
+  # Get the unoptimized baseline; it supplies the expected output if none was provided.
   k = Linearizer(*realized_ast)
+  lins.append(k)
   prg = get_prg(k)
   prg.exec(real_bufs)
-  wanna_output = [np.frombuffer(buf.as_buffer(), buf.dtype.np).copy() for buf in outbufs]
+  if len(wanna_output) == 0: wanna_output = [np.frombuffer(buf.as_buffer(), buf.dtype.np).copy() for buf in outbufs]
+  else:
+    for i, buf in enumerate(outbufs):
+      np.testing.assert_allclose(np.frombuffer(buf.as_buffer(), buf.dtype.np), wanna_output[i], atol=atol, rtol=rtol)

   # Check correctness of hand-coded optimizations.
   k = Linearizer(*realized_ast)
+  lins.append(k)
   k.hand_coded_optimizations()
   prg = get_prg(k)
   for buf in outbufs: buf.copyin(np.zeros((buf.size, ), dtype=buf.dtype.np).data)  # Zero to check that all values are filled
@@ -626,6 +639,7 @@ def check_opt(opts, create_k, expected_color_size):
     np.testing.assert_allclose(np.frombuffer(buf.as_buffer(), buf.dtype.np), wanna_output[i], atol=atol, rtol=rtol)
   for i, x in enumerate(opts):  # Check custom transformations if any.
     check_opt(x, lambda: Linearizer(*realized_ast), color_sizes[i] if i < len(color_sizes) else None)
+  return lins

 class TestKernelOpts(unittest.TestCase):
   def test_local_and_grouped_reduce(self):

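For reference, a minimal sketch of the test flow this commit introduces, assembled only from the hunks above (the `st`, `out0`, and `out1` setup comes from test_multioutput; the fill values 2 and 3 mirror that test):

    a_t = Tensor.full(st.shape, 2).contiguous().realize()  # realized input tensors backing the AST's LOAD buffers
    b_t = Tensor.full(st.shape, 3).contiguous().realize()
    # helper_linearizer_ast allocates output buffers, runs the unoptimized baseline,
    # reruns with hand_coded_optimizations() (and any custom opts), asserts every run
    # matches wanna_output, and returns the list of Linearizers it built.
    lins = helper_linearizer_ast((out0, out1), [a_t, b_t],
                                 wanna_output=[a_t.numpy()+b_t.numpy(), a_t.numpy()*b_t.numpy()])
    lin = lins[0]  # index 0 is the unoptimized baseline kernel, whose uops the test inspects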