-
Notifications
You must be signed in to change notification settings - Fork 7
/
train.py
121 lines (100 loc) · 3.25 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
# -*- coding: utf-8 -*-
"""
Created on Thu Oct 12 20:44:14 2017
@author: dmarr
"""
#from pylab import*
# Root of the Caffe checkout. Replace this with the Caffe path on your system;
# it is the only place the path needs to be changed.
caffe_root = '/home/dhubel/deeplab_v2/deeplab-public-ver2/'

import sys
# The pycaffe package is loaded from <caffe_root>/python, so that directory
# has to be on sys.path before `import caffe` runs.
sys.path.insert(0, caffe_root + 'python')
import caffe
import numpy as np
import os
import time

# The rest of the script uses paths relative to the Caffe root
# ('../models/...', '../pretrain/...'), so make it the working directory.
os.chdir(caffe_root)

# Run on GPU 0.
caffe.set_device(0)
caffe.set_mode_gpu()
# ---- Run configuration ----
# Name of the model sub-directory under ../models/ whose prototxt files are used.
MAC_model = 'MACblock_99'
# Index appended to the 'train'/'val' split names and to the 'model/' path
# written into the generated solver file.
k = 1
# When True the script also evaluates the test net periodically and records
# its loss; when False it only runs the solver.
record_val_loss = True
def _rewrite_split(prototxt_path, new_split):
    """Replace the split name stored in a net prototxt, in place.

    The current split name is located relative to the LAST occurrence of the
    text 'split' in the file: it is taken to start 11 characters after that
    occurrence and to end 2 characters before the last ',' found within the
    23 characters that follow it. Every occurrence of that name in the file
    is then replaced by ``new_split`` (same behaviour as the original
    ``str.replace`` call).

    Raises:
        ValueError: if 'split' or the following ',' cannot be found, or the
            located name is empty. An empty name would make ``str.replace``
            insert ``new_split`` between every character of the file.
    """
    with open(prototxt_path, 'r+') as f:
        s = f.read()
        index = s.rfind('split')
        if index == -1:
            raise ValueError("no 'split' entry found in %s" % prototxt_path)
        c = s[index:index + 23]
        index1 = c.rfind(',')
        old_split = s[index + 11:index + index1 - 2]
        if index1 == -1 or not old_split:
            raise ValueError(
                "could not locate the split name after 'split' in %s"
                % prototxt_path)
        f.seek(0, 0)
        f.write(s.replace(old_split, new_split))
        # Drop leftover bytes of the old content when the new text is shorter.
        f.truncate()


def _write_solver_aug(solver_path, aug_path, fold, enable_test):
    """Generate the solver file actually used for training.

    Reads ``solver_path``, switches its ``test_`` entries on or off, rewrites
    the text between the last 'model/' and the last 'train"' to
    ``'model/' + str(fold)``, and writes the result to ``aug_path``.
    ``solver_path`` itself is not modified.

    Args:
        solver_path: template solver prototxt to read.
        aug_path: destination file; it is overwritten.
        fold: value inserted after 'model/'.
        enable_test: True uncomments '#test_' entries; False comments out
            'test_' entries, unless some are already commented.

    Raises:
        ValueError: if 'model/' followed by 'train"' is not present.
    """
    with open(solver_path, 'r') as f:
        s = f.read()

    if enable_test:
        if 'test_' not in s:
            print('Please add test network in solver_train_LF.prototxt....')
        if '#test_' in s:
            s = s.replace('#test_', 'test_')
    elif '#test_' not in s and 'test_' in s:
        s = s.replace('test_', '#test_')

    ind = s.rfind('model/')
    ind1 = s.rfind('train"')
    if ind == -1 or ind1 == -1 or ind1 < ind:
        raise ValueError(
            "expected 'model/' followed by 'train\"' in %s" % solver_path)

    # Only open (and thereby truncate) the output once the content is ready.
    with open(aug_path, 'w') as w:
        w.write(s[:ind] + 'model/' + str(fold) + s[ind1:])


# Point the training net at split 'train<k>'.
_rewrite_split('../models/' + MAC_model + '/train.prototxt', 'train' + str(k))

# The validation net is only needed when validation loss is recorded.
if record_val_loss:
    _rewrite_split('../models/' + MAC_model + '/val.prototxt', 'val' + str(k))

# Produce solver_train_LF_aug.prototxt with the test net enabled or disabled
# to match record_val_loss.
_write_solver_aug('../models/' + MAC_model + '/solver_train_LF.prototxt',
                  '../models/' + MAC_model + '/solver_train_LF_aug.prototxt',
                  k, record_val_loss)
# Build the SGD solver from the generated solver file, then initialise the
# training net's weights from the pretrained model.
solver_prototxt = '../models/' + MAC_model + '/solver_train_LF_aug.prototxt'
caffemodel = '../pretrain/pretrain.caffemodel'

solver = caffe.SGDSolver(solver_prototxt)
solver.net.copy_from(caffemodel)
# Total number of solver iterations to run.
max_iter = 160000


def _save_losses(loss_dir, train_records, val_records):
    """Write the recorded loss lines to train_loss.txt / val_loss.txt.

    Each record is a string '<iteration> <loss>'; one record per line.
    """
    np.savetxt(loss_dir + 'train_loss.txt', train_records, fmt='%ls')
    np.savetxt(loss_dir + 'val_loss.txt', val_records, fmt='%ls')


if not record_val_loss:
    # No loss bookkeeping requested: let the solver run straight through.
    solver.step(max_iter)
else:
    loss_PATH = '../models/' + MAC_model + '/loss/'
    # Create the output directory up front so that a missing directory is
    # discovered before training starts rather than after it has finished.
    if not os.path.isdir(loss_PATH):
        os.makedirs(loss_PATH)

    train_loss = []
    val_loss = []
    display_interval = 20   # solver steps averaged into one train-loss record
    test_iter = 128         # forward passes of the test net per validation
    test_interval = 500     # validate whenever the iteration is a multiple
    start = time.time()

    for it in range(max_iter // display_interval):
        # Average the training loss over `display_interval` single steps.
        traloss = 0
        for i in range(display_interval):
            solver.step(1)
            # `* 1` yields a new array rather than the blob's own buffer.
            traloss = traloss + solver.net.blobs['loss'].data * 1
        tloss = traloss / display_interval

        iteration = (it + 1) * display_interval
        train_loss.append(str(iteration) + ' ' + str(tloss))
        print(time.time() - start)
        # Single pre-formatted string: prints the same text under Python 2
        # and Python 3 (a multi-argument print(...) shows a tuple on Python 2).
        print('Iteration %d train_loss= %s' % (iteration, tloss))

        if iteration % test_interval == 0:
            # Average the test net's loss over `test_iter` forward passes.
            valloss = 0
            for test_it in range(test_iter):
                solver.test_nets[0].forward()
                valloss = valloss + solver.test_nets[0].blobs['loss'].data * 1
            vloss = valloss / test_iter
            val_loss.append(str(iteration) + ' ' + str(vloss))
            print(time.time() - start)
            print('Iteration %d val_loss= %s' % (iteration, vloss))
            # Flush what has been recorded so far, so an interrupted run
            # still leaves usable loss files behind.
            _save_losses(loss_PATH, train_loss, val_loss)

    # Final write covers records added after the last validation checkpoint.
    _save_losses(loss_PATH, train_loss, val_loss)