Skip to content

Commit

Permalink
Add Metric.from_mask helper method (facebookresearch#3411)
Browse files Browse the repository at this point in the history
  • Loading branch information
poojasethi committed Nov 23, 2022
1 parent 94f1b9c commit aadd32e
Show file tree
Hide file tree
Showing 3 changed files with 134 additions and 15 deletions.
24 changes: 23 additions & 1 deletion parlai/core/metrics.py
Expand Up @@ -24,6 +24,7 @@
Optional,
Set,
Tuple,
Type,
Union,
)

Expand Down Expand Up @@ -272,7 +273,7 @@ def many(cls, *objs: List[TVector]) -> List[Metric]:
"""
Construct many of a Metric from the base parts.
Useful if you separately compute numerators and denomenators, etc.
Useful if you separately compute numerators and denominators, etc.
"""
lengths = [len(o) for o in objs]
objs = list(objs) # convert from tuple for inplace modification
Expand All @@ -286,6 +287,27 @@ def many(cls, *objs: List[TVector]) -> List[Metric]:
raise IndexError(f'Uneven {cls.__name__} constructions: {lengths}')
return [cls(*items) for items in zip(*objs)]

@classmethod
def from_mask(
cls, metric_per_token: torch.Tensor, mask: torch.Tensor, MyMetric: Type[Metric]
) -> List[Metric]:
"""
From token-level metrics, returns an aggregate MyMetric per example in the batch.
:param metric_per_token:
a (batchsize x num_tokens) Tensor
:param mask:
a (batchsize x num_tokens) Tensor to mask out tokens that should *not* be considered in the aggregate metric calculation.
:param MyMetric:
a subclass of Metric
:return:
a (batchsize) Tensor
"""
tokens_per_ex = mask.long().sum(dim=-1)
metric_per_ex = (metric_per_token * mask).sum(dim=-1)
metrics = MyMetric.many(metric_per_ex, tokens_per_ex)
return metrics


class FixedMetric(Metric):
"""
Expand Down
31 changes: 19 additions & 12 deletions parlai/core/torch_generator_agent.py
Expand Up @@ -34,7 +34,7 @@
from parlai.utils.misc import warn_once
from parlai.utils.io import PathManager
import parlai.utils.logging as logging
from parlai.core.metrics import SumMetric, AverageMetric, FairseqBleuMetric
from parlai.core.metrics import Metric, SumMetric, AverageMetric, FairseqBleuMetric
from parlai.utils.fp16 import FP16SafeCrossEntropy
import parlai.utils.fsdp as fsdp_utils
from parlai.utils.torch import (
Expand Down Expand Up @@ -710,28 +710,35 @@ def compute_loss(self, batch, return_output=False):
model_output = self.model(*self._model_input(batch), ys=batch.label_vec)
scores, preds, *_ = model_output
score_view = scores.reshape(-1, scores.size(-1))
loss = self.criterion(score_view, batch.label_vec.view(-1))
loss = loss.view(scores.shape[:-1]).sum(dim=1)
# save loss to metrics
loss_flattened = self.criterion(score_view, batch.label_vec.view(-1))
loss_per_token = loss_flattened.view(scores.shape[:-1])
notnull = batch.label_vec.ne(self.NULL_IDX)
target_tokens = notnull.long().sum(dim=-1)
correct = ((batch.label_vec == preds) * notnull).sum(dim=-1)

# save loss to metrics
# cross entropy loss
self.record_local_metric('loss', AverageMetric.many(loss, target_tokens))
self.record_local_metric(
'loss', Metric.from_mask(loss_per_token, notnull, AverageMetric)
)
# perplexity
self.record_local_metric('ppl', PPLMetric.many(loss, target_tokens))
self.record_local_metric(
'ppl', Metric.from_mask(loss_per_token, notnull, PPLMetric)
)
# token-wise accuracy
self.record_local_metric(
'token_acc', AverageMetric.many(correct, target_tokens)
'token_acc',
Metric.from_mask(batch.label_vec == preds, notnull, AverageMetric),
)
# utterance-wise exact match
num_target_tokens = notnull.long().sum(dim=-1)
num_tokens_correct = ((batch.label_vec == preds) * notnull).sum(dim=-1)
self.record_local_metric(
'token_em', AverageMetric.many(correct == target_tokens)
'token_em', AverageMetric.many(num_tokens_correct == num_target_tokens)
)

# actually do backwards loss
loss = loss_per_token.sum(dim=1)
loss = loss.sum()
loss /= target_tokens.sum() # average loss per token
loss /= num_target_tokens.sum() # average loss per token
if return_output:
return (loss, model_output)
else:
Expand Down Expand Up @@ -1440,7 +1447,7 @@ def set_block_list(self: TSType, block_list: Optional[SearchBlocklist]) -> TSTyp

def get_output_from_current_step(self):
"""
Get the outputput at the current step.
Get the output at the current step.
"""
return self.outputs[-1]

Expand Down
94 changes: 92 additions & 2 deletions tests/test_metrics.py
Expand Up @@ -13,6 +13,7 @@
AverageMetric,
SumMetric,
FixedMetric,
Metric,
Metrics,
GlobalAverageMetric,
MacroAverageMetric,
Expand All @@ -28,6 +29,7 @@
WeightedF1Metric,
AUCMetrics,
)
from parlai.core.torch_generator_agent import PPLMetric
import parlai.utils.testing as testing_utils


Expand Down Expand Up @@ -70,7 +72,6 @@ def test_sum_metric_additions(self):
self.assertAlmostEqual(actual_output, output, places=6)

def test_average_metric_inputs(self):

passing_inputs_and_outputs = [
((2, 4), 0.5),
((17.0, 10.0), 1.7),
Expand All @@ -91,7 +92,6 @@ def test_average_metric_inputs(self):
AverageMetric(input_[0], input_[1])

def test_average_metric_additions(self):

input_pairs_and_outputs = [
((2, 4), (1.5, 1), 0.7),
(
Expand Down Expand Up @@ -120,6 +120,96 @@ def test_macroaverage_additions(self):
assert (m1 + m2) == AverageMetric(4, 7)
assert MacroAverageMetric({'a': m1, 'b': m2}) == 0.5 * (1.0 / 3 + 3.0 / 4)

def test_average_metric_from_mask(self) -> None:
# first test case. batchsize=3, num_tokens=10
token_values_1 = torch.FloatTensor(
[
[1, 2, 3, 4, 5, 6, 7, 8, 9, 10],
[-10, -8, -6, -4, -2, 0, 2, 4, 6, 8],
[0.5, 1.0, 1.5, 2.0, 2.5, 3.0, 3.5, 4.0, 4.5, 5.0],
]
)
token_mask_1 = torch.LongTensor(
[
[1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
[1, 1, 1, 1, 1, 1, 0, 0, 0, 0],
[1, 0, 1, 0, 1, 0, 1, 0, 1, 0],
]
)
output_1 = [
AverageMetric(55, 10),
AverageMetric(-30, 6),
AverageMetric(12.5, 5),
]

# second test case. batchsize=4, num_tokens=5
token_values_2 = torch.FloatTensor(
[
[1, 2, 3, 4, 5],
[1.5, 0, -1, 3, -4],
[-3, -2, -1, 0, 1],
[4, 5, 6, 7, 8],
]
)
token_mask_2 = torch.LongTensor(
[
[1, 1, 1, 1, 1],
[1, 1, 1, 0, 0],
[1, 0, 1, 0, 1],
[0, 0, 0, 0, 0],
]
)
output_2 = [
AverageMetric(15, 5),
AverageMetric(0.5, 3),
AverageMetric(-3, 3),
AverageMetric(0, 0),
]

input_and_outputs = [
(token_values_1, token_mask_1, output_1),
(token_values_2, token_mask_2, output_2),
]

for token_values, token_mask, output in input_and_outputs:
actual_output = Metric.from_mask(token_values, token_mask, AverageMetric)
self.assertEqual(len(actual_output), len(output))
# Because Metric.from_mask() calls Metric.many(), which in turn converts tensors to lists,
# it possible for the actual and expected outputs to be close to each other but not exactly equal.
for a, o in zip(actual_output, output):
self.assertIsInstance(a, type(o))
self.assertAlmostEqual(a.value(), o.value(), places=6)

def test_ppl_metric_from_mask(self) -> None:
# batchsize=3, num_tokens=10
token_values = torch.FloatTensor(
[
[0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9],
[0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1],
[0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1],
]
)
token_mask = torch.LongTensor(
[
[1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
[1, 1, 1, 1, 1, 1, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
]
)
output = [
PPLMetric(4.5, 10),
PPLMetric(0.6, 6),
PPLMetric(0, 0),
]
actual_output = Metric.from_mask(token_values, token_mask, PPLMetric)

self.assertEqual(len(actual_output), len(output))
# Because Metric.from_mask() calls Metric.many(), which in turn converts tensors to lists,
# it possible for the actual and expected outputs to be close to each other but not exactly equal.
for a, o in zip(actual_output, output):
self.assertIsInstance(a, type(o))
self.assertAlmostEqual(a.value(), o.value(), places=6)


class TestMetrics(unittest.TestCase):
"""
Expand Down

0 comments on commit aadd32e

Please sign in to comment.