[GPU] Add new flag xla_gpu_exclude_nondeterministic_ops. #12520
Conversation
Imported from GitHub PR openxla/xla#12520

It's more granular than the existing --xla_gpu_deterministic_ops because it allows doing an autotuning compilation with non-deterministic ops disabled. --xla_gpu_deterministic_ops is a superset of --xla_gpu_exclude_nondeterministic_ops, so setting --xla_gpu_deterministic_ops=true also sets --xla_gpu_exclude_nondeterministic_ops=true.

Copybara import of the project:

-- f47f18016777468fe274bea00945d5209a2cdb57 by Ilia Sergachev <isergachev@nvidia.com>:

[GPU] Add new flag xla_gpu_exclude_nondeterministic_ops.

Merging this change closes #12520

FUTURE_COPYBARA_INTEGRATE_REVIEW=openxla/xla#12520 from openxla:new_determinism_flag f47f18016777468fe274bea00945d5209a2cdb57
PiperOrigin-RevId: 634336994
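The flag relationship described above can be sketched as follows. This is a minimal Python sketch of the implication, not the actual XLA C++ implementation; the helper name and the dict-based flag representation are hypothetical:

```python
def should_exclude_nondeterministic_ops(debug_options: dict) -> bool:
    """Mirrors the flag semantics described in the PR:
    --xla_gpu_deterministic_ops is a superset of
    --xla_gpu_exclude_nondeterministic_ops, so enabling the former
    implies the latter."""
    return (debug_options.get("xla_gpu_deterministic_ops", False)
            or debug_options.get("xla_gpu_exclude_nondeterministic_ops", False))


# The new flag alone excludes non-deterministic ops (e.g. during an
# autotuning compilation) without forcing full determinism elsewhere.
print(should_exclude_nondeterministic_ops(
    {"xla_gpu_exclude_nondeterministic_ops": True}))  # True

# The existing flag implies the new one.
print(should_exclude_nondeterministic_ops(
    {"xla_gpu_deterministic_ops": True}))  # True

# Neither flag set: non-deterministic ops remain allowed.
print(should_exclude_nondeterministic_ops({}))  # False
```

In practice, XLA debug flags like these are typically passed through the XLA_FLAGS environment variable, e.g. `XLA_FLAGS="--xla_gpu_exclude_nondeterministic_ops=true"`.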
This breaks one of TensorFlow's deterministic ops tests. I'm looking into it.
@sergachev Could you please rebase this PR, as it has conflicting changes at the moment.
Force-pushed f47f180 to 4e28374.
Done
Approving the rebase
This change fails this OSS test (running on Linux): https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/kernel_tests/nn_ops/cudnn_deterministic_ops_test.py The failure is a numerics miscomparison; I guess something in the way determinism is being controlled doesn't work out. Interestingly, running the same test in our internal build environment passes. I'll have a look at this, but if you have any ideas in the meantime, please let me know.
Force-pushed 4e28374 to 9a1688f.
Maybe I know what the problem is. After reading it again, I think it's this line: xla/xla/service/gpu/stream_executor_util.cc, line 627 in b17918d.
Force-pushed 9a1688f to 68b555f.
Removed that change and updated the comments.
Thanks. I'm not sure if that was the issue. In any case, there is another one, which I'm fixing internally.
Imported from GitHub PR openxla/xla#12520

It's more granular than the existing --xla_gpu_deterministic_ops because it allows doing an autotuning compilation with non-deterministic ops disabled. --xla_gpu_deterministic_ops is a superset of --xla_gpu_exclude_nondeterministic_ops, so setting --xla_gpu_deterministic_ops=true also sets --xla_gpu_exclude_nondeterministic_ops=true.

Copybara import of the project:

-- 4e2837457dc426154bf80f321c001c916a7d3677 by Ilia Sergachev <isergachev@nvidia.com>:

[GPU] Add new flag xla_gpu_exclude_nondeterministic_ops.

Merging this change closes #12520

PiperOrigin-RevId: 634756524