Mark all result classes as protected (#11130)
carmocca committed Dec 17, 2021
1 parent 860959f commit 8508cce
Showing 11 changed files with 96 additions and 93 deletions.
3 changes: 3 additions & 0 deletions CHANGELOG.md
@@ -129,6 +129,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Moved ownership of the `Accelerator` instance to the `TrainingTypePlugin`; all training-type plugins now take an optional parameter `accelerator` ([#11022](https://github.com/PyTorchLightning/pytorch-lightning/pull/11022))


- Marked the `ResultCollection`, `ResultMetric`, and `ResultMetricCollection` classes as protected ([#11130](https://github.com/PyTorchLightning/pytorch-lightning/pull/11130))


- DeepSpeed does not require lightning module zero 3 partitioning ([#10655](https://github.com/PyTorchLightning/pytorch-lightning/pull/10655))


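Migration note (not part of this diff): the change is purely a rename that adds a leading underscore to signal these classes are internal. A minimal sketch, assuming third-party code imported them directly under the old names:

```python
# Before this commit: the classes looked public.
from pytorch_lightning.trainer.connectors.logger_connector.result import (
    ResultCollection,
    ResultMetric,
    ResultMetricCollection,
)

# After this commit: the same classes, now marked as protected/internal.
from pytorch_lightning.trainer.connectors.logger_connector.result import (
    _ResultCollection,
    _ResultMetric,
    _ResultMetricCollection,
)
```

The module path is unchanged; only the class names gain the underscore prefix.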
2 changes: 1 addition & 1 deletion pytorch_lightning/core/lightning.py
@@ -351,7 +351,7 @@ def log(
results = self.trainer._results
if results is None:
raise MisconfigurationException(
"You are trying to `self.log()` but the loop `ResultCollection` is not registered"
"You are trying to `self.log()` but the loop's result collection is not registered"
" yet. This is most likely because you are trying to log in a `predict` hook,"
" but it doesn't support logging"
)
10 changes: 5 additions & 5 deletions pytorch_lightning/loops/base.py
@@ -19,7 +19,7 @@
from torchmetrics import Metric

import pytorch_lightning as pl
from pytorch_lightning.trainer.connectors.logger_connector.result import ResultCollection
from pytorch_lightning.trainer.connectors.logger_connector.result import _ResultCollection
from pytorch_lightning.trainer.progress import BaseProgress
from pytorch_lightning.utilities.enums import _FaultTolerantMode
from pytorch_lightning.utilities.exceptions import MisconfigurationException
@@ -282,7 +282,7 @@ def state_dict(self, destination: Optional[Dict] = None, prefix: str = "") -> Di
destination[key] = v.state_dict()
elif isinstance(v, Loop):
v.state_dict(destination, key + ".")
elif isinstance(v, ResultCollection):
elif isinstance(v, _ResultCollection):
# sync / unsync metrics
v.sync()
destination[key] = v.state_dict()
@@ -312,7 +312,7 @@ def _load_from_state_dict(self, state_dict: Dict, prefix: str, metrics: Optional
if isinstance(v, BaseProgress):
v.load_state_dict(state_dict[key])
elif (
isinstance(v, ResultCollection)
isinstance(v, _ResultCollection)
and self.trainer is not None
and self.trainer.lightning_module is not None
):
@@ -324,10 +324,10 @@ def _load_from_state_dict(self, state_dict: Dict, prefix: str, metrics: Optional
if metrics:
metric_attributes.update(metrics)

# The `ResultCollection` objects have 2 types of metrics: `Tensor` and `torchmetrics.Metric`.
# The `_ResultCollection` objects have 2 types of metrics: `Tensor` and `torchmetrics.Metric`.
# When creating a checkpoint, the `Metric`s are dropped from the loop `state_dict` to serialize only
# Python primitives. However, their states are saved with the model's `state_dict`.
# On reload, we need to re-attach the `Metric`s back to the `ResultCollection`.
# On reload, we need to re-attach the `Metric`s back to the `_ResultCollection`.
# The references are provided through the `metric_attributes` dictionary.
v.load_state_dict(
state_dict[key], metrics=metric_attributes, sync_fn=self.trainer.training_type_plugin.reduce
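To make the comment in `_load_from_state_dict` concrete, here is a minimal sketch (an illustration, not Lightning's actual loop code) of saving and restoring a result collection, where `metric_attributes` stands in for the `LightningModule`'s `torchmetrics.Metric` attributes and `reduce_fn` for `trainer.training_type_plugin.reduce`:

```python
from typing import Callable, Dict

from torchmetrics import Metric

from pytorch_lightning.trainer.connectors.logger_connector.result import _ResultCollection


def checkpoint_results(results: _ResultCollection) -> dict:
    # Sync the distributed state, then serialize. The `Metric` objects themselves are not
    # kept in this dict; their states travel with the model's own `state_dict`.
    results.sync()
    return results.state_dict()


def restore_results(
    results: _ResultCollection,
    state: dict,
    metric_attributes: Dict[str, Metric],  # e.g. the LightningModule's Metric attributes
    reduce_fn: Callable,                   # e.g. trainer.training_type_plugin.reduce
) -> None:
    # On reload, re-attach the live `Metric` references and the plugin's reduction function.
    results.load_state_dict(state, metrics=metric_attributes, sync_fn=reduce_fn)
```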
6 changes: 3 additions & 3 deletions pytorch_lightning/loops/dataloader/evaluation_loop.py
@@ -19,8 +19,8 @@

from pytorch_lightning.loops.dataloader import DataLoaderLoop
from pytorch_lightning.loops.epoch import EvaluationEpochLoop
from pytorch_lightning.trainer.connectors.logger_connector.result import _OUT_DICT, ResultCollection
from pytorch_lightning.trainer.states import RunningStage
from pytorch_lightning.trainer.connectors.logger_connector.result import _OUT_DICT, _ResultCollection
from pytorch_lightning.trainer.states import RunningStage, TrainerFn
from pytorch_lightning.utilities.types import EPOCH_OUTPUT


@@ -32,7 +32,7 @@ def __init__(self, verbose: bool = True) -> None:
self.epoch_loop = EvaluationEpochLoop()
self.verbose = verbose

self._results = ResultCollection(training=False)
self._results = _ResultCollection(training=False)
self._outputs: List[EPOCH_OUTPUT] = []
self._logged_outputs: List[_OUT_DICT] = []
self._max_batches: List[int] = []
4 changes: 2 additions & 2 deletions pytorch_lightning/loops/epoch/training_epoch_loop.py
@@ -21,7 +21,7 @@
from pytorch_lightning.loops.batch import TrainingBatchLoop
from pytorch_lightning.loops.batch.training_batch_loop import _OUTPUTS_TYPE as _BATCH_OUTPUTS_TYPE
from pytorch_lightning.loops.utilities import _get_active_optimizers, _is_max_limit_reached, _update_dataloader_iter
from pytorch_lightning.trainer.connectors.logger_connector.result import ResultCollection
from pytorch_lightning.trainer.connectors.logger_connector.result import _ResultCollection
from pytorch_lightning.trainer.progress import BatchProgress, SchedulerProgress
from pytorch_lightning.utilities import rank_zero_warn
from pytorch_lightning.utilities.apply_func import apply_to_collection
@@ -65,7 +65,7 @@ def __init__(self, min_steps: Optional[int] = None, max_steps: int = -1) -> None
self.batch_loop = TrainingBatchLoop()
self.val_loop = loops.EvaluationLoop(verbose=False)

self._results = ResultCollection(training=True)
self._results = _ResultCollection(training=True)
self._outputs: _OUTPUTS_TYPE = []
self._warning_cache = WarningCache()
self._dataloader_iter: Optional[Iterator] = None
4 changes: 2 additions & 2 deletions pytorch_lightning/loops/fit_loop.py
@@ -17,7 +17,7 @@
from pytorch_lightning.loops import Loop
from pytorch_lightning.loops.epoch import TrainingEpochLoop
from pytorch_lightning.loops.utilities import _is_max_limit_reached
from pytorch_lightning.trainer.connectors.logger_connector.result import ResultCollection
from pytorch_lightning.trainer.connectors.logger_connector.result import _ResultCollection
from pytorch_lightning.trainer.progress import Progress
from pytorch_lightning.trainer.supporters import TensorRunningAccum
from pytorch_lightning.utilities import rank_zero_deprecation
@@ -136,7 +136,7 @@ def _skip_backward(self, value: bool) -> None:
self.epoch_loop.batch_loop.optimizer_loop._skip_backward = value

@property
def _results(self) -> ResultCollection:
def _results(self) -> _ResultCollection:
if self.trainer.training:
return self.epoch_loop._results
if self.trainer.validating:
80 changes: 40 additions & 40 deletions pytorch_lightning/trainer/connectors/logger_connector/result.py
@@ -199,7 +199,7 @@ def _reconstruct(cls, state: dict, sync_fn: Optional[Callable] = None) -> "_Meta
return meta


class ResultMetric(Metric, DeviceDtypeModuleMixin):
class _ResultMetric(Metric, DeviceDtypeModuleMixin):
"""Wraps the value provided to `:meth:`~pytorch_lightning.core.lightning.LightningModule.log`"""

def __init__(self, metadata: _Metadata, is_tensor: bool) -> None:
@@ -316,25 +316,25 @@ def __setstate__(self, state: dict, sync_fn: Optional[Callable] = None) -> None:
super().__setstate__(d)

@classmethod
def _reconstruct(cls, state: dict, sync_fn: Optional[Callable] = None) -> "ResultMetric":
def _reconstruct(cls, state: dict, sync_fn: Optional[Callable] = None) -> "_ResultMetric":
# need to reconstruct twice because `meta` is used in `__init__`
meta = _Metadata._reconstruct(state["meta"])
result_metric = cls(meta, state["is_tensor"])
result_metric.__setstate__(state, sync_fn=sync_fn)
return result_metric

def to(self, *args: Any, **kwargs: Any) -> "ResultMetric":
def to(self, *args: Any, **kwargs: Any) -> "_ResultMetric":
self.__dict__.update(
apply_to_collection(self.__dict__, (torch.Tensor, Metric), move_data_to_device, *args, **kwargs)
)
return self


class ResultMetricCollection(dict):
class _ResultMetricCollection(dict):
"""Dict wrapper for easy access to metadata.
All of the leaf items should be instances of
:class:`~pytorch_lightning.trainer.connectors.logger_connector.result.ResultMetric`
:class:`~pytorch_lightning.trainer.connectors.logger_connector.result._ResultMetric`
with the same metadata.
"""

@@ -347,36 +347,36 @@ def has_tensor(self) -> bool:
return any(v.is_tensor for v in self.values())

def __getstate__(self, drop_value: bool = False) -> dict:
def getstate(item: ResultMetric) -> dict:
def getstate(item: _ResultMetric) -> dict:
return item.__getstate__(drop_value=drop_value)

items = apply_to_collection(dict(self), ResultMetric, getstate)
items = apply_to_collection(dict(self), _ResultMetric, getstate)
return {"items": items, "meta": self.meta.__getstate__(), "_class": self.__class__.__name__}

def __setstate__(self, state: dict, sync_fn: Optional[Callable] = None) -> None:
# can't use `apply_to_collection` as it does not recurse items of the same type
items = {k: ResultMetric._reconstruct(v, sync_fn=sync_fn) for k, v in state["items"].items()}
items = {k: _ResultMetric._reconstruct(v, sync_fn=sync_fn) for k, v in state["items"].items()}
self.update(items)

@classmethod
def _reconstruct(cls, state: dict, sync_fn: Optional[Callable] = None) -> "ResultMetricCollection":
def _reconstruct(cls, state: dict, sync_fn: Optional[Callable] = None) -> "_ResultMetricCollection":
rmc = cls()
rmc.__setstate__(state, sync_fn=sync_fn)
return rmc


_METRIC_COLLECTION = Union[_IN_METRIC, ResultMetricCollection]
_METRIC_COLLECTION = Union[_IN_METRIC, _ResultMetricCollection]


class ResultCollection(dict):
class _ResultCollection(dict):
"""
Collection (dictionary) of :class:`~pytorch_lightning.trainer.connectors.logger_connector.result.ResultMetric` or
:class:`~pytorch_lightning.trainer.connectors.logger_connector.result.ResultMetricCollection`
Collection (dictionary) of :class:`~pytorch_lightning.trainer.connectors.logger_connector.result._ResultMetric` or
:class:`~pytorch_lightning.trainer.connectors.logger_connector.result._ResultMetricCollection`
Example:
# `device` needs to be provided before logging
result = ResultCollection(training=True, torch.device("cpu"))
result = _ResultCollection(training=True, torch.device("cpu"))
# you can log to a specific collection.
# arguments: fx, key, value, metadata
@@ -395,14 +395,14 @@ def __init__(self, training: bool, device: Optional[Union[str, torch.device]] =
self.dataloader_idx: Optional[int] = None

@property
def result_metrics(self) -> List[ResultMetric]:
def result_metrics(self) -> List[_ResultMetric]:
o = []

def append_fn(v: ResultMetric) -> None:
def append_fn(v: _ResultMetric) -> None:
nonlocal o
o.append(v)

apply_to_collection(list(self.values()), ResultMetric, append_fn)
apply_to_collection(list(self.values()), _ResultMetric, append_fn)
return o

def _extract_batch_size(self, value: _METRIC_COLLECTION, batch_size: Optional[int], meta: _Metadata) -> int:
@@ -414,7 +414,7 @@ def _extract_batch_size(self, value: _METRIC_COLLECTION, batch_size: Optional[in
return batch_size

batch_size = 1
is_tensor = value.is_tensor if isinstance(value, ResultMetric) else value.has_tensor
is_tensor = value.is_tensor if isinstance(value, _ResultMetric) else value.has_tensor
if self.batch is not None and is_tensor and meta.on_epoch and meta.is_mean_reduction:
batch_size = extract_batch_size(self.batch)
self.batch_size = batch_size
@@ -485,30 +485,30 @@ def log(
self.update_metrics(key, value, batch_size)

def register_key(self, key: str, meta: _Metadata, value: _METRIC_COLLECTION) -> None:
"""Create one ResultMetric object per value.
"""Create one _ResultMetric object per value.
Value can be provided as a nested collection
"""

def fn(v: _IN_METRIC) -> ResultMetric:
metric = ResultMetric(meta, isinstance(v, torch.Tensor))
def fn(v: _IN_METRIC) -> _ResultMetric:
metric = _ResultMetric(meta, isinstance(v, torch.Tensor))
return metric.to(self.device)

value = apply_to_collection(value, (torch.Tensor, Metric), fn)
if isinstance(value, dict):
value = ResultMetricCollection(value)
value = _ResultMetricCollection(value)
self[key] = value

def update_metrics(self, key: str, value: _METRIC_COLLECTION, batch_size: int) -> None:
def fn(result_metric: ResultMetric, v: torch.Tensor) -> None:
def fn(result_metric: _ResultMetric, v: torch.Tensor) -> None:
# performance: avoid calling `__call__` to avoid the checks in `torch.nn.Module._call_impl`
result_metric.forward(v.to(self.device), batch_size)
result_metric.has_reset = False

apply_to_collections(self[key], value, ResultMetric, fn)
apply_to_collections(self[key], value, _ResultMetric, fn)

@staticmethod
def _get_cache(result_metric: ResultMetric, on_step: bool) -> Optional[torch.Tensor]:
def _get_cache(result_metric: _ResultMetric, on_step: bool) -> Optional[torch.Tensor]:
cache = None
if on_step and result_metric.meta.on_step:
cache = result_metric._forward_cache
@@ -529,11 +529,11 @@ def valid_items(self) -> Generator:
return (
(k, v)
for k, v in self.items()
if not (isinstance(v, ResultMetric) and v.has_reset)
if not (isinstance(v, _ResultMetric) and v.has_reset)
and self.dataloader_idx in (None, v.meta.dataloader_idx)
)

def _forked_name(self, result_metric: ResultMetric, on_step: bool) -> Tuple[str, str]:
def _forked_name(self, result_metric: _ResultMetric, on_step: bool) -> Tuple[str, str]:
name = result_metric.meta.name
forked_name = result_metric.meta.forked_name(on_step)
add_dataloader_idx = result_metric.meta.add_dataloader_idx
@@ -549,11 +549,11 @@ def metrics(self, on_step: bool) -> _METRICS:

for _, result_metric in self.valid_items():

# extract forward_cache or computed from the ResultMetric. ignore when the output is None
value = apply_to_collection(result_metric, ResultMetric, self._get_cache, on_step, include_none=False)
# extract forward_cache or computed from the _ResultMetric. ignore when the output is None
value = apply_to_collection(result_metric, _ResultMetric, self._get_cache, on_step, include_none=False)

# convert metric collection to dict container.
if isinstance(value, ResultMetricCollection):
if isinstance(value, _ResultMetricCollection):
value = dict(value.items())

# check if the collection is empty
@@ -594,23 +594,23 @@ def reset(self, metrics: Optional[bool] = None, fx: Optional[str] = None) -> Non
fx: Function to reset
"""

def fn(item: ResultMetric) -> None:
def fn(item: _ResultMetric) -> None:
requested_type = metrics is None or metrics ^ item.is_tensor
same_fx = fx is None or fx == item.meta.fx
if requested_type and same_fx:
item.reset()

apply_to_collection(self, ResultMetric, fn)
apply_to_collection(self, _ResultMetric, fn)

def to(self, *args: Any, **kwargs: Any) -> "ResultCollection":
def to(self, *args: Any, **kwargs: Any) -> "_ResultCollection":
"""Move all data to the given device."""
self.update(apply_to_collection(dict(self), (torch.Tensor, Metric), move_data_to_device, *args, **kwargs))

if "device" in kwargs:
self.device = kwargs["device"]
return self

def cpu(self) -> "ResultCollection":
def cpu(self) -> "_ResultCollection":
"""Move all data to CPU."""
return self.to(device="cpu")

@@ -634,7 +634,7 @@ def __repr__(self) -> str:

def __getstate__(self, drop_value: bool = True) -> dict:
d = self.__dict__.copy()
# all the items should be either `ResultMetric`s or `ResultMetricCollection`s
# all the items should be either `_ResultMetric`s or `_ResultMetricCollection`s
items = {k: v.__getstate__(drop_value=drop_value) for k, v in self.items()}
return {**d, "items": items}

@@ -643,14 +643,14 @@ def __setstate__(
) -> None:
self.__dict__.update({k: v for k, v in state.items() if k != "items"})

def setstate(k: str, item: dict) -> Union[ResultMetric, ResultMetricCollection]:
def setstate(k: str, item: dict) -> Union[_ResultMetric, _ResultMetricCollection]:
if not isinstance(item, dict):
raise ValueError(f"Unexpected value: {item}")
cls = item["_class"]
if cls == ResultMetric.__name__:
cls = ResultMetric
elif cls == ResultMetricCollection.__name__:
cls = ResultMetricCollection
if cls == _ResultMetric.__name__:
cls = _ResultMetric
elif cls == _ResultMetricCollection.__name__:
cls = _ResultMetricCollection
else:
raise ValueError(f"Unexpected class name: {cls}")
_sync_fn = sync_fn or (self[k].meta.sync.fn if k in self else None)
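As a usage illustration of the now-protected collection (a sketch based only on the methods visible in this file; the `log(...)` call is omitted because its full signature is collapsed in this hunk):

```python
import torch

from pytorch_lightning.trainer.connectors.logger_connector.result import _ResultCollection

# Per the class docstring, the device should be provided before logging.
results = _ResultCollection(training=True, device=torch.device("cpu"))

# Device management: move stored tensors/metrics explicitly, or use the CPU shorthand.
results = results.to(device="cpu")
results = results.cpu()

# Read out values: `on_step=True` gives the per-step view, `on_step=False` the epoch view.
step_metrics = results.metrics(on_step=True)
epoch_metrics = results.metrics(on_step=False)

# Reset both raw tensor entries and `torchmetrics.Metric` entries.
results.reset(metrics=None)
```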
4 changes: 2 additions & 2 deletions pytorch_lightning/trainer/trainer.py
@@ -65,7 +65,7 @@
from pytorch_lightning.trainer.connectors.checkpoint_connector import CheckpointConnector
from pytorch_lightning.trainer.connectors.data_connector import DataConnector
from pytorch_lightning.trainer.connectors.logger_connector import LoggerConnector
from pytorch_lightning.trainer.connectors.logger_connector.result import ResultCollection
from pytorch_lightning.trainer.connectors.logger_connector.result import _ResultCollection
from pytorch_lightning.trainer.connectors.signal_connector import SignalConnector
from pytorch_lightning.trainer.data_loading import TrainerDataLoadingMixin
from pytorch_lightning.trainer.optimizers import TrainerOptimizersMixin
@@ -2242,7 +2242,7 @@ def progress_bar_metrics(self) -> dict:
return self.logger_connector.progress_bar_metrics

@property
def _results(self) -> Optional[ResultCollection]:
def _results(self) -> Optional[_ResultCollection]:
active_loop = self._active_loop
if active_loop is not None:
return active_loop._results
