Add option to reweight loss to account for differences in relation frequency #1142
base: master
Conversation
Hi @sbonner0,

Thanks for working on this; having some form of loss weighting would definitely be useful from my point of view.

I think the main question in terms of design is: what are we weighting? In my opinion, the samples that drive the loss are always the positive ones. The negatives are a sort of function of the positives, so maybe just using positive weights (like you propose for the margin loss) could simplify things.

At a high level the API looks good to me. Maybe a potential improvement would be to bring this feature in line with the rest of the codebase. I would be happy to hear other people's thoughts.

Hope this helps!
Hi @sbonner0, first of all, thanks for starting with a proposal. 🚀

I like the idea of having (positive) triple weights, too, since it allows users to provide any (static, precalculated?) weights, including inverse entity or relation frequency weighting. In that case, it would make sense to extend the training instances accordingly.

I think the easiest option is to re-use these weights for the negative examples generated from a positive one. Some losses, e.g., the adversarial one(s), calculate dynamic negative loss weights, which could then be multiplied with the corresponding positive triple weight.
Hey @francesco-tuveri - thanks so much for your interest and feedback. I take your point on the weighting of the positive versus negative samples; however, I still feel there might be some value in also allowing the negatives to be weighted separately. I really agree with your suggestion to improve the API in relation to the pipeline, and I'll look into doing something along these lines.

Hey @mberr - thanks for the suggestion to add weights into the training instances. I guess the appropriate weights could then be returned from the data loader.

I was wondering if there is a way in which the weights could be applied in a more standard way for all the loss functions? As it stands, the weights would need to be applied in each loss function, which feels like a lot of redundant code. It would be ideal if loss functions returned the unreduced values, which could then, if required, be modified by the weights before being reduced via the chosen method.
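The "return unreduced values, apply weights, then reduce" idea can be sketched in a loss-agnostic way without touching any particular loss class. A minimal pure-Python sketch (the helper name `weighted_reduce` and its `reduction` options are illustrative, not PyKEEN API):

```python
def weighted_reduce(values, weights=None, reduction="mean"):
    """Apply optional per-sample weights to unreduced loss values, then reduce."""
    if weights is not None:
        values = [v * w for v, w in zip(values, weights)]
    if reduction == "sum":
        return sum(values)
    if reduction == "mean":
        # for a weighted mean, normalize by the total weight rather than the count
        denom = sum(weights) if weights is not None else len(values)
        return sum(values) / denom
    raise ValueError(f"unknown reduction: {reduction}")

# unweighted mean over three per-triple losses
print(weighted_reduce([1.0, 2.0, 3.0]))  # → 2.0
# a rare triple (weight 2.0) contributes more: (2 + 2 + 3) / 4
print(weighted_reduce([1.0, 2.0, 3.0], [2.0, 1.0, 1.0]))  # → 1.75
```

With a helper like this, each loss would only need to expose its unreduced values; the weighting logic would live in one place.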
To start, it might be easier to begin with support for relation-balanced triple weights:

```python
import torch
from class_resolver import ClassResolver

from pykeen.typing import MappedTriples


class InstanceWeighting:
    def calculate_weights(self, mapped_triples: MappedTriples) -> torch.FloatTensor:
        raise NotImplementedError


class RelationBalancedInstanceWeighting(InstanceWeighting):
    def calculate_weights(self, mapped_triples: MappedTriples) -> torch.FloatTensor:
        # unique() returns (values, inverse_indices, counts); we only need the latter two
        inverse, counts = mapped_triples[:, 1].unique(return_inverse=True, return_counts=True)[1:]
        return counts.reciprocal()[inverse]


instance_weighting_resolver = ClassResolver.from_subclasses(InstanceWeighting)
```
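For readers without torch at hand, the same computation in plain Python shows what the `unique(...)`/`reciprocal` combination does; this is an illustration of the idea, not part of the proposal:

```python
from collections import Counter

def relation_balanced_weights(triples):
    """Weight each (h, r, t) triple by 1 / count(r), mirroring the
    relation-balanced weighting sketched above."""
    counts = Counter(r for _, r, _ in triples)
    return [1.0 / counts[r] for _, r, _ in triples]

# relation 0 appears twice, relation 1 once
triples = [(0, 0, 1), (1, 0, 2), (2, 1, 3)]
print(relation_balanced_weights(triples))  # → [0.5, 0.5, 1.0]
```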
To support custom weights, we could now add

```python
class CustomInstanceWeighting(InstanceWeighting):
    def __init__(self, weights: torch.FloatTensor):
        self.weights = weights

    def calculate_weights(self, mapped_triples: MappedTriples) -> torch.FloatTensor:
        # check for consistency, e.g., one weight per triple
        return self.weights
```

although there are some things to consider, e.g., about inverse relations.
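The "check for consistency" placeholder could, for instance, verify that the number of user-provided weights matches the number of triples and that all weights are non-negative. A torch-free sketch of such a check (`check_weights` is a hypothetical helper, not PyKEEN API):

```python
def check_weights(weights, num_triples):
    """Validate user-provided instance weights before training starts."""
    if len(weights) != num_triples:
        raise ValueError(f"expected {num_triples} weights, got {len(weights)}")
    if any(w < 0 for w in weights):
        raise ValueError("instance weights must be non-negative")
    return weights

print(check_weights([1.0, 0.5], num_triples=2))  # → [1.0, 0.5]
```

Failing fast here is preferable to a shape mismatch surfacing deep inside a training step.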
Hey @mberr, sorry for the delay on this, but I have some free time and would love to try and finally make some progress on this from my end! Thanks for your fantastic reply above. It has been some time since I have looked at the codebase, however; is the file I have started adding this to the right place?
The place seems to be fine for now.
This file contains the code to transform triples into training instances. The datasets are generally "map-style" datasets, i.e., they define a `__getitem__` (and `__len__`) for per-sample access.

For sample weights, we thus need to extend the dataset (variants) to add an optional additional entry, the sample weight. In places where we need custom collators (i.e., functions which combine samples into a batch), this needs to be considered, too.

In the training loop, we iterate over the batches from these data loaders and use them to calculate batch losses. Here, we may need some adaptions to make sure that the "sample weight" part of the batch is not given to the model but forwarded to the loss function.
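The dataset-plus-collator extension described here can be sketched without torch: a map-style dataset gains an optional per-sample weight, and the collator stacks weights alongside the triples so the training loop can route them to the loss rather than the model. All class and key names below are illustrative, not the actual classes from `pykeen.triples`:

```python
class WeightedInstances:
    """Map-style dataset returning samples with an optional weight entry."""

    def __init__(self, triples, weights=None):
        self.triples = triples
        self.weights = weights  # None => unweighted training

    def __len__(self):
        return len(self.triples)

    def __getitem__(self, index):
        sample = {"triple": self.triples[index]}
        if self.weights is not None:
            sample["weight"] = self.weights[index]
        return sample


def collate(samples):
    """Combine samples into a batch; keep weights in a separate key so the
    training loop can forward them to the loss, not the model."""
    batch = {"triples": [s["triple"] for s in samples]}
    if "weight" in samples[0]:
        batch["weights"] = [s["weight"] for s in samples]
    return batch


ds = WeightedInstances([(0, 0, 1), (1, 1, 2)], weights=[0.5, 1.0])
batch = collate([ds[0], ds[1]])
print(batch["weights"])  # → [0.5, 1.0]
```

Keeping the weight optional means existing unweighted code paths stay untouched.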
Hi,
I'm also interested in having weighted loss (for different purposes than what @sbonner0 describes) so I looked at the current state of the PR. Here are a few comments.
```diff
@@ -83,6 +84,7 @@ def _process_batch_static(
         stop: Optional[int],
         label_smoothing: float = 0.0,
         slice_size: Optional[int] = None,
+        relation_weights: dict = None,
```
The term `relation` is a bit confusing here. Does it refer to the fact that the weights are based on relation frequencies? If so, I assume the argument should eventually be renamed `instance_weights`, which is more generic.

Do I understand it correctly?
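Whichever name is chosen, a `dict` keyed by relation id, as in this signature, would be applied to per-triple losses roughly as follows. `apply_relation_weights` is an illustrative helper, not code from the PR:

```python
def apply_relation_weights(batch_relations, per_triple_losses, relation_weights=None):
    """Scale each triple's loss by the weight of its relation, if weights are given."""
    if relation_weights is None:
        return per_triple_losses
    return [
        loss * relation_weights.get(r, 1.0)  # default weight 1.0 for unseen relations
        for r, loss in zip(batch_relations, per_triple_losses)
    ]

print(apply_relation_weights([0, 1], [1.0, 1.0], {0: 2.0, 1: 0.5}))  # → [2.0, 0.5]
```

A per-triple `instance_weights` tensor would subsume this: the dict lookup collapses into a simple element-wise multiplication.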
```diff
@@ -57,6 +57,7 @@ def _create_training_data_loader(
         dataset=triples_factory.create_slcwa_instances(
             batch_size=batch_size,
             shuffle=kwargs.pop("shuffle", True),
+            instance_weighting="test",
```
As I see it, the `TrainingLoop` gets a reference to some `InstanceWeighting` passed via `train()`. Then, it must inject it into the data loader, which injects it into the batched instances (the dataset) via `create_(s)lcwa_instances`. But the one object to actually use the `InstanceWeighting` is the training loop itself, when processing a batch.

Is the multi-step dependency injection necessary here? Instead, the `TrainingLoop` could simply keep a reference to the `InstanceWeighting` and pass it as an argument to `_process_batch_static()`.
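The simpler wiring suggested here might look as follows. This is a torch-free sketch with stand-in losses; the real `TrainingLoop` and `_process_batch_static` signatures differ, and `SketchTrainingLoop` is a made-up name:

```python
class SketchTrainingLoop:
    """Keeps the instance weighting itself instead of threading it
    through data loader and dataset construction."""

    def train(self, triples, instance_weighting=None):
        # compute weights once, locally, rather than injecting the
        # weighting object into the data pipeline
        weights = instance_weighting(triples) if instance_weighting is not None else None
        return self._process_batch(triples, weights)

    def _process_batch(self, triples, weights):
        # stand-in per-triple losses; a real loop would query the model here
        losses = [float(h) for h, _, _ in triples]
        if weights is None:
            return sum(losses) / len(losses)
        return sum(l * w for l, w in zip(losses, weights)) / sum(weights)


loop = SketchTrainingLoop()
print(loop.train([(1, 0, 1), (3, 0, 2)]))  # unweighted mean → 2.0
print(loop.train([(1, 0, 1), (3, 0, 2)], lambda ts: [3.0, 1.0]))  # (3 + 3) / 4 → 1.5
```

One advantage of this shape is that the same loop instance can be reused with different weightings across calls to `train()`.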
By the way, I assume the weighting procedure should be passed via `train` (i.e., via `training_kwargs`) and not via `training_loop_kwargs`, because the same `TrainingLoop` may be used multiple times with different weights. Is that right?
```diff
@@ -181,6 +181,12 @@ def __init__(
         logger.debug("we don't really need the triples factory: %s", triples_factory)

+        # Give all unique relation types equal weight by weighting each triple inversely related to its relations
```
Eventually, this piece of code should be removed (and put in `RelationBalancedInstanceWeighting.calculate_weights()` instead), shouldn't it?
Hey @vcharpenay, thanks for all your input, and sorry for the slow reply. I will try to address your comments later on, but I just wanted to say that I would be very happy to receive some help with this PR, so please do feel free to contribute and change whatever I have done thus far!
Sharing some initial work I had done around adding support for reweighting the loss to account for relation frequency. The current code is very much a work in progress, but I wanted to share it at this early stage for initial feedback.
TODO and ideas list:
Link to the relevant feature request
This was initially discussed here: #973
Description of the Change
In order to help address the issues associated with relation imbalance, this PR proposes the ability to enable the loss value to be reweighted at the triple level. This reweighting is computed as follows for relation type r:
weight_r = n_triples / (n_unique_relations * count(r))
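The formula above can be computed directly from the list of relation ids. A small stdlib-only illustration (the helper `relation_weights` is for exposition; it is not the PR's implementation):

```python
from collections import Counter

def relation_weights(relations):
    """Compute weight_r = n_triples / (n_unique_relations * count(r))
    for every relation appearing in the training triples."""
    counts = Counter(relations)
    n_triples = len(relations)
    n_unique = len(counts)
    return {r: n_triples / (n_unique * c) for r, c in counts.items()}

# relation 0 appears 3x, relation 1 once: 4 triples, 2 unique relations
print(relation_weights([0, 0, 0, 1]))  # → {0: 0.666..., 1: 2.0}
```

Note that rare relations receive weights above 1 and frequent ones below 1, while the weights average to 1 over relation types.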
Because the current loss functions perform the reduction inside the loss class itself, support for reweighting has to be added to each loss class separately. Currently I have only added support for the margin ranking loss (MRL) and BCE.
Reweighted training can be enabled for these two losses as follows:
I am certain that I have not done this in the most efficient manner possible and would welcome any and all feedback! For example, I wonder if this process makes sense for more advanced loss functions, such as the adversarial ones, where the triple loss is already being modified?
Possible Drawbacks
Adding more complexity to the pipeline for a niche use-case perhaps.
Verification Process
Ran a series of datasets with quite heavy imbalance in relation frequency and noticed improved performance on rarer relation types; for example, on Hetionet with the Compound-treats-Disease relation.
Release Notes