Towards SoftAdapt loss balancing for tf.compat.v1 #1586
base: master
Conversation
deepxde/callbacks.py
Outdated
loss_weights = dde.Variable(loss_weights, trainable=False, dtype=loss_weights.dtype)
loss_weights *= 0
I am trying to allow `loss_weights` to be a `Variable`, such that the loss function updates automatically every time the weights change. Any clue @lululxvi ?
Here, I was trying to set the `loss_weights` to 0, so the loss should give 0 for the next epochs (which is not the case so far). Should we define `loss_weights` differently in `model.compile`?
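As a sanity check of the idea, here is a minimal sketch in plain `tf.compat.v1` (made-up constant losses, not DeepXDE's API). Note that `loss_weights *= 0` builds a *new* tensor rather than mutating the variable; to make the already-built graph see updated weights, the update has to go through `assign`:

```python
import tensorflow.compat.v1 as tf

tf.disable_eager_execution()

# Non-trainable variable holding the loss weights
loss_weights = tf.Variable([1.0, 1.0], trainable=False)
# Two made-up constant loss terms, standing in for PDE/BC losses
losses = tf.constant([2.0, 3.0])
total_loss = tf.reduce_sum(losses * loss_weights)

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    before = sess.run(total_loss)              # 1*2 + 1*3 = 5.0
    sess.run(loss_weights.assign([0.0, 1.0]))  # mutate the variable in place
    after = sess.run(total_loss)               # 0*2 + 1*3 = 3.0

print(before, after)
```

With `assign`, zeroing the first weight really does zero its term on the next `sess.run`, which is the behavior the callback needs.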
Maybe we need to work here in:
Lines 169 to 183 in 3b08fe3
def losses(losses_fn):
    # Data losses
    losses = losses_fn(
        self.net.targets, self.net.outputs, loss_fn, self.net.inputs, self
    )
    if not isinstance(losses, list):
        losses = [losses]
    # Regularization loss
    if self.net.regularizer is not None:
        losses.append(tf.losses.get_regularization_loss())
    losses = tf.convert_to_tensor(losses)
    # Weighted losses
    if loss_weights is not None:
        losses *= loss_weights
    return losses
Thank you!
How do you plan to update `loss_weights`?
Implementing this in TensorFlow is tricky, as it uses a static graph. It should be much easier to implement in PyTorch, where you can directly change the loss weights.
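For contrast, an eager-mode sketch in plain PyTorch (made-up scalar losses, not DeepXDE's API): because the total loss is rebuilt on every iteration, mutating an ordinary Python list of weights is enough, with no graph surgery:

```python
import torch

# Made-up loss terms; in practice these come from PDE/BC residuals
loss_terms = [torch.tensor(2.0), torch.tensor(3.0)]
loss_weights = [1.0, 1.0]  # plain floats, freely mutated between epochs

def total_loss():
    # Rebuilt eagerly on every call, so it always sees the current weights
    return sum(w * l for w, l in zip(loss_weights, loss_terms))

before = total_loss().item()  # 1*2 + 1*3 = 5.0
loss_weights[0] = 0.0         # directly change the weight
after = total_loss().item()   # 0*2 + 1*3 = 3.0
```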
Thank you for your feedback. I would really prefer to implement this adaptive loss callback in `tf.compat.v1`. I think I'll start with a simple two-term loss (and one weighting parameter). |
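For reference, the SoftAdapt weights themselves are a softmax over the (approximate) rate of change of each loss term. A backend-agnostic NumPy sketch of that computation (my reading of the usual formulation, with a finite-difference rate estimate and a hypothetical `beta` temperature, not this PR's final code):

```python
import numpy as np

def softadapt_weights(prev_losses, curr_losses, beta=0.1):
    """SoftAdapt: weight each loss term by a softmax of its rate of change."""
    # s_i: finite-difference estimate of how fast loss term i is changing
    s = np.asarray(curr_losses, dtype=float) - np.asarray(prev_losses, dtype=float)
    # Softmax with the usual max-shift for numerical stability
    e = np.exp(beta * (s - s.max()))
    return e / e.sum()

# Two-term example: the first loss is rising, the second is falling,
# so the first (worsening) term receives the larger weight.
w = softadapt_weights([1.0, 1.0], [1.5, 0.5], beta=1.0)
```

The weights always sum to 1, so the slowest-improving terms are emphasized without inflating the total loss scale.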
Hi, if we define the loss weights as variables, we next have to define the `total_loss` appropriately.
Work in progress!