Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Changes required for AMD/rocm support with Triton #257

Open
reachtarunhere opened this issue Jan 22, 2024 · 0 comments
Open

Changes required for AMD/rocm support with Triton #257

reachtarunhere opened this issue Jan 22, 2024 · 0 comments

Comments

@reachtarunhere
Copy link

FlashAttention-2 is now supported on AMD GPUs with PyTorch; that port uses Triton kernels to do so.

https://github.com/ROCmSoftwarePlatform/flash-attention

JAX-Triton does not currently work on AMD/ROCm. When I run the add example I get the following error, which I suspect is caused by CUDA-specific code in triton_lib.py.

I can run any other tests requested here to help make progress on this.

(/jax_miniconda) Singularity> python add.py 
2024-01-22 13:12:41.578159: E external/xla/xla/stream_executor/plugin_registry.cc:90] Invalid plugin kind specified: DNN
/usr/share/libdrm/amdgpu.ids: No such file or directory
/usr/share/libdrm/amdgpu.ids: No such file or directory
/usr/share/libdrm/amdgpu.ids: No such file or directory
/usr/share/libdrm/amdgpu.ids: No such file or directory
/usr/share/libdrm/amdgpu.ids: No such file or directory
/usr/share/libdrm/amdgpu.ids: No such file or directory
/usr/share/libdrm/amdgpu.ids: No such file or directory
/usr/share/libdrm/amdgpu.ids: No such file or directory
/usr/share/libdrm/amdgpu.ids: No such file or directory
/usr/share/libdrm/amdgpu.ids: No such file or directory
/usr/share/libdrm/amdgpu.ids: No such file or directory
/usr/share/libdrm/amdgpu.ids: No such file or directory
/usr/share/libdrm/amdgpu.ids: No such file or directory
/usr/share/libdrm/amdgpu.ids: No such file or directory
/usr/share/libdrm/amdgpu.ids: No such file or directory
/usr/share/libdrm/amdgpu.ids: No such file or directory
2024-01-22 13:12:46.635298: W external/xla/xla/service/gpu/runtime/support.cc:58] Intercepted XLA runtime error:
INTERNAL: //
// Generated by LLVM NVPTX Back-End
//

.version 8.2
.target sm_90a
.address_size 64

	// .globl	add_kernel_0d1d2d

.visible .entry add_kernel_0d1d2d(
	.param .u64 add_kernel_0d1d2d_param_0,
	.param .u64 add_kernel_0d1d2d_param_1,
	.param .u64 add_kernel_0d1d2d_param_2
)
.maxntid 128, 1, 1
{
	.reg .pred 	%p<5>;
	.reg .b32 	%r<10>;
	.reg .b64 	%rd<8>;
	.loc	1 24 0
$L__func_begin0:
	.loc	1 24 0

	ld.param.u64 	%rd4, [add_kernel_0d1d2d_param_0];
	ld.param.u64 	%rd5, [add_kernel_0d1d2d_param_1];
$L__tmp0:
	.loc	1 33 39
	mov.u32 	%r5, %tid.x;
	and.b32  	%r6, %r5, 7;
	ld.param.u64 	%rd6, [add_kernel_0d1d2d_param_2];
	.loc	1 31 22
	mov.u32 %r1, %ctaid.x;
	.loc	1 32 22
	shl.b32 	%r7, %r1, 3;
	.loc	1 33 26
	or.b32  	%r8, %r7, %r6;
	.loc	1 34 19
	setp.lt.s32 	%p1, %r8, 8;
	.loc	1 35 22
	mul.wide.s32 	%rd7, %r8, 4;
	add.s64 	%rd1, %rd4, %rd7;
	.loc	1 35 14
	mov.u32 %r2, 0x0;
	@%p1 ld.global.b32 { %r2 }, [ %rd1 + 0 ];
	.loc	1 36 22
	add.s64 	%rd2, %rd5, %rd7;
	.loc	1 36 14
	mov.u32 %r3, 0x0;
	@%p1 ld.global.b32 { %r3 }, [ %rd2 + 0 ];
	.loc	1 37 15
	add.s32 	%r4, %r3, %r2;
	.loc	1 38 24
	add.s64 	%rd3, %rd6, %rd7;
	.loc	1 38 33
	and.b32  	%r9, %r5, 120;
	setp.eq.s32 	%p4, %r9, 0;
	and.pred  	%p3, %p4, %p1;
	@%p3 st.global.b32 [ %rd3 + 0 ], { %r4 };
	.loc	1 38 2
	ret;
$L__tmp1:
$L__func_end0:

}
	.file	1 "/jax_miniconda/add.py"
	.section	.debug_abbrev
	{
.b8 1
.b8 17
.b8 1
.b8 37
.b8 8
.b8 19
.b8 5
.b8 3
.b8 8
.b8 16
.b8 6
.b8 27
.b8 8
.b8 180
.b8 66
.b8 12
.b8 17
.b8 1
.b8 18
.b8 1
.b8 0
.b8 0
.b8 2
.b8 46
.b8 0
.b8 17
.b8 1
.b8 18
.b8 1
.b8 64
.b8 10
.b8 135
.b8 64
.b8 8
.b8 3
.b8 8
.b8 58
.b8 11
.b8 59
.b8 11
.b8 63
.b8 12
.b8 0
.b8 0
.b8 0
	}
	.section	.debug_info
	{
.b32 119
.b8 2
.b8 0
.b32 .debug_abbrev
.b8 8
.b8 1
.b8 116
.b8 114
.b8 105
.b8 116
.b8 111
.b8 110
.b8 0
.b8 2
.b8 0
.b8 97
.b8 100
.b8 100
.b8 46
.b8 112
.b8 121
.b8 0
.b32 .debug_line
.b8 47
.b8 106
.b8 97
.b8 120
.b8 95
.b8 109
.b8 105
.b8 110
.b8 105
.b8 99
.b8 111
.b8 110
.b8 100
.b8 97
.b8 0
.b8 1
.b64 $L__func_begin0
.b64 $L__func_end0
.b8 2
.b64 $L__func_begin0
.b64 $L__func_end0
.b8 1
.b8 156
.b8 97
.b8 100
.b8 100
.b8 95
.b8 107
.b8 101
.b8 114
.b8 110
.b8 101
.b8 108
.b8 95
.b8 48
.b8 100
.b8 49
.b8 100
.b8 50
.b8 100
.b8 0
.b8 97
.b8 100
.b8 100
.b8 95
.b8 107
.b8 101
.b8 114
.b8 110
.b8 101
.b8 108
.b8 95
.b8 48
.b8 100
.b8 49
.b8 100
.b8 50
.b8 100
.b8 0
.b8 1
.b8 24
.b8 1
.b8 0
	}
	.section	.debug_pubnames
	{
.b32 $L__pubNames_end0-$L__pubNames_start0
$L__pubNames_start0:
.b8 2
.b8 0
.b32 .debug_info
.b32 123
.b32 64
.b8 97
.b8 100
.b8 100
.b8 95
.b8 107
.b8 101
.b8 114
.b8 110
.b8 101
.b8 108
.b8 95
.b8 48
.b8 100
.b8 49
.b8 100
.b8 50
.b8 100
.b8 0
.b32 0
$L__pubNames_end0:
	}
	.section	.debug_pubtypes
	{
.b32 $L__pubTypes_end0-$L__pubTypes_start0
$L__pubTypes_start0:
.b8 2
.b8 0
.b32 .debug_info
.b32 123
.b32 0
$L__pubTypes_end0:
	}
	.section	.debug_loc	{	}
; No such file or directory
2024-01-22 13:12:46.635704: E external/xla/xla/pjrt/pjrt_stream_executor_client.cc:2716] Execution of replica 0 failed: INTERNAL: Failed to execute XLA Runtime executable: run time error: custom call 'xla.gpu.custom_call' failed: //
// Generated by LLVM NVPTX Back-End
//

.version 8.2
.target sm_90a
.address_size 64

	// .globl	add_kernel_0d1d2d

.visible .entry add_kernel_0d1d2d(
	.param .u64 add_kernel_0d1d2d_param_0,
	.param .u64 add_kernel_0d1d2d_param_1,
	.param .u64 add_kernel_0d1d2d_param_2
)
.maxntid 128, 1, 1
{
	.reg .pred 	%p<5>;
	.reg .b32 	%r<10>;
	.reg .b64 	%rd<8>;
	.loc	1 24 0
$L__func_begin0:
	.loc	1 24 0

	ld.param.u64 	%rd4, [add_kernel_0d1d2d_param_0];
	ld.param.u64 	%rd5, [add_kernel_0d1d2d_param_1];
$L__tmp0:
	.loc	1 33 39
	mov.u32 	%r5, %tid.x;
	and.b32  	%r6, %r5, 7;
	ld.param.u64 	%rd6, [add_kernel_0d1d2d_param_2];
	.loc	1 31 22
	mov.u32 %r1, %ctaid.x;
	.loc	1 32 22
	shl.b32 	%r7, %r1, 3;
	.loc	1 33 26
	or.b32  	%r8, %r7, %r6;
	.loc	1 34 19
	setp.lt.s32 	%p1, %r8, 8;
	.loc	1 35 22
	mul.wide.s32 	%rd7, %r8, 4;
	add.s64 	%rd1, %rd4, %rd7;
	.loc	1 35 14
	mov.u32 %r2, 0x0;
	@%p1 ld.global.b32 { %r2 }, [ %rd1 + 0 ];
	.loc	1 36 22
	add.s64 	%rd2, %rd5, %rd7;
	.loc	1 36 14
	mov.u32 %r3, 0x0;
	@%p1 ld.global.b32 { %r3 }, [ %rd2 + 0 ];
	.loc	1 37 15
	add.s32 	%r4, %r3, %r2;
	.loc	1 38 24
	add.s64 	%rd3, %rd6, %rd7;
	.loc	1 38 33
	and.b32  	%r9, %r5, 120;
	setp.eq.s32 	%p4, %r9, 0;
	and.pred  	%p3, %p4, %p1;
	@%p3 st.global.b32 [ %rd3 + 0 ], { %r4 };
	.loc	1 38 2
	ret;
$L__tmp1:
$L__func_end0:

}
	.file	1 "/jax_miniconda/add.py"
	.section	.debug_abbrev
	{
.b8 1
.b8 17
.b8 1
.b8 37
.b8 8
.b8 19
.b8 5
.b8 3
.b8 8
.b8 16
.b8 6
.b8 27
.b8 8
.b8 180
.b8 66
.b8 12
.b8 17
.b8 1
.b8 18
.b8 1
.b8 0
.b8 0
.b8 2
.b8 46
.b8 0
.b8 17
.b8 1
.b8 18
.b8 1
.b8 64
.b8 10
.b8 135
.b8 64
.b8 8
.b8 3
.b8 8
.b8 58
.b8 11
.b8 59
.b8 11
.b8 63
.b8 12
.b8 0
.b8 0
.b8 0
	}
	.section	.debug_info
	{
.b32 119
.b8 2
.b8 0
.b32 .debug_abbrev
.b8 8
.b8 1
.b8 116
.b8 114
.b8 105
.b8 116
.b8 111
.b8 110
.b8 0
.b8 2
.b8 0
.b8 97
.b8 100
.b8 100
.b8 46
.b8 112
.b8 121
.b8 0
.b32 .debug_line
.b8 47
.b8 106
.b8 97
.b8 120
.b8 95
.b8 109
.b8 105
.b8 110
.b8 105
.b8 99
.b8 111
.b8 110
.b8 100
.b8 97
.b8 0
.b8 1
.b64 $L__func_begin0
.b64 $L__func_end0
.b8 2
.b64 $L__func_begin0
.b64 $L__func_end0
.b8 1
.b8 156
.b8 97
.b8 100
.b8 100
.b8 95
.b8 107
.b8 101
.b8 114
.b8 110
.b8 101
.b8 108
.b8 95
.b8 48
.b8 100
.b8 49
.b8 100
.b8 50
.b8 100
.b8 0
.b8 97
.b8 100
.b8 100
.b8 95
.b8 107
.b8 101
.b8 114
.b8 110
.b8 101
.b8 108
.b8 95
.b8 48
.b8 100
.b8 49
.b8 100
.b8 50
.b8 100
.b8 0
.b8 1
.b8 24
.b8 1
.b8 0
	}
	.section	.debug_pubnames
	{
.b32 $L__pubNames_end0-$L__pubNames_start0
$L__pubNames_start0:
.b8 2
.b8 0
.b32 .debug_info
.b32 123
.b32 64
.b8 97
.b8 100
.b8 100
.b8 95
.b8 107
.b8 101
.b8 114
.b8 110
.b8 101
.b8 108
.b8 95
.b8 48
.b8 100
.b8 49
.b8 100
.b8 50
.b8 100
.b8 0
.b32 0
$L__pubNames_end0:
	}
	.section	.debug_pubtypes
	{
.b32 $L__pubTypes_end0-$L__pubTypes_start0
$L__pubTypes_start0:
.b8 2
.b8 0
.b32 .debug_info
.b32 123
.b32 0
$L__pubTypes_end0:
	}
	.section	.debug_loc	{	}
; No such file or directory; current tracing scope: custom-call.3; current profiling annotation: XlaModule:#prefix=jit(triton_kernel_call)/jit(main)/triton_kernel_call[fn=JITFunction(__main__:add_kernel) scalar_args=() name= custom_call_target_name=triton_kernel_call out_shapes=(ShapeDtypeStruct(shape=(8,), dtype=int32),) grid=(1,) num_warps=None num_stages=None num_ctas=1 enable_fp_fusion=True enable_warp_specialization=False enable_persistent=False input_output_aliases=() zeroed_outputs=() debug=False serialized_metadata=b'' block_size=8],hlo_module=jit_triton_kernel_call,program_id=2#.
jax.errors.SimplifiedTraceback: For simplicity, JAX has removed its internal frames from the traceback of the following exception. Set JAX_TRACEBACK_FILTERING=off to include these.

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/jax_miniconda/add.py", line 56, in <module>
    print(add(x_val, y_val))
  File "/jax_miniconda/add.py", line 44, in add
    return jt.triton_call(
  File "/jax_miniconda/lib/python3.10/site-packages/jax_triton/triton_lib.py", line 681, in triton_call
    out_flat = triton_kernel_call_p.bind(
  File "/users/tavangani/.local/lib/python3.10/site-packages/jax/_src/core.py", line 402, in bind
    return self.bind_with_trace(find_top_trace(args), args, params)
  File "/users/tavangani/.local/lib/python3.10/site-packages/jax/_src/core.py", line 405, in bind_with_trace
    out = trace.process_primitive(self, map(trace.full_raise, args), params)
  File "/users/tavangani/.local/lib/python3.10/site-packages/jax/_src/core.py", line 893, in process_primitive
    return primitive.impl(*tracers, **params)
  File "/users/tavangani/.local/lib/python3.10/site-packages/jax/_src/dispatch.py", line 87, in apply_primitive
    outs = fun(*args)
jaxlib.xla_extension.XlaRuntimeError: INTERNAL: Failed to execute XLA Runtime executable: run time error: custom call 'xla.gpu.custom_call' failed: //
// Generated by LLVM NVPTX Back-End
//

.version 8.2
.target sm_90a
.address_size 64

	// .globl	add_kernel_0d1d2d

.visible .entry add_kernel_0d1d2d(
	.param .u64 add_kernel_0d1d2d_param_0,
	.param .u64 add_kernel_0d1d2d_param_1,
	.param .u64 add_kernel_0d1d2d_param_2
)
.maxntid 128, 1, 1
{
	.reg .pred 	%p<5>;
	.reg .b32 	%r<10>;
	.reg .b64 	%rd<8>;
	.loc	1 24 0
$L__func_begin0:
	.loc	1 24 0

	ld.param.u64 	%rd4, [add_kernel_0d1d2d_param_0];
	ld.param.u64 	%rd5, [add_kernel_0d1d2d_param_1];
$L__tmp0:
	.loc	1 33 39
	mov.u32 	%r5, %tid.x;
	and.b32  	%r6, %r5, 7;
	ld.param.u64 	%rd6, [add_kernel_0d1d2d_param_2];
	.loc	1 31 22
	mov.u32 %r1, %ctaid.x;
	.loc	1 32 22
	shl.b32 	%r7, %r1, 3;
	.loc	1 33 26
	or.b32  	%r8, %r7, %r6;
	.loc	1 34 19
	setp.lt.s32 	%p1, %r8, 8;
	.loc	1 35 22
	mul.wide.s32 	%rd7, %r8, 4;
	add.s64 	%rd1, %rd4, %rd7;
	.loc	1 35 14
	mov.u32 %r2, 0x0;
	@%p1 ld.global.b32 { %r2 }, [ %rd1 + 0 ];
	.loc	1 36 22
	add.s64 	%rd2, %rd5, %rd7;
	.loc	1 36 14
	mov.u32 %r3, 0x0;
	@%p1 ld.global.b32 { %r3 }, [ %rd2 + 0 ];
	.loc	1 37 15
	add.s32 	%r4, %r3, %r2;
	.loc	1 38 24
	add.s64 	%rd3, %rd6, %rd7;
	.loc	1 38 33
	and.b32  	%r9, %r5, 120;
	setp.eq.s32 	%p4, %r9, 0;
	and.pred  	%p3, %p4, %p1;
	@%p3 st.global.b32 [ %rd3 + 0 ], { %r4 };
	.loc	1 38 2
	ret;
$L__tmp1:
$L__func_end0:

}
	.file	1 "/jax_miniconda/add.py"
	.section	.debug_abbrev
	{
.b8 1
.b8 17
.b8 1
.b8 37
.b8 8
.b8 19
.b8 5
.b8 3
.b8 8
.b8 16
.b8 6
.b8 27
.b8 8
.b8 180
.b8 66
.b8 12
.b8 17
.b8 1
.b8 18
.b8 1
.b8 0
.b8 0
.b8 2
.b8 46
.b8 0
.b8 17
.b8 1
.b8 18
.b8 1
.b8 64
.b8 10
.b8 135
.b8 64
.b8 8
.b8 3
.b8 8
.b8 58
.b8 11
.b8 59
.b8 11
.b8 63
.b8 12
.b8 0
.b8 0
.b8 0
	}
	.section	.debug_info
	{
.b32 119
.b8 2
.b8 0
.b32 .debug_abbrev
.b8 8
.b8 1
.b8 116
.b8 114
.b8 105
.b8 116
.b8 111
.b8 110
.b8 0
.b8 2
.b8 0
.b8 97
.b8 100
.b8 100
.b8 46
.b8 112
.b8 121
.b8 0
.b32 .debug_line
.b8 47
.b8 106
.b8 97
.b8 120
.b8 95
.b8 109
.b8 105
.b8 110
.b8 105
.b8 99
.b8 111
.b8 110
.b8 100
.b8 97
.b8 0
.b8 1
.b64 $L__func_begin0
.b64 $L__func_end0
.b8 2
.b64 $L__func_begin0
.b64 $L__func_end0
.b8 1
.b8 156
.b8 97
.b8 100
.b8 100
.b8 95
.b8 107
.b8 101
.b8 114
.b8 110
.b8 101
.b8 108
.b8 95
.b8 48
.b8 100
.b8 49
.b8 100
.b8 50
.b8 100
.b8 0
.b8 97
.b8 100
.b8 100
.b8 95
.b8 107
.b8 101
.b8 114
.b8 110
.b8 101
.b8 108
.b8 95
.b8 48
.b8 100
.b8 49
.b8 100
.b8 50
.b8 100
.b8 0
.b8 1
.b8 24
.b8 1
.b8 0
	}
	.section	.debug_pubnames
	{
.b32 $L__pubNames_end0-$L__pubNames_start0
$L__pubNames_start0:
.b8 2
.b8 0
.b32 .debug_info
.b32 123
.b32 64
.b8 97
.b8 100
.b8 100
.b8 95
.b8 107
.b8 101
.b8 114
.b8 110
.b8 101
.b8 108
.b8 95
.b8 48
.b8 100
.b8 49
.b8 100
.b8 50
.b8 100
.b8 0
.b32 0
$L__pubNames_end0:
	}
	.section	.debug_pubtypes
	{
.b32 $L__pubTypes_end0-$L__pubTypes_start0
$L__pubTypes_start0:
.b8 2
.b8 0
.b32 .debug_info
.b32 123
.b32 0
$L__pubTypes_end0:
	}
	.section	.debug_loc	{	}
; No such file or directory; current tracing scope: custom-call.3; current profiling annotation: XlaModule:#prefix=jit(triton_kernel_call)/jit(main)/triton_kernel_call[fn=JITFunction(__main__:add_kernel) scalar_args=() name= custom_call_target_name=triton_kernel_call out_shapes=(ShapeDtypeStruct(shape=(8,), dtype=int32),) grid=(1,) num_warps=None num_stages=None num_ctas=1 enable_fp_fusion=True enable_warp_specialization=False enable_persistent=False input_output_aliases=() zeroed_outputs=() debug=False serialized_metadata=b'' block_size=8],hlo_module=jit_triton_kernel_call,program_id=2#.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

No branches or pull requests

1 participant