/
write_job.py
157 lines (139 loc) · 5.26 KB
/
write_job.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
from itertools import product
import numpy as np
import argparse
# Output directory prefix for every generated file (tasks_<n>.sh, run_all.sh,
# log<n>.txt). It is concatenated directly in front of file names, so it
# keeps its trailing slash.
save_in_folder: str = "scripts/"
# NOTE(review): not referenced anywhere else in this file; each sweep in the
# __main__ block sets its own num_step values. Confirm before relying on it.
NUM_STEP: int = 10_000
def write_run_all_jobs(count, init_count=0, folder=None, parallel=10):
    """
    Generate ``run_all.sh``, a bash launcher for the task scripts
    ``tasks_<init_count>.sh`` ... ``tasks_<count-1>.sh``.

    Each task is started in the background with stdout/stderr redirected to
    ``log<k>.txt``, where ``k`` is the same index as its task script.
    A ``wait`` is emitted after every ``parallel`` tasks, so at most that
    many run at once. If the last batch is partial, one more ``wait`` is
    emitted at the end, so the launcher does not exit while that batch is
    still running.

    :param count: one past the index of the last task script to launch.
    :param init_count: index of the first task script to launch.
    :param folder: directory prefix (including trailing separator) for
        run_all.sh, the task scripts and the logs; defaults to
        ``save_in_folder``.
    :param parallel: number of tasks started before each ``wait``.
    :raises ValueError: if ``parallel`` is smaller than 1.
    """
    if folder is None:
        folder = save_in_folder
    if parallel < 1:
        raise ValueError("parallel must be >= 1, got {}".format(parallel))

    lines = ['#!/usr/bin/env bash',
             'chmod +x ./{}tasks_*.sh'.format(folder)]
    num_jobs = max(count - init_count, 0)
    for i in range(num_jobs):
        idx = init_count + i
        # The log uses the task's own index, so launchers generated with
        # different start offsets do not overwrite each other's logs.
        lines.append('./{0}tasks_{1}.sh &> {0}log{1}.txt &'.format(folder, idx))
        if (i + 1) % parallel == 0:
            lines.append('wait')
    # A full final batch already ended with 'wait'; a partial one still needs it.
    if num_jobs % parallel != 0:
        lines.append('wait')

    with open('{}run_all.sh'.format(folder), 'w') as f:
        f.write('\n'.join(lines) + '\n')
def write_jobs(all_comb, count, verbose, folder=None):
    """
    Write one bash task script per hyper-parameter combination.

    Combination ``k`` (0-based) of ``all_comb`` is written to
    ``<folder>tasks_<count + k>.sh``. The script holds a bash shebang and a
    single ``python main_torch.py`` command line.

    :param all_comb: iterable of 9-tuples
        ``(domain, num_step, model, state_update, overlap, T, M, B, lr)``.
        They are formatted, in that order, into ``--dataset``,
        ``--data_size``, ``--model``, ``--state_update``, ``--overlap``,
        ``--T``, ``--num_update``, ``--batch_size`` and ``--lr``.
    :param count: index of the first task script to write.
    :param verbose: print each index and command when this is ``True`` or
        the string ``"True"`` (any casing).
    :param folder: directory prefix (including trailing separator) for the
        task scripts; defaults to ``save_in_folder``.
    :return: the next unused task index, i.e. ``count`` plus the number of
        combinations written.
    """
    if folder is None:
        folder = save_in_folder
    cmd = "python main_torch.py --dataset {} --data_size {} --model {} --state_update {} --overlap {} " \
          "--T {} --num_update {} --batch_size {} --lr {} --buffer_size 100 --num_run 10\n"
    # --verbose arrives from argparse as a string; also accept a real bool.
    show = str(verbose).strip().lower() == "true"
    for domain, num_step, model, state_update, overlap, T, M, B, lr in all_comb:
        new_cmd = cmd.format(domain, num_step, model, state_update, overlap, T, M, B, lr)
        with open("{}tasks_{}.sh".format(folder, count), 'w') as f:
            # run_all.sh chmods and executes these scripts directly. A shebang
            # makes that work without depending on the shell's fallback for
            # interpreter-less files.
            f.write("#!/usr/bin/env bash\n")
            f.write(new_cmd)
        if show:
            print(count, new_cmd)
        count += 1
    return count
if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--verbose', default="True", type=str)
    parser.add_argument('--start', default=0, type=int)
    args = parser.parse_args()

    # Each sweep is crossed with itertools.product. Its value lists follow the
    # column order that write_jobs unpacks:
    #   (domain, num_step, model, state_update, overlap, T, M, B, lr)
    # Only uncommented sweeps are generated. Task indices continue from one
    # sweep to the next, starting at --start.
    sweeps = [
        # --- cw ---
        (['cw'], [10000], ['fpp'], ['True', 'False'], ['True'], [10], [1], [1], [0.001]),
        # (['cw'], [10000], ['t-bptt'], ['True'], ['True', 'False'], [10], [1], [1], [0.001]),
        # (['cw'], [10000], ['uoro'], ['True'], ['True'], [1], [1], [1], [0.001]),
        # --- lsd ---
        # (['lsd'], [10000], ['fpp'], ['True', 'False'], ['True'], [8, 16, 32], [1], [1, 2, 4, 8, 16], [0.001]),
        # (['lsd'], [10000], ['t-bptt'], ['True'], ['True', 'False'], [8, 16, 32], [1], [1], [0.001]),
        # (['lsd'], [10000], ['uoro'], ['True'], ['True'], [1], [1], [1], [0.001]),
        # --- mnist ---
        # (['mnist'], [28000], ['fpp'], ['True'], ['True'], [7, 14, 21, 28], [1], [1],
        #  [0.03, 0.01, 0.003, 0.001, 0.0003, 0.0001]),
        # (['mnist'], [28000], ['t-bptt'], ['True'], ['True', 'False'], [7, 14, 21, 28], [1], [1],
        #  [0.03, 0.01, 0.003, 0.001, 0.0003, 0.0001]),
        # (['mnist'], [28000], ['uoro'], ['True'], ['True'], [1], [1], [1],
        #  [0.03, 0.01, 0.003, 0.001, 0.0003, 0.0001]),
        # (['mnist'], [28000], ['fpp'], ['True'], ['True'], [28], [1], [8, 16],
        #  [0.03, 0.01, 0.003, 0.001, 0.0003, 0.0001]),
    ]

    count = args.start
    for axes in sweeps:
        count = write_jobs(list(product(*axes)), count, args.verbose)

    # print('sbatch --array={}-{} ./run.sh'.format(0, count-1))
    # NOTE(review): run_all.sh is generated starting from task 0 even when
    # --start > 0. Pass args.start as init_count if only the tasks written by
    # this invocation should be launched.
    write_run_all_jobs(count)