Skip to content

Commit

Permalink
Move ipu precision flag check to IPUPrecisionPlugin init (#12148)
Browse files Browse the repository at this point in the history
  • Loading branch information
four4fish committed Mar 5, 2022
1 parent b5fe056 commit 91052dc
Show file tree
Hide file tree
Showing 3 changed files with 14 additions and 9 deletions.
13 changes: 12 additions & 1 deletion pytorch_lightning/plugins/precision/ipu.py
Expand Up @@ -27,9 +27,20 @@


class IPUPrecisionPlugin(PrecisionPlugin):
"""Precision plugin for IPU integration."""
"""Precision plugin for IPU integration.
Raises:
ValueError:
If the precision is neither 16 nor 32.
"""

def __init__(self, precision: int) -> None:
supported_precision_values = (16, 32)
if precision not in supported_precision_values:
raise ValueError(
f"`Trainer(accelerator='ipu', precision={precision!r})` is not supported."
f" `precision` must be one of: {supported_precision_values}."
)
super().__init__()
self.precision = precision

Expand Down
6 changes: 0 additions & 6 deletions pytorch_lightning/trainer/connectors/accelerator_connector.py
Expand Up @@ -685,12 +685,6 @@ def _check_and_init_precision(self) -> PrecisionPlugin:

def _validate_precision_choice(self) -> None:
"""Validate the combination of choices for precision, AMP type, and accelerator."""
# TODO: change exception type to ImpactableConfigurationException
if isinstance(self.accelerator, IPUAccelerator):
if self._precision_flag not in (16, 32):
raise MisconfigurationException(
f"`Trainer(accelerator='ipu', precision={self._precision_flag!r})` is not supported."
)
if isinstance(self.accelerator, TPUAccelerator):
if self._precision_flag == 64:
raise MisconfigurationException(
Expand Down
4 changes: 2 additions & 2 deletions tests/accelerators/test_accelerator_connector.py
Expand Up @@ -929,9 +929,9 @@ def test_unsupported_ipu_choice(mock_ipu_acc_avail, monkeypatch):

monkeypatch.setattr(imports, "_IPU_AVAILABLE", True)
monkeypatch.setattr(ipu, "_IPU_AVAILABLE", True)
with pytest.raises(MisconfigurationException, match=r"accelerator='ipu', precision='bf16'\)` is not supported"):
with pytest.raises(ValueError, match=r"accelerator='ipu', precision='bf16'\)` is not supported"):
Trainer(accelerator="ipu", precision="bf16")
with pytest.raises(MisconfigurationException, match=r"accelerator='ipu', precision=64\)` is not supported"):
with pytest.raises(ValueError, match=r"accelerator='ipu', precision=64\)` is not supported"):
Trainer(accelerator="ipu", precision=64)


Expand Down

0 comments on commit 91052dc

Please sign in to comment.