
[Tracking Issue] Support Axis Union/Intersection #72

Open
3 tasks
yzh119 opened this issue Nov 18, 2022 · 7 comments
Assignees: yzh119
Labels: enhancement (New feature or request), help wanted (Extra attention is needed)

Comments

yzh119 (Member) commented Nov 18, 2022

The Problem

Currently, SparseTIR does not support lowering code to a co-iteration structure; whenever we want to add/multiply two sparse tensors/vectors, we need to create another axis to indicate the union/intersection of the axes.

Here is an example of SpMSpV.

I = T.dense_fixed(m)
J = T.sparse_variable(I, (n, nnz), indptr_j, indices_j)
IV = T.dense_fixed(1)
JV = T.sparse_variable(IV, (n, nnz), indptr_jv, indices_jv)
J_and = T.sparse_variable(I, (n, nnz), indptr_j_and, indices_j_and)
A = T.match_sparse_buffer(a, (I, J))
B = T.match_sparse_buffer(b, (IV, JV))
C = T.match_sparse_buffer(c, (I,))  # output vector
with T.iter([I, J_and], "SR", "spmspv") as [i, j]:
    with T.init():
        C[i] = T.float32(0)
    C[i] = C[i] + A[i, j] * B[0, j]

SparseTIR would generate several binary search blocks for indexing A and B because we do not have co-iteration yet, and we need the mid arrays produced by those binary search blocks to access A and B under the for-loop structure.
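In plain Python, what those binary search blocks compute is a position map (a "mid" array) from each column id of J_and into A's (or B's) own indices array; the names below are illustrative, not SparseTIR API:

```python
import bisect

def build_mid(row_indices, target_indices):
    """For each column id in target_indices, binary-search its position
    in the sorted row_indices; -1 marks a column that is absent (its
    value is treated as zero)."""
    mid = []
    for col in target_indices:
        pos = bisect.bisect_left(row_indices, col)
        mid.append(pos if pos < len(row_indices) and row_indices[pos] == col else -1)
    return mid

# Row i of A has nonzeros at columns [1, 4, 7]; J_and holds columns [1, 7, 9].
mid_a = build_mid([1, 4, 7], [1, 7, 9])
# mid_a == [0, 2, -1]; A's values for row i can now be gathered by position.
```

Each entry is computed independently, which is what makes these blocks straightforward to parallelize.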

Once we support axis union/intersection and co-iteration structure generation, we can declare J_and as:

J_and = T.intersection([J, JV], indptr_j_and, indices_j_and)
J_or = T.union([J, JV], indptr_j_or, indices_j_or)

and sparse iterations over union/intersection axes can then yield co-iteration structures in the sparse iteration lowering pass.
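For reference, the co-iteration such a lowering pass would emit is the classic two-pointer merge over two sorted index arrays. A plain-Python sketch (illustrative only, not the proposed T.intersection/T.union API):

```python
def co_iterate(ja, jb):
    """Two-pointer merge over two sorted index arrays.
    Yields (col, pos_a, pos_b); a position of -1 means the column is
    absent from that side. Intersection keeps only rows where both
    positions are >= 0; union keeps every yielded row."""
    pa, pb = 0, 0
    while pa < len(ja) and pb < len(jb):
        if ja[pa] == jb[pb]:
            yield ja[pa], pa, pb
            pa, pb = pa + 1, pb + 1
        elif ja[pa] < jb[pb]:
            yield ja[pa], pa, -1
            pa += 1
        else:
            yield jb[pb], -1, pb
            pb += 1
    for p in range(pa, len(ja)):  # drain leftovers (needed for the union case)
        yield ja[p], p, -1
    for p in range(pb, len(jb)):
        yield jb[p], -1, p

inter = [c for c, pa, pb in co_iterate([1, 4, 7], [4, 7, 9]) if pa >= 0 and pb >= 0]
union = [c for c, pa, pb in co_iterate([1, 4, 7], [4, 7, 9])]
# inter == [4, 7], union == [1, 4, 7, 9]
```

The merge visits each nonzero at most once, so no search blocks are needed when both operands are iterated together.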

Milestone

  • Support the co-iteration structure (either with a While construct, or by creating a new statement in TIR).
  • Support T.intersection/T.union, and possibly more general operators (consider SpGEMM).
  • Modify the sparse iteration lowering pass.
@yzh119 yzh119 added enhancement New feature or request help wanted Extra attention is needed labels Nov 18, 2022
@yzh119 yzh119 self-assigned this Nov 29, 2022
qelk123 (Contributor) commented Mar 30, 2023

Given the current design, compared to directly using co-iteration, I think we can first replace the binary search module with a linear traversal to generate the mid-buffer. In this way, the complexity of generating the mid-buffer for a sparse axis can be reduced from O(n log n) to O(n). This method only changes the binary search part and can probably avoid the potential performance loss caused by co-iteration, as in TACO. @yzh119
The sample code looks like this:
(1) A sparse axis iterating over a dense axis:

I = T.dense_fixed(m)
K = T.sparse_variable(I, (k, nnz), (indptr, indices), "int32")
K_detach = T.dense_fixed(k)
J = T.dense_fixed(n)
A = T.match_sparse_buffer(a, (I, K), "float32")
B = T.match_sparse_buffer(b, (K_detach, J), "float32")
C = T.match_sparse_buffer(c, (I, J), "float32")
with T.iter([I, J, K_detach], "SSR", "csrmm") as [i, j, k]:
    T.block_attr({"binary_search_vaild_check": False})
    with T.init():
        C[i, j] = T.float32(0)
    C[i, j] = C[i, j] + A[i, k] * B[k, j]

generated "binary search block":

     mid_0 = alloc_buffer(int32[5, 5])
     {
      for (i: int32, 0, 5) {
        for (k: int32, 0, 5) {
          block([5, 5], "binary_search_block_0_0") as [vi, vk] {
            bind(vi, i)
            bind(vk, k)
            tir.reads([K_indices[vi, 0:(K_indptr[(vi + 1)] - K_indptr[vi])]])
            tir.writes([mid_0[vi, vk]])
            tir.attrs({"sparse": True, "preprocess": True, "is_binary_search_block": True})
            cur = alloc_buffer(int32[1])
            high = alloc_buffer(int32[1])
             {
              if (vk == 0) {
                cur[0] = 0
                high[0] = (K_indptr[(vi + 1)] - K_indptr[vi])
              }
              if ((vk == K_indices[vi, cur[0]]) && (cur[0] < high[0])) {
                mid_0[vi, vk] = cur[0]
                cur[0] = (cur[0] + 1)
              } else {
                mid_0[vi, vk] = -1
              }
            }
          }
        }
      }
     }

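In plain Python, the generated traversal above amounts to the following sketch, where mid[vk] stores the position of dense column vk inside the row's indices array, or -1 when absent:

```python
def build_mid_by_traversal(indices_row, k):
    """Single O(k) pass over the dense axis: cur advances only when the
    dense position vk matches the next stored column id."""
    mid = [-1] * k
    cur, high = 0, len(indices_row)
    for vk in range(k):
        if cur < high and indices_row[cur] == vk:
            mid[vk] = cur
            cur += 1
    return mid

# A row with nonzeros at columns 0, 2, 3 out of k = 5 dense positions.
mid = build_mid_by_traversal([0, 2, 3], 5)
# mid == [0, -1, 1, 2, -1]
```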
(2) A sparse axis iterating over another sparse axis (intersection):

    I = T.dense_fixed(m)
    K = T.sparse_variable(I, (k, nnz), (indptr, indices), "int32")
    J = T.dense_fixed(n)
    K2 = T.sparse_variable(J, (k2, nnz2), (indptr2, indices2), "int32")
    A = T.match_sparse_buffer(a, (I, K), "float32")
    B = T.match_sparse_buffer(b, (J, K2), "float32")
    C = T.match_sparse_buffer(c, (I, J), "float32")
    with T.iter([I, J, K2], "SSR", "csrmm") as [i, j, k]:
        T.block_attr({"binary_search_vaild_check": False})
        with T.init():
            C[i, j] = T.float32(0)
        C[i, j] = C[i, j] + A[i, k] * B[j, k]

generated "binary search block":

     mid_0 = alloc_buffer(int32[m, n, k2])
     {
      for (i: int32, 0, m) {
        for (j: int32, 0, n) {
          block([m, n], "binary_search_block_0_0") as [vi, vj] {
            bind(vi, i)
            bind(vj, j)
            tir.reads([K2_indptr[vj:(vj + 2)], K_indices[vi, 0:(K_indptr[(vi + 1)] - K_indptr[vi])]])
            tir.writes([mid_0[vi, vj, 0:k2]])
            tir.attrs({"sparse": True, "preprocess": True, "is_binary_search_block": True})
            for (k_1: int32, 0, (K2_indptr[(vj + 1)] - K2_indptr[vj])) {
              block([k2], "binary_search_block_0_1") as [vk] {
                bind(vk, k_1)
                tir.reads([K_indices[vi, 0:(K_indptr[(vi + 1)] - K_indptr[vi])]])
                tir.writes([mid_0[vi, vj, vk]])
                tir.attrs({"sparse": True, "preprocess": True, "is_binary_search_block": True})
                cur = alloc_buffer(int32[1])
                high = alloc_buffer(int32[1])
                 {
                  if (vk == 0) {
                    cur[0] = 0
                    high[0] = (K_indptr[(vi + 1)] - K_indptr[vi])
                  }
                  while (((K_indices[vi, cur[0]] < K2_indices[vj, vk]) && (cur[0] < high[0]))) {
                    cur[0] = (cur[0] + 1)
                  }
                  if ((K2_indices[vj, vk] == K_indices[vi, cur[0]]) && (cur[0] < high[0])) {
                    mid_0[vi, vj, vk] = cur[0]
                    cur[0] = (cur[0] + 1)
                  } else {
                    mid_0[vi, vj, vk] = -1
                  }
                }
              }
            }
          }
        }
      }
     }

yzh119 (Member, Author) commented Mar 30, 2023

@qelk123 thanks for your proposal.
I have read some sparse compiler papers that use the optimization you propose; the only issue is that we cannot parallelize the loop over k, because cur[0] for different k values is not independent.

Binary search is certainly not the best solution, but a good property of binary search is that the computation at each position is independent of the others, so we can safely parallelize everything outside the block, and we can even compute the mid value inline after entering stage II by using schedule primitives such as compute_at and compute_inline. If we do not care about this property, there are many better alternatives, depending on the application.
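The dependence difference can be seen directly in a small sketch (illustrative Python, not generated code): the traversal carries cur across iterations, so positions must be visited in order, while each binary search lookup is independent and could be mapped in parallel:

```python
import bisect

row = [1, 4, 7, 12]   # sorted column ids of one sparse row
targets = [4, 9, 12]  # positions to locate (must be sorted for the traversal)

def locate_seq(row, targets):
    """Sequential traversal: each step depends on cur from the previous step."""
    out, cur = [], 0
    for t in targets:
        while cur < len(row) and row[cur] < t:
            cur += 1
        out.append(cur if cur < len(row) and row[cur] == t else -1)
    return out

def locate_one(t):
    """Independent lookup: safe to run for all targets in parallel."""
    pos = bisect.bisect_left(row, t)
    return pos if pos < len(row) and row[pos] == t else -1

seq = locate_seq(row, targets)
par = list(map(locate_one, targets))  # order-independent, parallelizable
# seq == par == [1, -1, 3]
```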

qelk123 (Contributor) commented Mar 31, 2023

Yes, you are right. If you want to achieve maximum parallelism, using binary search is reasonable. I also read your related paper; it seems the backend you considered is GPU, which provides sufficient parallelism. However, when it comes to other backends like CPUs and FPGAs, is binary search also compatible with them? Maybe a more proper way is to bind the generation of this search block to a specific target?
BTW, I am wondering whether any binary search block would be generated in your ideal co-iteration structure design for this project?

yzh119 (Member, Author) commented Mar 31, 2023

I don't think it needs to be bound to the target; we can just let the user decide what to use by annotating sparse iterations, like what we did for checking binary search validity. We might call it a search function instead of "binary search".

SparseTIR is not only designed for GPUs (in the paper our evaluation focuses on GPUs, but the CPU backend is fully supported by selecting the LLVM target), and CPUs also have parallelism (though the number of cores might not be as large as on GPUs). The FPGA backend is another story: currently TVM supports FPGAs by generating HLS code, but I suppose for sparsity we have better alternatives, e.g. the Sparse Abstract Machine, which generates RTL to handle co-iterations directly.

BTW, I am wondering whether any binary search block would be generated in your ideal co-iteration structure design for this project?

I guess so; it actually depends on the schedule, and we may inevitably emit search functions in some cases. I am trying to write some formal semantics for that, please stay tuned.

qelk123 (Contributor) commented Mar 31, 2023

Thank you for your reply. As you said, a CPU may not have as many cores as the CUDA cores in a GPU. In this situation the benefit from parallelism along the k axis may not be so obvious, since parallelism along m and n can already fully utilize the CPU cores. So the trade-off between parallelism and complexity is different on GPU and CPU. Would a more flexible search strategy that the user can choose be more performance-friendly?

yzh119 (Member, Author) commented Mar 31, 2023

There are more search function choices than the two we have discussed (binary search and the algorithm you mentioned). The selection is not only hardware-dependent but also data-dependent (on the sparse structure); an ideal solution is to design a cost model to predict running time.

So I think it's reasonable to make the lowering decision from sparse iteration annotations, and we will have some algorithm to annotate the sparse iterations by comparing the costs produced by our cost models (like what we did in MetaSchedule for TensorIR) during compilation.
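As a purely hypothetical sketch of such an annotation-driven choice (the cost model below is a toy stand-in, not MetaSchedule's):

```python
import math

def pick_search_strategy(nnz_per_row, lookups, parallel_lanes):
    """Toy cost model: a linear traversal costs O(nnz + lookups) on a single
    lane, while binary search costs O(lookups * log nnz) but its independent
    lookups divide across the available parallel lanes."""
    traversal_cost = nnz_per_row + lookups
    log_nnz = max(1, math.ceil(math.log2(max(2, nnz_per_row))))
    binary_cost = lookups * log_nnz / parallel_lanes
    return "binary_search" if binary_cost < traversal_cost else "traversal"

# With many parallel lanes (GPU-like), binary search is predicted cheaper.
gpu_choice = pick_search_strategy(nnz_per_row=4096, lookups=4096, parallel_lanes=1024)
# On a single lane (one CPU core), the linear traversal is predicted cheaper.
cpu_choice = pick_search_strategy(nnz_per_row=4096, lookups=4096, parallel_lanes=1)
```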

qelk123 (Contributor) commented Mar 31, 2023

OK, I got it; that would be a nice choice. We can discuss this in detail later.
