Commit 0beaf69

[Enhance] Enable full precision training on Ascend NPU. (#3085)
## Motivation

We will support full precision training on the next-generation Ascend NPU, so mixed precision no longer needs to be enabled by default.

## Modification

Determine whether the current chip supports full precision training, and enable mixed precision automatically only when it does not (see the sketch below).
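
In sketch form, the change amounts to the gate below. This is a condensed paraphrase of the `train.py` hunk further down, not a verbatim excerpt; `cfg` is assumed to be the usual MMCV `Config` passed into `train_segmentor`.

```python
from mmseg.utils import is_npu_support_full_precision

# Condensed sketch of the gate added in train_segmentor (see the full diff below).
# `cfg` is the MMCV Config object passed into train_segmentor.
if cfg.device == 'npu' and not is_npu_support_full_precision():
    # Older Ascend SoCs: fall back to FP16 training with dynamic loss scaling,
    # unless an optimizer_config was already set explicitly.
    cfg.optimizer_config = cfg.optimizer_config or dict(
        type='Fp16OptimizerHook', loss_scale='dynamic')
```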

## BC-breaking (Optional)

Not affected.

## Use cases (Optional)

We have verified the correctness on the Ascend NPU.
Ginray committed Jun 7, 2023
1 parent 41cfa70 commit 0beaf69
Showing 3 changed files with 15 additions and 4 deletions.
4 changes: 2 additions & 2 deletions mmseg/apis/train.py

@@ -15,7 +15,7 @@
 from mmseg.core import DistEvalHook, EvalHook, build_optimizer
 from mmseg.datasets import build_dataloader, build_dataset
 from mmseg.utils import (build_ddp, build_dp, find_latest_checkpoint,
-                         get_root_logger)
+                         get_root_logger, is_npu_support_full_precision)


 def init_random_seed(seed=None, device='cuda'):

@@ -136,7 +136,7 @@ def train_segmentor(model,
             logger=logger,
             meta=meta))

-    if cfg.device == 'npu':
+    if cfg.device == 'npu' and not is_npu_support_full_precision():
         optimiter_config = dict(type='Fp16OptimizerHook', loss_scale='dynamic')
         cfg.optimizer_config = optimiter_config if \
             not cfg.optimizer_config else cfg.optimizer_config
6 changes: 4 additions & 2 deletions mmseg/utils/__init__.py

@@ -3,9 +3,11 @@
 from .logger import get_root_logger
 from .misc import find_latest_checkpoint
 from .set_env import setup_multi_processes
-from .util_distribution import build_ddp, build_dp, get_device
+from .util_distribution import (build_ddp, build_dp, get_device,
+                                is_npu_support_full_precision)

 __all__ = [
     'get_root_logger', 'collect_env', 'find_latest_checkpoint',
-    'setup_multi_processes', 'build_ddp', 'build_dp', 'get_device'
+    'setup_multi_processes', 'build_ddp', 'build_dp', 'get_device',
+    'is_npu_support_full_precision'
 ]
9 changes: 9 additions & 0 deletions mmseg/utils/util_distribution.py

@@ -94,6 +94,15 @@ def is_npu_available():
     return hasattr(torch, 'npu') and torch.npu.is_available()


+def is_npu_support_full_precision() -> bool:
+    """Returns True if npu devices support full precision training."""
+    if not is_npu_available():
+        return False
+    import torch_npu.npu.utils as npu_utils
+    version_of_support_full_precision = 220
+    return npu_utils.get_soc_version() >= version_of_support_full_precision
+
+
 def get_device():
     """Returns an available device, cpu, npu, cuda or mlu."""
     is_device_available = {
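
A minimal usage sketch of the new helper, assuming `torch_npu` is installed so the SoC version query is available; the import path follows the `__init__.py` change above.

```python
from mmseg.utils import is_npu_support_full_precision

# SoC version >= 220 means the chip can train in full precision, so mixed
# precision is not enabled by default. On older SoCs (or when no NPU is
# available) the helper returns False, and an NPU run falls back to
# Fp16OptimizerHook with dynamic loss scaling.
if is_npu_support_full_precision():
    print('Full precision training is supported on this NPU.')
else:
    print('Mixed precision will be enabled automatically on NPU runs.')
```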
