[question] How do I load a tensorflow ckpt? #1147

Open
Syzygianinfern0 opened this issue Jan 2, 2022 · 4 comments
Labels
more information needed (Please fill the issue template completely), question (Further information is requested), RTFM (Answer is the documentation)

Comments

@Syzygianinfern0

I am trying to load a pre-trained model from some old code that uses this framework, and my familiarity with TensorFlow is very limited. I've tried multiple things but have not found the right way to load the model 🤯

Below is a representative snippet showing how the model is created and then saved. I just want to load the saved weights back for evaluation.

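import gym
import tensorflow as tf
from stable_baselines import PPO1
from stable_baselines.common.policies import FeedForwardPolicy

# Stand-in environment (assumption): the snippet below uses `env` without defining it
env = gym.make("CartPole-v1")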

training_sess = None


class MyMlpPolicy(FeedForwardPolicy):
    def __init__(self, sess, ob_space, ac_space, n_env, n_steps, n_batch, reuse=False, **_kwargs):
        super(MyMlpPolicy, self).__init__(
            sess,
            ob_space,
            ac_space,
            n_env,
            n_steps,
            n_batch,
            reuse,
            net_arch=[{"pi": [32, 16], "vf": [32, 16]}],
            feature_extraction="mlp",
            **_kwargs
        )
        global training_sess
        training_sess = sess


model = PPO1(MyMlpPolicy, env)

# This is how the model is saved
with model.graph.as_default():
    saver = tf.train.Saver()
    saver.save(training_sess, "./model_0.ckpt")

# The save call above writes four files:
# 1. checkpoint
# 2. model_0.ckpt.data-00000-of-00001
# 3. model_0.ckpt.index
# 4. model_0.ckpt.meta
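A minimal sketch of loading those files back, assuming stable-baselines 2.x on TensorFlow 1.x and that the model is rebuilt exactly as in the snippet (same MyMlpPolicy, same env spaces), so the variable names in the new graph match the ones stored in the checkpoint. Saver.restore expects the checkpoint prefix (./model_0.ckpt), not any one of the four files. checkpoint_prefix and restore_into_model are hypothetical helper names, not part of either library.

```python
import os
import re


def checkpoint_prefix(path):
    """Return the prefix that TF1's Saver.restore expects, e.g. './model_0.ckpt'.

    Accepts the prefix itself or any file written by Saver.save
    (.index, .meta, .data-XXXXX-of-XXXXX) and checks that the .index file exists.
    """
    # Strip a known checkpoint suffix so any of the written files can be passed in.
    path = re.sub(r"\.(index|meta|data-\d+-of-\d+)$", "", path)
    if not os.path.exists(path + ".index"):
        raise FileNotFoundError(f"no checkpoint index found for prefix {path!r}")
    return path


def restore_into_model(model, path):
    """Restore checkpoint variables into a freshly built stable-baselines model.

    Assumes `model` was constructed like the saved one, so variable names match.
    """
    import tensorflow as tf  # TF 1.x; imported lazily so this module loads without it

    with model.graph.as_default():
        saver = tf.train.Saver()  # covers all global variables in the model's graph
        saver.restore(model.sess, checkpoint_prefix(path))
```

Typical use (not run here): build model = PPO1(MyMlpPolicy, env) as above, call restore_into_model(model, "./model_0.ckpt"), and model.predict(obs) should then use the restored weights.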
@Miffyli added the "more information needed" label on Jan 2, 2022
@Miffyli
Collaborator

Miffyli commented Jan 2, 2022

Please fill in the issue template. If you only want to save and reload the full agent, you do not need to touch TensorFlow directly: just use the save and load functions (see the examples in the docs). We cannot offer custom tech support for saving/loading in a custom way like this.
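A minimal sketch of the save/load route suggested here, assuming stable-baselines 2.x and an old-style Gym env whose step returns (obs, reward, done, info): save with model.save, reload with PPO1.load, then evaluate. evaluate_mean_return is a hypothetical helper; recent stable-baselines 2.x releases also ship evaluate_policy in stable_baselines.common.evaluation, which may be preferable.

```python
def evaluate_mean_return(model, env, n_episodes=10, deterministic=True):
    """Return the mean undiscounted episode return of `model` on `env`.

    Works with any object exposing predict(obs, deterministic=...) -> (action, state),
    as stable-baselines models do, and an env using the old Gym API (4-tuple step).
    """
    if n_episodes < 1:
        raise ValueError("n_episodes must be at least 1")
    returns = []
    for _ in range(n_episodes):
        obs = env.reset()
        done, episode_return = False, 0.0
        while not done:
            action, _state = model.predict(obs, deterministic=deterministic)
            obs, reward, done, _info = env.step(action)
            episode_return += float(reward)
        returns.append(episode_return)
    return sum(returns) / len(returns)


# Typical use with stable-baselines 2.x (not run here; needs TensorFlow 1.x):
#   model.save("ppo1_agent")                  # after training
#   model = PPO1.load("ppo1_agent", env=env)  # later, for evaluation
#   print(evaluate_mean_return(model, env, n_episodes=10))
```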

@araffin
Collaborator

araffin commented Jan 2, 2022

We also highly recommend switching to Stable-Baselines3 (PyTorch).

@Syzygianinfern0
Author

> We also highly recommend switching to Stable-Baselines3 (PyTorch).

Yeah, that is what I currently use. I just need to run some old code for a comparison, and I only have the weights its authors provided.

@Miffyli
Collaborator

Miffyli commented Jan 2, 2022

> Yeah, that is what I currently use. I just need to run some old code for a comparison, and I only have the weights its authors provided.

In that case, you should look at the set_parameters method in the SB3 documentation :).

You can close this issue if your question has been answered.
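A rough sketch of the weight-porting half of the set_parameters route, assuming the old checkpoint's dense layers follow stable-baselines' scope/w, scope/b naming and that the SB3 policy is built with the same architecture (net_arch=[dict(pi=[32, 16], vf=[32, 16])], tanh activations, which as far as we know are the default on both sides). The function name and the layer map are hypothetical; list the real names with tf.train.list_variables("./model_0.ckpt") and model.policy.state_dict().keys() before writing the map.

```python
import numpy as np


def tf_dense_to_torch(tf_vars, layer_map):
    """Build a PyTorch-style state dict (as NumPy arrays) from TF dense-layer variables.

    tf_vars:   {tf_variable_name: array}, e.g. read from the old checkpoint.
    layer_map: {tf_scope: torch_prefix}, a hand-written (hypothetical) mapping such as
               {"model/pi_fc0": "mlp_extractor.policy_net.0"}.

    stable-baselines' dense layers store `w` as (in_features, out_features) because
    they compute x @ w + b; torch.nn.Linear stores `weight` as (out_features,
    in_features), so every kernel is transposed.
    """
    state = {}
    for tf_scope, torch_prefix in layer_map.items():
        w = np.asarray(tf_vars[tf_scope + "/w"])
        b = np.asarray(tf_vars[tf_scope + "/b"])
        if w.ndim != 2 or w.shape[1] != b.shape[0]:
            raise ValueError(f"{tf_scope}: kernel {w.shape} does not match bias {b.shape}")
        state[torch_prefix + ".weight"] = np.ascontiguousarray(w.T)
        state[torch_prefix + ".bias"] = b.copy()
    return state
```

The arrays could be read in TensorFlow 2.x, which can open TF1 checkpoints, via tf.train.load_checkpoint("./model_0.ckpt").get_tensor(name), converted with torch.as_tensor, and passed as model.set_parameters({"policy": state_dict}, exact_match=False); as we read the SB3 docs, set_parameters accepts a dict mapping object names to state dicts, and exact_match=False is needed because the dict has no optimizer entry.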

@araffin araffin added RTFM Answer is the documentation question Further information is requested labels Jan 2, 2022
Projects
None yet
Development

No branches or pull requests

3 participants