Problem

When using the torch backend with MeanIoU or one of its child classes, GPU utilization is low and does not increase with batch size, and the process runs very slowly. If the metric is changed to accuracy, the same process runs fast and fully utilizes the GPU. I have also tested MeanIoU (and OneHotMeanIoU) with the jax backend, where the process runs fast and fully utilizes the GPU at the same batch size.

Code to reproduce

https://gist.github.com/savindi-wijenayaka/43da7ac5930afc3ffbf20686ecca1193
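For reference, a condensed sketch of this kind of setup (toy model, hypothetical shapes, synthetic data; not the exact gist code), showing where the metric is swapped between the cases:

```python
# Minimal sketch of the reported setup (shapes and model are placeholders).
# The backend is chosen via KERAS_BACKEND before keras is imported.
import os
os.environ["KERAS_BACKEND"] = "torch"  # or "jax" to compare

import numpy as np
import keras

num_classes = 4
# Toy segmentation data: 256x256 RGB images with one-hot masks.
x = np.random.rand(64, 256, 256, 3).astype("float32")
y = np.eye(num_classes, dtype="float32")[
    np.random.randint(0, num_classes, size=(64, 256, 256))
]

model = keras.Sequential([
    keras.layers.Input((256, 256, 3)),
    keras.layers.Conv2D(32, 3, padding="same", activation="relu"),
    keras.layers.Conv2D(num_classes, 1, activation="softmax"),
])

model.compile(
    optimizer="adam",
    loss="categorical_crossentropy",
    # Case 3 shown here; swap for "accuracy" (Case 1). Case 2 uses
    # keras.metrics.MeanIoU on label-encoded (non-one-hot) masks.
    metrics=[keras.metrics.OneHotMeanIoU(num_classes=num_classes)],
)
model.fit(x, y, batch_size=16, epochs=1)
```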
Observations

GPU utilization when using torch as the backend is ~85% for the 1st case (i.e. running with accuracy as the metric). Time taken: 15 seconds.
For the other two cases (MeanIoU and OneHotMeanIoU), GPU utilization is ~12%. Times taken:
- Case 2 (MeanIoU): 8 hours, 26 minutes, 53 seconds.
- Case 3 (OneHotMeanIoU): 8 hours, 27 minutes, 43 seconds.
GPU utilization when using jax as the backend is ~85% for all 3 cases. Times taken:
- Case 1 (accuracy): 38 seconds
- Case 2 (MeanIoU): 28 seconds
- Case 3 (OneHotMeanIoU): 27 seconds
Errors/Warnings noticed in logs
When running with jax:
- In the beginning: W external/xla/xla/service/gpu/nvptx_compiler.cc:718] The NVIDIA driver's CUDA version is 12.3 which is older than the ptxas CUDA version (12.4.131). Because the driver is older than the ptxas version, XLA is disabling parallel compilation, which may slow down compilation. You should update your NVIDIA driver or use the NVIDIA-provided CUDA forward compatibility packages.
- In Case 1 : E external/xla/xla/service/slow_operation_alarm.cc:65] Trying algorithm eng0{} for conv (f32[16,32,2,2]{3,2,1,0}, u8[0]{0}) custom-call(f32[16,32,256,256]{3,2,1,0}, f32[32,32,255,255]{3,2,1,0}), window={size=255x255 rhs_reversal=1x1}, dim_labels=bf01_oi01->bf01, custom_call_target="__cudnn$convForward", backend_config={"operation_queue_id":"0","wait_on_operation_queues":[],"cudnn_conv_backend_config":{"conv_result_scale":1,"activation_mode":"kNone","side_input_scale":0,"leakyrelu_alpha":0},"force_earliest_schedule":false} is taking a while...
- In Case 2 & 3 : UserWarning: Explicitly requested dtype int64 requested in array is not available, and will be truncated to dtype int32. To enable more dtypes, set the jax_enable_x64 configuration option or the JAX_ENABLE_X64 shell environment variable. See https://github.com/google/jax#current-gotchas for more. return jnp.array(x, dtype=dtype)
However, all data fed into the model are float32, so it is unclear where the int64 request comes from (a sketch of enabling x64 in JAX follows after this section).
When running with pytorch: None
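Regarding the int64 warning above: JAX truncates to int32 unless 64-bit support is enabled. A minimal sketch of enabling it, which only addresses the warning, not the torch slowdown:

```python
# Enable 64-bit dtypes in JAX; must run before any JAX arrays are created.
# This silences the "dtype int64 ... truncated to dtype int32" warning quoted above.
import jax
jax.config.update("jax_enable_x64", True)

# Equivalently, set the shell environment variable before launching:
#   JAX_ENABLE_X64=1 python train.py
```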
Version details:
OS: Red Hat Enterprise Linux 9.3 (Plow)
GPU: NVIDIA H100 PCIe
CUDA Version: 12.3
NVIDIA-SMI 545.23.08
Driver Version: 545.23.08
torch: 2.2.2
jax: 0.4.26
The code is not able to run on Colab due to its high RAM requirement. If possible, could you simplify the code for reproducing the reported behaviour? Thanks!
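One possible simplification, sketched below with assumed placeholder shapes: skip model training entirely and time MeanIoU.update_state on random label masks, which should isolate the metric-side cost without needing much RAM:

```python
# Micro-benchmark sketch: time MeanIoU.update_state alone, no model, no fit().
# Shapes are placeholders; adjust to match the original workload.
import os
os.environ["KERAS_BACKEND"] = "torch"  # compare against "jax"

import time
import numpy as np
import keras

num_classes = 4
metric = keras.metrics.MeanIoU(num_classes=num_classes)

# Random label-encoded masks (MeanIoU's defaults expect class indices).
y_true = np.random.randint(0, num_classes, size=(16, 256, 256))
y_pred = np.random.randint(0, num_classes, size=(16, 256, 256))

start = time.perf_counter()
for _ in range(50):
    metric.update_state(y_true, y_pred)
elapsed = time.perf_counter() - start
print(f"50 updates: {elapsed:.2f}s, result={float(metric.result()):.4f}")
```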