Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Feature] Add RecorderHook #1300

Open
wants to merge 39 commits into
base: main
Choose a base branch
from

Conversation

Xinyu302
Copy link
Contributor

@Xinyu302 Xinyu302 commented Aug 10, 2023

A glcc summer camp project.

Motivation

The user expects to visualize the output of any layer of the network in the visualization hook, not necessarily the output of nn.module.forward, but also intermediate variables, and will inevitably need to invasively modify the code to update the output into the message hub. Then retrieve the value of the message hub in the visualization hook for visualization. Rather than intrusively modifying the code, we want to have a scheme that can non-intrusively obtain the output of any layer of the network. In addition, it also needs to provide the ability to get properties of a specified instance.

Modification

In short, use ast module to modify the forward function of runner.model.
Here is a simple example.

class ToyModel(BaseModel):

    def __init__(self, data_preprocessor=None):
        super().__init__(data_preprocessor=data_preprocessor)
        self.linear1 = nn.Linear(2, 2)
        self.linear2 = nn.Linear(2, 1)

    def forward(self, inputs, data_samples, mode='tensor'):
        if isinstance(inputs, list):
            inputs = torch.stack(inputs)
        if isinstance(data_samples, list):
            data_sample = torch.stack(data_samples)
        outputs = self.linear1(inputs)
        outputs = self.linear2(outputs)

        if mode == 'tensor':
            return outputs
        elif mode == 'loss':
            loss = (data_sample - outputs).sum()
            outputs = dict(loss=loss)
            return outputs
        elif mode == 'predict':
            return outputs

The API of RecorderHook is like this:

cfg.custom_hooks = [
    dict(
        type='RecorderHook',
        recorders=[
            dict(type='FunctionRecorder', ...),  # Function recorder1
            dict(type='FunctionRecorder', ...),  # Function recorder2
            dict(type='AttributeRecorder', ...)  # AttributeRecorder1
            ... 
        ],
    )
]

RecorderHook uses FunctionRecorder and AttributeRecorder to record different things in forward method.

FunctionRecorder

function

Gets the output and intermediate variables of the specified function or method. If the function has several intermediate variables of the same name

case

    custom_hooks=[
        dict(
            type='RecorderHook',
            recorders=[
                dict(type='FunctionRecorder', target='outputs', index=[1])
            ],
            save_dir='./work_dir',
            print_modification=True)
    ]

Forward method after modification

def forward(self, inputs, data_samples, mode='tensor'):
    from mmengine.logging import MessageHub
    import copy
    message_hub = MessageHub.get_current_instance()
    if isinstance(inputs, list):
        inputs = torch.stack(inputs)
    if isinstance(data_samples, list):
        data_sample = torch.stack(data_samples)
    outputs = self.linear1(inputs)
    outputs = self.linear2(outputs)
    message_hub.update_info('runner_model:forward:outputs@1', outputs)
    if mode == 'tensor':
        return outputs
    elif mode == 'loss':
        loss = (data_sample - outputs).sum()
        outputs = dict(loss=loss)
        return outputs
    elif mode == 'predict':
        return outputs

AttributeRecorder

function

Gets the value of the specified property. Insert the recorder code just at the front of the function.

case

    custom_hooks=[
        dict(
            type='RecorderHook',
            recorders=[
                dict(type='AttributeRecorder', target='self.linear1.weight')
            ],
            save_dir='./work_dir',
            print_modification=True)
    ]

Forward method after modification

def forward(self, inputs, data_samples, mode='tensor'):
    from mmengine.logging import MessageHub
    import copy
    message_hub = MessageHub.get_current_instance()
    if isinstance(self.linear1.weight, torch.Tensor):
        _deep_copy_self_linear1_weight = self.linear1.weight.detach().clone()
    else:
        _deep_copy_self_linear1_weight = copy.deepcopy(self.linear1.weight)
    message_hub.update_info('runner_model:forward:self.linear1.weight', _deep_copy_self_linear1_weight)
    if isinstance(inputs, list):
        inputs = torch.stack(inputs)
    if isinstance(data_samples, list):
        data_sample = torch.stack(data_samples)
    outputs = self.linear1(inputs)
    outputs = self.linear2(outputs)
    if mode == 'tensor':
        return outputs
    elif mode == 'loss':
        loss = (data_sample - outputs).sum()
        outputs = dict(loss=loss)
        return outputs
    elif mode == 'predict':
        return outputs

A more complicated case

Users can specify the model and function that they want to record.

class MMResNet50(BaseModel):

    def __init__(self):
        super().__init__()
        self.resnet = torchvision.models.resnet50()

    def forward(self, imgs, labels, mode):
        x = self.resnet(imgs)
        if mode == 'loss':
            return {'loss': F.cross_entropy(x, labels)}
        elif mode == 'predict':
            return x, labels
    custom_hooks=[
        dict(
            type='RecorderHook',
            recorders=[
                dict(
                    model='resnet',
                    method='_forward_impl',
                    type='FunctionRecorder',
                    target='x', index=[0,1,2])
            ],
            save_dir='./work_dir',
            print_modification=True)
    ]

after modification

def _forward_impl(self, x: Tensor) -> Tensor:
    from mmengine.logging import MessageHub
    import copy
    message_hub = MessageHub.get_current_instance()
    x = self.conv1(x)
    message_hub.update_info('resnet:_forward_impl:x@0', x)
    x = self.bn1(x)
    message_hub.update_info('resnet:_forward_impl:x@1', x)
    x = self.relu(x)
    message_hub.update_info('resnet:_forward_impl:x@2', x)
    x = self.maxpool(x)
    x = self.layer1(x)
    x = self.layer2(x)
    x = self.layer3(x)
    x = self.layer4(x)
    x = self.avgpool(x)
    x = torch.flatten(x, 1)
    x = self.fc(x)
    return x

TODO

  • Add unit test
  • Store data recorded to message hub, provide visualizer with support
  • Add docstring and type hint.

@Xinyu302 Xinyu302 changed the title [Feature] Add RecorderHook [WIP][Feature] Add RecorderHook Aug 11, 2023
@Xinyu302 Xinyu302 changed the title [WIP][Feature] Add RecorderHook [Feature] Add RecorderHook Sep 17, 2023


class FunctionRecorderTransformer(ast.NodeTransformer):

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please add docstring and type hint.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

2 participants