add build_models() / build_evaluator() / build_log_processor() #1310

Open · wants to merge 1 commit into base: main
Conversation

GuoPingPan

Thanks for your contribution; we appreciate it a lot. The following instructions will make your pull request healthier and easier to review. If you do not understand some items, don't worry, just make the pull request and seek help from the maintainers.

Motivation

Some methods of Runner, such as Runner.build_model / .build_evaluator / .build_log_processor, cannot be used as static methods, i.e. without creating a Runner instance.
To make it possible to build a specific module (model / evaluator / log_processor) on its own, I add standalone build functions that mirror the runner's methods:

def build_model(model: Union[nn.Module, Dict]) -> nn.Module
def build_log_processor(log_processor: Union[LogProcessor, Dict]) -> LogProcessor
def build_evaluator(evaluator: Union[Dict, List, Evaluator]) -> Evaluator
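These functions can be called directly, without a Runner. A minimal self-contained sketch of the dispatch pattern they share (return the object if it is already built, otherwise build it from a config dict); the toy registry and class below are illustrative stand-ins, not the real mmengine API:

```python
from typing import Dict, Union

LOG_PROCESSORS = {}  # toy registry standing in for mmengine's LOG_PROCESSORS


def register(cls):
    LOG_PROCESSORS[cls.__name__] = cls
    return cls


@register
class LogProcessor:  # simplified stand-in for mmengine's LogProcessor
    def __init__(self, window_size: int = 10):
        self.window_size = window_size


def build_log_processor(
        log_processor: Union[LogProcessor, Dict]) -> LogProcessor:
    if isinstance(log_processor, LogProcessor):
        return log_processor  # already built: return as-is
    if not isinstance(log_processor, dict):
        raise TypeError(
            'log processor should be a LogProcessor object or dict, '
            f'but got {log_processor}')
    cfg = dict(log_processor)  # avoid mutating the caller's config
    cls = LOG_PROCESSORS[cfg.pop('type')] if 'type' in cfg else LogProcessor
    return cls(**cfg)


# No Runner instance is needed:
lp = build_log_processor(dict(type='LogProcessor', window_size=20))
```

The same instance-passthrough / dict-build dispatch underlies all three proposed functions.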

Modification

  1. Modify mmengine/runner/log_processor.py
# line 580
def build_log_processor(
        log_processor: Union[LogProcessor, Dict]) -> LogProcessor:
    """Build log processor.

    Examples of ``log_processor``::

        # `LogProcessor` will be used
        log_processor = dict()

        # custom log_processor
        log_processor = dict(type='CustomLogProcessor')

    Args:
        log_processor (LogProcessor or dict): A log processor or a dict
            used to build a log processor. If ``log_processor`` is a log
            processor object, it is returned directly.

    Returns:
        :obj:`LogProcessor`: Log processor object built from
        ``log_processor``.
    """
    if isinstance(log_processor, LogProcessor):
        return log_processor
    elif not isinstance(log_processor, dict):
        raise TypeError(
            'log processor should be a LogProcessor object or dict, but '
            f'got {log_processor}')

    log_processor_cfg = copy.deepcopy(log_processor)  # type: ignore

    if 'type' in log_processor_cfg:
        log_processor = LOG_PROCESSORS.build(log_processor_cfg)
    else:
        log_processor = LogProcessor(**log_processor_cfg)  # type: ignore

    return log_processor  # type: ignore
  2. Add mmengine/evaluator/builder.py
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Dict, List, Union

from mmengine.registry import EVALUATOR
from .evaluator import Evaluator


def build_evaluator(evaluator: Union[Dict, List, Evaluator]) -> Evaluator:
    """Build evaluator.

    Examples of ``evaluator``::

        # evaluator could be a built Evaluator instance
        evaluator = Evaluator(metrics=[ToyMetric()])

        # evaluator can also be a list of dict
        evaluator = [
            dict(type='ToyMetric1'),
            dict(type='ToyEvaluator2')
        ]

        # evaluator can also be a list of built metric
        evaluator = [ToyMetric1(), ToyMetric2()]

        # evaluator can also be a dict with key metrics
        evaluator = dict(metrics=ToyMetric())
        # metric is a list
        evaluator = dict(metrics=[ToyMetric()])

    Args:
        evaluator (Evaluator or dict or list): An Evaluator object or a
            config dict or list of config dict used to build an Evaluator.

    Returns:
        Evaluator: Evaluator built from ``evaluator``.
    """
    if isinstance(evaluator, Evaluator):
        return evaluator
    elif isinstance(evaluator, dict):
        # if `metrics` is among the dict keys, build a customized evaluator
        if 'metrics' in evaluator:
            evaluator.setdefault('type', 'Evaluator')
            return EVALUATOR.build(evaluator)
        # otherwise, the default evaluator will be built
        else:
            return Evaluator(evaluator)  # type: ignore
    elif isinstance(evaluator, list):
        # use the default `Evaluator`
        return Evaluator(evaluator)  # type: ignore
    else:
        raise TypeError(
            'evaluator should be a dict, a list of dicts, or an Evaluator'
            f', but got {evaluator}')

Also export it from mmengine/evaluator/__init__.py:

from .builder import build_evaluator # NOTE
from .evaluator import Evaluator
from .metric import BaseMetric, DumpResults
from .utils import get_metric_value

__all__ = [
    'BaseMetric', 'Evaluator', 'get_metric_value', 'DumpResults',
    'build_evaluator' # NOTE
]
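The three accepted input shapes of the proposed build_evaluator can be exercised with a self-contained sketch; ToyMetric and the simplified Evaluator below are stand-ins for the mmengine classes, not the real API:

```python
from typing import Dict, List, Union


class ToyMetric:  # stand-in for a metric class
    pass


class Evaluator:  # simplified stand-in for mmengine's Evaluator
    def __init__(self, metrics):
        self.metrics = metrics if isinstance(metrics, list) else [metrics]


def build_evaluator(evaluator: Union[Dict, List, Evaluator]) -> Evaluator:
    """Dispatch on the three supported input shapes."""
    if isinstance(evaluator, Evaluator):
        return evaluator  # already built: pass through
    if isinstance(evaluator, dict):
        if 'metrics' in evaluator:
            return Evaluator(evaluator['metrics'])  # customized evaluator
        return Evaluator(evaluator)
    if isinstance(evaluator, list):
        return Evaluator(evaluator)  # default Evaluator over a metric list
    raise TypeError(
        'evaluator should be a dict, a list of dicts, or an Evaluator'
        f', but got {evaluator}')


e1 = build_evaluator([ToyMetric(), ToyMetric()])   # list of metrics
e2 = build_evaluator(dict(metrics=[ToyMetric()]))  # dict with `metrics` key
e3 = build_evaluator(e1)                           # instance passed through
```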
  3. Add mmengine/model/builder.py
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Dict, Union

import torch.nn as nn

from mmengine.registry import MODELS


def build_model(model: Union[nn.Module, Dict]) -> nn.Module:
    """Build function of Model.

    If ``model`` is a dict, it will be used to build an nn.Module object.
    Otherwise, if ``model`` is an nn.Module object, it will be returned
    directly.

    An example of ``model``::

        model = dict(type='ResNet')

    Args:
        model (nn.Module or dict): A ``nn.Module`` object or a dict to
            build nn.Module object. If ``model`` is a nn.Module object,
            just returns itself.

    Note:
        The returned model must implement ``train_step`` and ``test_step``
        if ``runner.train`` or ``runner.test`` will be called. If
        ``runner.val`` will be called or ``val_cfg`` is configured, the
        model must implement ``val_step``.

    Returns:
        nn.Module: Model built from ``model``.
    """
    if isinstance(model, nn.Module):
        return model
    elif isinstance(model, dict):
        model = MODELS.build(model)
        return model  # type: ignore
    else:
        raise TypeError('model should be a nn.Module object or dict, '
                        f'but got {model}')

Also export it from mmengine/model/__init__.py:

# Copyright (c) OpenMMLab. All rights reserved.
from mmengine.utils.dl_utils import TORCH_VERSION
from mmengine.utils.version_utils import digit_version
from .averaged_model import (BaseAveragedModel, ExponentialMovingAverage,
                             MomentumAnnealingEMA, StochasticWeightAverage)
from .base_model import BaseDataPreprocessor, BaseModel, ImgDataPreprocessor
from .base_module import BaseModule, ModuleDict, ModuleList, Sequential
from .builder import build_model # NOTE
from .test_time_aug import BaseTTAModel
from .utils import (convert_sync_batchnorm, detect_anomalous_params,
                    merge_dict, revert_sync_batchnorm, stack_batch)
from .weight_init import (BaseInit, Caffe2XavierInit, ConstantInit,
                          KaimingInit, NormalInit, PretrainedInit,
                          TruncNormalInit, UniformInit, XavierInit,
                          bias_init_with_prob, caffe2_xavier_init,
                          constant_init, initialize, kaiming_init, normal_init,
                          trunc_normal_init, uniform_init, update_init_info,
                          xavier_init)
from .wrappers import (MMDistributedDataParallel,
                       MMSeparateDistributedDataParallel, is_model_wrapper)

__all__ = [
    'MMDistributedDataParallel', 'is_model_wrapper', 'BaseAveragedModel',
    'StochasticWeightAverage', 'ExponentialMovingAverage',
    'MomentumAnnealingEMA', 'BaseModel', 'BaseDataPreprocessor',
    'ImgDataPreprocessor', 'MMSeparateDistributedDataParallel', 'BaseModule',
    'stack_batch', 'merge_dict', 'detect_anomalous_params', 'ModuleList',
    'ModuleDict', 'Sequential', 'revert_sync_batchnorm', 'update_init_info',
    'constant_init', 'xavier_init', 'normal_init', 'trunc_normal_init',
    'uniform_init', 'kaiming_init', 'caffe2_xavier_init',
    'bias_init_with_prob', 'BaseInit', 'ConstantInit', 'XavierInit',
    'NormalInit', 'TruncNormalInit', 'UniformInit', 'KaimingInit',
    'Caffe2XavierInit', 'PretrainedInit', 'initialize',
    'convert_sync_batchnorm', 'BaseTTAModel', 'build_model' # NOTE
]

if digit_version(TORCH_VERSION) >= digit_version('2.0.0'):
    from .wrappers import MMFullyShardedDataParallel  # noqa:F401
    __all__.append('MMFullyShardedDataParallel')
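
Analogously, build_model dispatches on instance vs config dict. A self-contained sketch with a toy registry; ToyNet and the register helper are hypothetical stand-ins for MODELS.build() and an nn.Module subclass:

```python
from typing import Dict, Union

MODELS = {}  # toy registry standing in for mmengine.registry.MODELS


def register(cls):
    MODELS[cls.__name__] = cls
    return cls


@register
class ToyNet:  # stands in for an nn.Module subclass
    def __init__(self, depth: int = 18):
        self.depth = depth


def build_model(model: Union[ToyNet, Dict]) -> ToyNet:
    if isinstance(model, ToyNet):  # already a module: return as-is
        return model
    if isinstance(model, dict):
        cfg = dict(model)  # don't mutate the caller's config
        return MODELS[cfg.pop('type')](**cfg)
    raise TypeError(
        f'model should be a module object or dict, but got {model}')


net = build_model(dict(type='ToyNet', depth=50))  # built from config
same = build_model(net)                           # instance passed through
```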

BC-breaking (Optional)

Does the modification introduce changes that break the backward-compatibility of the downstream repos?
If so, please describe how it breaks the compatibility and how the downstream projects should modify their code to keep compatibility with this PR.

Use cases (Optional)

If this PR introduces a new feature, it is better to list some use cases here, and update the documentation.

Checklist

  1. Pre-commit or other linting tools are used to fix the potential lint issues.
  2. The modification is covered by complete unit tests. If not, please add more unit tests to ensure the correctness.
  3. If the modification has potential influence on downstream projects, this PR should be tested with downstream projects, like MMDet or MMCls.
  4. The documentation has been modified accordingly, like docstring or example tutorials.

@CLAassistant

CLAassistant commented Aug 16, 2023

CLA assistant check
All committers have signed the CLA.

@GuoPingPan (Author) left a comment
Is the build_log_processor function needed?
