-
Notifications
You must be signed in to change notification settings - Fork 0
/
io_daemon.py
96 lines (82 loc) · 4.25 KB
/
io_daemon.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
'''
Copyright (C) 2020, Northwestern University and Lawrence Berkeley National Laboratory
See COPYRIGHT notice in top-level directory.
'''
import time
import h5py
import numpy as np
import multiprocessing
import horovod.tensorflow as hvd
class IOdaemon:
    '''Producer-side I/O worker that prefetches training samples from HDF5
    files into shared-memory ring buffers for a Horovod-distributed trainer.

    Each rank reads a disjoint, contiguous slice of the (shuffled) training
    file list and fills a ring of shared buffers. The consumer (main process)
    drains a buffer, resets its `num_samples` counter to 0, and notifies the
    condition variable so this producer can refill it.
    '''
    def __init__ (self, dataset, do_shuffle = 0, buffer_size = 128):
        '''Set up per-rank bookkeeping for the producer loop.

        dataset:     project object exposing `train_files`, `valid_files`,
                     `num_train_files`, and `shared_shuffled_index`.
        do_shuffle:  1 to refresh the shuffled file order from the main
                     process after each full pass over the local files;
                     0 to keep the initial order.
        buffer_size: number of samples each shared buffer holds.
        '''
        self.size = hvd.size()
        self.rank = hvd.rank()
        self.rng = np.random.default_rng()
        self.dataset = dataset
        # NOTE(review): the index length uses dataset.num_train_files while
        # self.num_train_files below uses len(dataset.train_files); these are
        # presumably equal -- confirm against the dataset class.
        self.shuffled_index = np.arange(self.dataset.num_train_files)
        self.buffer_size = buffer_size
        self.do_shuffle = do_shuffle
        # Sample offset inside the file currently being consumed.
        self.in_file_off = 0
        self.num_train_files = len(dataset.train_files)
        self.num_valid_files = len(dataset.valid_files)
        # Position within this rank's local slice of the shuffled index.
        self.file_index = 0
        # Each rank owns a contiguous block of the shuffled file index;
        # files beyond num_local_files * size are never read.
        self.offset = int(self.num_train_files / self.size) * self.rank
        # Assumes each sample is a 128^3 voxel grid with 12 uint16 channels
        # and a 4-element float32 label -- TODO confirm against HDF5 schema.
        self.data_shape = (self.buffer_size, 128, 128, 128, 12)
        self.label_shape = (self.buffer_size, 4)
        self.num_local_files = int(self.num_train_files / self.size)
        # Next ring-buffer slot this producer will fill.
        self.write_index = 0
        print ("R" + str(self.rank) + " will work on " + str(self.num_local_files) + " files.")
    def run (self, lock, cv, finish,
             data, label, num_samples):
        '''Producer loop: fill shared buffers with samples until told to stop.

        lock, cv:    multiprocessing lock and condition shared with the consumer.
        finish:      shared flag; value 1 requests shutdown.
        data, label: lists of shared raw byte buffers, one pair per ring slot.
        num_samples: list of shared counters; 0 means the slot is free to
                     fill, buffer_size means it is full and awaiting the
                     consumer.
        '''
        num_buffers = len(data)
        self.shuffled_index[:] = self.dataset.shared_shuffled_index[:]
        print ("R" + str(self.rank) + " updated shuffled_index, [0] is : " + str(self.shuffled_index[0]))
        print ("Number of buffers: " + str(num_buffers))
        while True:
            # Get the current write buffer index.
            write_index = self.write_index
            if finish.value == 1:
                break
            buf_off = 0
            if num_samples[write_index].value == 0:
                # The numpy views over the shared buffers are invariant for a
                # given write_index, so build them once per fill, not once per
                # file read (the original rebuilt them inside the loop).
                data_np = np.frombuffer(data[write_index], dtype = np.uint16).reshape(self.data_shape)
                label_np = np.frombuffer(label[write_index], dtype = np.float32).reshape(self.label_shape)
                while (buf_off < self.buffer_size):
                    # Choose a file to read.
                    file_index = self.shuffled_index[self.file_index + self.offset]
                    start = time.time()
                    # Context manager guarantees the HDF5 handle is closed
                    # even if a read raises (the original leaked it then).
                    with h5py.File(self.dataset.train_files[file_index], 'r') as f:
                        file_len = f['3Dmap'].shape[0]
                        read_off = self.in_file_off
                        # Read whatever is smaller: remaining buffer space or
                        # remaining samples in this file.
                        read_len = min (self.buffer_size - buf_off, file_len - read_off)
                        # Copy straight into the shared-memory views.
                        np.copyto(data_np[buf_off:buf_off + read_len], f['3Dmap'][read_off:read_off + read_len])
                        np.copyto(label_np[buf_off:buf_off + read_len], f['unitPar'][read_off:read_off + read_len])
                    end = time.time()
                    print ("R" + str(self.rank) + " read " + str(read_len) +\
                           " samples from " + str(self.dataset.train_files[file_index]) +\
                           " and it took " + str(end - start) + " secs")
                    # Update the offsets.
                    buf_off += read_len
                    self.in_file_off += read_len
                    # If one file has been all consumed, go for the next local file.
                    if self.in_file_off == file_len:
                        self.in_file_off = 0
                        self.file_index += 1
                        # If all the local files have been traversed over,
                        # wrap around the index and get the shuffled index from the main thread.
                        if self.file_index == self.num_local_files:
                            self.file_index = 0
                            if self.do_shuffle == 1:
                                self.shuffled_index[:] = self.dataset.shared_shuffled_index[:]
                                print ("R" + str(self.rank) + " updated shuffled_index, [0] is : " + str(self.shuffled_index[0]))
            # `with lock:` releases the lock even if notify/wait raises
            # (the original's bare acquire/release would not).
            with lock:
                if buf_off > 0:
                    # Publish the filled buffer and advance the ring.
                    num_samples[write_index].value = self.buffer_size
                    self.write_index = (write_index + 1) % num_buffers
                    cv.notify()
                # Block until the consumer frees the next slot or shutdown
                # is requested; cv.wait releases the lock while sleeping.
                while finish.value == 0 and num_samples[self.write_index].value > 0:
                    cv.wait()