Adopt torchmetrics (#4290)

ultmaster committed Nov 3, 2021
1 parent 8fc555a, commit dc58203

Showing 5 changed files with 23 additions and 19 deletions.
3 changes: 2 additions & 1 deletion dependencies/recommended.txt

@@ -8,7 +8,8 @@ torch == 1.9.0+cpu ; sys_platform != "darwin"
 torch == 1.9.0 ; sys_platform == "darwin"
 torchvision == 0.10.0+cpu ; sys_platform != "darwin"
 torchvision == 0.10.0 ; sys_platform == "darwin"
-pytorch-lightning >= 1.4.2
+pytorch-lightning >= 1.5
+torchmetrics
 onnx
 peewee
 graphviz
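The recommended pins move pytorch-lightning to >= 1.5, the release that removes the deprecated pl.metrics namespace, so torchmetrics must now be declared explicitly. A throwaway check, not part of the commit, that an environment picked up the new pairing:

    import pytorch_lightning
    import torchmetrics

    # both packages expose __version__; lightning should be >= 1.5 here
    print(pytorch_lightning.__version__)
    print(torchmetrics.__version__)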
1 change: 1 addition & 0 deletions dependencies/recommended_legacy.txt

@@ -6,6 +6,7 @@ torchvision == 0.7.0+cpu
 # It will install pytorch-lightning 0.8.x and unit tests won't work.
 # Latest version has conflict with tensorboard and tensorflow 1.x.
 pytorch-lightning
+torchmetrics

 keras == 2.1.6
 onnx
23 changes: 12 additions & 11 deletions nni/retiarii/evaluator/pytorch/cgo/accelerator.py

@@ -1,13 +1,14 @@
-from typing import Any, Union, Optional, List
-import torch
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT license.
+
+from typing import Any, List, Optional, Union
+
+import torch
 from pytorch_lightning.accelerators.accelerator import Accelerator
+from pytorch_lightning.plugins.environments import ClusterEnvironment
+from pytorch_lightning.plugins.training_type.training_type_plugin import TrainingTypePlugin
+from pytorch_lightning.trainer.connectors.accelerator_connector import AcceleratorConnector
 from pytorch_lightning.trainer import Trainer

-from pytorch_lightning.plugins import Plugin
-from pytorch_lightning.plugins.environments import ClusterEnvironment
-from pytorch_lightning.trainer.connectors.accelerator_connector import AcceleratorConnector

 from ....serializer import serialize_cls
@@ -69,9 +70,8 @@ def model_to_device(self) -> None:
         # bypass device placement from pytorch lightning
         pass

-    def setup(self, model: torch.nn.Module) -> torch.nn.Module:
-        self.model_to_device()
-        return self.model
+    def setup(self) -> None:
+        pass

     @property
     def is_global_zero(self) -> bool:
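These hooks sit on the training-type plugin NNI uses to bypass Lightning's device management: placement is deliberately a no-op so the cross-graph optimization (CGO) engine decides where each model lives, and setup is adapted to the hook signature used by pytorch-lightning 1.5, which no longer passes the model in or expects it back. A minimal sketch of the pattern (the class name is illustrative, and only the two hooks touched here are shown; a real plugin must implement the rest of the TrainingTypePlugin interface):

    from pytorch_lightning.plugins.training_type.training_type_plugin import TrainingTypePlugin

    class BypassPlugin(TrainingTypePlugin):
        # placement is left to the CGO engine, so this is deliberately a no-op
        def model_to_device(self) -> None:
            pass

        # no-op setup, matching the no-argument, None-returning
        # signature this commit adopts
        def setup(self) -> None:
            pass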
@@ -100,8 +100,9 @@ def get_accelerator_connector(
         deterministic: bool = False,
         precision: int = 32,
         amp_backend: str = 'native',
-        amp_level: str = 'O2',
-        plugins: Optional[Union[List[Union[Plugin, ClusterEnvironment, str]], Plugin, ClusterEnvironment, str]] = None,
+        amp_level: Optional[str] = None,
+        plugins: Optional[Union[List[Union[TrainingTypePlugin, ClusterEnvironment, str]],
+                                TrainingTypePlugin, ClusterEnvironment, str]] = None,
         **other_trainier_kwargs) -> AcceleratorConnector:
     gpu_ids = Trainer()._parse_devices(gpus, auto_select_gpus, tpu_cores)
     return AcceleratorConnector(
8 changes: 4 additions & 4 deletions nni/retiarii/evaluator/pytorch/cgo/evaluator.py

@@ -7,7 +7,7 @@

 import torch.nn as nn
 import torch.optim as optim
-import pytorch_lightning as pl
+import torchmetrics
 from torch.utils.data import DataLoader

 import nni

@@ -19,7 +19,7 @@

 @serialize_cls
 class _MultiModelSupervisedLearningModule(LightningModule):
-    def __init__(self, criterion: nn.Module, metrics: Dict[str, pl.metrics.Metric],
+    def __init__(self, criterion: nn.Module, metrics: Dict[str, torchmetrics.Metric],
                  n_models: int = 0,
                  learning_rate: float = 0.001,
                  weight_decay: float = 0.,

@@ -119,7 +119,7 @@ class MultiModelSupervisedLearningModule(_MultiModelSupervisedLearningModule):
     Class for optimizer (not an instance). default: ``Adam``
     """

-    def __init__(self, criterion: nn.Module, metrics: Dict[str, pl.metrics.Metric],
+    def __init__(self, criterion: nn.Module, metrics: Dict[str, torchmetrics.Metric],
                  learning_rate: float = 0.001,
                  weight_decay: float = 0.,
                  optimizer: optim.Optimizer = optim.Adam):

@@ -180,7 +180,7 @@ def __init__(self, criterion: nn.Module = nn.MSELoss,
                  learning_rate: float = 0.001,
                  weight_decay: float = 0.,
                  optimizer: optim.Optimizer = optim.Adam):
-        super().__init__(criterion, {'mse': pl.metrics.MeanSquaredError},
+        super().__init__(criterion, {'mse': torchmetrics.MeanSquaredError},
                          learning_rate=learning_rate, weight_decay=weight_decay, optimizer=optimizer)
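The substance of this file's change is mechanical: metric classes that pytorch-lightning used to re-export as pl.metrics now come from the standalone torchmetrics package, with the same stateful update()/compute() API, so type hints and metric dicts only need the namespace swapped. A quick sanity check of the drop-in replacement (the numbers are illustrative only):

    import torch
    import torchmetrics

    # was: pl.metrics.MeanSquaredError, removed in pytorch-lightning 1.5
    mse = torchmetrics.MeanSquaredError()
    mse.update(torch.tensor([2.5, 0.0]), torch.tensor([3.0, 0.0]))
    print(mse.compute())  # tensor(0.1250), i.e. mean of (0.5**2, 0.0**2)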
7 changes: 4 additions & 3 deletions nni/retiarii/evaluator/pytorch/lightning.py

@@ -9,6 +9,7 @@
 import pytorch_lightning as pl
 import torch.nn as nn
 import torch.optim as optim
+import torchmetrics
 from torch.utils.data import DataLoader

 import nni

@@ -140,7 +141,7 @@ def _check_dataloader(dataloader):
 ### The following are some commonly used Lightning modules ###

 class _SupervisedLearningModule(LightningModule):
-    def __init__(self, criterion: nn.Module, metrics: Dict[str, pl.metrics.Metric],
+    def __init__(self, criterion: nn.Module, metrics: Dict[str, torchmetrics.Metric],
                  learning_rate: float = 0.001,
                  weight_decay: float = 0.,
                  optimizer: optim.Optimizer = optim.Adam,

@@ -213,7 +214,7 @@ def _get_validation_metrics(self):
         return {name: self.trainer.callback_metrics['val_' + name].item() for name in self.metrics}


-class _AccuracyWithLogits(pl.metrics.Accuracy):
+class _AccuracyWithLogits(torchmetrics.Accuracy):
     def update(self, pred, target):
         return super().update(nn.functional.softmax(pred), target)

@@ -278,7 +279,7 @@ def __init__(self, criterion: nn.Module = nn.MSELoss,
                  weight_decay: float = 0.,
                  optimizer: optim.Optimizer = optim.Adam,
                  export_onnx: bool = True):
-        super().__init__(criterion, {'mse': pl.metrics.MeanSquaredError},
+        super().__init__(criterion, {'mse': torchmetrics.MeanSquaredError},
                          learning_rate=learning_rate, weight_decay=weight_decay, optimizer=optimizer,
                          export_onnx=export_onnx)
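Subclassing carries over unchanged as well, which is why _AccuracyWithLogits only needs its base class swapped. A standalone sketch of the same pattern, assuming a torchmetrics release contemporary with this commit (newer releases require a task argument to Accuracy):

    import torch
    import torch.nn as nn
    import torchmetrics

    class AccuracyWithLogits(torchmetrics.Accuracy):
        # was: pl.metrics.Accuracy; torchmetrics.Accuracy is the new base
        def update(self, pred, target):
            # turn raw logits into probabilities before the base update;
            # dim=-1 is explicit here to avoid the implicit-dim warning
            return super().update(nn.functional.softmax(pred, dim=-1), target)

    metric = AccuracyWithLogits()
    logits = torch.tensor([[2.0, -1.0], [0.5, 1.5]])
    metric.update(logits, torch.tensor([0, 1]))
    print(metric.compute())  # tensor(1.) -- both rows predicted correctly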
