
gy77/add freeze hook #1387 (Open)

gy-7 wants to merge 6 commits into main
Conversation

@gy-7 (Contributor) commented Oct 10, 2023

Motivation

Freeze some parameters of the model while training it.

Goal:

  1. Specify the epoch at which to freeze the specified network layers.
  2. Make the hook available to all downstream repositories.

Modification

Add FreezeHook and its unit tests.


Use cases

  1. Network layers matching freeze_layers are frozen before freeze_iter/freeze_epoch starts.
  2. Network layers matching unfreeze_layers are unfrozen before unfreeze_iter/unfreeze_epoch starts.
  3. freeze_layers/unfreeze_layers match network layers via regular expressions.
  4. The index of iter/epoch starts at 0, so epoch=0 is the first epoch.
  5. unfreeze_iter, unfreeze_epoch and unfreeze_layers are optional. If unfreeze_iter/unfreeze_epoch is not None, unfreeze_layers must not be None.
  6. Only one of freeze_iter and freeze_epoch can be set; the same applies to unfreeze_iter and unfreeze_epoch.
Consider the following example model structure:

ImageClassifier(
    (backbone): ResNet(
        ...
        (layer1): Sequential(...)
        (layer2): Sequential(...)
        (layer3): Sequential(...)
        (layer4): Sequential(...)
    )
    (neck): GlobalAveragePooling2d(...)
    (head): Linear(...)
)
Example 1: Freeze the parameters of the backbone before the start of the 1st training epoch.

custom_hooks = [
    ...
    dict(
        type="FreezeHook",
        freeze_layers="backbone.*",
        freeze_epoch=0)
]
Example 2: Freeze the layer1 and layer2 parameters of the backbone before the start of the 11th training epoch (freeze_epoch=10).

custom_hooks = [
    ...
    dict(
        type="FreezeHook",
        freeze_layers="backbone.layer1.*|backbone.layer2.*",
        freeze_epoch=10)
]
Example 3: Freeze the parameters of the backbone before the start of the 1st training epoch, then unfreeze them before the start of the 10th training epoch.

custom_hooks = [
    ...
    dict(
        type="FreezeHook",
        freeze_layers="backbone.*",
        freeze_epoch=0,
        unfreeze_layers="backbone.*",
        unfreeze_epoch=9)
]
Example 4: The verbose parameter controls whether to print the requires_grad flag of every model layer.

custom_hooks = [
    ...
    dict(
        type="FreezeHook",
        freeze_layers="backbone.*",
        freeze_epoch=1,
        verbose=True)
]

mmengine - INFO - backbone.conv1.weight requires_grad: True
mmengine - INFO - backbone.bn1.weight requires_grad: True
...
mmengine - INFO - head.light_head.weight requires_grad: True
mmengine - INFO - head.light_head.bias requires_grad: True
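
For readers who want a feel for the mechanics, here is a minimal sketch of how a regex-based freeze hook could be written against mmengine's public Hook, HOOKS and print_log APIs. The class name FreezeHookSketch and the reduced argument set (freeze_layers, freeze_epoch, verbose only) are illustrative assumptions; this is not the PR's actual implementation.

import re

from mmengine.hooks import Hook
from mmengine.logging import print_log
from mmengine.registry import HOOKS


@HOOKS.register_module()
class FreezeHookSketch(Hook):
    """Sketch only: freeze parameters whose names match a regex at one epoch."""

    def __init__(self, freeze_layers: str, freeze_epoch: int = 0,
                 verbose: bool = False):
        self.freeze_layers = freeze_layers
        self.freeze_epoch = freeze_epoch
        self.verbose = verbose

    def before_train_epoch(self, runner) -> None:
        # Freeze matching parameters right before the configured epoch starts.
        if runner.epoch != self.freeze_epoch:
            return
        for name, param in runner.model.named_parameters():
            if re.match(self.freeze_layers, name):
                param.requires_grad = False
            if self.verbose:
                print_log(f'{name} requires_grad: {param.requires_grad}',
                          logger='current')

The hook in this PR additionally supports freeze_iter, the unfreeze_* counterparts and argument validation, which are discussed in the review comments below.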

mmengine/hooks/freeze_hook.py (resolved review thread)
unfreeze_epoch (int): The epoch number to start unfreezing layers.
unfreeze_layers (tuple[str]): Model layers containing the keyword in
    unfreeze_layers will unfreeze the gradient.
log_grad (bool): Whether to log the requires_grad of each layer.

Collaborator:

Suggested change:
- log_grad (bool): Whether to log the requires_grad of each layer.
+ verbose (bool): Whether to log the requires_grad of each layer.


Args:
    freeze_epoch (int): The epoch number to start freezing layers.
    freeze_layers (tuple[str]): Model layers containing the keyword in

Collaborator:

Suggest making freeze_layers the first argument, and it should be a regex expression.

mmengine/hooks/freeze_hook.py (resolved review thread)
self.unfreeze_layers = unfreeze_layers
self.log_grad = log_grad

def modify_layers_grad(self, model, layers, requires_grad):

Collaborator:

Suggested change:
- def modify_layers_grad(self, model, layers, requires_grad):
+ def _modify_layers_grad(self, model, layers, requires_grad):

Collaborator:

Please update the type hint.

v.requires_grad = requires_grad
break

def log_model_grad(self, model, log_grad=False):

Collaborator:

Suggested change:
- def log_model_grad(self, model, log_grad=False):
+ def _log_model_grad(self, model, log_grad=False):

mmengine/hooks/freeze_hook.py (resolved review thread)

def __init__(
    self,
    freeze_layers: Union[Sequence[str], str],

Collaborator:

Since freeze_layers is already a regex expression, it is not necessary to make it a tuple of str ('exp1|exp2|exp3' is enough).
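
For illustration, a single regex with alternation can select several layers at once. The pattern and parameter names below are examples only, and re.match is just one plausible way to do the matching; the PR's actual matching code is not shown here.

import re

pattern = 'backbone.layer1.*|backbone.layer2.*'  # example pattern only
for name in ('backbone.layer1.0.conv1.weight',
             'backbone.layer3.0.conv1.weight'):
    # re.match anchors at the start of the parameter name, so the first
    # name matches and the second does not.
    print(name, bool(re.match(pattern, name)))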

(tuple, list)) and not isinstance(freeze_layers[0], str):
    raise TypeError(
        '`freeze_layers` must be a tuple or list of string')
if not isinstance(freeze_iter, (int, type(None))):

Collaborator:

Suggested change:
- if not isinstance(freeze_iter, (int, type(None))):
+ if not isinstance(freeze_iter, int) and freeze_iter is not None:

if not isinstance(verbose, bool):
    raise TypeError('`verbose` must be a boolean')
# check arguments value
if freeze_iter and freeze_iter < 0:

Collaborator:

Suggested change:
- if freeze_iter and freeze_iter < 0:
+ if freeze_iter is not None and freeze_iter < 0:
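
The is not None form matters because 0 is a valid value (iter/epoch indexing starts at 0) yet is falsy, so a plain truthiness test cannot tell "set to 0" apart from "not set". A quick illustration with a hypothetical value:

freeze_iter = 0  # hypothetical: freeze before the very first iteration

print(bool(freeze_iter))        # False: looks "unset" even though it was given
print(freeze_iter is not None)  # True: correctly detected as set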

if freeze_iter and freeze_iter < 0:
    raise ValueError(
        '`freeze_iter` must be greater than or equal to 0')
if freeze_epoch and freeze_epoch < 0:

Collaborator:

Suggested change:
- if freeze_epoch and freeze_epoch < 0:
+ if freeze_epoch is not None and freeze_epoch < 0:

Collaborator:

Merge this check into:

if not ((freeze_iter is None) ^ (freeze_epoch is None)):
    raise ValueError(...)

if freeze_iter is not None and freeze_iter < 0:
    raise ValueError(...)

if freeze_epoch is not None and freeze_epoch < 0:
    raise ValueError(...)
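
Spelling out the merged validation as a runnable helper, under the assumption (from use case 6) that exactly one of freeze_iter and freeze_epoch must be provided; the helper name check_freeze_args is hypothetical:

from typing import Optional


def check_freeze_args(freeze_iter: Optional[int],
                      freeze_epoch: Optional[int]) -> None:
    # Exactly one of freeze_iter / freeze_epoch must be given (assumption).
    if not ((freeze_iter is None) ^ (freeze_epoch is None)):
        raise ValueError(
            'exactly one of `freeze_iter` and `freeze_epoch` must be set')
    if freeze_iter is not None and freeze_iter < 0:
        raise ValueError('`freeze_iter` must be greater than or equal to 0')
    if freeze_epoch is not None and freeze_epoch < 0:
        raise ValueError('`freeze_epoch` must be greater than or equal to 0')


check_freeze_args(freeze_iter=None, freeze_epoch=0)      # passes
check_freeze_args(freeze_iter=None, freeze_epoch=None)   # raises ValueError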

"""Modify the `requires_grad` of the specified layers.

Args:
model (BaseModel): a BaseModel of mmengine.
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
model (BaseModel): a BaseModel of mmengine.
model (BaseModel): A BaseModel of mmengine.


def _modify_layers_grad(self, model: BaseModel, layers: Sequence[str],
                        requires_grad: bool):
    """Modify the `requires_grad` of the specified layers.

Collaborator:

Suggested change:
- """Modify the `requires_grad` of the specified layers.
+ """Modify the ``requires_grad`` of the specified layers.

print_log(
    f'{k} requires_grad: {v.requires_grad}', logger='current')

def _main(self,

Collaborator:

Suggested change:
- def _main(self,
+ def _freeze(self,

Comment on lines 229 to 231:

if self.freeze_iter is not None:
    self._main(runner, runner.iter, self.freeze_iter,
               self.unfreeze_iter)

Collaborator:

Suggested change:
- if self.freeze_iter is not None:
-     self._main(runner, runner.iter, self.freeze_iter,
-                self.unfreeze_iter)
+ if self.freeze_iter is not None and runner.iter in (self.freeze_iter, self.unfreeze_iter):
+     self._freeze(runner.model)
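
To make the proposed restructuring concrete, here is an illustrative, self-contained sketch of how the iter/epoch callbacks could dispatch to a single _freeze helper. The unfreeze keyword argument and the reuse of freeze_layers for unfreezing are simplifying assumptions of this sketch, not part of the reviewer's suggestion or the PR.

import re

from mmengine.hooks import Hook


class FreezeDispatchSketch(Hook):
    """Sketch only: dispatch freeze/unfreeze from the hook callbacks."""

    def __init__(self, freeze_layers, freeze_iter=None, unfreeze_iter=None,
                 freeze_epoch=None, unfreeze_epoch=None):
        self.freeze_layers = freeze_layers
        self.freeze_iter, self.unfreeze_iter = freeze_iter, unfreeze_iter
        self.freeze_epoch, self.unfreeze_epoch = freeze_epoch, unfreeze_epoch

    def _freeze(self, model, unfreeze=False):
        # Toggle requires_grad for every parameter matching the regex.
        for name, param in model.named_parameters():
            if re.match(self.freeze_layers, name):
                param.requires_grad = unfreeze

    def before_train_iter(self, runner, batch_idx, data_batch=None):
        if self.freeze_iter is not None and runner.iter in (
                self.freeze_iter, self.unfreeze_iter):
            self._freeze(runner.model,
                         unfreeze=(runner.iter == self.unfreeze_iter))

    def before_train_epoch(self, runner):
        if self.freeze_epoch is not None and runner.epoch in (
                self.freeze_epoch, self.unfreeze_epoch):
            self._freeze(runner.model,
                         unfreeze=(runner.epoch == self.unfreeze_epoch))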
