
[FEAT] Remove generator vars from gradient of the Supervised loss in TimeGAN (as loss does not depend on it) #309

Open
itakatz opened this issue Oct 19, 2023 · 0 comments
Labels
enhancement New feature or request

Comments


itakatz commented Oct 19, 2023

(This is not a problem in the algorithm per se, but rather a matter of code clarity, so I have not tagged it as a bug.)

In the definition of the `train_supervisor` method:

```python
def train_supervisor(self, x, opt):
    with GradientTape() as tape:
        h = self.embedder(x)
        h_hat_supervised = self.supervisor(h)
        generator_loss_supervised = self._mse(h[:, 1:, :], h_hat_supervised[:, :-1, :])

    var_list = self.supervisor.trainable_variables + self.generator.trainable_variables
    gradients = tape.gradient(generator_loss_supervised, var_list)
    apply_grads = [(grad, var) for (grad, var) in zip(gradients, var_list) if grad is not None]
    opt.apply_gradients(apply_grads)
    return generator_loss_supervised
```

the generator variables are added to `var_list`. This is not needed: `generator_loss_supervised` does not depend on them, so their gradients come back as `None` (which is exactly what the `grad is not None` filter absorbs) and no update to these parameters ever takes place. This is probably inherited from the original TimeGAN implementation.
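A quick standalone check of this behaviour (not part of the library code; `a` and `b` are just placeholder variables): by default, `tape.gradient` returns `None` for variables the loss is not connected to.

```python
import tensorflow as tf

a = tf.Variable(2.0)  # stands in for a supervisor variable (used in the loss)
b = tf.Variable(3.0)  # stands in for a generator variable (not used in the loss)

with tf.GradientTape() as tape:
    loss = a * a

# Unconnected variables get a None gradient by default, so apply_gradients
# would never update b even though it sits in the variable list.
print(tape.gradient(loss, [a, b]))  # [tf.Tensor(4.0, ...), None]
```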

If I am correct, I suggest removing these variables from `var_list` for better code clarity (a sketch of the change is below).
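For illustration, the method could then look something like this (a sketch based on the snippet above, not a tested patch; the `grad is not None` filter arguably also becomes unnecessary once every variable in `var_list` contributes to the loss):

```python
def train_supervisor(self, x, opt):
    with GradientTape() as tape:
        h = self.embedder(x)
        h_hat_supervised = self.supervisor(h)
        generator_loss_supervised = self._mse(h[:, 1:, :], h_hat_supervised[:, :-1, :])

    # Only the supervisor is meant to be updated in this training phase,
    # so var_list no longer includes the generator variables.
    var_list = self.supervisor.trainable_variables
    gradients = tape.gradient(generator_loss_supervised, var_list)
    opt.apply_gradients(zip(gradients, var_list))
    return generator_loss_supervised
```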

@itakatz itakatz added the enhancement New feature or request label Oct 19, 2023