
Towards SoftAdapt loss balancing for tf.compat.v1 #1586

Draft · wants to merge 4 commits into master
Conversation

@pescap (Contributor) commented Dec 6, 2023

Work in progress!

Comment on lines 596 to 598:

    loss_weights = dde.Variable(loss_weights, trainable=False, dtype=loss_weights.dtype)
    loss_weights *= 0

@pescap (Author) commented Dec 6, 2023

I am trying to allow loss_weights to be a Variable, so that the loss function updates automatically every time the weights change. Any clue @lululxvi?

Here, I was trying to set the loss_weights to 0, so the loss should be 0 for the following epochs (which is not the case so far).

Should we define loss_weights differently in model.compile?

Maybe we need to work here:

deepxde/deepxde/model.py

Lines 169 to 183 in 3b08fe3

def losses(losses_fn):
    # Data losses
    losses = losses_fn(
        self.net.targets, self.net.outputs, loss_fn, self.net.inputs, self
    )
    if not isinstance(losses, list):
        losses = [losses]
    # Regularization loss
    if self.net.regularizer is not None:
        losses.append(tf.losses.get_regularization_loss())
    losses = tf.convert_to_tensor(losses)
    # Weighted losses
    if loss_weights is not None:
        losses *= loss_weights
    return losses

Thank you!
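
For reference, here is a minimal runnable sketch of the Variable idea in tf.compat.v1 (illustrative only; the dummy losses and the name loss_weights_var are assumptions, not code from this PR). Re-assigning the variable changes the weighted loss without rebuilding the graph:

    import tensorflow.compat.v1 as tf
    tf.disable_eager_execution()

    # Two dummy loss terms standing in for DeepXDE's per-term losses.
    losses = tf.stack([tf.constant(1.0), tf.constant(2.0)])
    loss_weights_var = tf.Variable([1.0, 1.0], trainable=False, name="loss_weights")
    total_loss = tf.reduce_sum(losses * loss_weights_var)

    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        print(sess.run(total_loss))  # 3.0
        # Setting the weights to 0 now makes the loss 0 on the next run,
        # without recompiling the graph.
        sess.run(loss_weights_var.assign([0.0, 0.0]))
        print(sess.run(total_loss))  # 0.0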

@lululxvi (Owner)

How do you plan to update loss_weights?

@lululxvi (Owner) commented Dec 6, 2023

Implementing this in TensorFlow is tricky, as it uses a static graph. It should be much easier to implement in PyTorch, where you can directly change the loss_weights value.
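
For comparison, a toy sketch of why this is easy in eager PyTorch (the loss terms here are made up, not DeepXDE code): the weights are plain Python numbers that are re-read at every step, so a callback can simply overwrite them.

    import torch

    x = torch.tensor(1.0, requires_grad=True)
    opt = torch.optim.SGD([x], lr=0.1)
    loss_weights = [1.0, 1.0]  # plain numbers, read fresh each iteration

    for step in range(5):
        losses = [(x - 2.0) ** 2, x ** 2]  # two toy loss terms
        total = sum(w * l for w, l in zip(loss_weights, losses))
        opt.zero_grad()
        total.backward()
        opt.step()
        # The weights can be mutated freely between steps, e.g. by a callback.
        loss_weights = [0.5, 1.5]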

@pescap (Author) commented Dec 6, 2023

> Implementing this in TensorFlow is tricky, as it uses a static graph. It should be much easier to implement in PyTorch, where you can directly change the loss_weights value.

Thank you for your feedback. I would really prefer to implement this adaptive loss callback in tensorflow.compat.v1.

I think I'll start with a simple two-term loss (and one weighting parameter).

@haison19952013 commented Feb 1, 2024

> Implementing this in TensorFlow is tricky, as it uses a static graph. It should be much easier to implement in PyTorch, where you can directly change the loss_weights value.

  • It can be done if loss_weights is an argument of train_step(); see the placeholder sketch below.
  • If not, when we iteratively change the loss weights, we will need TensorFlow to rebuild the graph all over again. In other words, the model will .compile() again and training might be slow.
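
To illustrate the first bullet, one way to make the weights a run-time input in graph mode is a tf.placeholder fed at each training step; the graph is built once and only the fed values change (sketch with illustrative names and dummy losses):

    import tensorflow.compat.v1 as tf
    tf.disable_eager_execution()

    losses = tf.stack([tf.constant(1.0), tf.constant(2.0)])  # dummy loss terms
    loss_weights_ph = tf.placeholder(tf.float32, shape=[2])
    total_loss = tf.reduce_sum(losses * loss_weights_ph)

    with tf.Session() as sess:
        # Same graph both times; only the fed weight values change.
        print(sess.run(total_loss, feed_dict={loss_weights_ph: [1.0, 1.0]}))  # 3.0
        print(sess.run(total_loss, feed_dict={loss_weights_ph: [0.2, 0.4]}))  # 1.0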

@pescap (Author) commented Feb 1, 2024

Hi, if we define loss_weights as a Variable, there is no need to compile several times, right?

Next, we have to define the total_loss appropriately.

@haison19952013 commented Feb 2, 2024

> Hi, if we define loss_weights as a Variable, there is no need to compile several times, right?
>
> Next, we have to define the total_loss appropriately:

    total_loss = tf.math.reduce_sum(losses)
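
For reference, the SoftAdapt rule that the PR title refers to (Heydari et al., 2019) turns the recent rate of change of each loss term into a weight via a softmax. A NumPy sketch of one common variant (the function name and the default beta are illustrative):

    import numpy as np

    def softadapt_weights(prev_losses, curr_losses, beta=0.1):
        """Weights from two consecutive loss evaluations: terms that are
        decreasing more slowly (or increasing) receive larger weights."""
        s = np.asarray(curr_losses) - np.asarray(prev_losses)  # rates of change
        e = np.exp(beta * (s - s.max()))  # shift for numerical stability
        return e / e.sum()

    # Example: the second term stagnates, so it gets the larger weight.
    print(softadapt_weights([1.0, 1.0], [0.5, 0.9]))  # ~[0.49, 0.51]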
