-
Notifications
You must be signed in to change notification settings - Fork 516
/
misc.py
99 lines (82 loc) · 3.06 KB
/
misc.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
#
# For licensing see accompanying LICENSE file.
# Copyright (C) 2024 Apple Inc. All Rights Reserved.
#
from numbers import Number
from typing import Any, Dict, Union
import torch
from torch import Tensor
from corenet.metrics import METRICS_REGISTRY
from corenet.metrics.metric_base import AverageMetric
from corenet.utils import logger
@METRICS_REGISTRY.register(name="loss")
class LossMetric(AverageMetric):
    def gather_metrics(
        self,
        prediction: Union[Tensor, Dict],
        target: Union[Tensor, Dict],
        extras: Dict[str, Any],
    ) -> Union[Tensor, Dict[str, Tensor]]:
        """Extract the loss from ``extras`` and normalize it to tensor form.

        Args:
            prediction: Model output (unused here; part of the metric interface).
            target: Ground-truth labels (unused here; part of the metric interface).
            extras: Optional dict that may carry a ``"loss"`` entry, given as a
                Number, a Tensor, or a dict mapping names to Numbers/Tensors.

        Returns:
            A Tensor, or a ``Dict[str, Tensor]`` when the loss is composite.
            A missing loss is reported as a scalar zero tensor.
        """
        if extras is None:
            extras = {}
        loss = extras.get("loss", None)
        if loss is None:
            # No loss reported for this step; treat it as zero.
            loss = 0.0

        if isinstance(loss, Tensor):
            return loss
        elif isinstance(loss, Number):
            return torch.tensor(loss, device=self.device)
        elif isinstance(loss, dict):
            # Drop a possible placeholder entry keyed by None.
            loss.pop(None, None)
            for k, v in loss.items():
                if isinstance(v, Number):
                    # BUGFIX: convert the value `v`. The original passed the
                    # whole `loss` dict to torch.tensor, which raises.
                    loss[k] = torch.tensor(v, device=self.device)
                elif not isinstance(v, Tensor):
                    logger.error(
                        "Loss metric supports Number, Tensor, or Dict of Tensors."
                        f" Got {v} with {type(v)} type under key {k}."
                    )
            return loss
        else:
            # NOTE(review): assumes logger.error aborts execution; otherwise
            # this path implicitly returns None — confirm against corenet.utils.logger.
            logger.error(
                "Loss metric supports Number, Tensor, or Dict of Tensors."
                f" Got {loss} with {type(loss)} type."
            )
@METRICS_REGISTRY.register(name="grad_norm")
class GradNormMetric(AverageMetric):
    def gather_metrics(
        self,
        prediction: Union[Tensor, Dict],
        target: Union[Tensor, Dict],
        extras: Dict[str, Any],
    ) -> Union[Tensor, Dict[str, Tensor]]:
        """Extract the gradient norm from ``extras`` and normalize it to tensor form.

        Args:
            prediction: Model output (unused here; part of the metric interface).
            target: Ground-truth labels (unused here; part of the metric interface).
            extras: Optional dict that may carry a ``"grad_norm"`` entry, given
                as a Number, a Tensor, or a dict mapping names to
                Numbers/Tensors/strings (string values are discarded).

        Returns:
            A Tensor, or a ``Dict[str, Tensor]`` when the norm is composite.
            A missing grad-norm is reported as a scalar zero tensor.
        """
        if extras is None:
            extras = {}
        grad_norm = extras.get("grad_norm", None)
        if grad_norm is None:
            # No grad-norm reported for this step; treat it as zero.
            grad_norm = 0.0

        if isinstance(grad_norm, Tensor):
            return grad_norm
        elif isinstance(grad_norm, Number):
            return torch.tensor(grad_norm, device=self.device)
        elif isinstance(grad_norm, dict):
            # Drop a possible placeholder entry keyed by None.
            grad_norm.pop(None, None)
            # BUGFIX: iterate over a snapshot so that deleting string-valued
            # entries below does not raise "dictionary changed size during
            # iteration".
            for k, v in list(grad_norm.items()):
                if isinstance(v, Number):
                    # BUGFIX: convert the value `v`. The original passed the
                    # whole `grad_norm` dict to torch.tensor, which raises.
                    grad_norm[k] = torch.tensor(v, device=self.device)
                elif isinstance(v, str):
                    # Non-numeric diagnostic entries are not averageable; drop them.
                    del grad_norm[k]
                elif not isinstance(v, Tensor):
                    logger.error(
                        "Grad-norm metric supports Number, Tensor, or Dict of Tensors."
                        f" Got {v} with {type(v)} type under key {k}."
                    )
            return grad_norm
        else:
            # NOTE(review): assumes logger.error aborts execution; otherwise
            # this path implicitly returns None — confirm against corenet.utils.logger.
            logger.error(
                "Grad-norm metric supports Number, Tensor, or Dict of Tensors."
                f" Got {grad_norm} with {type(grad_norm)} type."
            )