XlaRuntimeError when applying FusedAttention #225

Open
luise1030 opened this issue Aug 28, 2023 · 0 comments

luise1030 commented Aug 28, 2023

Description

Hi, while applying FusedAttention with jax-triton, we got the following XLA error on an NVIDIA A100:

```
2023-08-28 03:06:51.319566: W external/xla/xla/service/gpu/runtime/support.cc:58] Intercepted XLA runtime error:
INTERNAL: Shared memory requested exceeds device resources.
2023-08-28 03:06:51.319790: E external/xla/xla/pjrt/pjrt_stream_executor_client.cc:2593] Execution of replica 0 failed: INTERNAL: Failed to execute XLA Runtime executable: run time error: custom call 'xla.gpu.custom_call' failed: Shared memory requested exceeds device resources.; current tracing scope: custom-call.371; current profiling annotation: XlaModule:#hlo_module=pjit__wrapped_step_fn,program_id=190#.
2023-08-28 03:06:51.319901: W external/xla/xla/service/gpu/runtime/support.cc:58] Intercepted XLA runtime error:
INTERNAL: Shared memory requested exceeds device resources.
2023-08-28 03:06:51.320187: W external/xla/xla/service/gpu/runtime/support.cc:58] Intercepted XLA runtime error:
INTERNAL: Shared memory requested exceeds device resources.
2023-08-28 03:06:51.320240: E external/xla/xla/pjrt/pjrt_stream_executor_client.cc:2593] Execution of replica 0 failed: INTERNAL: Failed to execute XLA Runtime executable: run time error: custom call 'xla.gpu.custom_call' failed: Shared memory requested exceeds device resources.; current tracing scope: custom-call.371; current profiling annotation: XlaModule:#hlo_module=pjit__wrapped_step_fn,program_id=190#.
2023-08-28 03:06:51.320386: E external/xla/xla/pjrt/pjrt_stream_executor_client.cc:2593] Execution of replica 0 failed: INTERNAL: Failed to execute XLA Runtime executable: run time error: custom call 'xla.gpu.custom_call' failed: Shared memory requested exceeds device resources.; current tracing scope: custom-call.371; current profiling annotation: XlaModule:#hlo_module=pjit__wrapped_step_fn,program_id=190#.
2023-08-28 03:06:51.320465: W external/xla/xla/service/gpu/runtime/support.cc:58] Intercepted XLA runtime error:
INTERNAL: Shared memory requested exceeds device resources.
2023-08-28 03:06:51.320846: E external/xla/xla/pjrt/pjrt_stream_executor_client.cc:2593] Execution of replica 0 failed: INTERNAL: Failed to execute XLA Runtime executable: run time error: custom call 'xla.gpu.custom_call' failed: Shared memory requested exceeds device resources.; current tracing scope: custom-call.371; current profiling annotation: XlaModule:#hlo_module=pjit__wrapped_step_fn,program_id=190#.
```
```
jax.errors.SimplifiedTraceback: For simplicity, JAX has removed its internal frames from the traceback of the following exception. Set JAX_TRACEBACK_FILTERING=off to include these.

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/usr/lib/python3.10/runpy.py", line 196, in _run_module_as_main
    return _run_code(code, main_globals, None,
  File "/usr/lib/python3.10/runpy.py", line 86, in _run_code
    exec(code, run_globals)
  File "/workspace/paxml/paxml/main.py", line 510, in <module>
    app.run(main, flags_parser=absl_flags.flags_parser)
  File "/usr/local/lib/python3.10/dist-packages/absl/app.py", line 308, in run
    _run_main(main, args)
  File "/usr/local/lib/python3.10/dist-packages/absl/app.py", line 254, in _run_main
    sys.exit(main(argv))
  File "/workspace/paxml/paxml/main.py", line 445, in main
    _main(argv)
  File "/root/.local/lib/python3.10/site-packages/praxis/py_utils.py", line 1023, in wrapper
    result = func(*args, **kwargs)
  File "/workspace/paxml/paxml/main.py", line 487, in _main
    run(experiment_config=experiment_config,
  File "/root/.local/lib/python3.10/site-packages/praxis/py_utils.py", line 1023, in wrapper
    result = func(*args, **kwargs)
  File "/workspace/paxml/paxml/main.py", line 420, in run
    run_experiment(
  File "/root/.local/lib/python3.10/site-packages/praxis/py_utils.py", line 1023, in wrapper
    result = func(*args, **kwargs)
  File "/workspace/paxml/paxml/main.py", line 285, in run_experiment
    train.train_and_evaluate(
  File "/root/.local/lib/python3.10/site-packages/praxis/py_utils.py", line 1023, in wrapper
    result = func(*args, **kwargs)
  File "/workspace/paxml/paxml/train.py", line 274, in train_and_evaluate
    executor.start()
  File "/workspace/paxml/paxml/executors.py", line 269, in start
    _train_and_evaluate_common(
  File "/workspace/paxml/paxml/executors.py", line 406, in _train_and_evaluate_common
    program_output = train_program.run(partitioned_train_state, step_i)
  File "/root/.local/lib/python3.10/site-packages/praxis/py_utils.py", line 1023, in wrapper
    result = func(*args, **kwargs)
  File "/workspace/paxml/paxml/programs.py", line 332, in run
    new_step, new_state, train_outputs = self.train_step(
  File "/workspace/paxml/paxml/programs.py", line 620, in train_step
    return step + 1, *train_step(state, prng_key, inputs, static_args)
  File "/workspace/paxml/paxml/trainer_lib.py", line 1634, in __call__
    return pjitted_fn(*args)
jaxlib.xla_extension.XlaRuntimeError: INTERNAL: Failed to execute XLA Runtime executable: run time error: custom call 'xla.gpu.custom_call' failed: Shared memory requested exceeds device resources.; current tracing scope: custom-call.371; current profiling annotation: XlaModule:#hlo_module=pjit__wrapped_step_fn,program_id=190#.: while running replica 0 and partition 0 of a replicated computation (other replicas may have failed as well).
```
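The final error means a custom-call kernel asked for more shared memory per thread block than the GPU allows. Since only `test7BFA` fails, the kernel is presumably one of those enabled by `USE_FLASH_ATTENTION` or `USE_TRITON_LAYER_NORM`, though the log does not say which. For compute capability 8.0 (A100), the CUDA C++ Programming Guide lists 163 KB as the maximum shared memory per block, with anything above 48 KB requiring an explicit opt-in. The sketch that follows is a hypothetical back-of-envelope model, not how Triton actually allocates shared memory: it assumes a forward attention tile holds one Q block plus `num_stages` buffered K and V blocks. The function names and the model itself are illustrative assumptions; the point is only how quickly head dimension and pipeline stages push a tile past the limit.

```python
# Hypothetical, simplified model of shared memory used by one flash-attention
# forward tile. Triton's real allocation depends on compiler version, layouts,
# and the backward kernels; treat this as an order-of-magnitude check only.

# Max shared memory per thread block (opt-in) for compute capability 8.0, in bytes.
A100_MAX_SMEM_PER_BLOCK = 163 * 1024


def estimate_attention_smem(block_q: int, block_k: int, head_dim: int,
                            dtype_bytes: int = 2, num_stages: int = 2) -> int:
    """Bytes for one Q tile plus `num_stages` buffered K and V tiles (assumed model)."""
    q_tile = block_q * head_dim                      # elements in the Q tile
    kv_tiles = 2 * num_stages * block_k * head_dim   # K and V, each multi-buffered
    return (q_tile + kv_tiles) * dtype_bytes


def fits_on_a100(smem_bytes: int) -> bool:
    """True if the estimate is within the per-block opt-in limit."""
    return smem_bytes <= A100_MAX_SMEM_PER_BLOCK


if __name__ == "__main__":
    for head_dim in (64, 128):
        for num_stages in (1, 2, 3):
            smem = estimate_attention_smem(128, 128, head_dim, num_stages=num_stages)
            status = "fits" if fits_on_a100(smem) else "exceeds limit"
            print(f"head_dim={head_dim} stages={num_stages}: {smem} bytes ({status})")
```

Under this model, a 128x128 tile with head dimension 128 in bf16 fits with two stages (163,840 bytes) but not with three, which is the kind of margin where a small change in block sizes or stages can trigger the error above.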

Steps to reproduce:
Add the following model variants to /root/.local/lib/python3.10/site-packages/paxml/tasks/lm/params/nvidia.py:

```diff
--- a/paxml/tasks/lm/params/nvidia.py
+++ b/paxml/tasks/lm/params/nvidia.py
@@ -350,6 +350,20 @@ class NVIDIA70BProxy(NVIDIA5B):
   MODEL_DIMS = 8192
   HIDDEN_DIMS = 4 * 8192
 
+@experiment_registry.register
+class test7B(NVIDIA70BProxy):
+  PERCORE_BATCH_SIZE = 16
+  MICROBATCH_SIZE = 1
+  USE_FLASH_ATTENTION = False
+  USE_TRITON_LAYER_NORM = False
+  NUM_LAYERS = 8
+  NUM_STAGES = 4
+  ICI_MESH_SHAPE = [4, 1, 1, 1]
+
+@experiment_registry.register
+class test7BFA(test7B):
+  USE_FLASH_ATTENTION = True
+  USE_TRITON_LAYER_NORM = True
 
 @experiment_registry.register
 class NVIDIA116BProxy(NVIDIA5B):
```

Run without FusedAttention (passes):
```bash
python3 -u -m paxml.main --noenable_checkpoint_saving --job_log_dir=./jax_tmp --exp=paxml.tasks.lm.params.nvidia.test7B
```

Run with FusedAttention (fails):
```bash
python3 -u -m paxml.main --noenable_checkpoint_saving --job_log_dir=./jax_tmp --exp=paxml.tasks.lm.params.nvidia.test7BFA
```
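To rerun both cases back to back, a small driver along these lines could be used. It is a sketch that assumes paxml is importable from the current interpreter (it uses `sys.executable` in place of `python3`) and reuses the exact flags from the two commands; the helper names `build_command` and `run_all` are hypothetical.

```python
import subprocess
import sys
from typing import Dict, List

# Experiment names taken from the two commands in the report.
EXPERIMENTS: Dict[str, str] = {
    "without FusedAttention": "paxml.tasks.lm.params.nvidia.test7B",
    "with FusedAttention": "paxml.tasks.lm.params.nvidia.test7BFA",
}


def build_command(exp: str, job_log_dir: str = "./jax_tmp") -> List[str]:
    """Build the paxml.main invocation used in the report for one experiment."""
    return [
        sys.executable, "-u", "-m", "paxml.main",
        "--noenable_checkpoint_saving",
        f"--job_log_dir={job_log_dir}",
        f"--exp={exp}",
    ]


def run_all() -> Dict[str, int]:
    """Run each experiment in turn and collect its process exit code."""
    results: Dict[str, int] = {}
    for label, exp in EXPERIMENTS.items():
        proc = subprocess.run(build_command(exp), check=False)
        results[label] = proc.returncode
    return results
```

Calling `run_all()` returns one exit code per case; given the behavior reported here, one would expect 0 for the first case and a nonzero code for the second.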

Versions:

```bash
python3 -m pip install git+https://github.com/google/paxml orbax==0.1.6 --user
python3 -m pip install git+https://github.com/google/praxis --user
python3 -m pip install git+https://github.com/google/flax --user
python3 -m pip uninstall orbax orbax-checkpoint -y
python3 -m pip install git+https://github.com/google/orbax/#subdirectory=checkpoint --user

python3 -m pip uninstall triton -y
python3 -m pip install git+https://github.com/openai/triton@b24dc19##subdirectory=python

python3 -m pip uninstall jax-triton -y
python3 -m pip install git+https://github.com/jax-ml/jax-triton@4f97b83 --no-deps

python3 -m pip uninstall jax jaxlib -y
git clone https://github.com/google/jax
pushd jax
git checkout 8d80e25
# Build jaxlib
apt update -y; apt install g++ -y
python3 -m pip install numpy wheel build
python3 build/build.py --enable_cuda
# Install JAX
python3 setup.py develop --user
# Install jaxlib
python3 -m pip install dist/*.whl
popd

# Change the source of pallas.ops used by praxis
sed -i 's/jax.experimental.pallas.ops/jax_triton.pallas.ops/g' /root/.local/lib/python3.10/site-packages/praxis/layers/gpu_fast_attention.py
```
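To confirm what actually ended up installed after these steps, the versions can be recorded with the standard library's `importlib.metadata`. This is a sketch; the package list is an assumption based on the install commands. Because several packages were installed from git or with `setup.py develop`, the reported version strings may not identify the exact commits (b24dc19, 4f97b83, 8d80e25), so those are worth stating alongside them. Recent JAX releases also provide `jax.print_environment_info()` for a fuller summary.

```python
from importlib import metadata
from typing import Dict, Optional, Sequence

# Distribution names assumed from the install commands above.
PACKAGES = (
    "jax", "jaxlib", "jax-triton", "triton",
    "flax", "praxis", "paxml", "orbax-checkpoint",
)


def installed_versions(packages: Sequence[str] = PACKAGES) -> Dict[str, Optional[str]]:
    """Map each distribution name to its installed version, or None if absent."""
    versions: Dict[str, Optional[str]] = {}
    for name in packages:
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            versions[name] = None
    return versions


if __name__ == "__main__":
    for name, version in installed_versions().items():
        print(f"{name:18s} {version or 'not installed'}")
```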

NVIDIA GPU info

4 A100-SXM-80GB GPUs
