-
Notifications
You must be signed in to change notification settings - Fork 512
/
base_optim.py
80 lines (70 loc) · 2.5 KB
/
base_optim.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
#
# For licensing see accompanying LICENSE file.
# Copyright (C) 2024 Apple Inc. All Rights Reserved.
#
import argparse
from corenet.utils import logger
class BaseOptim(object):
"""Base class for optimizer
Args:
opts: Command-line arguments
"""
def __init__(self, opts: argparse.Namespace, *args, **kwargs) -> None:
self.eps = 1e-8
self.lr = getattr(opts, "scheduler.lr")
self.weight_decay = getattr(opts, "optim.weight_decay")
@classmethod
def add_arguments(cls, parser: argparse.ArgumentParser) -> argparse.ArgumentParser:
"""Add optimizer arguments"""
if cls != BaseOptim:
# Don't re-register arguments in subclasses that don't override `add_arguments()`.
return parser
group = parser.add_argument_group(cls.__name__)
group.add_argument(
"--optim.name",
type=str,
default="sgd",
help="Name of the optimizer. Defaults to SGD.",
)
group.add_argument(
"--optim.eps",
type=float,
default=1e-8,
help="Optimizer epsilon value. Defaults to 1.e-8.",
)
group.add_argument(
"--optim.weight-decay",
default=4e-5,
type=float,
help="Weight decay (or L2 penalty). Defaults to 4.e-5.",
)
group.add_argument(
"--optim.no-decay-bn-filter-bias",
action="store_true",
default=False,
help="When enabled, the weight in normalization layers and biases in the model are not decayed."
"Defaults to False.",
)
group.add_argument(
"--optim.bypass-parameters-check",
action="store_true",
default=False,
help="Bypass parameter check when creating optimizer. Defaults to False",
)
return parser
def __repr__(self) -> str:
group_dict = dict()
for i, group in enumerate(self.param_groups):
for key in sorted(group.keys()):
if key == "params":
continue
if key not in group_dict:
group_dict[key] = [group[key]]
else:
group_dict[key].append(group[key])
format_string = self.__class__.__name__ + " ("
format_string += "\n"
for k, v in group_dict.items():
format_string += "\t {0}: {1}\n".format(k, v)
format_string += ")"
return format_string