
[Tracking Issue] Upgrade sparse_fuse primitive #38

Open
yzh119 opened this issue Sep 23, 2022 · 0 comments

Problem

Previously, our fusion syntax looked like this:

# before fuse
with T.iter([I, J, K], "SSR", "sddmm") as [i, j, k]:
    # body

# after fuse
with T.iter([T.fuse(I, J), K], "SSR", "sddmm") as [i, j, k]:
    # body

while the body is kept in its original form.

After sparse iteration lowering, the fused sparse iteration is emitted as the following structure:

for j, k in T.grid(nnz, feat_size):
    with T.block("sddmm0"):
        vi = T.axis.spatial(1, 0)
        vj, vk = T.axis.remap("SR", [j, k])
        T.reads(
            A[mid_0[vi, vj], vk], mid_0[vi, vj], B[J_indices[vi, vj], vk], J_indices[vi, vj]
        )
        T.writes(C[vi, vj])
        T.block_attr({"sparse": True})
        with T.init():
            C[vi, vj] = T.float32(0)
        C[vi, vj] = C[vi, vj] + A[mid_0[vi, vj], vk] * B[J_indices[vi, vj], vk]  

These semantics are problematic for read/write region analysis: we did not change the definition of C, which is still a 2-dimensional buffer, while the fused iterator has collapsed its iteration space to a single dimension of extent nnz. This mismatch can cause problems when applying schedule primitives such as compute_at/...
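For reference, here is a hedged NumPy sketch (not the actual lowering; the helper name sddmm_fused_reference is made up) of what the fused block above computes: mid_0 maps each fused nonzero position back to its row, J_indices holds the column of each nonzero, and the write pattern effectively touches a 1 x nnz region even though C is declared 2-D.

import numpy as np

def sddmm_fused_reference(A, B, indptr, indices):
    """NumPy model of the fused "sddmm0" block above (illustrative only)."""
    _, feat_size = A.shape
    nnz = indices.shape[0]
    # mid_0[j]: the row i that owns the j-th nonzero, i.e. the largest i with
    # indptr[i] <= j (materialized by the lowering as an auxiliary buffer).
    mid_0 = np.searchsorted(indptr, np.arange(nnz), side="right") - 1
    # The block only ever writes C[0, 0:nnz]: a 1 x nnz region, not m x n.
    C = np.zeros((1, nnz), dtype=A.dtype)
    for j in range(nnz):
        for k in range(feat_size):
            C[0, j] += A[mid_0[j], k] * B[indices[j], k]
    return C

So the write region inferred for C is [0:1, 0:nnz], which does not line up with C's declared shape.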

Proposed Solution

We should fuse two axes instead of two sparse iterators:

# before fusion
I = T.dense_fixed(m)
J = T.sparse_variable(I, (n, nnz), (indptr, indices), "int32")
J_detach = T.dense_fixed(n)
K = T.dense_fixed(feat_size)
A = T.match_sparse_buffer(a, (I, K), "float32")
B = T.match_sparse_buffer(b, (J_detach, K), "float32")
C = T.match_sparse_buffer(c, (I, J), "float32")

with T.iter([I, J, K], "SSR", "sddmm") as [i, j, k]:
    with T.init():
        C[i, j] = 0.0
    C[i, j] = C[i, j] + A[i, k] * B[j, k]

# after fusion (by calling sparse_fuse(I, J))
I = T.dense_fixed(m)
J = T.sparse_variable(I, (n, nnz), (indptr, indices), "int32")
IJ = T.dense_fixed(nnz, "int32")
J_detach = T.dense_fixed(n)
K = T.dense_fixed(feat_size)
A = T.match_sparse_buffer(a, (I, K), "float32")
B = T.match_sparse_buffer(b, (J_detach, K), "float32")
C = T.match_sparse_buffer(c, (IJ,), "float32")

with T.iter([IJ, K], "SR", "sddmm") as [ij, k]:
    with T.init():
        C[ij] = 0.0
    C[ij] = C[ij] + A[T.sparse_fuse_decode((I, J), ij, 0), k] * B[T.sparse_fuse_decode((I, J), ij, 1), k]

Note that C has been transformed into a collapsed 1-D buffer instead of a 2-D one. The new T.sparse_fuse_decode primitive would be lowered to binary-search procedures in stage II.
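As a rough illustration (assumed helper name and signature, not the actual stage-II lowering), the decode can be modeled as a binary search over J's indptr for dimension 0 and a direct load from J's indices for dimension 1:

import numpy as np

def sparse_fuse_decode(indptr, indices, ij, dim):
    """Hypothetical scalar model of T.sparse_fuse_decode((I, J), ij, dim)."""
    if dim == 0:
        # Recover the row: the largest i with indptr[i] <= ij,
        # found via binary search over J's row-pointer array.
        return int(np.searchsorted(indptr, ij, side="right")) - 1
    else:
        # Recover the column: the ij-th entry of J's indices array.
        return int(indices[ij])

For example, with indptr = [0, 2, 3] and indices = [1, 4, 0], decoding position ij = 2 yields row 1 for dim 0 and column 0 for dim 1.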
