-
Notifications
You must be signed in to change notification settings - Fork 512
/
base_multi_modal_img_text.py
67 lines (57 loc) · 2.2 KB
/
base_multi_modal_img_text.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
#
# For licensing see accompanying LICENSE file.
# Copyright (C) 2024 Apple Inc. All Rights Reserved.
#
import argparse
from corenet.modeling.models import MODEL_REGISTRY, BaseAnyNNModel
@MODEL_REGISTRY.register(name="__base__", type="multi_modal_image_text")
class BaseMultiModalImageText(BaseAnyNNModel):
    """Shared base for models that work on paired image and text data.

    Registered in ``MODEL_REGISTRY`` as ``"__base__"`` under the
    ``"multi_modal_image_text"`` type. It owns the
    ``--model.multi-modal-image-text.*`` command-line options. On
    construction it caches the image-encoder and text-encoder learning-rate
    multipliers taken from ``opts``.

    Args:
        opts: Command-line arguments
    """

    def __init__(self, opts, *args, **kwargs) -> None:
        super().__init__(opts, *args, **kwargs)

        # The option strings registered in `add_arguments` are dashed, and
        # argparse stores them under attribute names with '-' replaced by '_'.
        # That is why the keys below are underscored. No default is passed to
        # getattr, so a missing option raises AttributeError.
        img_key = "model.multi_modal_image_text.lr_multiplier_img_encoder"
        text_key = "model.multi_modal_image_text.lr_multiplier_text_encoder"
        self.lr_multiplier_img_encoder = getattr(opts, img_key)
        self.lr_multiplier_text_encoder = getattr(opts, text_key)

    @classmethod
    def add_arguments(
        cls, parser: argparse.ArgumentParser
    ) -> argparse.ArgumentParser:
        """Register the ``--model.multi-modal-image-text.*`` options.

        Only ``BaseMultiModalImageText`` itself adds options. A subclass that
        inherits this method without overriding it gets ``parser`` back
        unchanged, so the same option strings are never added twice.

        Args:
            parser: Parser to extend.

        Returns:
            The same ``parser`` object that was passed in.
        """
        if cls is not BaseMultiModalImageText:
            return parser

        owner = cls.__name__
        group = parser.add_argument_group(title=owner)

        group.add_argument(
            "--model.multi-modal-image-text.name",
            default=None,
            type=str,
            help="Name of the multi-modal image-text model",
        )
        # Learning-rate multipliers; `__init__` reads both of them back.
        group.add_argument(
            "--model.multi-modal-image-text.lr-multiplier-img-encoder",
            default=1.0,
            type=float,
            help=f"LR multiplier for the image encoder in {owner}",
        )
        group.add_argument(
            "--model.multi-modal-image-text.lr-multiplier-text-encoder",
            default=1.0,
            type=float,
            help=f"LR multiplier for the text encoder in {owner}",
        )
        group.add_argument(
            "--model.multi-modal-image-text.pretrained",
            default=None,
            type=str,
            help="Path of the pretrained backbone",
        )
        group.add_argument(
            "--model.multi-modal-image-text.freeze-batch-norm",
            action="store_true",
            help="Freeze batch norm layers",
        )
        return parser