[Question] RecurrentPPO: Reset LSTM states early? #239

Open
4 tasks done
phisad opened this issue Apr 5, 2024 · 3 comments
Labels
enhancement New feature or request question Further information is requested

Comments

@phisad

phisad commented Apr 5, 2024

❓ Question

Hi and thanks for the great work!

I am using RecurrentPPO in a current project, and it strikes me that on L294 the self._last_lstm_states added to the buffer are actually the ones from the last terminal state (and not all zeros) when an environment has been reset on L252. Is my understanding correct?

If so, would it not be better to check for an episode start earlier, one line before L242, and set the states to zero for those environments there, instead of handling this in _process_sequence of RecurrentActorCriticPolicy (L198) on each forward pass?
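The proposal above can be sketched roughly as follows. This is only an illustration, not the library's code: the helper name `reset_states_at_episode_start` is hypothetical, and the state shape `(n_layers, n_envs, hidden_size)` and the float encoding of `episode_starts` are assumptions.

```python
import torch


def reset_states_at_episode_start(lstm_states, episode_starts):
    """Zero the hidden and cell states of environments that just started an episode.

    lstm_states: tuple (h, c), each assumed to have shape (n_layers, n_envs, hidden_size)
    episode_starts: float tensor of shape (n_envs,), 1.0 where a new episode begins
    """
    h, c = lstm_states
    # 1.0 keeps a state, 0.0 wipes it; broadcast over layers and hidden units.
    keep = (1.0 - episode_starts).view(1, -1, 1)
    return h * keep, c * keep


# Three environments; only the second one was reset.
h = torch.ones(1, 3, 4)
c = torch.full((1, 3, 4), 2.0)
episode_starts = torch.tensor([0.0, 1.0, 0.0])

h_new, c_new = reset_states_at_episode_start((h, c), episode_starts)
```

Calling such a helper before the states are written to the buffer would store zeros for the reset environment instead of the terminal states of the previous episode.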

Checklist

@phisad phisad added the question Further information is requested label Apr 5, 2024
@araffin
Member

araffin commented Apr 5, 2024

Hello,
that's a good suggestion =)
Would you mind giving it a try and checking that you obtain the exact same results?
If so, please open a PR ;)

That would hopefully simplify things and make them much faster.
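A minimal sketch of such a same-results check for a single step, assuming a plain `torch.nn.LSTM` stands in for the policy's recurrent layer (the sizes and variable names are made up for illustration): zeroing the states before the forward pass should give the same output as masking them inside the forward pass.

```python
import torch
from torch import nn

torch.manual_seed(0)
lstm = nn.LSTM(input_size=3, hidden_size=4)

obs = torch.randn(1, 2, 3)  # (seq_len=1, n_envs=2, features)
h = torch.randn(1, 2, 4)    # states left over from the previous step
c = torch.randn(1, 2, 4)
episode_start = torch.tensor([1.0, 0.0])  # env 0 was just reset

with torch.no_grad():
    # Variant A: mask the states inside the forward pass.
    keep = (1.0 - episode_start).view(1, -1, 1)
    out_a, (h_a, c_a) = lstm(obs, (keep * h, keep * c))

    # Variant B: zero the states early, then run the forward pass without a mask.
    h_b, c_b = h.clone(), c.clone()
    h_b[:, episode_start.bool()] = 0.0
    c_b[:, episode_start.bool()] = 0.0
    out_b, (h_b, c_b) = lstm(obs, (h_b, c_b))

same = torch.allclose(out_a, out_b) and torch.allclose(h_a, h_b) and torch.allclose(c_a, c_b)
```

This only covers one rollout step; it says nothing about what happens to sequences during the training update.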

@araffin araffin added the enhancement New feature or request label Apr 5, 2024
@phisad
Author

phisad commented Apr 11, 2024

Alright, thanks for the confirmation. ^^

I'll give it a try and make sure that the tests in https://github.com/Stable-Baselines-Team/stable-baselines3-contrib/blob/master/tests/test_lstm.py pass without any errors (and maybe even run a bit faster).

@araffin
Member

araffin commented Apr 11, 2024

Thinking about this issue again, I'm afraid we still need

    (1.0 - episode_start).view(1, n_seq, 1) * lstm_states[0],
    (1.0 - episode_start).view(1, n_seq, 1) * lstm_states[1],

to reset the states manually when a new episode starts? (At least when updating the network, i.e. when calling train().)

Or can we pass all hidden states to PyTorch?
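To illustrate the concern, here is a small sketch under stated assumptions: a plain `torch.nn.LSTM` with made-up sizes, and a hypothetical helper `step_with_reset` that applies the mask quoted above at every step. `nn.LSTM` accepts only one initial `(h_0, c_0)` per sequence, so a single call over a sequence that contains an episode start in the middle cannot reset the state at that point.

```python
import torch
from torch import nn

torch.manual_seed(0)
n_seq, seq_len, n_features, hidden_size = 2, 5, 3, 4
lstm = nn.LSTM(n_features, hidden_size)

features = torch.randn(seq_len, n_seq, n_features)
episode_starts = torch.zeros(seq_len, n_seq)
episode_starts[3, 1] = 1.0  # sequence 1 begins a new episode at step 3
h0 = torch.randn(1, n_seq, hidden_size)
c0 = torch.randn(1, n_seq, hidden_size)


def step_with_reset(lstm, features, states, episode_starts):
    """Run the LSTM one step at a time, zeroing states wherever an episode starts."""
    h, c = states
    outputs = []
    for x_t, start_t in zip(features, episode_starts):
        keep = (1.0 - start_t).view(1, -1, 1)
        out, (h, c) = lstm(x_t.unsqueeze(0), (keep * h, keep * c))
        outputs.append(out)
    return torch.cat(outputs), (h, c)


with torch.no_grad():
    masked_out, _ = step_with_reset(lstm, features, (h0, c0), episode_starts)
    # One call over the whole sequence: only the initial state can be supplied.
    single_out, _ = lstm(features, (h0, c0))
    # Reference for sequence 1 after its reset: a fresh run from zero states.
    fresh_out, _ = lstm(features[3:, 1:2])

# Sequence 0 has no reset, so both variants agree on it.
no_reset_matches = torch.allclose(masked_out[:, 0], single_out[:, 0], atol=1e-5)
# Sequence 1 diverges from step 3 onwards because the single call never resets.
reset_differs = not torch.allclose(masked_out[3:, 1], single_out[3:, 1], atol=1e-5)
```

In this toy setup the single call is only safe for sequences without a mid-sequence episode start; whether the training sequences can be arranged to guarantee that is the open question here.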
