Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use MPS device when available #951

Open
wants to merge 32 commits into
base: master
Choose a base branch
from
Open

Use MPS device when available #951

wants to merge 32 commits into from

Conversation

araffin
Copy link
Member

@araffin araffin commented Jul 4, 2022

Description

Add support for MPS device (uses it if available) and save cloudpickle version (important to debug saving/loading issues).

DO NOT MERGE: this PR must be tested on a MPS device first

closes #914

Motivation and Context

  • I have raised an issue to propose this change (required for new features and bug fixes)

Types of changes

  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to change)
  • Documentation (update in the documentation)

Checklist:

  • I've read the CONTRIBUTION guide (required)
  • I have updated the changelog accordingly (required).
  • My change requires a change to the documentation.
  • I have updated the tests accordingly (required for a bug fix or a new feature).
  • I have updated the documentation accordingly.
  • I have reformatted the code using make format (required)
  • I have checked the codestyle using make check-codestyle and make lint (required)
  • I have ensured make pytest and make type both pass. (required)
  • I have checked that the documentation builds using make doc (required)

Note: You can run most of the checks using make commit-checks.

Note: we are using a maximum length of 127 characters per line

@araffin
Copy link
Member Author

araffin commented Aug 13, 2022

@qgallouedec could you test this PR (do make pytest) on a MPS enabled machine? (best would be to test sb3 contrib too)

We should probably add a warning in the doc about the minimum pytorch version? (or in the code)

@qgallouedec
Copy link
Collaborator

Not only the pytest failed, but it caused a Python Fatal Error:

(env) quentingallouedec@MacBook-Pro-de-Quentin stable-baselines3 % pytest tests/test_cnn.py 
=========================================== test session starts ============================================
platform darwin -- Python 3.9.13, pytest-7.1.2, pluggy-1.0.0
rootdir: /Users/quentingallouedec/stable-baselines3, configfile: setup.cfg
plugins: xdist-2.5.0, forked-1.4.0, env-0.6.2, typeguard-2.13.3, cov-3.0.0
collected 14 items                                                                                         

tests/test_cnn.py Fatal Python error: Aborted

Current thread 0x0000000101308580 (most recent call first):
  File "/Users/quentingallouedec/stable-baselines3/env/lib/python3.9/site-packages/torch/nn/modules/conv.py", line 453 in _conv_forward
  File "/Users/quentingallouedec/stable-baselines3/env/lib/python3.9/site-packages/torch/nn/modules/conv.py", line 457 in forward
  File "/Users/quentingallouedec/stable-baselines3/env/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1130 in _call_impl
  File "/Users/quentingallouedec/stable-baselines3/env/lib/python3.9/site-packages/torch/nn/modules/container.py", line 139 in forward
  File "/Users/quentingallouedec/stable-baselines3/env/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1130 in _call_impl
  File "/Users/quentingallouedec/stable-baselines3/stable_baselines3/common/torch_layers.py", line 93 in forward
  File "/Users/quentingallouedec/stable-baselines3/env/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1130 in _call_impl
  File "/Users/quentingallouedec/stable-baselines3/stable_baselines3/common/policies.py", line 129 in extract_features
  File "/Users/quentingallouedec/stable-baselines3/stable_baselines3/common/policies.py", line 588 in forward
  File "/Users/quentingallouedec/stable-baselines3/env/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1130 in _call_impl
  File "/Users/quentingallouedec/stable-baselines3/stable_baselines3/common/on_policy_algorithm.py", line 167 in collect_rollouts
  File "/Users/quentingallouedec/stable-baselines3/stable_baselines3/common/on_policy_algorithm.py", line 248 in learn
  File "/Users/quentingallouedec/stable-baselines3/stable_baselines3/a2c/a2c.py", line 197 in learn
  File "/Users/quentingallouedec/stable-baselines3/tests/test_cnn.py", line 33 in test_cnn
  File "/Users/quentingallouedec/stable-baselines3/env/lib/python3.9/site-packages/_pytest/python.py", line 192 in pytest_pyfunc_call
  File "/Users/quentingallouedec/stable-baselines3/env/lib/python3.9/site-packages/pluggy/_callers.py", line 39 in _multicall
  File "/Users/quentingallouedec/stable-baselines3/env/lib/python3.9/site-packages/pluggy/_manager.py", line 80 in _hookexec
  File "/Users/quentingallouedec/stable-baselines3/env/lib/python3.9/site-packages/pluggy/_hooks.py", line 265 in __call__
  File "/Users/quentingallouedec/stable-baselines3/env/lib/python3.9/site-packages/_pytest/python.py", line 1761 in runtest
  File "/Users/quentingallouedec/stable-baselines3/env/lib/python3.9/site-packages/_pytest/runner.py", line 166 in pytest_runtest_call
  File "/Users/quentingallouedec/stable-baselines3/env/lib/python3.9/site-packages/pluggy/_callers.py", line 39 in _multicall
  File "/Users/quentingallouedec/stable-baselines3/env/lib/python3.9/site-packages/pluggy/_manager.py", line 80 in _hookexec
  File "/Users/quentingallouedec/stable-baselines3/env/lib/python3.9/site-packages/pluggy/_hooks.py", line 265 in __call__
  File "/Users/quentingallouedec/stable-baselines3/env/lib/python3.9/site-packages/_pytest/runner.py", line 259 in <lambda>
  File "/Users/quentingallouedec/stable-baselines3/env/lib/python3.9/site-packages/_pytest/runner.py", line 338 in from_call
  File "/Users/quentingallouedec/stable-baselines3/env/lib/python3.9/site-packages/_pytest/runner.py", line 258 in call_runtest_hook
  File "/Users/quentingallouedec/stable-baselines3/env/lib/python3.9/site-packages/_pytest/runner.py", line 219 in call_and_report
  File "/Users/quentingallouedec/stable-baselines3/env/lib/python3.9/site-packages/_pytest/runner.py", line 130 in runtestprotocol
  File "/Users/quentingallouedec/stable-baselines3/env/lib/python3.9/site-packages/_pytest/runner.py", line 111 in pytest_runtest_protocol
  File "/Users/quentingallouedec/stable-baselines3/env/lib/python3.9/site-packages/pluggy/_callers.py", line 39 in _multicall
  File "/Users/quentingallouedec/stable-baselines3/env/lib/python3.9/site-packages/pluggy/_manager.py", line 80 in _hookexec
  File "/Users/quentingallouedec/stable-baselines3/env/lib/python3.9/site-packages/pluggy/_hooks.py", line 265 in __call__
  File "/Users/quentingallouedec/stable-baselines3/env/lib/python3.9/site-packages/_pytest/main.py", line 347 in pytest_runtestloop
  File "/Users/quentingallouedec/stable-baselines3/env/lib/python3.9/site-packages/pluggy/_callers.py", line 39 in _multicall
  File "/Users/quentingallouedec/stable-baselines3/env/lib/python3.9/site-packages/pluggy/_manager.py", line 80 in _hookexec
  File "/Users/quentingallouedec/stable-baselines3/env/lib/python3.9/site-packages/pluggy/_hooks.py", line 265 in __call__
  File "/Users/quentingallouedec/stable-baselines3/env/lib/python3.9/site-packages/_pytest/main.py", line 322 in _main
  File "/Users/quentingallouedec/stable-baselines3/env/lib/python3.9/site-packages/_pytest/main.py", line 268 in wrap_session
  File "/Users/quentingallouedec/stable-baselines3/env/lib/python3.9/site-packages/_pytest/main.py", line 315 in pytest_cmdline_main
  File "/Users/quentingallouedec/stable-baselines3/env/lib/python3.9/site-packages/pluggy/_callers.py", line 39 in _multicall
  File "/Users/quentingallouedec/stable-baselines3/env/lib/python3.9/site-packages/pluggy/_manager.py", line 80 in _hookexec
  File "/Users/quentingallouedec/stable-baselines3/env/lib/python3.9/site-packages/pluggy/_hooks.py", line 265 in __call__
  File "/Users/quentingallouedec/stable-baselines3/env/lib/python3.9/site-packages/_pytest/config/__init__.py", line 164 in main
  File "/Users/quentingallouedec/stable-baselines3/env/lib/python3.9/site-packages/_pytest/config/__init__.py", line 187 in console_main
  File "/Users/quentingallouedec/stable-baselines3/env/bin/pytest", line 8 in <module>
zsh: abort      pytest tests/test_cnn.py

Don't know what it is. I will investigate.

@qgallouedec
Copy link
Collaborator

Well, I'm pretty sure the problem comes from the fact that the observation is transposed before passing into the CNN of the feature extractor, and this seems to cause some more bugs: pytorch/pytorch#81557

To reproduce:

from stable_baselines3 import A2C
from stable_baselines3.common.envs import FakeImageEnv

env = FakeImageEnv()
model = A2C("CnnPolicy", env).learn(250)

It causes fatal error in this line:

return self.linear(self.cnn(observations))

without traceback, but with this error message:

Assertion failed: (mapIt != _jitValueTypes.end()), function getStaticType, file MPSRuntime_Project.h, line 435.
zsh: abort      /Users/quentingallouedec/stable-baselines3/env/bin/python 

But more generally, there are still some features missing, such as support for the multinomial distribution (pytorch/pytorch#80760) for SB3 to work fully on the mps device

So we still have to be a bit patient.

@araffin
Copy link
Member Author

araffin commented Aug 16, 2022

Thanks for testing =)

@qgallouedec
Copy link
Collaborator

Pytorch 1.13 is out. MPS is still not fully supported and causes bugs in SB3.
To keep track of MPS op coverage, see pytorch/pytorch#77764

@kulinseth
Copy link

@qgallouedec , can you please provide which Ops are missing ?
Also if there is any Functional issue , can you provide a repro case? We will take a look.

@kulinseth
Copy link

Not only the pytest failed, but it caused a Python Fatal Error:

(env) quentingallouedec@MacBook-Pro-de-Quentin stable-baselines3 % pytest tests/test_cnn.py 
=========================================== test session starts ============================================
platform darwin -- Python 3.9.13, pytest-7.1.2, pluggy-1.0.0
rootdir: /Users/quentingallouedec/stable-baselines3, configfile: setup.cfg
plugins: xdist-2.5.0, forked-1.4.0, env-0.6.2, typeguard-2.13.3, cov-3.0.0
collected 14 items                                                                                         

tests/test_cnn.py Fatal Python error: Aborted

Current thread 0x0000000101308580 (most recent call first):
  File "/Users/quentingallouedec/stable-baselines3/env/lib/python3.9/site-packages/torch/nn/modules/conv.py", line 453 in _conv_forward
  File "/Users/quentingallouedec/stable-baselines3/env/lib/python3.9/site-packages/torch/nn/modules/conv.py", line 457 in forward
  File "/Users/quentingallouedec/stable-baselines3/env/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1130 in _call_impl
  File "/Users/quentingallouedec/stable-baselines3/env/lib/python3.9/site-packages/torch/nn/modules/container.py", line 139 in forward
  File "/Users/quentingallouedec/stable-baselines3/env/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1130 in _call_impl
  File "/Users/quentingallouedec/stable-baselines3/stable_baselines3/common/torch_layers.py", line 93 in forward
  File "/Users/quentingallouedec/stable-baselines3/env/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1130 in _call_impl
  File "/Users/quentingallouedec/stable-baselines3/stable_baselines3/common/policies.py", line 129 in extract_features
  File "/Users/quentingallouedec/stable-baselines3/stable_baselines3/common/policies.py", line 588 in forward
  File "/Users/quentingallouedec/stable-baselines3/env/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1130 in _call_impl
  File "/Users/quentingallouedec/stable-baselines3/stable_baselines3/common/on_policy_algorithm.py", line 167 in collect_rollouts
  File "/Users/quentingallouedec/stable-baselines3/stable_baselines3/common/on_policy_algorithm.py", line 248 in learn
  File "/Users/quentingallouedec/stable-baselines3/stable_baselines3/a2c/a2c.py", line 197 in learn
  File "/Users/quentingallouedec/stable-baselines3/tests/test_cnn.py", line 33 in test_cnn
  File "/Users/quentingallouedec/stable-baselines3/env/lib/python3.9/site-packages/_pytest/python.py", line 192 in pytest_pyfunc_call
  File "/Users/quentingallouedec/stable-baselines3/env/lib/python3.9/site-packages/pluggy/_callers.py", line 39 in _multicall
  File "/Users/quentingallouedec/stable-baselines3/env/lib/python3.9/site-packages/pluggy/_manager.py", line 80 in _hookexec
  File "/Users/quentingallouedec/stable-baselines3/env/lib/python3.9/site-packages/pluggy/_hooks.py", line 265 in __call__
  File "/Users/quentingallouedec/stable-baselines3/env/lib/python3.9/site-packages/_pytest/python.py", line 1761 in runtest
  File "/Users/quentingallouedec/stable-baselines3/env/lib/python3.9/site-packages/_pytest/runner.py", line 166 in pytest_runtest_call
  File "/Users/quentingallouedec/stable-baselines3/env/lib/python3.9/site-packages/pluggy/_callers.py", line 39 in _multicall
  File "/Users/quentingallouedec/stable-baselines3/env/lib/python3.9/site-packages/pluggy/_manager.py", line 80 in _hookexec
  File "/Users/quentingallouedec/stable-baselines3/env/lib/python3.9/site-packages/pluggy/_hooks.py", line 265 in __call__
  File "/Users/quentingallouedec/stable-baselines3/env/lib/python3.9/site-packages/_pytest/runner.py", line 259 in <lambda>
  File "/Users/quentingallouedec/stable-baselines3/env/lib/python3.9/site-packages/_pytest/runner.py", line 338 in from_call
  File "/Users/quentingallouedec/stable-baselines3/env/lib/python3.9/site-packages/_pytest/runner.py", line 258 in call_runtest_hook
  File "/Users/quentingallouedec/stable-baselines3/env/lib/python3.9/site-packages/_pytest/runner.py", line 219 in call_and_report
  File "/Users/quentingallouedec/stable-baselines3/env/lib/python3.9/site-packages/_pytest/runner.py", line 130 in runtestprotocol
  File "/Users/quentingallouedec/stable-baselines3/env/lib/python3.9/site-packages/_pytest/runner.py", line 111 in pytest_runtest_protocol
  File "/Users/quentingallouedec/stable-baselines3/env/lib/python3.9/site-packages/pluggy/_callers.py", line 39 in _multicall
  File "/Users/quentingallouedec/stable-baselines3/env/lib/python3.9/site-packages/pluggy/_manager.py", line 80 in _hookexec
  File "/Users/quentingallouedec/stable-baselines3/env/lib/python3.9/site-packages/pluggy/_hooks.py", line 265 in __call__
  File "/Users/quentingallouedec/stable-baselines3/env/lib/python3.9/site-packages/_pytest/main.py", line 347 in pytest_runtestloop
  File "/Users/quentingallouedec/stable-baselines3/env/lib/python3.9/site-packages/pluggy/_callers.py", line 39 in _multicall
  File "/Users/quentingallouedec/stable-baselines3/env/lib/python3.9/site-packages/pluggy/_manager.py", line 80 in _hookexec
  File "/Users/quentingallouedec/stable-baselines3/env/lib/python3.9/site-packages/pluggy/_hooks.py", line 265 in __call__
  File "/Users/quentingallouedec/stable-baselines3/env/lib/python3.9/site-packages/_pytest/main.py", line 322 in _main
  File "/Users/quentingallouedec/stable-baselines3/env/lib/python3.9/site-packages/_pytest/main.py", line 268 in wrap_session
  File "/Users/quentingallouedec/stable-baselines3/env/lib/python3.9/site-packages/_pytest/main.py", line 315 in pytest_cmdline_main
  File "/Users/quentingallouedec/stable-baselines3/env/lib/python3.9/site-packages/pluggy/_callers.py", line 39 in _multicall
  File "/Users/quentingallouedec/stable-baselines3/env/lib/python3.9/site-packages/pluggy/_manager.py", line 80 in _hookexec
  File "/Users/quentingallouedec/stable-baselines3/env/lib/python3.9/site-packages/pluggy/_hooks.py", line 265 in __call__
  File "/Users/quentingallouedec/stable-baselines3/env/lib/python3.9/site-packages/_pytest/config/__init__.py", line 164 in main
  File "/Users/quentingallouedec/stable-baselines3/env/lib/python3.9/site-packages/_pytest/config/__init__.py", line 187 in console_main
  File "/Users/quentingallouedec/stable-baselines3/env/bin/pytest", line 8 in <module>
zsh: abort      pytest tests/test_cnn.py

Don't know what it is. I will investigate.

Is this still happening in latest nightly cc @qgallouedec ?

@qgallouedec
Copy link
Collaborator

qgallouedec commented Nov 17, 2022

With the latest nightly:

% /Users/quentingallouedec/stable-baselines3/env/bin/python /Users/quentingallouedec/stable-baselines3/test_mps.py
[W NNPACK.cpp:64] Could not initialize NNPACK! Reason: Unsupported hardware.
Traceback (most recent call last):
  File "/Users/quentingallouedec/stable-baselines3/test_mps.py", line 5, in <module>
    model = A2C("CnnPolicy", env).learn(250)
  File "/Users/quentingallouedec/stable-baselines3/stable_baselines3/a2c/a2c.py", line 193, in learn
    return super().learn(
  File "/Users/quentingallouedec/stable-baselines3/stable_baselines3/common/on_policy_algorithm.py", line 248, in learn
    continue_training = self.collect_rollouts(self.env, callback, self.rollout_buffer, n_rollout_steps=self.n_steps)
  File "/Users/quentingallouedec/stable-baselines3/stable_baselines3/common/on_policy_algorithm.py", line 166, in collect_rollouts
    actions, values, log_probs = self.policy(obs_tensor)
  File "/Users/quentingallouedec/stable-baselines3/env/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1427, in _call_impl
    return forward_call(*input, **kwargs)
  File "/Users/quentingallouedec/stable-baselines3/stable_baselines3/common/policies.py", line 576, in forward
    log_prob = distribution.log_prob(actions)
  File "/Users/quentingallouedec/stable-baselines3/stable_baselines3/common/distributions.py", line 279, in log_prob
    return self.distribution.log_prob(actions)
  File "/Users/quentingallouedec/stable-baselines3/env/lib/python3.9/site-packages/torch/distributions/categorical.py", line 123, in log_prob
    self._validate_sample(value)
  File "/Users/quentingallouedec/stable-baselines3/env/lib/python3.9/site-packages/torch/distributions/distribution.py", line 298, in _validate_sample
    valid = support.check(value)
  File "/Users/quentingallouedec/stable-baselines3/env/lib/python3.9/site-packages/torch/distributions/constraints.py", line 257, in check
    return (value % 1 == 0) & (self.lower_bound <= value) & (value <= self.upper_bound)
NotImplementedError: The operator 'aten::remainder.Tensor_out' is not currently implemented for the MPS device. If you want this op to be added in priority during the prototype phase of this feature, please comment on https://github.com/pytorch/pytorch/issues/77764. As a temporary fix, you can set the environment variable `PYTORCH_ENABLE_MPS_FALLBACK=1` to use the CPU as a fallback for this op. WARNING: this will be slower than running natively on MPS.

EDIT: tested with PyTorch 2.0.0.dev20221220

@kulinseth
Copy link

With the latest nightly:

% /Users/quentingallouedec/stable-baselines3/env/bin/python /Users/quentingallouedec/stable-baselines3/test_mps.py
[W NNPACK.cpp:64] Could not initialize NNPACK! Reason: Unsupported hardware.
Traceback (most recent call last):
  File "/Users/quentingallouedec/stable-baselines3/test_mps.py", line 5, in <module>
    model = A2C("CnnPolicy", env).learn(250)
  File "/Users/quentingallouedec/stable-baselines3/stable_baselines3/a2c/a2c.py", line 193, in learn
    return super().learn(
  File "/Users/quentingallouedec/stable-baselines3/stable_baselines3/common/on_policy_algorithm.py", line 248, in learn
    continue_training = self.collect_rollouts(self.env, callback, self.rollout_buffer, n_rollout_steps=self.n_steps)
  File "/Users/quentingallouedec/stable-baselines3/stable_baselines3/common/on_policy_algorithm.py", line 166, in collect_rollouts
    actions, values, log_probs = self.policy(obs_tensor)
  File "/Users/quentingallouedec/stable-baselines3/env/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1427, in _call_impl
    return forward_call(*input, **kwargs)
  File "/Users/quentingallouedec/stable-baselines3/stable_baselines3/common/policies.py", line 576, in forward
    log_prob = distribution.log_prob(actions)
  File "/Users/quentingallouedec/stable-baselines3/stable_baselines3/common/distributions.py", line 279, in log_prob
    return self.distribution.log_prob(actions)
  File "/Users/quentingallouedec/stable-baselines3/env/lib/python3.9/site-packages/torch/distributions/categorical.py", line 123, in log_prob
    self._validate_sample(value)
  File "/Users/quentingallouedec/stable-baselines3/env/lib/python3.9/site-packages/torch/distributions/distribution.py", line 298, in _validate_sample
    valid = support.check(value)
  File "/Users/quentingallouedec/stable-baselines3/env/lib/python3.9/site-packages/torch/distributions/constraints.py", line 257, in check
    return (value % 1 == 0) & (self.lower_bound <= value) & (value <= self.upper_bound)
NotImplementedError: The operator 'aten::remainder.Tensor_out' is not currently implemented for the MPS device. If you want this op to be added in priority during the prototype phase of this feature, please comment on https://github.com/pytorch/pytorch/issues/77764. As a temporary fix, you can set the environment variable `PYTORCH_ENABLE_MPS_FALLBACK=1` to use the CPU as a fallback for this op. WARNING: this will be slower than running natively on MPS.

Its in PR. Will try to priortize the merge.
pytorch/pytorch#87582

@qgallouedec
Copy link
Collaborator

qgallouedec commented Feb 14, 2023

Some progress!

The code in #951 (comment) now ('2.0.0.dev20230214') outputs

/Users/quentingallouedec/stable-baselines3/env/lib/python3.10/site-packages/torch/distributions/categorical.py:118: UserWarning: 1MPS: no support for int64 min/max ops, casting it to int32 (Triggered internally at /Users/runner/work/pytorch/pytorch/pytorch/aten/src/ATen/native/mps/operations/ReduceOps.mm:1260.)
  samples_2d = torch.multinomial(probs_2d, sample_shape.numel(), True).T
Traceback (most recent call last):
  File "/Users/quentingallouedec/stable-baselines3/stable_baselines3/mps.py", line 5, in <module>
    model = A2C("CnnPolicy", env).learn(250)
  File "/Users/quentingallouedec/stable-baselines3/stable_baselines3/a2c/a2c.py", line 190, in learn
    return super().learn(
  File "/Users/quentingallouedec/stable-baselines3/stable_baselines3/common/on_policy_algorithm.py", line 246, in learn
    continue_training = self.collect_rollouts(self.env, callback, self.rollout_buffer, n_rollout_steps=self.n_steps)
  File "/Users/quentingallouedec/stable-baselines3/stable_baselines3/common/on_policy_algorithm.py", line 165, in collect_rollouts
    actions, values, log_probs = self.policy(obs_tensor)
  File "/Users/quentingallouedec/stable-baselines3/env/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/Users/quentingallouedec/stable-baselines3/stable_baselines3/common/policies.py", line 622, in forward
    log_prob = distribution.log_prob(actions)
  File "/Users/quentingallouedec/stable-baselines3/stable_baselines3/common/distributions.py", line 292, in log_prob
    return self.distribution.log_prob(actions)
  File "/Users/quentingallouedec/stable-baselines3/env/lib/python3.10/site-packages/torch/distributions/categorical.py", line 123, in log_prob
    self._validate_sample(value)
  File "/Users/quentingallouedec/stable-baselines3/env/lib/python3.10/site-packages/torch/distributions/distribution.py", line 298, in _validate_sample
    valid = support.check(value)
  File "/Users/quentingallouedec/stable-baselines3/env/lib/python3.10/site-packages/torch/distributions/constraints.py", line 257, in check
    return (value % 1 == 0) & (self.lower_bound <= value) & (value <= self.upper_bound)
RuntimeError: MPS: does not support floor_divide op with int64 input

I reported it in the MPS issue tracker

@BasLaa
Copy link

BasLaa commented Mar 16, 2023

Is there any progress on this? Is mps usable in any way already?

@araffin
Copy link
Member Author

araffin commented Apr 3, 2023

Is there any progress on this? Is mps usable in any way already?

@BasLaa you can already give it a try by passing device="mps" to the constructor and using latest pytorch version (pytorch nightly is probably even better).
It should work at least partially (but you might need to use the cpu fallback), please report any issue here.

@araffin
Copy link
Member Author

araffin commented Oct 6, 2023

@qgallouedec how is the support with PyTorch 2.1.0?

@qgallouedec
Copy link
Collaborator

The number of errors decreases. Here's one a them:

TypeError: Cannot convert a MPS Tensor to float64 dtype as the MPS framework doesn't support float64. Please use float32 instead.

Is double precision a feature of sb3 or should single precision be forced systematically?

@araffin
Copy link
Member Author

araffin commented Oct 8, 2023

Is double precision a feature of sb3 or should single precision be forced systematically?

I think we don't really support float64... (mainly to avoid issues when using CUDA)
there are several places where we already force float32 anyway (#1572), including preprocessing if I recall.

@tty666
Copy link

tty666 commented Oct 29, 2023

If you need someone to test something please tell me I could with my Mac because this PR is there for a while now and nobody comes with a solution or a review ...
Just tell me what to do and I will perform the testing for you to deliver this MPS support ...

@qgallouedec
Copy link
Collaborator

@tty666 thank you for the proposal. Feel free to test and provide your feedback if any. As far as I remember, there are still some issues related to dtype (float64 instead of float32), see #951 (comment). As soon as all the CI passes, we can consider this PR as ready to be merged

@ArthurMynl
Copy link

Any news regarding this PR? Is someone working on it?

@araffin
Copy link
Member Author

araffin commented Jan 10, 2024

Any news regarding this PR? Is someone working on it?

#951 (comment)

@lsibilla
Copy link

Hello!

I just tried this out, out of curiosity and it seems to work. The small snippet above and another project I have been working on recently work very similarly with and without MPS.

I can see GPU going to 100% with asitop and no crashes.

Performance-wise it's not as good as we might expect but that might related to my particular use-case.

@lsibilla
Copy link

Hi. I see the tests are still failing. I'll try to give a bit more details on my setup.

First, I'm running a MacBook Pro M1 Pro. The test from yesterday was running with Python 3.12.

This morning, I cloned the repo, switched to the feat/mps-support branch, created a Python 3.11 venv and ran test_cnn.py:

(.venv-3.11) ➜  stable-baselines3 git:(feat/mps-support) ✗ pytest tests/test_cnn.py            
======================================= test session starts =======================================
platform darwin -- Python 3.11.8, pytest-8.1.1, pluggy-1.4.0
rootdir: /Users/lsibilla/src/lab/stable-baselines3
configfile: pyproject.toml
plugins: cov-5.0.0, anyio-4.3.0, env-1.1.3, xdist-3.5.0
collected 29 items                                                                                

tests/test_cnn.py .............s...............                                             [100%]

======================================== warnings summary =========================================
tests/test_cnn.py: 25 warnings
  /Users/lsibilla/src/lab/stable-baselines3/stable_baselines3/common/utils.py:524: UserWarning: 'has_mps' is deprecated, please use 'torch.backends.mps.is_built()'
    if hasattr(th, "has_mps") and th.backends.mps.is_built():

-- Docs: https://docs.pytest.org/en/stable/how-to/capture-warnings.html
====================== 28 passed, 1 skipped, 25 warnings in 76.15s (0:01:16) ======================
(.venv-3.11) ➜  stable-baselines3 git:(feat/mps-support) ✗ 

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

Supporting PyTorch GPU compatibility on Apple Silicon chips
7 participants