SobolEngine.draw does not respect the default dtype and always uses the passed in dtype (defaulted to float32) #126478

Closed
saitcakmak opened this issue May 16, 2024 · 1 comment
Labels
module: python frontend: For issues relating to PyTorch's Python frontend
module: random: Related to random number generation in PyTorch (rng generator)
triaged: This issue has been looked at a team member, and triaged and prioritized into an appropriate module

Comments


saitcakmak commented May 16, 2024

🐛 Describe the bug

Issue description

If the SobolEngine is initialized after torch.set_default_dtype(torch.float64), SobolEngine.draw ignores the dtype argument and instead returns samples with the default dtype.

Works as expected with torch.set_default_dtype(torch.float32):

>>> import torch
>>> from torch.quasirandom import SobolEngine
>>> torch.set_default_dtype(torch.float32)
>>> sobol_engine = SobolEngine(dimension=10, scramble=True, seed=0)
>>> sobol_engine.draw(n=5, dtype=torch.float32).dtype
torch.float32
>>> sobol_engine.draw(n=5, dtype=torch.float64).dtype
torch.float64

Still works as expected if we update the default dtype but keep the previous SobolEngine instance:

>>> torch.set_default_dtype(torch.float64)
>>> sobol_engine.draw(n=5, dtype=torch.float32).dtype
torch.float32
>>> sobol_engine.draw(n=5, dtype=torch.float64).dtype
torch.float64

The dtype argument is ignored if SobolEngine is initialized after torch.set_default_dtype(torch.float64):

>>> sobol_engine = SobolEngine(dimension=10, scramble=True, seed=0)
>>> sobol_engine.draw(n=5, dtype=torch.float32).dtype
torch.float64

Expected behavior

SobolEngine(...).draw(n=n, dtype=dtype) should always produce samples with the provided dtype.
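
Until draw behaves this way on every version, a caller can enforce the expected dtype on their side. The following is a minimal sketch of such a workaround, not the maintainers' fix; the helper name draw_as is hypothetical and not part of PyTorch.

```python
import torch
from torch.quasirandom import SobolEngine


def draw_as(engine: SobolEngine, n: int, dtype: torch.dtype) -> torch.Tensor:
    """Hypothetical helper: draw n points and force the requested dtype."""
    samples = engine.draw(n=n, dtype=dtype)
    # Tensor.to returns the tensor unchanged when the dtype already matches,
    # so the cast costs nothing where draw already honors the argument.
    return samples.to(dtype)


# Reproduce the setup from the report: engine created under a float64 default.
previous = torch.get_default_dtype()
torch.set_default_dtype(torch.float64)
try:
    engine = SobolEngine(dimension=10, scramble=True, seed=0)
    samples = draw_as(engine, n=5, dtype=torch.float32)
finally:
    # Restore the global default so the example has no side effects.
    torch.set_default_dtype(previous)

print(samples.dtype)  # torch.float32
```

The explicit cast makes the result independent of the default dtype in effect when the engine was constructed.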

Other proposed improvements

Currently, the dtype argument defaults to torch.float32.

class SobolEngine:
    ...
    def draw(self, n: int = 1, out: Optional[torch.Tensor] = None,
             dtype: torch.dtype = torch.float32) -> torch.Tensor:
    ...

We can update the default argument to None and, when it is None, produce samples with dtype=torch.get_default_dtype().

class SobolEngine:
    ...
    def draw(self, n: int = 1, out: Optional[torch.Tensor] = None,
             dtype: Optional[torch.dtype] = None) -> torch.Tensor:
    ...
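
One way to read the proposal above is that None gets resolved to the current default dtype at call time. The sketch below illustrates that with a subclass; the class name DefaultDtypeSobolEngine is hypothetical and this is not the actual PyTorch implementation. It only covers the None default and does not by itself address the explicit-dtype bug described earlier.

```python
from typing import Optional

import torch
from torch.quasirandom import SobolEngine


class DefaultDtypeSobolEngine(SobolEngine):
    """Hypothetical subclass for illustration only (not part of PyTorch)."""

    def draw(
        self,
        n: int = 1,
        out: Optional[torch.Tensor] = None,
        dtype: Optional[torch.dtype] = None,
    ) -> torch.Tensor:
        # Resolve None when draw is called, not when the class is defined,
        # so later calls to torch.set_default_dtype are taken into account.
        if dtype is None:
            dtype = torch.get_default_dtype()
        return super().draw(n=n, out=out, dtype=dtype)
```

Resolving the default inside the method body, instead of writing dtype=torch.get_default_dtype() in the signature, matters because a signature default would be evaluated only once, at definition time.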

Versions

PyTorch version: 2.3.0
Is debug build: False
CUDA used to build PyTorch: None
ROCM used to build PyTorch: N/A

OS: macOS 14.4.1 (arm64)
GCC version: Could not collect
Clang version: 15.0.0 (clang-1500.3.9.4)
CMake version: Could not collect
Libc version: N/A

Python version: 3.10.14 (main, May 6 2024, 14:42:37) [Clang 14.0.6 ] (64-bit runtime)
Python platform: macOS-14.4.1-arm64-arm-64bit
Is CUDA available: False
CUDA runtime version: No CUDA
CUDA_MODULE_LOADING set to: N/A
GPU models and configuration: No CUDA
Nvidia driver version: No CUDA
cuDNN version: No CUDA
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

CPU:
Apple M1 Pro

Versions of relevant libraries:
[pip3] torch==2.3.0
[conda] pytorch 2.3.0 py3.10_0 pytorch

cc @pbelevich @albanD

@saitcakmak
Contributor Author

cc @Balandat

mikaylagawarecki added the "module: python frontend", "triaged", and "triage review" labels May 20, 2024
albanD changed the title from "SobolEngine.draw ignores dtype argument when default dtype is float64" to "SobolEngine.draw does not respect the default dtype and always uses the passed in dtype (defaulted to float32)" May 20, 2024
drisspg added the "module: random" label May 20, 2024