[Feature] Support calculating loss during validation #1503

Merged (16 commits, merged May 17, 2024)


fanqiNO1 (Collaborator) commented on Feb 22, 2024

Background

Early stopping may need the validation loss as its monitored metric, but mmengine currently cannot calculate and parse the validation loss as a metric.

However, model implementations differ widely, and calculating the validation loss is not a common requirement. The calculation should therefore not be initiated by mmengine. Instead, the model should compute the loss and return it following a convention, and mmengine should parse the returned loss and report it as a metric.

This PR therefore implements this return-and-parse convention without introducing a breaking change.
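As an illustration of the use case from the Background, an early-stopping setup that monitors the validation loss might be configured as in the sketch below. It assumes mmengine's EarlyStoppingHook with its monitor, rule, and patience arguments, and it assumes the parsed loss is reported under the metric key 'val_loss'; this PR description does not itself specify the key name.

```python
# Sketch of an mmengine-style config fragment.
# Assumption: the parsed validation loss is reported as the metric 'val_loss'.
custom_hooks = [
    dict(
        type='EarlyStoppingHook',
        monitor='val_loss',  # assumed metric key for the validation loss
        rule='less',         # a lower loss counts as an improvement
        patience=5,          # stop after 5 validations without improvement
    )
]
```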

Design

To avoid a breaking change, the model computes the loss during val_step (in model.forward with mode='predict', or in predict), wraps it in a BaseDataElement, and appends it to the end of the val_step result.

ValLoop therefore inspects the last item of each val_step result to decide whether it is the validation loss. If it is, ValLoop removes it, accumulates the loss, and reports it at the end of the loop; the remaining items are passed on to compute the other metrics, such as accuracy. If the last item is not a validation loss, the result is left unchanged.
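The parsing step described above can be sketched without any framework code. This is a minimal sketch under stated assumptions: LossElement is a hypothetical stand-in for a BaseDataElement that holds only a loss dict, losses are plain floats instead of tensors, and each loss term is averaged over iterations. The isinstance check is deliberately simple; a real loop would need a stricter test (for example, that the element holds nothing but a loss field), because ordinary predictions can also be data elements.

```python
from dataclasses import dataclass, field
from typing import Any, Dict, List, Sequence, Tuple


@dataclass
class LossElement:
    """Hypothetical stand-in for a BaseDataElement carrying only `loss`."""
    loss: Dict[str, float] = field(default_factory=dict)


def split_val_loss(outputs: Sequence[Any]) -> Tuple[List[Any], Dict[str, float]]:
    """Strip a trailing loss element from one val_step result, if present."""
    outputs = list(outputs)
    if outputs and isinstance(outputs[-1], LossElement):
        return outputs[:-1], outputs[-1].loss
    # No loss element: the outputs pass through unchanged.
    return outputs, {}


def run_val_loop(batches: Sequence[Sequence[Any]]) -> Tuple[List[Any], Dict[str, float]]:
    """Collect prediction items and average each loss term over iterations."""
    preds: List[Any] = []
    history: Dict[str, List[float]] = {}
    for outputs in batches:
        items, loss = split_val_loss(outputs)
        preds.extend(items)  # only these reach the regular metrics
        for name, value in loss.items():
            history.setdefault(name, []).append(float(value))
    # Assumption: a plain per-iteration mean for each loss term.
    loss_metrics = {name: sum(v) / len(v) for name, v in history.items()}
    return preds, loss_metrics
```

For example, two iterations reporting a loss of 0.5 and 0.25, plus one iteration without a loss element, yield the predictions from all three iterations and {'loss': 0.375}.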

Adaptation

Custom Model

Take https://github.com/open-mmlab/mmengine/blob/02f80e8bdd38f6713e04a872304861b02157905a/examples/distributed_training.py#L14-#L25 as an example.

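import torch.nn.functional as F
import torchvision
from mmengine.model import BaseModel
+from mmengine.structures import BaseDataElement

class MMResNet50(BaseModel):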

    def __init__(self):
        super().__init__()
        self.resnet = torchvision.models.resnet50()

    def forward(self, imgs, labels, mode):
        x = self.resnet(imgs)
        if mode == 'loss':
            return {'loss': F.cross_entropy(x, labels)}
        elif mode == 'predict':
-            return x, labels
+            val_loss = {'loss': F.cross_entropy(x, labels)}
+            return x, labels, BaseDataElement(loss=val_loss)
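The reason existing metrics keep working can be illustrated without PyTorch. In this sketch, LossOnly, strip_loss, and count_correct are hypothetical names: LossOnly stands in for BaseDataElement(loss=...), and count_correct plays the role of a metric written for the original two-item (x, labels) output, with plain lists in place of tensors.

```python
from typing import Any, List, Sequence


class LossOnly:
    """Hypothetical stand-in for BaseDataElement(loss=...)."""

    def __init__(self, loss: dict):
        self.loss = loss


def strip_loss(outputs: Sequence[Any]) -> List[Any]:
    """Drop a trailing loss-only element, mirroring the convention above."""
    if outputs and isinstance(outputs[-1], LossOnly):
        return list(outputs[:-1])
    return list(outputs)


def count_correct(data_samples: Sequence[Any]) -> int:
    """A pre-existing metric that expects exactly (scores, labels)."""
    scores, labels = data_samples  # raises ValueError if a third item is left in
    # Argmax of each score row, computed with plain Python.
    preds = [max(range(len(row)), key=row.__getitem__) for row in scores]
    return sum(int(p == y) for p, y in zip(preds, labels))
```

Once the trailing element is stripped, count_correct receives exactly (scores, labels) as before; if the raw three-item output were passed instead, the two-way unpacking would raise ValueError.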

MMPreTrain

Take https://github.com/open-mmlab/mmpretrain/blob/17a886cb5825cd8c26df4e65f7112d404b99fe12/mmpretrain/models/classifiers/image.py#L248-L249 as an example.

    def predict(self,
                inputs: torch.Tensor,
                data_samples: Optional[List[DataSample]] = None,
                **kwargs) -> List[DataSample]:
        """Predict results from a batch of inputs.

        Args:
            inputs (torch.Tensor): The input tensor with shape
                (N, C, ...) in general.
            data_samples (List[DataSample], optional): The annotation
                data of every samples. Defaults to None.
            **kwargs: Other keyword arguments accepted by the ``predict``
                method of :attr:`head`.
        """
        feats = self.extract_feat(inputs)
-       return self.head.predict(feats, data_samples, **kwargs)
+       preds = self.head.predict(feats, data_samples, **kwargs)
+       loss = self.head.loss(feats, data_samples)
+       loss_sample = DataSample(loss=loss)
+       preds.append(loss_sample)
+       return preds
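predict is also used at test and inference time, where data_samples may be None and head.loss would then have no labels to compare against. One possible refinement, sketched below with hypothetical callables standing in for self.head.loss and DataSample(loss=...), is to append the loss sample only when annotations are passed in. The None check is the only guard shown; a real model might also verify that each sample has a ground-truth label.

```python
from typing import Any, Callable, List, Optional, Sequence


def append_val_loss(preds: List[Any],
                    data_samples: Optional[Sequence[Any]],
                    compute_loss: Callable[[Sequence[Any]], dict],
                    make_loss_sample: Callable[[dict], Any]) -> List[Any]:
    """Append a loss sample only when annotations are available.

    `compute_loss` stands in for `self.head.loss(feats, data_samples)` and
    `make_loss_sample` for `DataSample(loss=...)`; both are hypothetical
    callables used for illustration.
    """
    if data_samples is None:
        # Plain inference: no labels, so no loss; return predictions as-is.
        return preds
    preds.append(make_loss_sample(compute_loss(data_samples)))
    return preds
```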

MMPose

Calculating the loss this way may not be correct.

Take https://github.com/open-mmlab/mmpose/blob/5a3be9451bdfdad2053a90dc1199e3ff1ea1a409/mmpose/models/pose_estimators/topdown.py#L99-#L120 as an example.

    def predict(self, inputs: Tensor, data_samples: SampleList) -> SampleList:
        """Predict results from a batch of inputs and data samples with post-
        processing.

        Args:
            inputs (Tensor): Inputs with shape (N, C, H, W)
            data_samples (List[:obj:`PoseDataSample`]): The batch
                data samples

        Returns:
            list[:obj:`PoseDataSample`]: The pose estimation results of the
            input images. The return value is `PoseDataSample` instances with
            ``pred_instances`` and ``pred_fields``(optional) field , and
            ``pred_instances`` usually contains the following keys:

                - keypoints (Tensor): predicted keypoint coordinates in shape
                    (num_instances, K, D) where K is the keypoint number and D
                    is the keypoint dimension
                - keypoint_scores (Tensor): predicted keypoint scores in shape
                    (num_instances, K)
        """
        assert self.with_head, (
            'The model must have head to perform prediction.')

        if self.test_cfg.get('flip_test', False):
            _feats = self.extract_feat(inputs)
            _feats_flip = self.extract_feat(inputs.flip(-1))
            feats = [_feats, _feats_flip]
+           loss = self.head.loss(_feats, data_samples, train_cfg=self.train_cfg)
        else:
            feats = self.extract_feat(inputs)
+           loss = self.head.loss(feats, data_samples, train_cfg=self.train_cfg)

        preds = self.head.predict(feats, data_samples, test_cfg=self.test_cfg)

        if isinstance(preds, tuple):
            batch_pred_instances, batch_pred_fields = preds
        else:
            batch_pred_instances = preds
            batch_pred_fields = None

        results = self.add_pred_to_datasample(batch_pred_instances,
                                              batch_pred_fields, data_samples)
+       # BaseDataElement comes from mmengine.structures
+       loss_sample = BaseDataElement(loss=loss)
+       results.append(loss_sample)
        return results

In addition, add dict(type='GenerateTarget', encoder=codec) to val_pipeline, as in train_pipeline, so that the validation samples carry the targets that head.loss compares against.
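A val_pipeline with that change might look like the sketch below. The transform names and the MSRAHeatmap codec settings follow common MMPose 1.x top-down heatmap configs and are assumptions here, not part of this PR; the point is only that GenerateTarget appears before PackPoseInputs, mirroring train_pipeline.

```python
# Sketch of a top-down heatmap val_pipeline; transform names and codec
# settings are assumed from common MMPose 1.x configs, not from this PR.
codec = dict(
    type='MSRAHeatmap', input_size=(192, 256), heatmap_size=(48, 64), sigma=2)

val_pipeline = [
    dict(type='LoadImage'),
    dict(type='GetBBoxCenterScale'),
    dict(type='TopdownAffine', input_size=codec['input_size']),
    # Added so validation samples carry the targets that head.loss needs.
    dict(type='GenerateTarget', encoder=codec),
    dict(type='PackPoseInputs'),
]
```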

MikasaLee commented:

I hope val loss can be merged into mmengine as soon as possible; it is a very useful feature.

zhouzaida linked an issue on Feb 26, 2024 that may be closed by this pull request.
HAOCHENYE previously approved these changes on Mar 5, 2024.
HAOCHENYE previously approved these changes on May 6, 2024.
Review thread on mmengine/runner/loops.py (outdated, resolved).
zhouzaida merged commit d1f1aab into open-mmlab:main on May 17, 2024.
9 of 20 checks passed
Linked issue that this pull request may close: [Feature] Support calculating loss in the validation step
4 participants