get_q in BuildingEnvelope Disconnects Gradients #118

Open
HarryLTS opened this issue Nov 5, 2023 · 0 comments
HarryLTS commented Nov 5, 2023

When I pass tensors with gradients into the forward function of sys (an object of class BuildingEnvelope), the gradients are removed at the point where get_q is called. This happens even if sys is instantiated with the kwargs backend='torch', requires_grad=True. This affects any system that uses BuildingEnvelope: in the following setup, gradients from a loss computed on 'yn' cannot propagate back to policy (a blocks.MLP_bounds).

    policy_node = Node(policy, ['yn', 'D', 'UB', 'LB'], ['U'], name='policy')
    system_node = Node(sys, ['xn', 'U', 'Dhidden'], ['xn', 'yn'], name='system')

    cl_system = System([policy_node, system_node], nsteps=args.nsteps, name='cl_system')
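
For reference, the disconnect can be observed on sys directly, outside the closed-loop System. Below is a minimal sketch: it assumes sys is the BuildingEnvelope instance from above, that calling sys(x, u, d) returns (xn, yn) as implied by the Node definition, and that nx/nu/nd hold the model's state, input, and disturbance dimensions (adjust if those attribute names differ):

    import torch

    # inputs that carry gradients, shaped for the model
    x = torch.randn(1, sys.nx, requires_grad=True)
    u = torch.randn(1, sys.nu, requires_grad=True)
    d = torch.randn(1, sys.nd, requires_grad=True)

    xn, yn = sys(x, u, d)

    # if the graph were intact, yn.grad_fn would trace back to u;
    # instead the graph is cut at the point where get_q is called
    print(yn.requires_grad, yn.grad_fn)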

The cause of this issue seems to be that BuildingEnvelope.get_q is wrapped by @cast_backend, which calls torch.tensor(return_tensor, dtype=torch.float32) on the tensor returned by get_q; torch.tensor copies the data into a new leaf tensor, which detaches it from the autograd graph. If this line is removed, the gradients are able to propagate and the policy can be trained normally.
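
The detaching behavior is easy to reproduce in plain PyTorch, independent of neuromancer: torch.tensor copies its argument into a new leaf tensor and discards the autograd history, while a dtype cast via Tensor.to keeps the result in the graph:

    import torch

    x = torch.ones(3, dtype=torch.float64, requires_grad=True)
    y = x * 2.0  # y has a grad_fn and is connected to x

    # torch.tensor() copies the data into a new leaf tensor,
    # severing the connection to x (the same call @cast_backend makes):
    z = torch.tensor(y, dtype=torch.float32)
    print(z.requires_grad)   # False

    # a dtype cast keeps the result in the autograd graph:
    z2 = y.to(torch.float32)
    print(z2.requires_grad)  # True

So rather than deleting the line, a gradient-preserving alternative might be to cast with .to(torch.float32) inside cast_backend, though I have not tested that change against the library.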

@drgona drgona assigned drgona and RBirmiwal and unassigned drgona and RBirmiwal Nov 5, 2023