-
Notifications
You must be signed in to change notification settings - Fork 6
/
generate_dataset.py
62 lines (48 loc) · 1.48 KB
/
generate_dataset.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
import argparse
import time
import os
import ray
from ray.exceptions import RayTimeoutError
from backward import generate_bwd
@ray.remote
def ray_generate_bwd(n=1):
    """Ray task: build ``n`` sequences with ``generate_bwd`` and save each one.

    Every item yielded by iterating ``generate_bwd(n)`` is handed to
    ``append_sample``, which writes it to disk from inside the worker.

    Args:
        n: number of sequences to request from ``generate_bwd``.

    Returns:
        Whatever ``generate_bwd(n)`` returned, unchanged.
    """
    # NOTE(review): assumes generate_bwd returns a re-iterable container
    # (e.g. a list). A one-shot generator would be consumed by the loop below
    # and come back empty; confirm against backward.generate_bwd.
    generated = generate_bwd(n)
    for item in generated:
        append_sample(item)
    return generated
def append_sample(sample, fldr='./data/dataset'):
    """Write one sample to its own new text file inside ``fldr``.

    The file is named after the current ``time.time()`` stamp
    (``<stamp>.txt``). If that name is already taken, for example by another
    call in the same clock tick or by a parallel worker, a numeric suffix is
    added (``<stamp>_1.txt``, ``<stamp>_2.txt``, ...). Two samples are
    therefore never appended into the same file.

    Args:
        sample: text to write (passed straight to ``file.write``).
        fldr: output directory. It is created (with parents) if missing.

    Returns:
        The path of the file that was written.
    """
    # exist_ok=True keeps this safe when several workers create it at once.
    os.makedirs(fldr, exist_ok=True)
    stem = '{}'.format(time.time())
    path = os.path.join(fldr, stem + '.txt')
    suffix = 0
    while True:
        try:
            # 'x' fails if the file exists, instead of silently appending to it.
            fi = open(path, 'x')
        except FileExistsError:
            suffix += 1
            path = os.path.join(fldr, '{}_{}.txt'.format(stem, suffix))
            continue
        with fi:
            fi.write(sample)
        return path
if __name__ == '__main__':
    # Example: python generate_dataset.py --cpu 8 --num 64 --n 1
    parser = argparse.ArgumentParser(
        description='Generate dataset samples in parallel with Ray and generate_bwd.')
    parser.add_argument('--cpu', type=int, default=8,
                        help='CPUs given to Ray, i.e. max tasks running at once')
    parser.add_argument('--num', type=int, default=64,
                        help='number of generate_bwd tasks to run')
    parser.add_argument('--n', type=int, default=1,
                        help='sequences per task; also the per-task timeout in seconds')
    args = parser.parse_args()
    if args.cpu < 1:
        parser.error('--cpu must be >= 1')
    if args.num < 0:
        parser.error('--num must be >= 0')

    sequences_per_process = args.n
    cpu = args.cpu
    num = args.num
    # Rounds of up to `cpu` concurrent tasks. Ceiling division so a remainder
    # (num not a multiple of cpu) is counted as its own round.
    process_runs = -(-num // cpu)

    ray.shutdown()
    ray.init(num_cpus=cpu)
    t0 = time.time()
    print('{} samples {} cpu {} process_runs {} seq per process'.format(num, cpu, process_runs, sequences_per_process))

    # Submit every task before waiting on any of them, so Ray can run up to
    # `cpu` at once. Calling ray.get right after each .remote() would
    # serialize the whole run.
    refs = [ray_generate_bwd.remote(sequences_per_process) for _ in range(num)]

    fails = 0
    dataset = []
    for ref in refs:
        try:
            # Timeout budget: one second per requested sequence.
            out = ray.get(ref, timeout=sequences_per_process)
            dataset.extend(out)
        except RayTimeoutError as e:
            # A timed-out task otherwise keeps running and holds a CPU slot.
            # NOTE(review): needs a Ray release that provides ray.cancel;
            # confirm against the pinned Ray version.
            ray.cancel(ref, force=True)
            print('fail', e)
            fails += 1
        except TypeError as e:
            print('fail', e)
            fails += 1

    print(time.time() - t0)
    print(len(dataset))
    print('{} failed'.format(fails))