/
defaulf_cfg.py
164 lines (142 loc) · 9.14 KB
/
defaulf_cfg.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
import argparse
class default_parser:
    """Factory for the project's command-line interface.

    Each ``*parser`` method returns an ``argparse.ArgumentParser`` created
    with ``add_help=False`` so it can be composed as a *parent* parser.
    ``get_args`` discovers every such method by name, merges them into one
    final parser, parses ``sys.argv`` and auto-derives the experiment name.
    """

    def __init__(self) -> None:
        # Stateless: the class only groups related parser-building methods.
        pass

    def wandb_parser(self):
        """Weights & Biases logging options."""
        parser = argparse.ArgumentParser(add_help=False)
        parser.add_argument('--wandb', action='store_true')
        parser.add_argument('--wandb_project', type=str, default='NeurIPs2022-Sparse SAM', help="Project name in wandb.")
        parser.add_argument('--wandb_name', type=str, default='Default', help="Experiment name in wandb.")
        return parser

    def base_parser(self):
        """Experiment bookkeeping: output paths, resume, seed, epochs."""
        parser = argparse.ArgumentParser(add_help=False)
        parser.add_argument('--output_dir', type=str, default='logs', help='Name of dir where save all experiments.')
        parser.add_argument('--output_name', type=str, default=None, help="Name of dir where save the log.txt&ckpt.pth of this experiment. (None means auto-set)")
        parser.add_argument('--resume', action='store_true', help="resume model,opt,etc.")
        parser.add_argument('--resume_path', type=str, default='.')
        parser.add_argument('--seed', type=int, default=1234)
        parser.add_argument('--log_freq', type=int, default=10, help="Frequency of recording information.")
        parser.add_argument('--start_epoch', type=int, default=0)
        parser.add_argument('--epochs', type=int, default=200, help="Epochs of training.")
        return parser

    def dist_parser(self):
        """Distributed-training bootstrap options."""
        parser = argparse.ArgumentParser(add_help=False)
        parser.add_argument('--dist_url', default='env://', help='url used to set up distributed training')
        return parser

    def data_parser(self):
        """Dataset and dataloader options."""
        parser = argparse.ArgumentParser(add_help=False)
        parser.add_argument('--dataset', type=str, default='CIFAR10_base', help="Dataset name in `DATASETS` registry.")
        parser.add_argument('--datadir', type=str, default='/public/data0/DATA-1/users/mipeng7/datasets', help="Path to your dataset.")
        parser.add_argument('--batch_size', type=int, default=128, help="Batch size used in training and validation.")
        # NOTE(review): `store_true` with default=True means these two flags can
        # never be switched off from the CLI. Kept as-is for backward
        # compatibility; consider argparse.BooleanOptionalAction to fix.
        parser.add_argument('--pin_memory', action='store_true', default=True)
        parser.add_argument('--drop_last', action='store_true', default=True)
        parser.add_argument('--distributed_val', action='store_true', help="Enabling distributed evaluation (Only works when use multi gpus).")
        return parser

    def base_opt_parser(self):
        """Base optimizer (SGD/Adam) hyper-parameters."""
        parser = argparse.ArgumentParser(add_help=False)
        parser.add_argument('--opt', type=str, default='sgd')
        parser.add_argument('--lr', type=float, default=0.05)
        parser.add_argument('--weight_decay', type=float, default=5e-4)
        # sgd
        parser.add_argument('--momentum', type=float, default=0.9, help="Momentum for SGD.(None means the default in optm)")
        parser.add_argument('--nesterov', action="store_true")
        # adam
        parser.add_argument('--betas', type=float, default=None, nargs='+', help="Betas for AdamW Optimizer.(None means the default in optm)")
        parser.add_argument('--eps', type=float, default=None, help="Epsilon for AdamW Optimizer.(None means the default in optm)")
        return parser

    def sam_opt_parser(self):
        """SAM / Sparse-SAM (SSAM-F, SSAM-D) specific hyper-parameters."""
        parser = argparse.ArgumentParser(add_help=False)
        parser.add_argument('--rho', type=float, default=0.05, help="Perturbation intensity of SAM type optims.")
        parser.add_argument('--sparsity', type=float, default=0.2, help="The proportion of parameters that do not calculate perturbation.")
        parser.add_argument('--update_freq', type=int, default=5, help="Update frequency (epoch) of sparse SAM.")
        parser.add_argument('--pattern', choices=["unstructured", "structured", "nm"])
        parser.add_argument('--implicit', action='store_true')
        parser.add_argument('--m_structured', type=int)
        parser.add_argument('--n_structured', type=int)
        parser.add_argument('--num_samples', type=int, default=1024, help="Number of samples to compute fisher information. Only for `ssam-f`.")
        parser.add_argument('--drop_rate', type=float, default=0.5, help="Death Rate in `ssam-d`. Only for `ssam-d`.")
        parser.add_argument('--drop_strategy', type=str, default='gradient', help="Strategy of Death. Only for `ssam-d`.")
        parser.add_argument('--growth_strategy', type=str, default='random', help="Only for `ssam-d`.")
        return parser

    def lr_scheduler_parser(self):
        """Learning-rate schedule and warmup options."""
        parser = argparse.ArgumentParser(add_help=False)
        parser.add_argument('--warmup_epoch', type=int, default=0)
        parser.add_argument('--warmup_init_lr', type=float, default=0.0)
        parser.add_argument('--lr_scheduler', type=str, default='CosineLRscheduler')
        # CosineLRscheduler
        parser.add_argument('--eta_min', type=float, default=0)
        # MultiStepLRscheduler
        parser.add_argument('--milestone', type=int, nargs='+', default=[60, 120, 160], help="Milestone for MultiStepLRscheduler.")
        parser.add_argument('--gamma', type=float, default=0.2, help="Gamma for MultiStepLRscheduler.")
        return parser

    def model_parser(self):
        """Model architecture options."""
        parser = argparse.ArgumentParser(add_help=False)
        parser.add_argument('--model', type=str, default='resnet18', help="Model in registry to use.")
        parser.add_argument('--patch_size', type=int, default=4, help="Patch size used in ViT for CIFAR.")
        parser.add_argument('--samconv', action='store_true', help='used for sparse bp(in conv)')
        parser.add_argument('--culinear', action='store_true', help='used for sparse bp(in vit)')
        return parser

    def get_args(self):
        """Merge every ``*parser`` method into one parser and parse sys.argv.

        Returns:
            argparse.Namespace: parsed arguments with ``output_name`` and
            ``wandb_name`` filled in by :meth:`auto_set_name`.
        """
        # Discover parser factories by naming convention (public, ends with 'parser').
        all_parser_funcs = [
            getattr(self, name) for name in dir(self)
            if not name.startswith('_')
            and name.endswith('parser')
            and callable(getattr(self, name))
        ]
        all_parsers = [parser_func() for parser_func in all_parser_funcs]
        final_parser = argparse.ArgumentParser(parents=all_parsers)
        args = final_parser.parse_args()
        self.auto_set_name(args)
        return args

    def auto_set_name(self, args):
        """Derive ``args.output_name`` (and ``args.wandb_name``) from the
        hyper-parameters when the user did not set them explicitly.

        Raises:
            ValueError: if ``args.opt`` / ``args.pattern`` describe a
                combination the naming scheme does not understand.
        """
        def opt_hyper_str(args):
            # Return the list of run-name fragments describing the optimizer setup.
            args_opt = args.opt.split('-')
            if len(args_opt) == 1:
                # Plain base optimizer, e.g. 'sgd'.
                return [str(args.opt)]
            elif len(args_opt) == 2:
                sam_opt, base_opt = args_opt[0], args_opt[1]
                if sam_opt[:3].upper() == 'SAM' and args.implicit is False and (args.samconv is False and args.culinear is False):
                    # using optimizer SAM
                    return [str(args.opt), 'rho{}'.format(str(args.rho))]
                elif sam_opt[:4].upper() == 'SSAM' and args.implicit is False and args.samconv is False:
                    # using optimizer SSAMF or SSAMD explicitly
                    outlist = [str(args.opt), 'rho{}'.format(str(args.rho)), 'pattern-{}'.format(str(args.pattern))]
                    if args.pattern == 'unstructured' or args.pattern == 'structured':
                        outlist += ['sparsity{}'.format(str(args.sparsity))]
                    elif args.pattern == 'nm':
                        outlist += ['n{}m{}'.format(str(args.n_structured), str(args.m_structured))]
                    else:
                        raise ValueError("Wrong args.Pattern")
                    outlist += ['explicit']
                    return outlist
                elif sam_opt[:3].upper() == 'SAM' and args.implicit is True and (args.samconv is True or args.culinear is True):
                    # using optimizer SSAMF implicitly
                    outlist = ['SSAMF', 'rho{}'.format(str(args.rho)), 'pattern-{}'.format(str(args.pattern))]
                    if args.pattern == 'structured':
                        outlist += ['sparsity{}'.format(str(args.sparsity))]
                    elif args.pattern == 'nm':
                        outlist += ['n{}m{}'.format(str(args.n_structured), str(args.m_structured))]
                    else:
                        raise ValueError("Wrong args.Pattern")
                    outlist += ['implicit']
                    return outlist
            # Bug fix: unmatched combinations previously fell through and
            # returned None, which crashed the '_'.join below with an opaque
            # TypeError. Fail loudly with the offending values instead.
            raise ValueError(
                'Cannot derive a run name from opt={!r} (implicit={}, samconv={}, culinear={})'.format(
                    args.opt, args.implicit, args.samconv, args.culinear))

        if args.output_name is None:
            args.output_name = '_'.join([
                args.dataset,
                'bsz' + str(args.batch_size),
                'epoch' + str(args.epochs),
                args.model,
                'lr' + str(args.lr),
            ] + opt_hyper_str(args) + ['seed{}'.format(args.seed)])
        # Mirror the auto-set name into wandb unless the user chose one.
        if args.wandb_name == 'Default':
            args.wandb_name = args.output_name