
ACER raises an error when GaussianHeadWithFixedCovariance is used #144

Open
muupan opened this issue May 26, 2021 · 0 comments · May be fixed by #145
Labels
bug Something isn't working

Comments

@muupan
Member

muupan commented May 26, 2021

Reported in #143

ACER assumes that every parameter of a distribution, as returned by `get_params_of_distribution`, requires grad, so that the algorithm can compute gradients with respect to those parameters.
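As the issue notes, ACER computes gradients with respect to the distribution parameters themselves, not only the network weights. The sketch below shows that kind of step with `torch.autograd.grad`. It is an illustration only: the negative log-probability of a fixed action stands in for ACER's actual objective, and the tensor values are assumptions.

```python
import torch

# Distribution parameters that are part of the autograd graph.
loc = torch.zeros(1, 2, requires_grad=True)
scale = torch.ones(1, 2, requires_grad=True)
distrib = torch.distributions.Normal(loc, scale)

action = torch.tensor([[0.5, -0.5]])
# Illustrative surrogate loss (not ACER's real objective).
loss = -distrib.log_prob(action).sum()

# Differentiate with respect to the distribution's own loc and scale tensors.
grad_loc, grad_scale = torch.autograd.grad(loss, [distrib.loc, distrib.scale])
print(grad_loc)    # tensor([[-0.5000,  0.5000]])
print(grad_scale)  # tensor([[0.7500, 0.7500]])
```

If either tensor passed to `torch.autograd.grad` does not require grad, PyTorch cannot differentiate with respect to it, which is why ACER checks this up front.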

pfrl/pfrl/agents/acer.py, lines 172 to 180 in 44bf2e4:

```python
def get_params_of_distribution(distrib):
    if isinstance(distrib, torch.distributions.Independent):
        return get_params_of_distribution(distrib.base_dist)
    elif isinstance(distrib, torch.distributions.Categorical):
        return (distrib._param,)
    elif isinstance(distrib, torch.distributions.Normal):
        return distrib.loc, distrib.scale
    else:
        raise NotImplementedError("{} is not supported by ACER".format(type(distrib)))
```
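To see what this function returns for the distribution types it supports, the sketch below copies `get_params_of_distribution` unchanged and calls it on small hand-built distributions. The logits, means, scales, and shapes are illustrative assumptions.

```python
import torch


def get_params_of_distribution(distrib):
    # Same logic as the acer.py snippet: unwrap Independent, then return raw parameters.
    if isinstance(distrib, torch.distributions.Independent):
        return get_params_of_distribution(distrib.base_dist)
    elif isinstance(distrib, torch.distributions.Categorical):
        return (distrib._param,)
    elif isinstance(distrib, torch.distributions.Normal):
        return distrib.loc, distrib.scale
    else:
        raise NotImplementedError("{} is not supported by ACER".format(type(distrib)))


# A batch of one discrete distribution over 3 actions.
logits = torch.zeros(1, 3, requires_grad=True)
categorical = torch.distributions.Categorical(logits=logits)

# A batch of one 2-D diagonal Gaussian wrapped in Independent (the first branch above).
loc = torch.zeros(1, 2, requires_grad=True)
scale = torch.ones(1, 2, requires_grad=True)
gaussian = torch.distributions.Independent(torch.distributions.Normal(loc, scale), 1)

cat_params = get_params_of_distribution(categorical)
gauss_params = get_params_of_distribution(gaussian)
print([tuple(p.shape) for p in cat_params])    # [(1, 3)]
print([tuple(p.shape) for p in gauss_params])  # [(1, 2), (1, 2)]
```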

pfrl/pfrl/agents/acer.py, lines 218 to 221 in 44bf2e4:

```python
distrib_params = get_params_of_distribution(distrib)
for param in distrib_params:
    assert param.shape[0] == 1
    assert param.requires_grad
```
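These assertions pass whenever every returned tensor has batch size 1 and is part of the autograd graph. The sketch below mirrors them in a hypothetical helper, `check_distrib_params`, that handles only the Gaussian case. It applies the helper to a Gaussian whose mean and standard deviation are both produced by a linear layer. The layer sizes and the softplus parameterization are illustrative assumptions.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


def check_distrib_params(distrib):
    # Hypothetical helper mirroring the acer.py assertions (Gaussian case only):
    # unwrap Independent, then require batch size 1 and requires_grad on loc and scale.
    if isinstance(distrib, torch.distributions.Independent):
        distrib = distrib.base_dist
    for param in (distrib.loc, distrib.scale):
        assert param.shape[0] == 1
        assert param.requires_grad


torch.manual_seed(0)
obs = torch.randn(1, 4)        # a single observation (batch size 1)
net = nn.Linear(4, 2 * 2)      # outputs a 2-D mean and a 2-D pre-softplus scale
mean, raw_scale = net(obs).chunk(2, dim=-1)
scale = F.softplus(raw_scale) + 1e-4  # keep the scale strictly positive

distrib = torch.distributions.Independent(torch.distributions.Normal(mean, scale), 1)
check_distrib_params(distrib)  # passes: loc and scale both depend on net's weights
print(distrib.base_dist.scale.requires_grad)  # True
```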

However, when `GaussianHeadWithFixedCovariance` (`class GaussianHeadWithFixedCovariance(nn.Module):`) is used, the `scale` parameter of the resulting `torch.distributions.Normal` distribution does not require grad, so the `assert param.requires_grad` check fails with an `AssertionError`.
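A minimal reproduction without PFRL follows. It assumes that a fixed-covariance head builds its scale from a constant rather than from network outputs. The constant 1.0, the `torch.full_like` construction, and the layer sizes are illustrative assumptions, not PFRL's code. The mean still requires grad, but the scale is a constant tensor outside the autograd graph, so the second assertion fails.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
obs = torch.randn(1, 4)   # batch size 1, as ACER's check expects
net = nn.Linear(4, 2)     # learned mean for a 2-D action
mean = net(obs)

# Fixed standard deviation: full_like creates a tensor with requires_grad=False.
scale = torch.full_like(mean, 1.0)
distrib = torch.distributions.Independent(torch.distributions.Normal(mean, scale), 1)

params = (distrib.base_dist.loc, distrib.base_dist.scale)
print([p.requires_grad for p in params])  # [True, False]

failed = False
try:
    for param in params:
        assert param.shape[0] == 1
        assert param.requires_grad  # same check as acer.py; fails for the scale
except AssertionError:
    failed = True
print(failed)  # True
```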

muupan self-assigned this May 26, 2021
muupan linked a pull request May 26, 2021 that will close this issue
muupan added the bug label May 26, 2021