training issue about dsrnn #3

Open
hoonm1n opened this issue Apr 25, 2023 · 4 comments

Comments


hoonm1n commented Apr 25, 2023

Hi @Shuijing725!
I have an issue with your code when training with S-RNN.
I changed `robot.policy` in config.py from "selfAttn_merge_srnn" to "srnn".
When I then ran train.py, I got the following error:

```
<Monitor<CrowdSimPredRealGST>>
No ghost version.
new gst
new st model
LOADED MODEL
device: cuda:0

Traceback (most recent call last):
  File "/root/dsrnn_ws/CrowdNav_Prediction_AttnGraph/train.py", line 247, in <module>
    main()
  File "/root/dsrnn_ws/CrowdNav_Prediction_AttnGraph/train.py", line 90, in main
    actor_critic = Policy(
  File "/root/dsrnn_ws/CrowdNav_Prediction_AttnGraph/rl/networks/model.py", line 28, in __init__
    self.base = base(obs_shape, base_kwargs)
  File "/root/dsrnn_ws/CrowdNav_Prediction_AttnGraph/rl/networks/srnn_model.py", line 378, in __init__
    robot_size = 7 if args.env_type == 'crowd_sim' else 2
AttributeError: 'Namespace' object has no attribute 'env_type'. Did you mean: 'env_name'?
```

Could you help me solve this issue? Thank you for reading!
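The first traceback is an ordinary `AttributeError` on an `argparse.Namespace`: line 378 of `srnn_model.py` reads `args.env_type`, but the parsed arguments only define `env_name`. A minimal stdlib reproduction, assuming the argument is registered with `argparse` as `--env-name` (the flag spelling and the `-v0` suffix on the environment id are assumptions; only the attribute name `env_name` comes from the traceback):

```python
import argparse

# Hypothetical stand-in for the repo's arguments.py: only env_name is defined.
parser = argparse.ArgumentParser()
parser.add_argument('--env-name', default='CrowdSimPredRealGST-v0')  # assumed flag/default
args = parser.parse_args([])  # use defaults only

print(args.env_name)              # CrowdSimPredRealGST-v0
print(hasattr(args, 'env_type'))  # False

try:
    # Same expression as srnn_model.py line 378.
    robot_size = 7 if args.env_type == 'crowd_sim' else 2
except AttributeError as err:
    print(err)  # 'Namespace' object has no attribute 'env_type'
```

A `getattr(args, 'env_type', ...)` fallback would silence this error, but the robot feature size still has to match what the environment actually produces.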


hoonm1n commented Apr 25, 2023

After I modified that line in srnn_model.py, I got another error:
```
Traceback (most recent call last):
  File "/root/dsrnn_ws/CrowdNav_Prediction_AttnGraph/train.py", line 247, in <module>
    main()
  File "/root/dsrnn_ws/CrowdNav_Prediction_AttnGraph/train.py", line 162, in main
    value, action, action_log_prob, recurrent_hidden_states = actor_critic.act(
  File "/root/dsrnn_ws/CrowdNav_Prediction_AttnGraph/rl/networks/model.py", line 60, in act
    value, actor_features, rnn_hxs = self.base(inputs, rnn_hxs, masks, infer=True)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1130, in _call_impl
    return forward_call(*input, **kwargs)
  File "/root/dsrnn_ws/CrowdNav_Prediction_AttnGraph/rl/networks/srnn_model.py", line 441, in forward
    nodes_current_selected = self.robot_linear(robot_node)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1130, in _call_impl
    return forward_call(*input, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/linear.py", line 114, in forward
    return F.linear(input, self.weight, self.bias)
RuntimeError: mat1 and mat2 shapes cannot be multiplied (16x7 and 2x3)
```
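The second error comes from the robot embedding layer. `F.linear(input, weight, bias)` computes `input @ weight.T + bias`, and the `weight` of `nn.Linear(in_features, out_features)` has shape `(out_features, in_features)`. Read that way, `mat2` (2x3) is the transposed weight of a layer built for 2 input features and 3 outputs, while the robot node fed to it has 7 features per row. That is consistent with `robot_size` falling through to the `else 2` branch after the edit. The pure-Python sketch below is not PyTorch code; it only mirrors the shape rule so the numbers can be checked by hand (`linear_output_shape` is a hypothetical helper):

```python
def linear_output_shape(input_shape, in_features, out_features):
    """Mimic the shape rule of F.linear: (N, in) @ (in, out) -> (N, out)."""
    rows, cols = input_shape
    # F.linear multiplies by weight.T, whose shape is (in_features, out_features).
    mat2 = (in_features, out_features)
    if cols != mat2[0]:
        raise RuntimeError(
            f"mat1 and mat2 shapes cannot be multiplied "
            f"({rows}x{cols} and {mat2[0]}x{mat2[1]})"
        )
    return (rows, out_features)


# Layer built with robot_size = 2, but each robot node row has 7 features.
try:
    linear_output_shape((16, 7), in_features=2, out_features=3)
except RuntimeError as err:
    print(err)  # mat1 and mat2 shapes cannot be multiplied (16x7 and 2x3)

# With in_features = 7 the same input passes the shape check.
print(linear_output_shape((16, 7), in_features=7, out_features=3))  # (16, 3)
```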

@Shuijing725

Do you want to add trajectory prediction? If not, change `env_name` in arguments.py to `CrowdSimVarNum-v0`.
If you do, I'll have to fix it, because S-RNN is not compatible with prediction right now.


hoonm1n commented Apr 26, 2023

Does that mean the old network can't be trained in the current repository (CrowdNav_Prediction_AttnGraph) yet?

Shuijing725 commented Apr 26, 2023

This repo has several different gym environments (see README -> Training -> ii). S-RNN can be trained in the `CrowdSimVarNum-v0` environment, which does not include human trajectory prediction, but it cannot be trained in the environments that include prediction.
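One way to surface this incompatibility before any network is built is a fail-fast check on the policy/environment pair. This is a hypothetical helper, not part of the repository; the only pairing it encodes is the one stated in this reply (S-RNN trains with `CrowdSimVarNum-v0`), and the prediction environment id used in the example is assumed for illustration:

```python
# Hypothetical compatibility table: environments S-RNN can be trained in.
SRNN_COMPATIBLE_ENVS = {'CrowdSimVarNum-v0'}


def check_policy_env(policy: str, env_name: str) -> None:
    """Raise early if S-RNN is paired with an environment it does not support."""
    if policy == 'srnn' and env_name not in SRNN_COMPATIBLE_ENVS:
        raise ValueError(
            f"policy 'srnn' does not support env '{env_name}'; "
            f"use one of {sorted(SRNN_COMPATIBLE_ENVS)}"
        )


check_policy_env('srnn', 'CrowdSimVarNum-v0')  # passes silently

try:
    check_policy_env('srnn', 'CrowdSimPredRealGST-v0')  # assumed env id
except ValueError as err:
    print(err)
```

Calling such a check right after the config and arguments are loaded in train.py would replace the shape error with a readable message; other policies are passed through unchanged.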
