Hopper Support? #215

Open
IzzyPutterman opened this issue Aug 2, 2023 · 5 comments

Comments

@IzzyPutterman

Hey, I am running on the latest main (d258f8f), using:
"jax @ git+https://github.com/google/jax@d872812a359a3bafcfdeba1fcdb874ec77c209db",
"triton @ git+https://github.com/openai/triton@3452615d795bc0c69a189e41f1e775904e5659be#subdirectory=python"
When running on a Hopper node, I get the following error:

    test_out = test_fn(**fn_kwargs, **extra_test_kwargs)    # My call to triton
  File "/usr/local/lib/python3.10/dist-packages/jax/_src/traceback_util.py", line 166, in reraise_with_filtered_traceback
    return fun(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/jax/_src/custom_derivatives.py", line 620, in __call__
    out_flat = custom_vjp_call_p.bind(flat_fun, flat_fwd, flat_bwd,
  File "/usr/local/lib/python3.10/dist-packages/jax/_src/custom_derivatives.py", line 770, in bind
    outs = top_trace.process_custom_vjp_call(self, fun, fwd, bwd_, tracers,
  File "/usr/local/lib/python3.10/dist-packages/jax/_src/core.py", line 840, in process_custom_vjp_call
    return fun.call_wrapped(*tracers)
  File "/usr/local/lib/python3.10/dist-packages/jax/_src/linear_util.py", line 188, in call_wrapped
    ans = self.f(*args, **dict(self.params, **kwargs))
  File "/usr/local/lib/python3.10/dist-packages/jax/_src/traceback_util.py", line 166, in reraise_with_filtered_traceback
    return fun(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/jax/_src/pjit.py", line 252, in cache_miss
    outs, out_flat, out_tree, args_flat, jaxpr = _python_pjit_helper(
  File "/usr/local/lib/python3.10/dist-packages/jax/_src/pjit.py", line 165, in _python_pjit_helper
    out_flat = pjit_p.bind(*args_flat, **params)
  File "/usr/local/lib/python3.10/dist-packages/jax/_src/core.py", line 2596, in bind
    return self.bind_with_trace(top_trace, args, params)
  File "/usr/local/lib/python3.10/dist-packages/jax/_src/core.py", line 389, in bind_with_trace
    out = trace.process_primitive(self, map(trace.full_raise, args), params)
  File "/usr/local/lib/python3.10/dist-packages/jax/_src/core.py", line 821, in process_primitive
    return primitive.impl(*tracers, **params)
  File "/usr/local/lib/python3.10/dist-packages/jax/_src/pjit.py", line 1203, in _pjit_call_impl
    return xc._xla.pjit(name, f, call_impl_cache_miss, [], [], donated_argnums,
  File "/usr/local/lib/python3.10/dist-packages/jax/_src/pjit.py", line 1187, in call_impl_cache_miss
    out_flat, compiled = _pjit_call_impl_python(
  File "/usr/local/lib/python3.10/dist-packages/jax/_src/pjit.py", line 1143, in _pjit_call_impl_python
    return compiled.unsafe_call(*args), compiled
  File "/usr/local/lib/python3.10/dist-packages/jax/_src/profiler.py", line 314, in wrapper
    return func(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/jax/_src/interpreters/pxla.py", line 1228, in __call__
    results = self.xla_executable.execute_sharded(input_bufs)
jax._src.traceback_util.UnfilteredStackTrace: jaxlib.xla_extension.XlaRuntimeError: INTERNAL: Failed to execute XLA Runtime executable: run time error: custom call 'xla.gpu.custom_call' failed: ptxas exited with non-zero error code 65280, output: ptxas fatal   : PTX with .target 'sm_90a' cannot be compiled for architecture 'sm_90'
; current tracing scope: custom-call.7; current profiling annotation: XlaModule:#hlo_module=jit_attention,program_id=22#.

Is Hopper supported?
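The ptxas message above says the generated PTX declares `.target sm_90a` while ptxas was asked to compile for plain `sm_90`. The `a` suffix marks Hopper's architecture-specific features, which are not forward compatible, so such PTX has to be assembled for the matching `sm_90a` target. Below is a minimal sketch of that constraint, assuming the caller has the PTX text in hand; it is not necessarily what the PR mentioned later in this thread does. The helper names `ptx_target` and `ptxas_command` are hypothetical, and the code only builds the ptxas argument list rather than running it.

```python
import re

# Matches the PTX `.target` directive, e.g. ".target sm_90a".
# Only the first target token is captured (hypothetical helper, illustration only).
_TARGET_RE = re.compile(r"^\s*\.target\s+([A-Za-z0-9_]+)", re.MULTILINE)


def ptx_target(ptx_source: str) -> str:
    """Return the architecture named by the PTX `.target` directive."""
    match = _TARGET_RE.search(ptx_source)
    if match is None:
        raise ValueError("no .target directive found in PTX")
    return match.group(1)


def ptxas_command(ptx_path: str, cubin_path: str, ptx_source: str) -> list[str]:
    """Build a ptxas invocation whose -arch matches the PTX's own .target.

    Passing a generic "sm_90" for PTX that declares ".target sm_90a" is the
    mismatch ptxas rejects in the error above.
    """
    arch = ptx_target(ptx_source)
    return ["ptxas", f"-arch={arch}", "-o", cubin_path, ptx_path]


ptx = """
.version 8.0
.target sm_90a
.address_size 64
"""
print(ptxas_command("kernel.ptx", "kernel.cubin", ptx))
# ['ptxas', '-arch=sm_90a', '-o', 'kernel.cubin', 'kernel.ptx']
```

Deriving the architecture from the PTX itself, rather than from the device's base compute capability, is one way to keep the two from drifting apart.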

@sharadmv
Collaborator

sharadmv commented Aug 3, 2023

Are you using a version of Triton that supports Hopper? I personally don't have one and haven't tested it out.

@IzzyPutterman
Author

Yep, I have been able to run torch-triton kernels on Hopper for a few months now.

@sharadmv
Collaborator

sharadmv commented Aug 3, 2023

@chr1sj0nes do we expect this to work?

@abhinavgoel95

From our initial tests, it seems like this PR fixes the issue.

@sharadmv
Collaborator

Awesome, thanks for the fix.
