Use num_batch_dims=0 to deal with observations containing scalars #231

Open
wants to merge 1 commit into base: master

wookayin
Contributor

@wookayin wookayin commented May 23, 2022

When a (nested) observation contains a scalar value on some axes
with dimension 1, batch_concat() crashes with a shape mismatch
error (e.g., cannot concatenate tensors of shape [100, 1] and []).

This happens because batch_concat assumes by default that the input
tensors are batched, but by the time the statistics are computed
the batch dimension (axis=0) has already been eliminated. Passing
num_batch_dims=0 tells batch_concat to treat the inputs as unbatched.
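
To make the failure mode concrete, here is a minimal NumPy sketch. `batch_concat_sketch` is a hypothetical stand-in written for illustration only; it is not the library's implementation, and its handling of leaves with too few dimensions is an assumption based on the error described above. The shapes `[100, 1]` and `[]` are taken from the example error message.

```python
import numpy as np


def batch_concat_sketch(leaves, num_batch_dims=1):
    """Hypothetical stand-in: flatten each leaf past its leading batch dims,
    then concatenate along the last axis."""
    flat = []
    for x in leaves:
        x = np.asarray(x)
        if x.ndim < num_batch_dims:
            # Assumption: a leaf with fewer dims than num_batch_dims is
            # passed through unchanged.
            flat.append(x)
        else:
            # Keep the leading batch dims and flatten everything after them.
            flat.append(x.reshape(x.shape[:num_batch_dims] + (-1,)))
    return np.concatenate(flat, axis=-1)


# Per-feature statistics after the batch axis has been reduced away:
# one leaf of shape [100, 1] and one scalar leaf of shape [].
stats = [np.zeros((100, 1)), np.zeros(())]

# Default num_batch_dims=1: the scalar has no batch axis, so shapes mismatch.
try:
    batch_concat_sketch(stats)
    default_call_failed = False
except ValueError:
    default_call_failed = True

# num_batch_dims=0: every leaf is flattened to 1-D before concatenation.
merged = batch_concat_sketch(stats, num_batch_dims=0)
```

In this sketch the default call raises a `ValueError`, while the `num_batch_dims=0` call flattens both leaves to 1-D and returns a single vector with 101 entries.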

Note: This applies to the SAC agent. The PPO agent already has a correct implementation (comparison here).

@wookayin
Contributor Author

One additional favor: please rebase (squash) when merging, rather than creating a merge commit.
