/
utils.py
158 lines (135 loc) · 4.51 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
import os
import copy
import numpy as np
from collections import OrderedDict as OD
from collections import defaultdict as DD
from collections import Iterable
import torch
import torch.nn as nn
import torch.nn.functional as F
import collections
import numpy as np
import pandas as pd
import torch
def load_best_args(
    args,
    target='acc',
    avg_over='run',
    keep=('method', 'use_augs', 'task_free', 'dataset', 'mem_size', 'mir_head_only', 'distill_coef', 'n_iters'),
):
    """Overwrite hyperparameters on `args` with the best values found in a sweep.

    Reads the sweep results from 'sweeps/hp_result_final.csv', keeps only the
    rows matching the experiment configuration described by the `keep` columns,
    averages `target` over `avg_over` for each remaining hyperparameter
    configuration, and writes the best configuration's values onto `args`.

    :param args: namespace-like object read with getattr / mutated with setattr
    :param target: metric column to maximize
    :param avg_over: column identifying repeated runs (e.g. the seed)
    :param keep: columns that define the experiment and are never overwritten
    """
    # load the dataframe with the hparam sweep results
    df = pd.read_csv('sweeps/hp_result_final.csv')

    # subselect the runs matching the current configuration; if a filter
    # would leave no rows, skip it instead of emptying the frame
    for key in keep:
        filtered = df[df[key] == getattr(args, key)]
        if filtered.shape[0] == 0:
            print(f'skipping over {key}')
        else:
            df = filtered

    # the swept hyperparameters are the columns still varying across the kept
    # rows, excluding the run identifier and the target metric themselves
    # (list.remove would raise if either column happens not to vary)
    unique = df.nunique()
    arg_list = [c for c in unique[unique > 1].index if c not in (avg_over, target)]

    # rank configurations by the mean of the target metric
    acc_per_cfg = df.groupby(arg_list)[target].agg(['mean', 'std'])
    acc_per_cfg = acc_per_cfg.rename(columns={'mean': f'{target}_mean', 'std': f'{target}_std'})
    arg_values = acc_per_cfg[f'{target}_mean'].idxmax()

    # idxmax returns a scalar for a single-key groupby and a tuple otherwise;
    # normalize so we can zip (avoids the deprecated collections.Iterable)
    if not isinstance(arg_values, tuple):
        arg_values = (arg_values,)

    print('overwriting args')
    for k, v in zip(arg_list, arg_values):
        if k in keep:
            continue
        print(f'{k} from {getattr(args, k)} to {v}')
        setattr(args, k, v)
def sho_(x, nrow=8):
    """Denormalize `x` from [-1, 1] back to [0, 1], write it as an image grid
    to 'tmp.png', then open that file with the system image viewer.

    :param x: image tensor; a 5-d tensor is flattened over its first two dims
    :param nrow: images per grid row (replaced by x.size(1) for 5-d input)
    """
    from torchvision.utils import save_image
    from PIL import Image

    img = x * 0.5 + 0.5
    # presumably (batch, seq, C, H, W): flatten so each grid row is one
    # sequence -- TODO confirm against callers
    if img.ndim == 5:
        nrow = img.size(1)
        img = img.reshape(-1, *img.shape[2:])
    save_image(img, 'tmp.png', nrow=nrow)
    Image.open('tmp.png').show()
def save_(x, name='tmp.png', nrow=8):
    """Denormalize `x` from [-1, 1] back to [0, 1] and save it as an image grid.

    :param x: image tensor; a 5-d tensor is flattened over its first two dims
    :param name: output file path
    :param nrow: images per grid row (replaced by x.size(1) for 5-d input)
    """
    x = x * .5 + .5
    from torchvision.utils import save_image
    if x.ndim == 5:
        nrow = x.size(1)
        x = x.reshape(-1, *x.shape[2:])
    # bug fix: nrow was computed but never passed to save_image, so 5-d
    # inputs were not laid out one sequence per row as in sho_
    save_image(x, name, nrow=nrow)
# --- MIR utils
''' For MIR '''
def overwrite_grad(pp, new_grad, grad_dims):
    """
    Overwrite every parameter's gradient with a slice of a flat gradient
    vector (used whenever constraint violations occur).

    pp: callable returning an iterable of parameters
    new_grad: flat tensor holding the replacement gradients, laid out
        parameter-by-parameter according to grad_dims
    grad_dims: list storing the number of elements of each parameter
    """
    beg = 0
    for param, dim in zip(pp(), grad_dims):
        # keep a running offset instead of re-summing the grad_dims prefix
        # on every iteration (the original was O(n^2) in the layer count)
        en = beg + dim
        this_grad = new_grad[beg: en].contiguous().view(param.data.size())
        param.grad = torch.zeros_like(param.data)
        param.grad.data.copy_(this_grad)
        beg = en
def get_grad_vector(pp, grad_dims):
    """
    Gather all parameter gradients into one flat vector.

    pp: sequence of parameters (device taken from pp[0]); parameters whose
        .grad is None leave their slot zero-filled
    grad_dims: list storing the number of elements of each parameter
    """
    grads = torch.zeros(size=(sum(grad_dims),), device=pp[0].device)
    beg = 0
    for param, dim in zip(pp, grad_dims):
        # running offset instead of re-summing the prefix each iteration;
        # the offset advances even for grad-less params to keep slots aligned
        en = beg + dim
        if param.grad is not None:
            grads[beg: en].copy_(param.grad.data.view(-1))
        beg = en
    return grads
def get_future_step_parameters(this_net, grad_vector, grad_dims, lr=1):
    r"""Return a copy of `this_net` after one virtual SGD step.

    The clone's gradients are replaced by `grad_vector` (via overwrite_grad),
    then each parameter is updated as \theta \leftarrow \theta - lr \, \delta\theta.

    :param this_net: network to clone (left untouched)
    :param grad_vector: flat gradient vector laid out according to grad_dims
    :param grad_dims: number of elements per parameter
    :param lr: step size of the virtual update
    :return: the updated clone
    """
    lookahead = copy.deepcopy(this_net)
    overwrite_grad(lookahead.parameters, grad_vector, grad_dims)
    with torch.no_grad():
        for p in lookahead.parameters():
            if p.grad is None:
                continue
            p.data.add_(p.grad.data, alpha=-lr)
    return lookahead
def get_grad_dims(self):
    """Record the element count of every parameter of self.net into
    self.grad_dims (used elsewhere to slice flat gradient vectors)."""
    self.grad_dims = [p.data.numel() for p in self.net.parameters()]
# Taken from
# https://github.com/aimagelab/mammoth/blob/cb9a36d788d6ad051c9eee0da358b25421d909f5/models/gem.py#L34
def store_grad(params, grads, grad_dims):
    """
    Store the current parameter gradients (of past tasks) into a flat buffer.

    params: callable returning an iterable of parameters
    grads: 1-d tensor to fill; it is zeroed first, so slots of parameters
        without a gradient stay zero
    grad_dims: list with the number of elements per parameter
    """
    grads.fill_(0.0)
    begin = 0
    for param, dim in zip(params(), grad_dims):
        # running offset: the original re-summed the grad_dims prefix every
        # iteration (O(n^2)) and inconsistently mixed builtin sum with np.sum
        end = begin + dim
        if param.grad is not None:
            grads[begin: end].copy_(param.grad.data.view(-1))
        begin = end
# Taken from
# https://github.com/aimagelab/mammoth/blob/cb9a36d788d6ad051c9eee0da358b25421d909f5/models/agem.py#L21
def project(gxy: torch.Tensor, ger: torch.Tensor) -> torch.Tensor:
    """Remove from `gxy` its component along `ger` (A-GEM gradient projection).

    Returns gxy - ((gxy . ger) / (ger . ger)) * ger.
    """
    num = torch.dot(gxy, ger)
    den = torch.dot(ger, ger)
    return gxy - (num / den) * ger