Skip to content

Commit

Permalink
Merge branch 'master' into cocomap
Browse files Browse the repository at this point in the history
  • Loading branch information
sadra-barikbin committed May 28, 2023
2 parents e2ac8ee + 9b4bef8 commit e4683de
Show file tree
Hide file tree
Showing 6 changed files with 185 additions and 52 deletions.
64 changes: 59 additions & 5 deletions ignite/handlers/param_scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from typing import Any, cast, Dict, List, Mapping, Optional, Sequence, Tuple, Type, Union

import torch
from torch.optim.lr_scheduler import ReduceLROnPlateau
from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts, ReduceLROnPlateau
from torch.optim.optimizer import Optimizer

# https://github.com/pytorch/ignite/issues/2773
Expand Down Expand Up @@ -792,6 +792,57 @@ def simulate_values( # type: ignore[override]
return output


class _CosineAnnealingWarmRestarts:
def __init__(self, lr_scheduler: CosineAnnealingWarmRestarts):
self._lr_scheduler = lr_scheduler

@property
def last_epoch(self) -> int:
return self._lr_scheduler.last_epoch

@last_epoch.setter
def last_epoch(self, value: int) -> None:
self._lr_scheduler.last_epoch = value

@property
def optimizer(self) -> torch.optim.Optimizer:
return self._lr_scheduler.optimizer

def get_lr(self, epoch: Optional[int] = None) -> List[float]:
T_mult = self._lr_scheduler.T_mult
eta_min = self._lr_scheduler.eta_min

if epoch is None and self.last_epoch < 0:
epoch = 0
if epoch is None:
epoch = self.last_epoch + 1
self._lr_scheduler.T_cur = self._lr_scheduler.T_cur + 1
if self._lr_scheduler.T_cur >= self._lr_scheduler.T_i:
self._lr_scheduler.T_cur = self._lr_scheduler.T_cur - self._lr_scheduler.T_i
self._lr_scheduler.T_i = self._lr_scheduler.T_i * T_mult
else:
if epoch < 0:
raise ValueError("Expected non-negative epoch, but got {}".format(epoch))
if epoch >= self._lr_scheduler.T_0:
if T_mult == 1:
self._lr_scheduler.T_cur = epoch % self._lr_scheduler.T_0
else:
n = int(math.log((epoch / self._lr_scheduler.T_0 * (T_mult - 1) + 1), T_mult))
self._lr_scheduler.T_cur = epoch - self._lr_scheduler.T_0 * (T_mult**n - 1) / (T_mult - 1)
self._lr_scheduler.T_i = self._lr_scheduler.T_0 * T_mult**n
else:
self._lr_scheduler.T_i = self._lr_scheduler.T_0
self._lr_scheduler.T_cur = epoch

self.last_epoch = math.floor(epoch)

return [
eta_min
+ (base_lr - eta_min) * (1 + math.cos(math.pi * self._lr_scheduler.T_cur / self._lr_scheduler.T_i)) / 2
for base_lr in self._lr_scheduler.base_lrs
]


class LRScheduler(ParamScheduler):
"""A wrapper class to call `torch.optim.lr_scheduler` objects as `ignite` handlers.
Expand Down Expand Up @@ -853,7 +904,10 @@ def __init__(
f"but given {type(lr_scheduler)}"
)

self.lr_scheduler = lr_scheduler
self.lr_scheduler: Union[PyTorchLRScheduler, _CosineAnnealingWarmRestarts] = lr_scheduler
if isinstance(lr_scheduler, CosineAnnealingWarmRestarts):
self.lr_scheduler = _CosineAnnealingWarmRestarts(lr_scheduler)

super(LRScheduler, self).__init__(
optimizer=self.lr_scheduler.optimizer,
param_name="lr",
Expand All @@ -863,7 +917,7 @@ def __init__(
warnings.warn(
"Please make sure to attach scheduler to Events.ITERATION_COMPLETED "
"instead of Events.ITERATION_STARTED to make sure to use "
"the first lr value from the optimizer, otherwise it is will be skipped"
"the first lr value from the optimizer, otherwise it will be skipped"
)
self.lr_scheduler.last_epoch += 1

Expand All @@ -876,9 +930,9 @@ def __call__(self, engine: Optional[Engine], name: Optional[str] = None) -> None
def get_param(self) -> Union[float, List[float]]:
"""Method to get current optimizer's parameter value"""
# Emulate context manager for pytorch>=1.4
self.lr_scheduler._get_lr_called_within_step = True # type: ignore[attr-defined]
self.lr_scheduler._get_lr_called_within_step = True # type: ignore[union-attr]
lr_list = cast(List[float], self.lr_scheduler.get_lr())
self.lr_scheduler._get_lr_called_within_step = False # type: ignore[attr-defined]
self.lr_scheduler._get_lr_called_within_step = False # type: ignore[union-attr]
if len(lr_list) == 1:
return lr_list[0]
else:
Expand Down
6 changes: 3 additions & 3 deletions ignite/metrics/metric.py
Original file line number Diff line number Diff line change
Expand Up @@ -219,7 +219,7 @@ def __init__(
@abstractmethod
def reset(self) -> None:
"""
Resets the metric to it's initial state.
Resets the metric to its initial state.
By default, this is called at the start of each epoch.
"""
Expand All @@ -240,7 +240,7 @@ def update(self, output: Any) -> None:
@abstractmethod
def compute(self) -> Any:
"""
Computes the metric based on it's accumulated state.
Computes the metric based on its accumulated state.
By default, this is called at the end of each epoch.
Expand Down Expand Up @@ -273,7 +273,7 @@ def iteration_completed(self, engine: Engine) -> None:
Note:
``engine.state.output`` is used to compute metric values.
The majority of implemented metrics accepts the following formats for ``engine.state.output``:
The majority of implemented metrics accept the following formats for ``engine.state.output``:
``(y_pred, y)`` or ``{'y_pred': y_pred, 'y': y}``. ``y_pred`` and ``y`` can be torch tensors or
list of tensors/numbers if applicable.
Expand Down
105 changes: 69 additions & 36 deletions ignite/metrics/precision.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,8 +61,8 @@ def _prepare_output(self, output: Sequence[torch.Tensor]) -> Sequence[torch.Tens
num_classes = 2 if self._type == "binary" else y_pred.size(1)
if self._type == "multiclass" and y.max() + 1 > num_classes:
raise ValueError(
f"y_pred contains less classes than y. Number of predicted classes is {num_classes}"
f" and element in y has invalid class = {y.max().item() + 1}."
f"y_pred contains fewer classes than y. Number of classes in the prediction is {num_classes}"
f" and an element in y has invalid class = {y.max().item() + 1}."
)
y = y.view(-1)
if self._type == "binary" and self._average is False:
Expand All @@ -86,30 +86,32 @@ def _prepare_output(self, output: Sequence[torch.Tensor]) -> Sequence[torch.Tens

@reinit__is_reduced
def reset(self) -> None:
# `numerator`, `denominator` and `weight` are three variables chosen to be abstract
# representatives of the ones that are measured for cases with different `average` parameters.
# `weight` is only used when `average='weighted'`. Actual value of these three variables is
# as follows.
#
# average='samples':
# numerator (torch.Tensor): sum of metric value for samples
# denominator (int): number of samples
#
# average='weighted':
# numerator (torch.Tensor): number of true positives per class/label
# denominator (torch.Tensor): number of predicted(for precision) or actual(for recall)
# positives per class/label
# weight (torch.Tensor): number of actual positives per class
#
# average='micro':
# numerator (torch.Tensor): sum of number of true positives for classes/labels
# denominator (torch.Tensor): sum of number of predicted(for precision) or actual(for recall) positives
# for classes/labels
#
# average='macro' or boolean or None:
# numerator (torch.Tensor): number of true positives per class/label
# denominator (torch.Tensor): number of predicted(for precision) or actual(for recall)
# positives per class/label
"""
`numerator`, `denominator` and `weight` are three variables chosen to be abstract
representatives of the ones that are measured for cases with different `average` parameters.
`weight` is only used when `average='weighted'`. The actual values of these three variables are
as follows.
average='samples':
numerator (torch.Tensor): sum of metric value for samples
denominator (int): number of samples
average='weighted':
numerator (torch.Tensor): number of true positives per class/label
denominator (torch.Tensor): number of predicted(for precision) or actual(for recall) positives per
class/label.
weight (torch.Tensor): number of actual positives per class
average='micro':
numerator (torch.Tensor): sum of number of true positives for classes/labels
denominator (torch.Tensor): sum of number of predicted(for precision) or actual(for recall) positives for
classes/labels.
average='macro' or boolean or None:
numerator (torch.Tensor): number of true positives per class/label
denominator (torch.Tensor): number of predicted(for precision) or actual(for recall) positives per
class/label.
"""

self._numerator: Union[int, torch.Tensor] = 0
self._denominator: Union[int, torch.Tensor] = 0
Expand All @@ -120,16 +122,20 @@ def reset(self) -> None:

@sync_all_reduce("_numerator", "_denominator")
def compute(self) -> Union[torch.Tensor, float]:
# Return value of the metric for `average` options `'weighted'` and `'macro'` is computed as follows.
#
# .. math:: \text{Precision/Recall} = \frac{ numerator }{ denominator } \cdot weight
#
# wherein `weight` is the internal variable `weight` for `'weighted'` option and :math:`1/C`
# for the `macro` one. :math:`C` is the number of classes/labels.
#
# Return value of the metric for `average` options `'micro'`, `'samples'`, `False` and None is as follows.
#
# .. math:: \text{Precision/Recall} = \frac{ numerator }{ denominator }
r"""
Return value of the metric for `average` options `'weighted'` and `'macro'` is computed as follows.
.. math::
\text{Precision/Recall} = \frac{ numerator }{ denominator } \cdot weight
wherein `weight` is the internal variable `_weight` for `'weighted'` option and :math:`1/C`
for the `macro` one. :math:`C` is the number of classes/labels.
Return value of the metric for `average` options `'micro'`, `'samples'`, `False` and None is as follows.
.. math::
\text{Precision/Recall} = \frac{ numerator }{ denominator }
"""

if not self._updated:
raise NotComputableError(
Expand Down Expand Up @@ -367,6 +373,33 @@ def thresholded_output_transform(output):

@reinit__is_reduced
def update(self, output: Sequence[torch.Tensor]) -> None:
r"""
Update the metric state using prediction and target.
Args:
output: a binary tuple of tensors (y_pred, y) whose shapes follow the table below. N stands for the batch
dimension, `...` for possible additional dimensions and C for class dimension.
.. list-table::
:widths: 20 10 10 10
:header-rows: 1
* - Output member\\Data type
- Binary
- Multiclass
- Multilabel
* - y_pred
- (N, ...)
- (N, C, ...)
- (N, C, ...)
* - y
- (N, ...)
- (N, ...)
- (N, C, ...)
For binary and multilabel data, both y and y_pred should consist of 0's and 1's, but for multiclass
data, y_pred and y should consist of probabilities and integers respectively.
"""
self._check_shape(output)
self._check_type(output)
y_pred, y, correct = self._prepare_output(output)
Expand Down
2 changes: 2 additions & 0 deletions ignite/metrics/running_average.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,6 +143,8 @@ def attach(self, engine: Engine, name: str, _usage: Union[str, MetricUsage] = Ep
if self.epoch_bound:
# restart average every epoch
engine.add_event_handler(Events.EPOCH_STARTED, self.started)
else:
engine.add_event_handler(Events.STARTED, self.started)
# compute metric
engine.add_event_handler(Events.ITERATION_COMPLETED, self.iteration_completed)
# apply running average
Expand Down
46 changes: 44 additions & 2 deletions tests/ignite/handlers/test_param_scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import numpy as np
import pytest
import torch
from torch.optim.lr_scheduler import ExponentialLR, StepLR
from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts, ExponentialLR, StepLR

from ignite.engine import Engine, Events
from ignite.handlers.param_scheduler import (
Expand Down Expand Up @@ -650,7 +650,7 @@ def test_lr_scheduler(torch_lr_scheduler_cls, kwargs):
state_dict1 = scheduler1.state_dict()

torch_lr_scheduler2 = torch_lr_scheduler_cls(optimizer=optimizer2, **kwargs)
with pytest.warns(UserWarning, match=r"the first lr value from the optimizer, otherwise it is will be skipped"):
with pytest.warns(UserWarning, match=r"the first lr value from the optimizer, otherwise it will be skipped"):
scheduler2 = LRScheduler(torch_lr_scheduler2, use_legacy=True)
state_dict2 = scheduler2.state_dict()

Expand Down Expand Up @@ -1362,3 +1362,45 @@ def test_reduce_lr_on_plateau_scheduler_asserts():
with pytest.raises(ValueError, match=r"Length of argument metric_values should be equal to num_events."):
metric_values = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6]
ReduceLROnPlateauScheduler.simulate_values(5, metric_values, 0.01)


@pytest.mark.parametrize("warmup_end_value", [0.23, None])
@pytest.mark.parametrize("T_0", [1, 12])
@pytest.mark.parametrize("T_mult", [1, 3])
def test_create_lr_scheduler_with_warmup_cosine(warmup_end_value, T_0, T_mult):
    """Warm-up followed by CosineAnnealingWarmRestarts must reproduce the plain cosine schedule.

    A reference list of learning rates is recorded from a bare
    ``CosineAnnealingWarmRestarts``; the warm-up wrapped scheduler must first
    follow a linear ramp and then yield exactly the reference values.
    """
    lr = 0.2
    steps = 200
    warm_steps = 50
    warm_start = 0.023

    def new_optimizer():
        param = torch.zeros([1], requires_grad=True)
        return torch.optim.SGD([param], lr=lr)

    def new_cosine_scheduler(optim):
        return CosineAnnealingWarmRestarts(optim, T_0=T_0, T_mult=T_mult)

    def current_lr(optim):
        return optim.param_groups[0]["lr"]

    # Reference values: the lr seen before each step of the bare torch scheduler.
    optimizer = new_optimizer()
    reference = new_cosine_scheduler(optimizer)
    cosine_lrs = []
    for _ in range(steps):
        cosine_lrs.append(current_lr(optimizer))
        reference.step()

    # Same scheduler on a fresh optimizer, this time behind a warm-up phase.
    optimizer = new_optimizer()
    scheduler = create_lr_scheduler_with_warmup(
        new_cosine_scheduler(optimizer),
        warmup_start_value=warm_start,
        warmup_end_value=warmup_end_value,
        warmup_duration=warm_steps,
    )

    has_end_value = warmup_end_value is not None
    # Without an explicit end value the test expects one fewer warm-up step.
    real_warm_steps = warm_steps if has_end_value else warm_steps - 1

    warm_lrs = []
    for _ in range(real_warm_steps + steps):
        scheduler(None)
        warm_lrs.append(current_lr(optimizer))

    # The ramp ends at the given value, or at the optimizer's lr when none is given.
    ramp_end = warmup_end_value if has_end_value else lr
    np.testing.assert_allclose(np.linspace(warm_start, ramp_end, warm_steps), warm_lrs[:warm_steps])
    assert warm_lrs[real_warm_steps:] == cosine_lrs
14 changes: 8 additions & 6 deletions tests/ignite/metrics/test_running_average.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,9 +125,9 @@ def test_epoch_unbound():
batch_size = 10
n_classes = 10
data = list(range(n_iters))
loss_values = iter(range(n_epochs * n_iters))
y_true_batch_values = iter(np.random.randint(0, n_classes, size=(n_epochs * n_iters, batch_size)))
y_pred_batch_values = iter(np.random.rand(n_epochs * n_iters, batch_size, n_classes))
loss_values = iter(range(2 * n_epochs * n_iters))
y_true_batch_values = iter(np.random.randint(0, n_classes, size=(2 * n_epochs * n_iters, batch_size)))
y_pred_batch_values = iter(np.random.rand(2 * n_epochs * n_iters, batch_size, n_classes))

def update_fn(engine, batch):
loss_value = next(loss_values)
Expand All @@ -146,9 +146,7 @@ def update_fn(engine, batch):

running_avg_acc = [None]

@trainer.on(Events.STARTED)
def running_avg_output_init(engine):
engine.state.running_avg_output = None
trainer.state.running_avg_output = None

@trainer.on(Events.ITERATION_COMPLETED, running_avg_acc)
def manual_running_avg_acc(engine, running_avg_acc):
Expand Down Expand Up @@ -187,6 +185,10 @@ def assert_equal_running_avg_output_values(engine):

trainer.run(data, max_epochs=3)

running_avg_acc[0] = None
trainer.state.running_avg_output = None
trainer.run(data, max_epochs=3)


def test_multiple_attach():
n_iters = 100
Expand Down

0 comments on commit e4683de

Please sign in to comment.