
[Bug] remove_unused_args incompatible with VarUseDefAnalysis in make_packed_api #98

Open
qelk123 opened this issue Apr 6, 2023 · 2 comments



qelk123 commented Apr 6, 2023

Expected behavior

Params removed by the remove_unused_args pass shouldn't be checked by the make_packed_api pass.

Actual behavior

In a PrimFunc with dynamic behavior, variable args used to describe axis metadata (the shape of the indices buffer, indptr buffer, or data buffer) are removed by the remove_unused_args pass, but are still required by VarUseDefAnalysis in the make_packed_api pass.

Case

TVM Script kernel:

@T.prim_func
def csrmm(
    a: T.handle,
    b: T.handle,
    c: T.handle,
    indptr: T.handle,
    indices: T.handle,
    m: T.int32,
    n: T.int32,
    num_tiles: T.int32,
    nnz: T.int32,
    cwm: T.int32,
) -> None:
    T.func_attr({"global_symbol": "main", "tir.noalias": True, "sparse_tir_level": 2})
    I = T.dense_fixed(m)
    J = T.sparse_variable(I, (n, nnz), (indptr, indices), "int32")
    J_detach = T.dense_fixed(n)
    K1 = T.dense_fixed(num_tiles)
    K2 = T.dense_fixed(cwm)
    K3 = T.dense_fixed(32)
    A = T.match_sparse_buffer(a, (I, J), "float32")
    B = T.match_sparse_buffer(b, (J_detach, K1, K2, K3), "float32")
    C = T.match_sparse_buffer(c, (I, K1, K2, K3), "float32")
    with T.sp_iter([I, J, K1, K2, K3], "SRSSS", "csrmm") as [i, j, k1, k2, k3]:
        with T.init():
            C[i, k1, k2, k3] = 0.0
        C[i, k1, k2, k3] = C[i, k1, k2, k3] + A[i, j] * B[j, k1, k2, k3]

PrimFunc before make_packed_api (decomposed format with 5 buckets and 2 tile blocks):

primfn(b: handle, c: handle, num_tiles: int32, cwm: int32, a_0_0: handle, indices_i_0_0: handle, indices_j_0_0: handle, num_rows_0_0: int32, a_0_1: handle, indices_i_0_1: handle, indices_j_0_1: handle, num_rows_0_1: int32, a_0_2: handle, indices_i_0_2: handle, indices_j_0_2: handle, num_rows_0_2: int32, a_0_3: handle, indices_i_0_3: handle, indices_j_0_3: handle, num_rows_0_3: int32, a_0_4: handle, indices_i_0_4: handle, indices_j_0_4: handle, num_rows_0_4: int32, a_0_5: handle, indices_i_0_5: handle, indices_j_0_5: handle, num_rows_0_5: int32, a_1_0: handle, indices_i_1_0: handle, indices_j_1_0: handle, num_rows_1_0: int32, a_1_1: handle, indices_i_1_1: handle, indices_j_1_1: handle, num_rows_1_1: int32, a_1_2: handle, indices_i_1_2: handle, indices_j_1_2: handle, num_rows_1_2: int32, a_1_3: handle, indices_i_1_3: handle, indices_j_1_3: handle, num_rows_1_3: int32, a_1_4: handle, indices_i_1_4: handle, indices_j_1_4: handle, num_rows_1_4: int32, a_1_5: handle, indices_i_1_5: handle, indices_j_1_5: handle, num_rows_1_5: int32) -> ()
  attr = {"target": Target(kind='cuda', keys={'cuda', 'gpu'}, attrs={'thread_warp_size': 32, 'max_num_threads': 1024, 'arch': "sm_75"}, host=Target(kind='llvm', keys={'cpu'}, attrs={'link-params': (bool)0})), "tir.noalias": True, "global_symbol": "main", "composable": 1, "sparse_tir_level": 0, "tir.is_entry_func": True}
  buffers = {B_data: Buffer(B: Pointer(global float32), float32, [(((n: int32*num_tiles)*cwm)*32)], []),
             C_data: Buffer(C: Pointer(global float32), float32, [(((m: int32*num_tiles)*cwm)*32)], []),
             A_0_0_data: Buffer(A_0_0: Pointer(global float32), float32, [num_rows_0_0], []),
             I_0_0_indices_data: Buffer(I_0_0_indices.data: Pointer(global int32), int32, [num_rows_0_0], []),
             J_0_0_indices_data: Buffer(J_0_0_indices.data: Pointer(global int32), int32, [num_rows_0_0], []),
             ...
  buffer_map = {...} {
	...
  }

From this example we can see that m and n are removed from the arg list: they are used only to describe the shapes of the B_data and C_data buffers and in some block read and write regions, and none of these places are inspected by the remove_unused_args pass, so the vars are dropped from the arg list.

However, these args are still checked by VarUseDefAnalysis in the make_packed_api pass, which uses them when checking the DLTensor shapes of B_data and C_data, causing the error message:

TVMError: Not all Vars are passed in api_args: 'n' 'm' is not bound to any variables
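
For context, here is a minimal standalone reduction that should trip the same check (my own sketch, not the original repro; the old-style T.var syntax, the BindTarget usage, and the exact attributes make_packed_api expects may differ across TVM versions):

import tvm
from tvm.script import tir as T

@T.prim_func
def f(b: T.handle) -> None:
    T.func_attr({"global_symbol": "f", "tir.noalias": True})
    n = T.var("int32")
    # n appears only inside a compound shape expression, so the arg
    # binder cannot recover it from the incoming DLTensor's shape
    B = T.match_buffer(b, (n * 32,), "float32")
    T.evaluate(0)

mod = tvm.IRModule({"f": f})
mod = tvm.tir.transform.BindTarget(tvm.target.Target("llvm"))(mod)
# expected: TVMError: Not all Vars are passed in api_args: 'n' ...
tvm.tir.transform.MakePackedAPI()(mod)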

Possible solutions

Keep vars used in the shape and strides fields of a BufferNode in the arg list. We can modify how the VarUseDefAnalysis pass handles BufferNode to achieve this.

Also, this solution won't affect the current case, since we remove these vars by specializing them to constant values before the remove_unused_args pass.
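
A rough sketch of what that could look like from the Python side (the helper name is mine; the real fix would live in the C++ VarUseDefAnalysis): collect every Var appearing in the shapes and strides of buffer_map buffers, then keep those vars when pruning args.

import tvm
from tvm import tir
from tvm.tir import stmt_functor

def buffer_shape_stride_vars(func: tir.PrimFunc):
    # vars that a remove_unused_args-style pass should keep as params
    keep = set()

    def visit(node):
        if isinstance(node, tir.Var):
            keep.add(node)

    for buf in func.buffer_map.values():
        for expr in list(buf.shape) + list(buf.strides):
            stmt_functor.post_order_visit(expr, visit)
    return keep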


yzh119 commented Apr 6, 2023

I have fixed this in a PR to TVM upstream: apache/tvm#14502
Let's fix the behavior when upstreaming :)


qelk123 commented Apr 6, 2023

@yzh119 Do you mean simply bypassing the var check for the buffer shape and strides fields in the VarUseDefAnalysis pass?
If so, I am wondering whether it will cause an assertion error when the B_data or C_data buffer shape is checked during execution, or whether it can inversely derive the values of these vars from the actual buffer shape even when the shape is computed from different symbolic vars. In other words, can we just treat these symbolic vars of type T.int32 as vars declared by T.var("int32") in the body of the TVM Script function?
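
To make the question concrete, my understanding (an illustrative sketch in old-style TVMScript, not tested against this branch) is that a body-declared var works exactly when it appears as a bare shape entry, because the arg binder can then read it straight from the DLTensor:

from tvm.script import tir as T

@T.prim_func
def bindable(b: T.handle) -> None:
    # n is declared in the body rather than passed as a param; since
    # shape[0] is the bare var, MakePackedAPI can bind n from the
    # incoming DLTensor's shape at call time
    n = T.var("int32")
    B = T.match_buffer(b, (n,), "float32")
    T.evaluate(0)

For shapes like n*num_tiles*cwm*32 above, the binder cannot invert the product; it can only emit an equality assert, which itself needs n already defined, which is presumably why the check fails.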
