[Bug] (suggested fix) mmpose.models.pose_estimators.topdown.TopdownPoseEstimator is unable to be symbolically traced because of untraceable add_pred_to_datasample() and loss() #3012

Open
2 tasks done
elisa-aleman opened this issue Apr 5, 2024 · 1 comment
elisa-aleman commented Apr 5, 2024

Prerequisite

Environment

The machine is not available at the moment, so the full environment report is missing.

Using:
torch: 2.0.0+cu118
torchvision: 0.15.0+cu118
mmengine: 0.10.3
mmrazor: 1.0.0
mmpose: 1.3.1

Reproduces the problem - code sample

While using mmrazor to quantize this model, I hit an error when the fx graph was being built by symbolic tracing.

Before running the code below, I applied the fixes for the torch 2.0.0 incompatibility suggested in mmrazor #632, and the fix suggested in mmrazor #633 for nn.Parameters inside TopdownPoseEstimator not being traced.

from mmrazor.models.task_modules.tracer.fx.custom_tracer import CustomTracer
from mmpose.models.pose_estimators.topdown import TopdownPoseEstimator
from mmengine.config import Config


cfg = Config.fromfile('/mmpose/configs/body_2d_keypoint/rtmpose/coco/rtmpose-t_8xb256-420e_coco-256x192.py')

rtmpose = TopdownPoseEstimator(
    backbone=cfg.model.backbone,
    neck=cfg.model.neck,
    head=cfg.model.head,
    train_cfg=cfg.train_cfg,
    data_preprocessor=cfg.model.data_preprocessor,
)

tracer = CustomTracer(
    skipped_methods=[
        'mmpose.models.heads.RTMCCHead.loss',
        'mmpose.models.heads.RTMCCHead.predict',
    ]
)
traced_graph = tracer.trace(rtmpose)

Reproduces the problem - error message

Traceback (most recent call last):
  File "..../site-packages/mmrazor/models/task_modules/tracer/fx/custom_tracer.py", line 421, in trace
    'output', (self.create_arg(fn(*args)), ), {},

  File "..../site-packages/mmpose/models/pose_estimators/base.py", line 161, in forward
    return self.predict(inputs, data_samples)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "..../site-packages/mmpose/models/pose_estimators/topdown.py", line 117, in predict
    results = self.add_pred_to_datasample(batch_pred_instances,
              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "..../site-packages/mmpose/models/pose_estimators/topdown.py", line 138, in add_pred_to_datasample
    assert len(batch_pred_instances) == len(batch_data_samples)
           ^^^^^^^^^^^^^^^^^^^^^^^^^
  File "..../site-packages/torch/fx/proxy.py", line 420, in __len__
    raise RuntimeError("'len' is not supported in symbolic tracing by default. If you want "
RuntimeError: 'len' is not supported in symbolic tracing by default. If you want this call to be recorded, please call torch.fx.wrap('len') at module scope
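
For context, the same failure can be reproduced outside mmpose. The sketch below is not mmpose or mmrazor code: check_lengths, wrapped_check_lengths, and the two toy modules are hypothetical stand-ins, and it uses plain torch.fx.symbolic_trace rather than mmrazor's CustomTracer. It only illustrates that len() on a traced value raises, and that a module-level function decorated with torch.fx.wrap is recorded as a single call instead of being traced into.

```python
import torch
import torch.fx


def check_lengths(a, b):
    # Hypothetical stand-in for add_pred_to_datasample(): during symbolic
    # tracing `a` and `b` are Proxy objects, and len(Proxy) raises.
    assert len(a) == len(b)
    return a


@torch.fx.wrap  # must be applied at module scope
def wrapped_check_lengths(a, b):
    # Same body, but the tracer records one opaque call_function node
    # instead of executing the body on Proxy objects.
    assert len(a) == len(b)
    return a


class Untraceable(torch.nn.Module):
    def forward(self, x, y):
        return check_lengths(x, y) + 1


class Traceable(torch.nn.Module):
    def forward(self, x, y):
        return wrapped_check_lengths(x, y) + 1


def try_trace(module):
    """Return (GraphModule, None) on success or (None, error) on failure."""
    try:
        return torch.fx.symbolic_trace(module), None
    except RuntimeError as err:
        return None, err


failed_gm, failed_err = try_trace(Untraceable())
traced_gm, traced_err = try_trace(Traceable())
```

Here failed_err should carry the same "'len' is not supported in symbolic tracing" message as the traceback above, while traced_gm can be called on real tensors, at which point the wrapped function runs normally.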

Additional information

For loss(), I suggest the following patch:

@@ -68,8 +68,8 @@ def loss(self, inputs: Tensor, data_samples: SampleList) -> dict:
         feats = self.extract_feat(inputs)

-        losses = dict()
-
         if self.with_head:
-            losses.update(
-                self.head.loss(feats, data_samples, train_cfg=self.train_cfg))
+            losses = self.head.loss(
+                feats, data_samples, train_cfg=self.train_cfg)
+        else:
+            losses = {}

         return losses
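
One plausible reason the original loss() does not trace (an assumption on my part, not verified against mmrazor's CustomTracer) is that dict.update() has to iterate over its argument, while a skipped head.loss() only yields a Proxy during tracing. The sketch below uses plain torch.fx and a hypothetical head_loss stand-in to compare the two styles; returning the head's result directly avoids touching the Proxy.

```python
import torch
import torch.fx


@torch.fx.wrap  # hypothetical stand-in for a skipped head.loss()
def head_loss(feats):
    return {"loss_kpt": feats.sum()}


class UpdateLoss(torch.nn.Module):
    def forward(self, x):
        losses = dict()
        # dict.update() needs to iterate over its argument, but during
        # tracing the argument is a Proxy, not a real dict.
        losses.update(head_loss(x))
        return losses


class AssignLoss(torch.nn.Module):
    def forward(self, x):
        # The Proxy is passed through untouched, so tracing can succeed.
        return head_loss(x)


def can_trace(module):
    """Return True if torch.fx can symbolically trace the module."""
    try:
        torch.fx.symbolic_trace(module)
        return True
    except Exception:
        return False


update_ok = can_trace(UpdateLoss())
assign_ok = can_trace(AssignLoss())
```

In this toy setup only the assignment variant traces, which matches the direction of the patch above.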

For add_pred_to_datasample(), I suggest moving the body into a module-level function wrapped with torch.fx.wrap:

@@ -138,48 +138,62 @@ def add_pred_to_datasample(self, batch_pred_instances: InstanceList,
-        assert len(batch_pred_instances) == len(batch_data_samples)
-        if batch_pred_fields is None:
-            batch_pred_fields = []
         output_keypoint_indices = self.test_cfg.get('output_keypoint_indices',
                                                    None)
-
-        for pred_instances, pred_fields, data_sample in zip_longest(
-                batch_pred_instances, batch_pred_fields, batch_data_samples):
-
-            gt_instances = data_sample.gt_instances
-
-            # convert keypoint coordinates from input space to image space
-            input_center = data_sample.metainfo['input_center']
-            input_scale = data_sample.metainfo['input_scale']
-            input_size = data_sample.metainfo['input_size']
-
-            pred_instances.keypoints[..., :2] = \
-                pred_instances.keypoints[..., :2] / input_size * input_scale \
-                + input_center - 0.5 * input_scale
-            if 'keypoints_visible' not in pred_instances:
-                pred_instances.keypoints_visible = \
-                    pred_instances.keypoint_scores
-
-            if output_keypoint_indices is not None:
-                # select output keypoints with given indices
-                num_keypoints = pred_instances.keypoints.shape[1]
-                for key, value in pred_instances.all_items():
-                    if key.startswith('keypoint'):
-                        pred_instances.set_field(
-                            value[:, output_keypoint_indices], key)
-
-            # add bbox information into pred_instances
-            pred_instances.bboxes = gt_instances.bboxes
-            pred_instances.bbox_scores = gt_instances.bbox_scores
-
-            data_sample.pred_instances = pred_instances
-
-            if pred_fields is not None:
-                if output_keypoint_indices is not None:
-                    # select output heatmap channels with keypoint indices
-                    # when the number of heatmap channel matches num_keypoints
-                    for key, value in pred_fields.all_items():
-                        if value.shape[0] != num_keypoints:
-                            continue
-                        pred_fields.set_field(value[output_keypoint_indices],
-                                              key)
-                data_sample.pred_fields = pred_fields
+        batch_data_samples = _add_pred_to_datasample(
+            output_keypoint_indices,
+            batch_pred_instances,
+            batch_pred_fields,
+            batch_data_samples
+            )
         return batch_data_samples
+
+
+# torch.fx.wrap requires `import torch` in topdown.py if it is not
+# already imported there.
+@torch.fx.wrap
+def _add_pred_to_datasample(
+    output_keypoint_indices,
+    batch_pred_instances: InstanceList,
+    batch_pred_fields: Optional[PixelDataList],
+    batch_data_samples: SampleList) -> SampleList:
+    # Module-level helper, so torch.fx records it as one opaque call
+    # instead of tracing into len(), zip_longest() and the loop below.
+    assert len(batch_pred_instances) == len(batch_data_samples)
+    if batch_pred_fields is None:
+        batch_pred_fields = []
+
+    for pred_instances, pred_fields, data_sample in zip_longest(
+            batch_pred_instances, batch_pred_fields, batch_data_samples):
+
+        gt_instances = data_sample.gt_instances
+
+        # convert keypoint coordinates from input space to image space
+        input_center = data_sample.metainfo['input_center']
+        input_scale = data_sample.metainfo['input_scale']
+        input_size = data_sample.metainfo['input_size']
+
+        pred_instances.keypoints[..., :2] = \
+            pred_instances.keypoints[..., :2] / input_size * input_scale \
+            + input_center - 0.5 * input_scale
+        if 'keypoints_visible' not in pred_instances:
+            pred_instances.keypoints_visible = \
+                pred_instances.keypoint_scores
+
+        if output_keypoint_indices is not None:
+            # select output keypoints with given indices
+            num_keypoints = pred_instances.keypoints.shape[1]
+            for key, value in pred_instances.all_items():
+                if key.startswith('keypoint'):
+                    pred_instances.set_field(
+                        value[:, output_keypoint_indices], key)
+
+        # add bbox information into pred_instances
+        pred_instances.bboxes = gt_instances.bboxes
+        pred_instances.bbox_scores = gt_instances.bbox_scores
+
+        data_sample.pred_instances = pred_instances
+
+        if pred_fields is not None:
+            if output_keypoint_indices is not None:
+                # select output heatmap channels with keypoint indices
+                # when the number of heatmap channel matches num_keypoints
+                for key, value in pred_fields.all_items():
+                    if value.shape[0] != num_keypoints:
+                        continue
+                    pred_fields.set_field(value[output_keypoint_indices],
+                                          key)
+            data_sample.pred_fields = pred_fields
+
+    # return only after every sample in the batch has been processed
+    return batch_data_samples

This solves the fx tracing error shown above, although there are still other issues I have yet to solve.

@elisa-aleman
Author

Added reproducing code and full fix suggestion
