Exporting MultiInputActorCriticPolicy as ONNX #1873

Open
4 tasks done
MaximCamilleri opened this issue Mar 18, 2024 · 5 comments
Labels: more information needed (Please fill the issue template completely), question (Further information is requested)

Comments

@MaximCamilleri

❓ Question

Hi,

I am looking into using ONNX with SB3. I have tested two models (A2C and PPO) on a custom environment using a MultiInputActorCriticPolicy; the environment's observation space is a Dict. So far I have not been able to produce an ONNX-exportable policy.

The documentation says "The following examples are for MlpPolicy only, and are general examples". Is it possible to export a model like mine to ONNX? If so, could you provide an example?

Thanks

Checklist

MaximCamilleri added the question (Further information is requested) label on Mar 18, 2024
araffin added the more information needed (Please fill the issue template completely) label on Mar 18, 2024
@araffin (Member) commented Mar 18, 2024

Hello,
What have you tried so far, and what errors did you encounter?

Please provide a minimal and working code example (see link in issue template for what that means).

@MaximCamilleri (Author)

Hello, thanks for your response.

I have tried a couple of things so far. First, I tried wrapping my model in an ONNX-exportable policy using the method shown in the documentation. My code is as follows:

import torch as th
from stable_baselines3 import PPO


class OnnxablePolicy(th.nn.Module):
    def __init__(self, policy):
        super().__init__()
        self.policy = policy

    def forward(self, input):
        return self.policy(input)


model = PPO.load("Models/ppo.zip")
onnx_policy = OnnxablePolicy(model.policy)

# obs_dict is the dummy dict observation built in the snippet below
th.onnx.export(
    onnx_policy,
    obs_dict,
    "ONNX/ppo_model.onnx",
    opset_version=17,
    input_names=["input"],
)

To build the dummy input, obs_dict, I used the following snippet:

import numpy as np

obs = env.reset()
obs_dict = {}
for key in obs.keys():
    # Add a batch dimension and convert each entry to a float tensor
    obs_dict[key] = th.from_numpy(np.array([obs[key]])).float()

This creates an input with the same structure as the observation space after common.preprocessing.preprocess_obs is run.
The error I was getting at this point is: TypeError: OnnxablePolicy2.forward() missing 1 required positional argument: 'input'
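A likely cause, offered as a hedged reading of the PyTorch docs: the TorchScript-based torch.onnx.export takes `args` as a tuple that may end with a dictionary of named arguments, and a dict passed on its own appears to be handled the same way. The observation keys then do not match the `input` parameter of `forward`, so `forward` ends up being called without it. The helpers below (`split_export_args` and `call_like_exporter` are hypothetical) are a simplified model of that convention, not PyTorch's actual code; they show why wrapping the dict as `(obs_dict, {})` changes the outcome.

```python
import inspect
from typing import Any, Callable, Dict, List, Tuple


def split_export_args(args: Any) -> Tuple[List[Any], Dict[str, Any]]:
    """Hypothetical, simplified model of the args convention: a trailing
    dict is read as named arguments, everything before it as positional."""
    args_list = list(args) if isinstance(args, tuple) else [args]
    named: Dict[str, Any] = {}
    if args_list and isinstance(args_list[-1], dict):
        named = args_list.pop()
    return args_list, named


def call_like_exporter(fn: Callable[..., Any], args: Any) -> Any:
    """Call `fn` the way the simplified model above would."""
    positional, named = split_export_args(args)
    sig = inspect.signature(fn)
    # Only named entries that match a parameter name are forwarded; other keys
    # (such as observation keys like "a") are silently ignored.
    for name in list(sig.parameters)[len(positional):]:
        param = sig.parameters[name]
        if name in named:
            positional.append(named[name])
        elif param.default is not inspect.Parameter.empty:
            positional.append(param.default)
    return fn(*positional)


def forward(input):
    # Stand-in for OnnxablePolicy.forward(self, input)
    return input


obs_dict = {"a": [0.1, 0.2, 0.3]}

try:
    call_like_exporter(forward, obs_dict)  # dict alone is consumed as named args
except TypeError as err:
    print("obs_dict ->", err)

print("(obs_dict, {}) ->", call_like_exporter(forward, (obs_dict, {})))
```

With the trailing empty dict, obs_dict stays the first positional argument, so `forward` receives the whole observation dict.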

I also tried the approach seen here, and created the following code:

class OnnxablePolicy(th.nn.Module):
    def __init__(self, extractor, action_net, value_net):
        super(OnnxablePolicy, self).__init__()
        self.extractor = extractor
        self.action_net = action_net
        self.value_net = value_net

    def forward(self, input):
        action_hidden = value_hidden = self.extractor(input)
        return self.action_net(action_hidden), self.value_net(value_hidden)

onnx_policy = OnnxablePolicy(model.policy.features_extractor, model.policy.action_net, model.policy.value_net)

th.onnx.export(
    onnx_policy,
    obs_dict,
    "ONNX/ppo_model.onnx",
    opset_version=17,
    input_names=["input"],
)

Which resulted in the same error as before.
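Separately from the args issue, this second wrapper feeds the features_extractor output straight into action_net and value_net, while SB3's ActorCriticPolicy runs mlp_extractor in between, so it would likely hit a shape mismatch once export gets further. Below is a sketch of a wrapper that follows the policy's own order (extract_features, then mlp_extractor, then the two heads), returns raw outputs instead of sampled actions, and takes one tensor per observation key so ONNX only sees flat inputs. Assumptions: SB3 with the default shared features extractor; `KeyedActorCritic` is hypothetical, and `_ToyPolicy`/`_ToyMlpExtractor` are hypothetical stand-ins used only so the sketch runs without a trained model.

```python
from typing import Dict, Sequence, Tuple

import torch as th
from torch import nn


class KeyedActorCritic(nn.Module):
    """Hypothetical ONNX wrapper: takes one tensor per observation key and
    returns raw action logits and the value estimate (no sampling)."""

    def __init__(self, policy: nn.Module, keys: Sequence[str]):
        super().__init__()
        self.policy = policy
        self.keys = list(keys)

    def forward(self, *tensors: th.Tensor) -> Tuple[th.Tensor, th.Tensor]:
        # Rebuild the dict observation from flat positional inputs.
        obs: Dict[str, th.Tensor] = dict(zip(self.keys, tensors))
        # Assumes a shared features extractor (the PPO/A2C default).
        # In SB3, extract_features also applies preprocess_obs.
        features = self.policy.extract_features(obs)
        latent_pi, latent_vf = self.policy.mlp_extractor(features)
        return self.policy.action_net(latent_pi), self.policy.value_net(latent_vf)


class _ToyMlpExtractor(nn.Module):
    # Hypothetical stand-in returning (latent_pi, latent_vf) like SB3's MlpExtractor.
    def __init__(self, n_in: int, n_hidden: int):
        super().__init__()
        self.pi = nn.Linear(n_in, n_hidden)
        self.vf = nn.Linear(n_in, n_hidden)

    def forward(self, x: th.Tensor) -> Tuple[th.Tensor, th.Tensor]:
        return th.tanh(self.pi(x)), th.tanh(self.vf(x))


class _ToyPolicy(nn.Module):
    # Hypothetical stand-in exposing the attribute names the wrapper relies on.
    def __init__(self):
        super().__init__()
        self.mlp_extractor = _ToyMlpExtractor(3, 8)
        self.action_net = nn.Linear(8, 2)  # 2 logits, e.g. for Discrete(2)
        self.value_net = nn.Linear(8, 1)

    def extract_features(self, obs: Dict[str, th.Tensor]) -> th.Tensor:
        return obs["a"].float()


wrapper = KeyedActorCritic(_ToyPolicy(), keys=["a"])
dummy = (th.zeros(1, 3),)  # one batched tensor per key, in `keys` order
with th.no_grad():
    logits, value = wrapper(*dummy)
print(logits.shape, value.shape)
```

With a trained model, `policy` would be `model.policy` and `keys` something like `list(model.observation_space.spaces)`; an untested sketch of the export call is `th.onnx.export(wrapper, dummy, "ppo_model.onnx", input_names=keys, output_names=["logits", "value"], opset_version=17)`. Because the last element of `dummy` is a tensor rather than a dict, nothing is read as keyword arguments. For Box action spaces, `action_net` outputs the action mean rather than logits.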

Finally, I tried exporting the policy as is:

model = PPO.load("Models/ppo.zip")
obs = env.reset()
th.onnx.export(
    model.policy,
    obs,
    "ONNX/ppo_model.onnx",
    opset_version=17,
    input_names=["input"],
)

This seemingly got me the furthest, producing the new error:

.venv/Lib/site-packages/stable_baselines3/common/preprocessing.py
    110     assert isinstance(obs, Dict), f"Expected dict, got {type(obs)}"
    111     preprocessed_obs = {}
    112     for key, _obs in obs.items():
AssertionError: Expected dict, got <class 'torch.Tensor'>

@araffin (Member) commented Mar 25, 2024

I gave it a try, but this one seems a bit hard; you probably need PyTorch's experimental ONNX exporter (based on dynamo).
What got me further was passing (obs_dict, {}) as the args; otherwise PyTorch tries to interpret the dict as keyword arguments.

My current attempt (the export seems to work, but loading doesn't :/):

import torch as th
from typing import Tuple
import gymnasium as gym

from stable_baselines3 import PPO
from stable_baselines3.common.policies import BasePolicy

import onnx
import onnxruntime as ort
import numpy as np


class OnnxableSB3Policy(th.nn.Module):
    def __init__(self, policy: BasePolicy):
        super().__init__()
        self.policy = policy

    def forward(self, observation):
        # Debugging: print and return the raw "a" input to check that the dict
        # reaches forward(); the policy call below is currently unreachable.
        print(observation)
        return observation["a"]
        # NOTE: Preprocessing is included, but postprocessing
        # (clipping/unscaling actions) is not.
        # If needed, you also need to transpose the images so that they are channel first.
        # Use deterministic=False if you want to export the stochastic policy.
        return self.policy._predict(observation, deterministic=True)


class Custom(gym.Env):
    def __init__(self):
        super().__init__()
        self.observation_space = gym.spaces.Dict(
            {
                "a": gym.spaces.Box(low=-1, high=1, shape=(3,), dtype=np.float32),
                # "b": gym.spaces.Discrete(5),
            }
        )
        self.action_space = gym.spaces.Discrete(2)

    def reset(self, seed=None):
        return self.observation_space.sample(), {}

    def step(self, action):
        return self.observation_space.sample(), 0.0, False, False, {}


env = Custom()
obs, _ = env.reset()
# Example: model = PPO("MlpPolicy", "Pendulum-v1")
PPO("MultiInputPolicy", env).save("PathToTrainedModel")
model = PPO.load("PathToTrainedModel.zip", device="cpu")


onnx_policy = OnnxableSB3Policy(model.policy)

observation_size = model.observation_space.shape
# Add batch dimension
dummy_input = {
    # "a": np.array(obs["a"])[np.newaxis, ...],
    "a": np.array(obs["a"]),
    # "b": np.array(obs["b"])[np.newaxis, ...],
}
dummy_input_tensor = {
    "a": th.as_tensor(dummy_input["a"]),
    # "b": th.as_tensor(dummy_input["b"]),
}

print(model.predict(dummy_input, deterministic=True))


th.onnx.export(
    onnx_policy,
    args=(dummy_input_tensor, {}),
    f="my_ppo_model.onnx",
    opset_version=17,
    input_names=["input"],
)

##### Load and test with onnx


onnx_path = "my_ppo_model.onnx"
onnx_model = onnx.load(onnx_path)
onnx.checker.check_model(onnx_model)

observation = dummy_input.copy()
ort_sess = ort.InferenceSession(onnx_path)

# print(ort_sess.get_inputs()[0].name)
# print(ort_sess.get_inputs())

output = ort_sess.run(None, {"input": observation})

print(output)

# Check that the predictions are the same
# with th.no_grad():
#     print(model.policy(th.as_tensor(observation), deterministic=True))

@araffin (Member) commented Mar 31, 2024

"Due to design differences, input/output format between PyTorch model and exported ONNX model are often not the same. E.g., None is allowed for PyTorch model, but are not supported by ONNX. Nested constructs of tensors are allowed for PyTorch model, but only flattened tensors are supported by ONNX, etc."

from https://pytorch.org/docs/stable/onnx_dynamo.html#torch.onnx.ONNXProgram.adapt_torch_inputs_to_onnx
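In practice this suggests ONNX Runtime cannot be fed a Python dict of observations as a single input, which is likely why `ort_sess.run(None, {"input": observation})` in the previous comment fails. A sketch of the flat feed ONNX Runtime expects: one array per graph input, keyed by input name, with a leading batch dimension. Assumptions: the model was exported with one input per observation key and `input_names` equal to those keys (the actual names can be checked with `ort_sess.get_inputs()`), every key is a float Box observation, and `build_ort_feed` is a hypothetical helper. Only NumPy is needed, so the `InferenceSession.run` call is left as a comment.

```python
from typing import Dict, Mapping, Sequence

import numpy as np


def build_ort_feed(
    obs: Mapping[str, np.ndarray], input_names: Sequence[str]
) -> Dict[str, np.ndarray]:
    """Hypothetical helper: map each ONNX input name to a batched float32
    array taken from the dict observation entry with the same key."""
    feed = {}
    for name in input_names:
        value = np.asarray(obs[name], dtype=np.float32)
        # Add a leading batch dimension of 1 for a single observation.
        feed[name] = value[np.newaxis, ...]
    return feed


obs = {"a": np.array([0.5, -0.2, 0.1], dtype=np.float32)}
feed = build_ort_feed(obs, input_names=["a"])
print({k: v.shape for k, v in feed.items()})
# With onnxruntime installed and a model exported with input_names=["a"]:
# outputs = ort.InferenceSession("my_ppo_model.onnx").run(None, feed)
```

Discrete or MultiDiscrete keys would need whatever dtype the exported graph declares for that input rather than a blanket float32 cast.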

@NickLucche (Contributor)

Hi all, I wouldn't export the sampling procedure to ONNX here (self.policy._predict(observation, deterministic=True)), but would rather have the network output the raw logits and implement the sampling as a post-processing step.
A consistent export procedure would be a nice feature to add to the framework :)
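A minimal sketch of that split, assuming a Discrete action space and an exported graph whose output is the raw action logits with shape (batch, n_actions): the deterministic action is the argmax, and the stochastic one is a categorical sample from the softmax. Only NumPy is used; `select_actions` is hypothetical, and continuous (Box) actions would instead need the mean plus the policy's log-std and any clipping/unscaling, which this does not cover.

```python
from typing import Optional

import numpy as np


def select_actions(
    logits: np.ndarray,
    deterministic: bool = True,
    rng: Optional[np.random.Generator] = None,
) -> np.ndarray:
    """Hypothetical post-processing for Discrete actions given raw logits
    of shape (batch, n_actions) from the exported network."""
    if deterministic:
        # Mode of the categorical distribution.
        return np.argmax(logits, axis=-1)
    rng = rng if rng is not None else np.random.default_rng()
    # Numerically stable softmax over the action dimension.
    shifted = logits - logits.max(axis=-1, keepdims=True)
    probs = np.exp(shifted)
    probs /= probs.sum(axis=-1, keepdims=True)
    # Inverse-CDF sampling with one uniform draw per batch row.
    cdf = np.cumsum(probs, axis=-1)
    cdf[..., -1] = 1.0  # guard against rounding leaving the last entry < 1
    u = rng.random((logits.shape[0], 1))
    return (u < cdf).argmax(axis=-1)


logits = np.array([[2.0, -1.0], [0.1, 0.3]], dtype=np.float32)
print(select_actions(logits))
print(select_actions(logits, deterministic=False, rng=np.random.default_rng(0)))
```

Keeping sampling outside the graph means the same ONNX file serves both deterministic and stochastic use, and the random number generator stays under the caller's control.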
