Skip to content

Commit c134d5d

Browse files
committed
Potential final model v2
1 parent 93e0e86 commit c134d5d

File tree

12 files changed

+218
-59
lines changed

12 files changed

+218
-59
lines changed

CurriculumLib.py

Lines changed: 0 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -147,19 +147,6 @@ def __getitem__(self, idx):
147147
elNorm = np.stack([iris_norm, pupil_norm], axis=0) # Respect iris first policy
148148

149149
elNorm = torch.from_numpy(elNorm).to(self.prec)
150-
'''
151-
print('...')
152-
print(img.type())
153-
print(label.type())
154-
print(spatialWeights.type())
155-
print(pupil_center.type())
156-
print(iris_center.type())
157-
print(elPts.type())
158-
print(elNorm.type())
159-
print(cond.type())
160-
print(imInfo.type())
161-
print('---')
162-
'''
163150
return (img, label, spatialWeights, distMap, pupil_center, iris_center, elNorm, cond, imInfo)
164151

165152
def readImage(self, idx):

analysis/pixelStats.py

Lines changed: 92 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,92 @@
1+
#!/usr/bin/env python3
2+
# -*- coding: utf-8 -*-
3+
"""
4+
Created on Wed May 6 07:29:05 2020
5+
6+
@author: rakshit
7+
"""
8+
9+
import os
10+
import cv2
11+
import sys
12+
import h5py
13+
import pickle
14+
import numpy as np
15+
16+
sys.path.append('..')
17+
18+
path2curObjs = '/home/rakshit/Documents/Python_Scripts/GIW_e2e/curObjects/baseline'
19+
path2ds = '/media/rakshit/tank/Dataset/All'
20+
21+
ds_list = ['Fuhl', 'PupilNet', 'LPW', 'NVGaze', 'OpenEDS', 'riteyes_general']
22+
cond = ['Natural', 'Constrained', 'Natural', 'Constrained', 'Natural', 'Constrained']
23+
24+
curObjs_list = os.listdir(path2curObjs)
25+
26+
def readImages(obj):
27+
archNums = np.unique(obj.imList[:, 1])
28+
I_list = []
29+
for archNum in np.nditer(archNums):
30+
f = h5py.File(os.path.join(path2ds, obj.arch[archNum]+'.h5'), 'r')
31+
im_ids = obj.imList[obj.imList[:, 1] == archNum, 0]
32+
loc = np.in1d(np.arange(f['Images'].shape[0]), im_ids)
33+
I = np.array(f['Images'][loc, ...]).astype(np.float32)
34+
I_list.append(normalize(I))
35+
f.close()
36+
return np.concatenate(I_list, axis=0)
37+
38+
def normalize(imgs):
39+
# Given a large amount of images, normalize and return
40+
L, H, W = imgs.shape
41+
mu = np.mean(imgs.reshape(L, -1), axis=1) # mu [L, ]
42+
std = np.std(imgs.reshape(L, -1), axis=1) # std [L, ]
43+
norm_data = imgs.reshape(L, -1) - np.stack([mu for i in range(H*W)], axis=1)
44+
norm_data = norm_data/np.stack([std for i in range(H*W)], axis=1)
45+
return norm_data
46+
47+
pxStats = {'name': [],
48+
'train': [],
49+
'valid': [],
50+
'test': []}
51+
52+
for ds_name in ds_list:
53+
print('Starting: {}'.format(ds_name))
54+
pxStats['name'].append(ds_name)
55+
path2curObj = os.path.join(path2curObjs, 'cond_'+ds_name+'.pkl')
56+
trainObj, validObj, testObj = pickle.load(open(path2curObj, 'rb'))
57+
58+
# Extract train stats
59+
norm_data = readImages(trainObj)
60+
vals = np.apply_along_axis(lambda x: np.histogram(x,
61+
range=(-4, 4),
62+
bins=40),
63+
axis=1,
64+
arr=norm_data)
65+
pxStats['train'].append(vals)
66+
print('Train done: {}'.format(ds_name))
67+
68+
# Extract valid stats
69+
norm_data = readImages(validObj)
70+
vals = np.apply_along_axis(lambda x: np.histogram(x,
71+
range=(-4, 4),
72+
bins=40),
73+
axis=1,
74+
arr=norm_data)
75+
pxStats['valid'].append(vals)
76+
print('Valid done: {}'.format(ds_name))
77+
78+
# Extract test stats
79+
norm_data = readImages(testObj)
80+
vals = np.apply_along_axis(lambda x: np.histogram(x,
81+
range=(-4, 4),
82+
bins=40),
83+
axis=1,
84+
arr=norm_data)
85+
pxStats['test'].append(vals)
86+
print('Test done: {}'.format(ds_name))
87+
88+
# Save out data
89+
f = open('statData.pkl','wb')
90+
pickle.dump(pxStats, f)
91+
f.close()
92+
print('Done: {}'.format(ds_name))

args.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ def parse_args():
3535
parser.add_argument('--model', type=str, default='ritnet_v2', help='select model')
3636
parser.add_argument('--lr', type=float, default=5e-5, help='learning rate')
3737
parser.add_argument('--batchsize', type=int, default=16, help='select a batchsize')
38-
parser.add_argument('--resume', type=int, default=0, help='resume?')
38+
parser.add_argument('--resume', type=int, default=1, help='resume?')
3939
parser.add_argument('--loadfile', type=str, default='', help='load experiment')
4040
parser.add_argument('--expname', type=str, default='dev', help='experiment number')
4141
parser.add_argument('--prec', type=int, default=32, help='precision. 16, 32, 64')
Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,52 @@
1+
#!/usr/bin/env python3
2+
# -*- coding: utf-8 -*-
3+
'''
4+
5+
@author: rakshit
6+
'''
7+
import os
8+
import sys
9+
import pickle
10+
11+
sys.path.append('..')
12+
import CurriculumLib as CurLib
13+
from CurriculumLib import DataLoader_riteyes
14+
15+
path2data = '/media/rakshit/tank/Dataset'
16+
path2h5 = os.path.join(path2data, 'All')
17+
keepOld = False
18+
19+
DS_sel = pickle.load(open('dataset_selections.pkl', 'rb'))
20+
AllDS = CurLib.readArchives(os.path.join(path2data, 'MasterKey'))
21+
list_ds = ['NVGaze', 'OpenEDS', 'riteyes_general']
22+
23+
# Generate objects per dataset
24+
subsets_train = []
25+
subsets_test = []
26+
27+
for setSel in list_ds:
28+
subsets_train += DS_sel['train'][setSel]
29+
subsets_test += DS_sel['test'][setSel]
30+
31+
# Train object
32+
AllDS_cond = CurLib.selSubset(AllDS, subsets_train)
33+
dataDiv_obj = CurLib.generate_fileList(AllDS_cond, mode='vanilla', notest=False)
34+
trainObj = DataLoader_riteyes(dataDiv_obj, path2h5, 0, 'train', True, (480, 640), scale=0.5)
35+
validObj = DataLoader_riteyes(dataDiv_obj, path2h5, 0, 'valid', False, (480, 640), scale=0.5)
36+
37+
# Test object
38+
AllDS_cond = CurLib.selSubset(AllDS, subsets_test)
39+
dataDiv_obj = CurLib.generate_fileList(AllDS_cond, mode='none', notest=True)
40+
testObj = DataLoader_riteyes(dataDiv_obj, path2h5, 0, 'test', False, (480, 640), scale=0.5)
41+
42+
path2save = os.path.join(os.getcwd(), 'baseline', 'cond_'+'pretrained'+'.pkl')
43+
if os.path.exists(path2save) and keepOld:
44+
print('Preserving old selections ...')
45+
# This ensure that the original selection remains the same
46+
trainObj_orig, validObj_orig, testObj_orig = pickle.load(open(path2save, 'rb'))
47+
trainObj.imList = trainObj_orig.imList
48+
validObj.imList = validObj_orig.imList
49+
testObj.imList = testObj_orig.imList
50+
pickle.dump((trainObj, validObj, testObj), open(path2save, 'wb'))
51+
else:
52+
pickle.dump((trainObj, validObj, testObj), open(path2save, 'wb'))

models/RITnet_v2.py

Lines changed: 21 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
import torch.nn as nn
99
import torch.nn.functional as F
1010

11-
from utils import normPts, regressionModule, linStack, unnormPts, convBlock
11+
from utils import normPts, regressionModule, linStack, convBlock
1212
from loss import conf_Loss, get_ptLoss, get_seg2ptLoss, get_segLoss
1313
from loss import WeightedHausdorffDistance
1414

@@ -155,7 +155,7 @@ def __init__(self,
155155
chz=32,
156156
growth=1.2,
157157
actfunc=F.leaky_relu,
158-
norm=nn.BatchNorm2d,
158+
norm=nn.InstanceNorm2d,
159159
selfCorr=False,
160160
disentangle=False):
161161
super(DenseNet2D, self).__init__()
@@ -205,20 +205,15 @@ def forward(self,
205205
op = self.dec(x4, x3, x2, x1, x)
206206

207207
#%% Weighted Hauss Loss
208+
'''
209+
# This loss does not conflict with segmentation losses
208210
dsizes = torch.from_numpy(np.stack([[H/1.]*B, [W/1.]*B], axis=1)).to(x.device)
209-
210-
# wHauss expects GT as rows and cols.
211-
flag_segSamples = (1 -cond[:,1]).to(torch.float32)
212-
num_segSamples = torch.sum(flag_segSamples)
213-
loss_wHauss_iri = self.wHauss(torch.sigmoid(op[:,1,...]), # Iris heatmap
214-
unnormPts(elNorm[:, 0, :2], [H, W])[:, [1, 0]], # Pixel locs
215-
dsizes)
216-
loss_wHauss_iri = torch.sum(loss_wHauss_iri*flag_segSamples)/num_segSamples if num_segSamples else 0.0
217-
loss_wHauss_pup = self.wHauss(torch.sigmoid(op[:,-1,...]),
218-
pupil_center[:, [1, 0]], dsizes)
219-
loss_wHauss_pup = loss_wHauss_pup.mean()
220-
loss_wHauss = .5*loss_wHauss_pup + .5*loss_wHauss_iri
221211
212+
pupMap = torch.softmax(op[:, -1, ...].view(B, -1), dim=1)
213+
loss_wHauss_pup = self.wHauss(pupMap.view(B, H, W),
214+
pupil_center[:, [1, 0]], dsizes)
215+
loss_wHauss = loss_wHauss_pup.mean()
216+
'''
222217
#%%
223218
op_tup = get_allLoss(op, # Output segmentation map
224219
elOut, # Predicted Ellipse parameters
@@ -230,8 +225,9 @@ def forward(self,
230225
cond, # Condition
231226
ID, # Image and dataset ID
232227
alpha)
228+
233229
loss, pred_c_seg = op_tup
234-
loss += 5e-3*loss_wHauss
230+
#loss += 5e-4*loss_wHauss
235231

236232
if self.disentangle:
237233
pred_ds = self.dsIdentify_lin(latent)
@@ -283,22 +279,26 @@ def get_allLoss(op, # Network output
283279
# Segmentation to pupil center loss using center of mass
284280
l_seg2pt_pup, pred_c_seg_pup = get_seg2ptLoss(op[:, 2, ...],
285281
normPts(pupil_center,
286-
target.shape[1:]), temperature=1)
287-
l_seg2pt_pup = torch.mean(l_seg2pt_pup)
282+
target.shape[1:]), temperature=16)
288283

289284
# Segmentation to iris center loss using center of mass
290285
if torch.sum(loc_onlyMask):
291286
# Iris center is only present when GT masks are present. Note that
292287
# elNorm will hold garbage values. Those samples should not be backprop
293-
l_seg2pt_iri, pred_c_seg_iri = get_seg2ptLoss(op[:, 1, ...],
294-
elNorm[:, 0, :2], temperature=1)
288+
w = np.clip((0.5/0.1)*alpha, 0, 0.5) # w: [0->0.5]
289+
iriMap = op[:, 1, ...]*(0.5+w) + op[:, 2, ...]*(w-0.5) # gradual handoff
290+
l_seg2pt_iri, pred_c_seg_iri = get_seg2ptLoss(iriMap,
291+
elNorm[:, 0, :2],
292+
temperature=16)
295293
temp = torch.stack([loc_onlyMask, loc_onlyMask], dim=1)
296294
l_seg2pt_iri = torch.sum(l_seg2pt_iri*temp)/torch.sum(temp.to(torch.float32))
295+
l_seg2pt_pup = torch.mean(l_seg2pt_pup)
297296

298297
else:
299298
# If GT map is absent, loss is set to 0.0
300299
# Set Iris and Pupil center to be same
301300
l_seg2pt_iri = 0.0
301+
l_seg2pt_pup = torch.mean(l_seg2pt_pup)
302302
pred_c_seg_iri = torch.clone(elOut[:, 5:7])
303303

304304
pred_c_seg = torch.stack([pred_c_seg_iri,
@@ -310,7 +310,8 @@ def get_allLoss(op, # Network output
310310

311311
# Bottleneck ellipse losses
312312
# NOTE: This loss is only activated when normalized ellipses do not exist
313-
l_pt = get_ptLoss(elOut[:, 5:7], normPts(pupil_center, target.shape[1:]), 1-loc_onlyMask)
313+
l_pt = get_ptLoss(elOut[:, 5:7], normPts(pupil_center,
314+
target.shape[1:]), 1-loc_onlyMask)
314315

315316
# Compute ellipse losses - F1 loss for valid samples
316317
l_ellipse = get_ptLoss(elOut, elNorm.view(-1, 10), loc_onlyMask)
-397 Bytes
Binary file not shown.

pretrained.git_ok

10.5 MB
Binary file not shown.

pytorchtools.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ def __init__(self,
3737
self.best_score = None
3838
self.early_stop = False
3939
self.val_loss_min = np.Inf if mode == 'min' else -np.Inf
40-
self.delta = delta if mode == 'max' else -delta
40+
self.delta = delta
4141
self.path2save = path2save
4242
self.fName = fName
4343
self.mode = mode

runRC_baseline.sh

Lines changed: 2 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -6,20 +6,9 @@ workers=12
66
lr=0.0005
77

88
spack env activate riteyes4
9-
# Load necessary modules
10-
# spack load /7qmaaiw # Load OpenCV
11-
# spack load /jthz32l # Load pytorch by hash
12-
# spack load /dtlfq7l # Load torchvision by hash
13-
# spack load /fvki7dt # Load scipy
14-
# spack load /rso7arf # Load matplotlib
15-
# spack load /me57ozl # Load image manipulation library
16-
# spack load /bblye5g # Load sklearn for metrics
17-
# spack load /zzdgeg6 # Load tensorboardx (latest)
18-
# spack load /me75cc2 # Load tqdm
19-
# spack load /hlxw2mt # Load h5py with MPI
209

2110
declare -a curObj_list=("NVGaze" "PupilNet" "OpenEDS" "Fuhl" "riteyes-general" "LPW")
22-
declare -a batchsize_list=("36" "48" "36" "48" "36" "48")
11+
declare -a batchsize_list=("48" "60" "48" "60" "48" "60")
2312
declare -a selfCorr_list=("0")
2413
declare -a disentangle_list=("0")
2514

@@ -36,7 +25,7 @@ do
3625
str+="--disp=0 --overfit=0 --lr=${lr} --selfCorr=${selfCorr} --disentangle=${disentangle}"
3726
echo $str
3827
echo -e $str > command.lock
39-
sbatch -J ${baseJobName} -o "rc_log/baseline/${baseJobName}.o" -e "rc_log/baseline/${baseJobName}.e" --mem=16G --cpus-per-task=9 -p debug -A riteyes --gres=gpu:p4:2 -t 0-1:0:0 command.lock
28+
sbatch -J ${baseJobName} -o "rc_log/baseline/${baseJobName}.o" -e "rc_log/baseline/${baseJobName}.e" --mem=16G --cpus-per-task=9 -p tier3 -A riteyes --gres=gpu:v100:1 -t 2-0:0:0 command.lock
4029
done
4130
done
4231
done

runRC_testCode.sh

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
#!/bin/bash -l
2+
3+
path2ds="/home/rsk3900/Datasets/"
4+
epochs=2
5+
workers=12
6+
lr=0.0005
7+
8+
spack env activate riteyes4
9+
10+
declare -a curObj_list=("NVGaze" "PupilNet" "OpenEDS" "Fuhl" "riteyes-general" "LPW")
11+
declare -a batchsize_list=("16" "24" "16" "24" "16" "24")
12+
declare -a selfCorr_list=("0")
13+
declare -a disentangle_list=("0")
14+
15+
for i in "${!curObj_list[@]}"
16+
do
17+
for selfCorr in "${selfCorr_list[@]}"
18+
do
19+
for disentangle in "${disentangle_list[@]}"
20+
do
21+
batchsize=${batchsize_list[i]}
22+
baseJobName="RC_e2e_${curObj_list[i]}_${selfCorr}_${disentangle}"
23+
str="#!/bin/bash\npython3 train.py --path2data=${path2ds} --expname=${baseJobName} "
24+
str+="--curObj=${curObj_list[i]} --batchsize=${batchsize} --workers=${workers} --prec=32 --epochs=${epochs} "
25+
str+="--disp=0 --overfit=0 --lr=${lr} --selfCorr=${selfCorr} --disentangle=${disentangle}"
26+
echo $str
27+
echo -e $str > command.lock
28+
sbatch -J ${baseJobName} -o "rc_log/baseline/${baseJobName}.o" -e "rc_log/system_test/${baseJobName}.e" --mem=16G --cpus-per-task=9 -p debug -A riteyes --gres=gpu:p4:2 -t 0-5:0:0 command.lock
29+
done
30+
done
31+
done

0 commit comments

Comments
 (0)