-
Notifications
You must be signed in to change notification settings - Fork 11
/
viz_cifar.py
executable file
·221 lines (194 loc) · 7.32 KB
/
viz_cifar.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
"""
Tools for plotting / visualization
"""
import matplotlib
matplotlib.use('Agg') # no displayed figures -- need to call before loading pylab
import matplotlib.pyplot as plt
import matplotlib.cm as cm
import numpy as np
import warnings
def is_square(shp, n_colors=1):
    """
    Report which entries of shp are perfect squares, either directly or after dividing out
    the number of color channels.

    shp may be a scalar or array-like of non-negative counts. Returns a boolean (numpy bool
    for a scalar input, boolean array otherwise) that is true where the count is k**2 or
    n_colors * k**2 for some integer k.
    """
    counts = np.asarray(shp)
    per_channel = counts / float(n_colors)
    # plain square: count == round(sqrt(count))**2
    plain_square = counts == np.round(np.sqrt(counts)) ** 2
    # square once the color channels are divided out: count == n_colors * k**2
    colour_square = counts == n_colors * np.round(np.sqrt(per_channel)) ** 2
    return plain_square | colour_square
def show_receptive_fields(theta, P=None, n_colors=None, max_display=100, grid_wa=None):
    """
    Display receptive fields in a grid of subplots on the current matplotlib figure.

    Tries to intelligently guess whether to treat the rows, the columns, or the last two axes
    together as containing the receptive fields. It does this by checking which axes are
    square numbers -- so you can get some unexpected plots if the wrong axis is a square
    number, or if multiple axes are. It also tries to handle the last axis containing color
    channels correctly.

    Parameters
    ----------
    theta : ndarray
        Receptive fields. Reshaped to a 2-D matrix whose columns are individual fields.
    P : ndarray, optional
        If given, each field is projected as ``np.dot(P, theta)`` before display, and the
        image width is derived from ``P.shape[0]`` instead of ``theta.shape[0]``.
    n_colors : int, optional
        Number of color channels. If None, 3 is assumed when the last axis of theta has
        length 3, otherwise 1. With 3 colors each field is unpacked channel-first
        (colors x height x width) and clipped to [0, 1] for display.
    max_display : int
        Maximum number of fields (columns) to draw.
    grid_wa : int, optional
        Number of subplot rows; defaults to ceil(sqrt(number of fields)).

    Returns
    -------
    bool
        False if nothing was drawn (neither axis looks like an image, or there are no
        fields), True otherwise.
    """
    shp = np.array(theta.shape)
    if n_colors is None:
        # guess RGB when the trailing axis has exactly 3 entries
        n_colors = 3 if shp[-1] == 3 else 1
    # Fold a trailing color axis into the preceding axis. This needs at least two axes: a 1-D
    # input whose length happens to equal n_colors has no preceding axis (indexing shp[-2]
    # would raise IndexError) and is treated as a single column further down instead.
    if len(shp) >= 2 and shp[-1] == n_colors:
        shp[-2] *= n_colors
        theta = theta.reshape(shp[:-1])
        shp = np.array(theta.shape)
    if len(shp) > 2:
        # merge last two axes
        shp[-2] *= shp[-1]
        theta = theta.reshape(shp[:-1])
        shp = np.array(theta.shape)
    if len(shp) > 2:
        # merge leading axes
        theta = theta.reshape((-1, shp[-1]))
        shp = np.array(theta.shape)
    if len(shp) == 1:
        theta = theta.reshape((-1, 1))
        shp = np.array(theta.shape)

    # Figure out the right orientation by looking for the axis with a square number of
    # entries (up to the number of colors); transpose so fields run along the columns.
    is_sqr = is_square(shp, n_colors=n_colors)
    if is_sqr[0] and is_sqr[1]:
        warnings.warn("Unsure of correct matrix orientation. "
                      "Assuming receptive fields along first dimension.")
    elif is_sqr[1]:
        theta = theta.T
    elif not is_sqr[0]:
        # neither direction corresponds well to an image
        # NOTE if you delete this next line, the code will work. The rfs just won't look very
        # image like
        return False

    theta = theta[:, :max_display].copy()
    nf = theta.shape[1]
    if nf == 0:
        # no fields to draw (the grid-size computation below would divide by zero)
        return False

    n_pixels = theta.shape[0] if P is None else P.shape[0]
    img_w = int(np.ceil(np.sqrt(n_pixels / float(n_colors))))
    if grid_wa is None:
        grid_wa = int(np.ceil(np.sqrt(float(nf))))
    grid_wb = int(np.ceil(nf / float(grid_wa)))
    if P is not None:
        theta = np.dot(P, theta)

    for jj in range(nf):
        plt.subplot(grid_wa, grid_wb, jj + 1)
        # zero-pad the field up to a full img_w x img_w (x colors) patch
        ptch = np.zeros((n_colors * img_w ** 2,))
        ptch[:theta.shape[0]] = theta[:, jj]
        if n_colors == 3:
            ptch = ptch.reshape((n_colors, img_w, img_w))
            ptch = ptch.transpose((1, 2, 0))  # move color channels to end
            plt.imshow(ptch.clip(0.0, 1.0), interpolation='nearest')
        else:
            ptch = ptch.reshape((img_w, img_w))
            plt.imshow(ptch, interpolation='nearest', cmap=cm.Greys_r)
        plt.axis('off')
    return True
def plot_parameter(theta_in, base_fname_part1, base_fname_part2="", title='', n_colors=None):
    """
    Save both a raw and a receptive-field style plot of the contents of theta_in.

    Prints summary statistics of theta_in, then writes
    ``base_fname_part1 + '_raw_' + base_fname_part2 + '.pdf'`` (a dot plot for 1-D values, a
    grayscale image with colorbar for 2-D values). If the values are at least 2-D, or are 1-D
    with a square number of entries, it also writes
    ``base_fname_part1 + '_rf_' + base_fname_part2 + '.pdf'`` via show_receptive_fields
    (skipped when that function declines to draw).

    Parameters
    ----------
    theta_in : scalar or array-like
        Values to plot. Size-1 axes are squeezed out; more than 2 axes are flattened to 2-D
        keeping the first axis.
    base_fname_part1 : str
        Mandatory root of the output filenames.
    base_fname_part2 : str
        Optional suffix placed after ``_raw_`` / ``_rf_``.
    title : str
        Figure title, also prefixed to the printed statistics.
    n_colors : int, optional
        Forwarded to show_receptive_fields.
    """
    # np.array copies its input and, unlike calling .copy() first, also accepts plain Python
    # scalars and lists
    theta = np.array(theta_in)
    # single-argument print(...) gives identical output on Python 2 and 3
    print("%s min %g median %g mean %g max %g shape %s" % (
        title, np.min(theta), np.median(theta), np.mean(theta), np.max(theta), theta.shape))
    theta = np.squeeze(theta)
    if len(theta.shape) == 0:
        # it's a scalar -- make it a 1d array
        theta = np.array([theta])
    if len(theta.shape) > 2:
        theta = theta.reshape((theta.shape[0], -1))
    shp = theta.shape

    ## display basic figure
    plt.figure(figsize=[8, 8])
    try:
        if len(shp) == 1:
            plt.plot(theta, '.', alpha=0.5)
        else:
            plt.imshow(theta, interpolation='nearest', aspect='auto', cmap=cm.Greys_r)
            # a colorbar needs a mappable such as this image; calling plt.colorbar() after a
            # plain line plot raises RuntimeError
            plt.colorbar()
        plt.title(title)
        plt.savefig(base_fname_part1 + '_raw_' + base_fname_part2 + '.pdf')
    finally:
        plt.close()

    ## also display it in basis function view if it's a matrix, or
    ## if it's a bias with a square number of entries
    if len(shp) >= 2 or is_square(shp[0]):
        if len(shp) == 1:
            theta = theta.reshape((-1, 1))
        plt.figure(figsize=[8, 8])
        try:
            if show_receptive_fields(theta, n_colors=n_colors):
                plt.suptitle(title + "receptive fields")
                plt.savefig(base_fname_part1 + '_rf_' + base_fname_part2 + '.pdf')
        finally:
            plt.close()
def max_value(inputlist, index):
    """
    Return the largest ``row[index]`` over every row in inputlist.

    Raises ValueError (from the builtin max) when inputlist is empty.
    """
    return max(row[index] for row in inputlist)
def plot_2D(x, num_steps, filename):
    """
    Scatter-plot 2-D points and save each plot as a PDF with axes fixed to [-1.5, 1.5].

    Parameters
    ----------
    x : sequence
        When num_steps == 1, a sequence of points; otherwise a sequence holding one sequence
        of points per step. Only the first two coordinates of each point are plotted.
    num_steps : int
        Number of steps to plot.
    filename : str
        Output root. Writes ``filename + '.pdf'`` when num_steps == 1, otherwise
        ``filename + '_step_<t>.pdf'`` for each t in range(num_steps).
    """
    if num_steps == 1:
        _scatter_2d_to_pdf(x, filename + '.pdf')
    else:
        for time in range(num_steps):
            _scatter_2d_to_pdf(x[time], filename + '_step_' + str(time) + '.pdf')


def _scatter_2d_to_pdf(points, path):
    """Scatter the first two coordinates of each point on a fresh figure and save to path."""
    x_0 = [pt[0] for pt in points]
    x_1 = [pt[1] for pt in points]
    # a fresh figure keeps points from leaking onto whatever figure was current before
    plt.figure()
    try:
        plt.scatter(x_0, x_1)
        plt.axis([-1.5, 1.5, -1.5, 1.5])
        plt.savefig(path)
    finally:
        plt.close()
def plot_grad(grad, filename, rng=((-1.5, 1.5), (-1.5, 1.5))):
    """
    Draw one quiver (arrow) plot per step and save each as ``filename + '_step_<i>.pdf'``.

    Parameters
    ----------
    grad : sequence
        One entry per step. Each entry is a sequence of rows whose items 0 and 1 are the
        arrow start (x, y) and items 2 and 3 are the vector components; vectors are drawn
        with unit length, i.e. only their direction is shown.
    filename : str
        Output root.
    rng : ((x_beg, x_end), (y_beg, y_end))
        Axis limits applied to every plot; the default matches the previously hard-coded
        [-1.5, 1.5] square.

    Note: closes the current figure before plotting.
    """
    plt.close()
    (x_beg, x_end), (y_beg, y_end) = rng
    for step, rows in enumerate(grad):
        start_0 = np.asarray([row[0] for row in rows])
        start_1 = np.asarray([row[1] for row in rows])
        vec_0 = np.asarray([row[2] for row in rows])
        vec_1 = np.asarray([row[3] for row in rows])
        speed = np.sqrt(vec_0 ** 2 + vec_1 ** 2)
        # normalize to unit length; zero vectors stay zero instead of becoming 0/0 = nan
        safe_speed = np.where(speed > 0, speed, 1.0)
        u_dir = vec_0 / safe_speed
        v_dir = vec_1 / safe_speed
        plt.figure()
        try:
            # NOTE(review): arrows are colored by the x-component vec_0 with limits [0, 1],
            # not by speed -- confirm this is intended.
            plt.quiver(start_0, start_1, u_dir, v_dir, vec_0,
                       cmap=cm.winter, headlength=3, clim=[0., 1.])
            plt.axis([x_beg, x_end, y_beg, y_end])
            plt.colorbar()
            plt.savefig(filename + '_step_' + str(step) + '.pdf')
        finally:
            plt.close()
def plot_images(X, fname, form=".png", savenp=True):
    """
    Plot images in a grid.

    X is expected to be a 4d tensor of dimensions [# images]x[# colors]x[height]x[width].
    The grid is saved to ``fname + form``; if show_receptive_fields declines to draw, a
    warning is issued and no image is saved. When savenp is true the raw array is also saved
    to ``fname + '.npz'`` under the key ``X``.

    Parameters
    ----------
    X : ndarray
        Images, [# images]x[# colors]x[height]x[width].
    fname : str
        Output root for both files.
    form : str
        Extension (including the dot) for the plotted grid; defaults to ".png".
    savenp : bool
        Whether to also save X with np.savez.
    """
    # Flatten each image into one column: Xcol has shape (colors*height*width, # images),
    # with the color axis varying slowest -- the channel-first layout show_receptive_fields
    # unpacks for 3-color patches.
    # NOTE(review): when # images equals # colors, show_receptive_fields treats the trailing
    # (image) axis as a color axis and folds it; confirm callers avoid that case.
    Xcol = X.reshape((X.shape[0], -1)).T
    plt.figure(figsize=[8, 8])
    try:
        if show_receptive_fields(Xcol, n_colors=X.shape[1]):
            plt.savefig(fname + form)
        else:
            warnings.warn('Images unexpected shape.')
    finally:
        plt.close()
    ## save as a .npz file
    if savenp:
        np.savez(fname + '.npz', X=X)