This repository has been archived by the owner on Nov 3, 2023. It is now read-only.

New Metric.from_mask helper method #3411

Open
stephenroller opened this issue Jan 26, 2021 · 2 comments

Comments

@stephenroller
Contributor

We have quite a few places where per-token losses/metrics come with a corresponding mask:

metric_per_token # torch.Tensor of shape (batchsize, num_tokens)
mask # torch.BoolTensor of shape (batchsize, num_tokens)

And we want a per-batch-example average:

tokens_per_ex = mask.long().sum(dim=-1)
metric_per_ex = (metric_per_token * mask).sum(dim=-1)
metrics: List[MyMetric] = MyMetric.many(metric_per_ex, tokens_per_ex)
self.record_local_metric('metric_name', metrics)

I'd like us to have a helper classmethod in Metric called from_mask:

class Metric:
    @classmethod
    def from_mask(cls, metric_per_token, token_mask):
        # returns the equivalent of the "metrics" object above
        ...
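A minimal sketch of what the helper could compute, not the actual implementation. To keep it dependency-free, nested Python lists stand in for the tensors, and the `Metric` and `AverageMetric` classes here are hypothetical stand-ins that only mimic the `many` call used above. With real tensors, the body would presumably reduce to the `mask.long().sum(dim=-1)` and `(metric_per_token * mask).sum(dim=-1)` lines from the snippet above.

```python
from typing import List, Sequence


class Metric:
    """Hypothetical stand-in base class; only what this sketch needs."""

    def __init__(self, numer: float, denom: float = 1.0):
        self._numer = numer
        self._denom = denom

    @classmethod
    def many(cls, numers: Sequence[float], denoms: Sequence[float]) -> List["Metric"]:
        # Build one metric object per batch example.
        if len(numers) != len(denoms):
            raise ValueError("numers and denoms must have the same length")
        return [cls(n, d) for n, d in zip(numers, denoms)]

    @classmethod
    def from_mask(cls, metric_per_token, token_mask) -> List["Metric"]:
        # Count the unmasked tokens and sum the unmasked values, row by row.
        tokens_per_ex = [sum(1 for keep in row if keep) for row in token_mask]
        metric_per_ex = [
            sum(value for value, keep in zip(values, row) if keep)
            for values, row in zip(metric_per_token, token_mask)
        ]
        return cls.many(metric_per_ex, tokens_per_ex)


class AverageMetric(Metric):
    def value(self) -> float:
        return self._numer / self._denom


metric_per_token = [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]
mask = [[True, True, False], [True, False, False]]
metrics = AverageMetric.from_mask(metric_per_token, mask)
values = [m.value() for m in metrics]  # [1.5, 4.0]
```

Because `from_mask` calls `cls.many`, any subclass gets the helper for free. The sketch does not handle a row whose mask is entirely false (the denominator would be zero); how that case should behave is left open here.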

Once this is done, add unit tests for this (test AverageMetric and PPLMetric directly). Checkpoint there.

After you've implemented this, upgrade TorchGeneratorAgent to use your new helper, upgrading the code for loss, ppl, and token_acc.

Example:

if batch.label_vec is None:
    raise ValueError('Cannot compute loss without a label.')
model_output = self.model(*self._model_input(batch), ys=batch.label_vec)
scores, preds, *_ = model_output
score_view = scores.view(-1, scores.size(-1))
loss = self.criterion(score_view, batch.label_vec.view(-1))
loss = loss.view(scores.shape[:-1]).sum(dim=1)
# save loss to metrics
notnull = batch.label_vec.ne(self.NULL_IDX)
target_tokens = notnull.long().sum(dim=-1)
correct = ((batch.label_vec == preds) * notnull).sum(dim=-1)
self.record_local_metric('loss', AverageMetric.many(loss, target_tokens))
self.record_local_metric('ppl', PPLMetric.many(loss, target_tokens))
self.record_local_metric(
    'token_acc', AverageMetric.many(correct, target_tokens)
)
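A rough sketch of how the three metrics above might look after the upgrade, and a check that both patterns agree. Everything here is an assumption for illustration: nested lists replace tensors, the metric classes are list-based stand-ins, and `NULL_IDX = 0` is a made-up padding id. One thing to note is that `from_mask` needs the per-token loss, so the `.sum(dim=1)` in the example would have to move into the helper. The old pattern sums over every position, which only matches if the criterion already yields zero at padded positions, as the toy data below assumes.

```python
import math

NULL_IDX = 0  # hypothetical padding id


class AverageMetric:
    """Hypothetical list-based stand-in for the real metric class."""

    def __init__(self, numer, denom):
        self.numer, self.denom = numer, denom

    def value(self):
        return self.numer / self.denom

    @classmethod
    def many(cls, numers, denoms):
        return [cls(n, d) for n, d in zip(numers, denoms)]

    @classmethod
    def from_mask(cls, metric_per_token, token_mask):
        denoms = [sum(1 for keep in row if keep) for row in token_mask]
        numers = [
            sum(v for v, keep in zip(vals, row) if keep)
            for vals, row in zip(metric_per_token, token_mask)
        ]
        return cls.many(numers, denoms)


class PPLMetric(AverageMetric):
    def value(self):
        # Perplexity as exp of the mean per-token loss.
        return math.exp(super().value())


# Toy batch: two examples, three token slots, 0 marks padding.
label_vec = [[5, 7, 0], [3, 0, 0]]
preds = [[5, 2, 0], [3, 9, 9]]
loss_per_token = [[0.5, 1.5, 0.0], [2.0, 0.0, 0.0]]

notnull = [[tok != NULL_IDX for tok in row] for row in label_vec]
hits = [[int(y == p) for y, p in zip(ys, ps)] for ys, ps in zip(label_vec, preds)]

# Old pattern: reduce by hand, then call many().
target_tokens = [sum(row) for row in notnull]
loss = [sum(row) for row in loss_per_token]
correct = [sum(h for h, keep in zip(hs, row) if keep) for hs, row in zip(hits, notnull)]
old = {
    'loss': AverageMetric.many(loss, target_tokens),
    'ppl': PPLMetric.many(loss, target_tokens),
    'token_acc': AverageMetric.many(correct, target_tokens),
}

# New pattern: pass the per-token values and the mask to from_mask().
new = {
    'loss': AverageMetric.from_mask(loss_per_token, notnull),
    'ppl': PPLMetric.from_mask(loss_per_token, notnull),
    'token_acc': AverageMetric.from_mask(hits, notnull),
}

for name in old:
    assert [m.value() for m in old[name]] == [m.value() for m in new[name]]
```

The first example's padded slot counts as a "hit" (label 0, prediction 0), so the mask is what keeps it out of `token_acc`. Checks like the final loop could also serve as a starting point for the unit tests requested above.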

See if you can find at least one other place that would benefit from the same upgrade.

@github-actions

This issue has not had activity in 30 days. Please feel free to reopen if you have more issues. You may apply the "never-stale" tag to prevent this from happening.

@poojasethi
Contributor

Hi @klshuster and @stephenroller! I've just submitted a PR for this issue: #4894

poojasethi added a commit that referenced this issue Nov 29, 2022
* Add Metric.from_mask helper method (#3411)

* Use cls directly instead of passing in MyMetric