
[Feature Request] Add Attention nets (GTrXL model in particular) #165

Open
RemiG3 opened this issue Mar 15, 2023 · 13 comments · May be fixed by #176
Labels
enhancement New feature or request

Comments

@RemiG3

RemiG3 commented Mar 15, 2023

🚀 Feature Request

This feature request is a duplicate from stable-baselines3 (see DLR-RM/stable-baselines3#177).

The idea is to add the GTrXL model from the paper Stabilizing Transformers for Reinforcement Learning to the contrib repo, as done in RLlib: https://github.com/ray-project/ray/blob/master/rllib/models/torch/attention_net.py.
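For context, the main architectural change GTrXL makes to Transformer-XL is replacing each residual connection with a GRU-style gating layer, g(x, y) = (1 - z) ⊙ x + z ⊙ h, where x is the residual stream and y is the attention or MLP sublayer output. The sketch below is a minimal NumPy illustration of one such gate, assuming row-vector inputs, a hypothetical `gru_gate` helper, and a hypothetical dict layout for the weights; it is not the RLlib or PR implementation.

```python
import numpy as np


def sigmoid(a):
    return 1.0 / (1.0 + np.exp(-a))


def gru_gate(x, y, W, U, b_g=2.0):
    """GRU-style gating layer g(x, y) as described for GTrXL (sketch).

    x: residual-stream input, shape (..., d)
    y: sublayer output (attention or MLP), shape (..., d)
    W, U: dicts of (d, d) matrices keyed by "r", "z", "g" (hypothetical layout)
    b_g: update-gate bias; a positive value pushes the gate towards the identity map
    """
    r = sigmoid(y @ W["r"] + x @ U["r"])          # reset gate
    z = sigmoid(y @ W["z"] + x @ U["z"] - b_g)    # update gate
    h = np.tanh(y @ W["g"] + (r * x) @ U["g"])    # candidate activation
    return (1.0 - z) * x + z * h                  # interpolate between input and candidate


# Usage: with a large gate bias the layer passes x through almost unchanged.
rng = np.random.default_rng(0)
d = 8
W = {k: rng.normal(scale=0.1, size=(d, d)) for k in "rzg"}
U = {k: rng.normal(scale=0.1, size=(d, d)) for k in "rzg"}
x = rng.normal(size=(3, d))
y = rng.normal(size=(3, d))
out = gru_gate(x, y, W, U, b_g=20.0)
print(np.abs(out - x).max())  # very small
```

With a large positive `b_g`, the update gate z stays near 0, so the layer starts out close to the identity map, which is the motivation given for the positive gate bias.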

@araffin has already mentioned that he created it and will make it public (comment).

I wonder if this is still relevant?

@araffin
Member

araffin commented Mar 15, 2023

> @araffin has already mentioned that he created it and will make it public (DLR-RM/stable-baselines3#177 (comment)).

I meant the SB3 contrib repo.

For GTrXL, are you willing to contribute that algorithm?
Please read carefully the contributing guide if you decide to.

@RemiG3
Author

RemiG3 commented Mar 16, 2023

> I meant the SB3 contrib repo.

Sorry for the misunderstanding.

> For GTrXL, are you willing to contribute that algorithm?

I'm not sure yet; I will try to implement it for my experiments first.

@araffin
Member

araffin commented Mar 31, 2023

Also related: https://github.com/maohangyu/TIT_open_source

@richardjozsa

@RemiG3 hey, have you started implementing it? Maybe I can lend a hand :)

@RemiG3
Author

RemiG3 commented Apr 6, 2023

Yes, I have implemented it, but I haven't tested it properly. I'm currently having some trouble with my custom environment that I'm trying to solve.

@araffin is it possible to create a new branch for this feature (to share the code)?
If it is possible, I'll clean up the code and push it to this new branch soon.

@araffin
Member

araffin commented Apr 6, 2023

> Yes, I have implemented it, but not tested properly. I'm currently having some troubles with my custom environment that I'm trying to solve.
>
> @araffin is it possible to create a new branch for this feature (to share the code)? If it is possible, I'll clean up the code and push it to this new branch soon.

Yes, that's what forks and pull requests are meant for.

@richardjozsa

I have come across Transformers-RL, which is quite modular and easy to tune; the only downside is that it has only been implemented for Gaussian policies.

araffin linked a pull request on Apr 11, 2023 that will close this issue
@RemiG3
Author

RemiG3 commented Apr 11, 2023

Hey, I finally opened PR #176 to share the code.

It should work, but I'm not sure about its performance.
It would be nice if someone could run comparisons with other methods (for example, the RLlib attention net).
I won't have time in the next few days.

@eric000888

eric000888 commented Apr 18, 2023

RemiG3, thank you for adding the attention net to contrib. What should the input shape look like, for example, if I want to use the CartPole environment?
Thanks again.

@RemiG3
Author

RemiG3 commented Apr 19, 2023

Thank you, @eric000888, for reporting this (feel free to provide the code you tested as you did in your first edits).

I have updated the branch to fix a bug in the minibatch dimensions.
However, I still get an exception when batch_size = 1 or n_steps = 1, and I found the same exception with RecurrentPPO.

So it should now work for batch_size > 1 and n_steps > 1 (like RecurrentPPO).

EDIT: I also added assertions for these cases, as in the original PPO.
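The constraints described above can be expressed as simple checks before training. A minimal sketch, using a hypothetical `check_rollout_params` helper (not the actual assertions in the PR or in SB3's PPO), and assuming the rollout buffer holds `n_steps * n_envs` transitions that are split into mini-batches of `batch_size`:

```python
import warnings


def check_rollout_params(n_steps: int, batch_size: int, n_envs: int = 1) -> int:
    """Hypothetical validation mirroring the constraints discussed in this thread.

    Returns the rollout buffer size (n_steps * n_envs).
    """
    # Both cases currently raise an exception inside training, so fail early instead.
    assert batch_size > 1, f"`batch_size` must be greater than 1, got {batch_size}"
    assert n_steps > 1, f"`n_steps` must be greater than 1, got {n_steps}"

    buffer_size = n_steps * n_envs
    if buffer_size % batch_size != 0:
        # Not an error, but the last mini-batch of each epoch will be smaller.
        warnings.warn(
            f"Rollout buffer size {buffer_size} is not a multiple of "
            f"batch_size={batch_size}; the last mini-batch will be truncated."
        )
    return buffer_size


# Usage with the settings from the CartPole example in this thread (one env):
print(check_rollout_params(n_steps=240, batch_size=12))  # 240
```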

@eric000888

eric000888 commented Apr 25, 2023

RemiG3,
Sorry for the late response; here is the code from my first post:

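import gym
from stable_baselines3.common.vec_env import DummyVecEnv

# AttentionPPO and MlpAttnPolicy come from the PR #176 branch, not a released sb3-contrib version
from sb3_contrib.ppo_attention.ppo_attention import AttentionPPO
from sb3_contrib.ppo_attention.policies import MlpAttnPolicy

# A single CartPole environment wrapped in a vectorized env
VE = DummyVecEnv([lambda: gym.make("CartPole-v1")])

model = AttentionPPO(
    "MlpAttnPolicy",
    VE,
    n_steps=240,
    learning_rate=0.0003,
    verbose=1,
    batch_size=12,
    ent_coef=0.03,
    vf_coef=0.5,
    seed=1,
    n_epochs=10,
    max_grad_norm=1,
    gae_lambda=0.95,
    gamma=0.99,
    device="cpu",
    # Separate MLP heads for the policy (pi) and value function (vf)
    policy_kwargs=dict(
        net_arch=dict(pi=[64, 32], vf=[64, 32]),
    ),
)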

First I create a vectorized environment, then set up the model the same way as the LSTM-based RecurrentPPO, and then call model.learn().
I traced through the code and found that the internally computed values are fine at the beginning, but after a few iterations they start returning 'NA' and then training stops.
I have seen other implementations that use stacked frames or a sliding window as the input format, so I'm a bit confused about what the correct input format should be. But from your code, I think the input should be just one record at a time, with no need to stack records.
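To make the two input formats concrete: with stacked frames or a sliding window, each step's input contains the last few observations, whereas a memory-based model such as GTrXL or an LSTM receives only the current observation (shape `(n_envs, 4)` for CartPole) and keeps past context in its internal state. A NumPy sketch of the sliding-window format, assuming a single episode stored as a `(T, obs_dim)` array and a hypothetical `sliding_windows` helper:

```python
import numpy as np


def sliding_windows(obs, window):
    """Stack the last `window` observations for every time step (zero-padded at the start).

    obs: array of shape (T, obs_dim) from a single episode
    returns: array of shape (T, window, obs_dim)
    """
    T, obs_dim = obs.shape
    # Pad with window - 1 zero observations so early steps still get a full window.
    padded = np.concatenate([np.zeros((window - 1, obs_dim), dtype=obs.dtype), obs], axis=0)
    return np.stack([padded[t : t + window] for t in range(T)], axis=0)


# Usage: 5 steps of a CartPole-like observation (4 floats per step).
obs = np.arange(20, dtype=float).reshape(5, 4)
windows = sliding_windows(obs, window=3)
print(windows.shape)  # (5, 3, 4)
```

A memory-based policy would instead be fed `obs[t]` alone at step t, which is why stacking is not needed there.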

I followed the code and saw that you concatenate the input tensor with the memory tensor. However, the input from SB3 is a single record at first, and after the first full rollout loop it becomes a batch of records; that throws an error because the memory is still a single tensor instead of a batch.
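The shape problem described above can be reproduced in isolation. A NumPy sketch, assuming a batch-first `(batch, time, features)` layout and a hypothetical `concat_memory` helper (not the PR's code): the cached memory is prepended along the time axis, so it needs one memory entry per sequence in the batch, not a single shared one.

```python
import numpy as np


def concat_memory(memory, x):
    """Prepend cached memory to the current inputs along the time axis.

    memory: (batch, mem_len, d), one memory per sequence in the batch
    x:      (batch, seq_len, d), current inputs
    returns (batch, mem_len + seq_len, d)
    """
    if memory.shape[0] != x.shape[0]:
        raise ValueError(
            f"memory batch size {memory.shape[0]} does not match input batch size {x.shape[0]}"
        )
    return np.concatenate([memory, x], axis=1)


batch, mem_len, d = 12, 4, 8
x = np.zeros((batch, 1, d))  # one new step for each of 12 sequences

# A single memory for the whole batch cannot be concatenated with 12 inputs:
single_memory = np.zeros((1, mem_len, d))
# concat_memory(single_memory, x) raises ValueError

# Keeping one memory slot per sequence works:
memory = np.zeros((batch, mem_len, d))
print(concat_memory(memory, x).shape)  # (12, 5, 8)
```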

Thank you for the update; I will try it this weekend.

@eric000888

Another question: if you just use GTrXL as a feature extractor in the PPO model, would you get the same results? The LSTM-based RecurrentPPO has a flag for whether to use the LSTM layer, similar to a feature-extractor layer.

@eric000888

eric000888 commented Apr 25, 2023

Another thing: GTrXL demands more computing power, and PPO is like aiming at a moving target, so I found training a GTrXL PPO to be a daunting task, especially with multiple layers. If you could update the gradient over the whole trajectory, you might speed up learning; that is, collect all actions/observations first and then do a single pass of backpropagation.
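The equivalence behind processing a whole trajectory at once can be checked with plain causal self-attention: computing all T steps in one masked pass gives the same outputs as stepping through the trajectory with a growing key/value cache, but as one batched computation that an autodiff framework could backpropagate through in a single pass. A NumPy sketch of a single attention head, assuming no memory segments, positional encodings, or gating (so it is much simpler than GTrXL), with hypothetical helper names:

```python
import numpy as np


def softmax(a, axis=-1):
    a = a - a.max(axis=axis, keepdims=True)  # numerical stability
    e = np.exp(a)
    return e / e.sum(axis=axis, keepdims=True)


def causal_attention_full(q, k, v):
    """Single-head causal attention over a whole trajectory in one pass. q, k, v: (T, d)."""
    T, d = q.shape
    scores = q @ k.T / np.sqrt(d)
    future = np.triu(np.ones((T, T), dtype=bool), k=1)  # True where key is a future step
    scores = np.where(future, -np.inf, scores)
    return softmax(scores, axis=-1) @ v


def causal_attention_stepwise(q, k, v):
    """Same computation one step at a time with a growing key/value cache, as in a rollout."""
    T, d = q.shape
    outputs = []
    for t in range(T):
        k_cache, v_cache = k[: t + 1], v[: t + 1]  # everything observed up to step t
        scores = q[t] @ k_cache.T / np.sqrt(d)
        outputs.append(softmax(scores) @ v_cache)
    return np.stack(outputs)


# Usage: both orderings give the same outputs for a random 6-step trajectory.
rng = np.random.default_rng(1)
T, d = 6, 4
q, k, v = rng.normal(size=(3, T, d))
print(np.allclose(causal_attention_full(q, k, v), causal_attention_stepwise(q, k, v)))  # True
```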

Projects
None yet

4 participants