/
parallel_module.py
78 lines (61 loc) · 2.71 KB
/
parallel_module.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
import os
import logging
import numpy as np
import dill
import tensorflow as tf
import gzip
from utils import get_jobdir
class ParallelModule:
    """Base class for a named module with parameter and buffer checkpoint helpers.

    ``scope`` names the module. It is used as the logger name and as the
    file-name stem for every parameter/buffer file the helpers below build.
    Subclasses implement the parameter accessors, ``store_transitions`` and
    ``train``; the base versions raise ``NotImplementedError``.
    """

    def __init__(self, scope):
        """Create the module.

        Args:
            scope: Module name; used as the logger name and as the file stem
                for saved parameters and buffers.
        """
        self._scope = scope
        self._logger = logging.getLogger(self._scope)
        # Job directory as resolved by the project's ``utils.get_jobdir``.
        self._jobdir = get_jobdir(self._logger)
        self._seed = None

    # ------------------------------------------------------------------
    # Hooks that subclasses must implement.
    # ------------------------------------------------------------------
    def get_params(self, scope=''):
        """Return this module's parameters. Must be overridden.

        ``scope`` is an optional selector whose meaning is up to the subclass.
        """
        raise NotImplementedError

    def get_global_params(self, scope=''):
        """Return this module's global parameters. Must be overridden."""
        raise NotImplementedError

    def set_params(self, params, scope=''):
        """Install ``params`` as this module's parameters. Must be overridden."""
        raise NotImplementedError

    def set_global_params(self, params, scope=''):
        """Install ``params`` as this module's global parameters. Must be overridden."""
        raise NotImplementedError

    def store_transitions(self, batch, info):
        """Store a batch of data together with ``info``. Must be overridden."""
        raise NotImplementedError

    def train(self):
        """Run training. Must be overridden."""
        raise NotImplementedError

    # ------------------------------------------------------------------
    # Private path helpers.
    # ------------------------------------------------------------------
    @staticmethod
    def _prefix_suffix(prefix):
        """Map an optional prefix to the file-name suffix ``'_<prefix>'``, or ``''`` if falsy."""
        return f'_{prefix}' if prefix else ''

    def _params_stem(self, path, prefix):
        """Return ``<path>/<scope>[_<prefix>]``, the stem shared by parameter files."""
        return os.path.join(path, f'{self._scope}{self._prefix_suffix(prefix)}')

    def _load_saved_params(self, path, prefix):
        """Load ``<stem>.npy`` and return element ``[1]`` of the stored array.

        Element ``[1]`` matches a file holding the ``(stem, params)`` pair
        returned by ``save_params`` / ``save_global_params``.
        NOTE(review): confirm against the caller that actually writes the file.

        Such a pair is stored as an object array. numpy >= 1.16.3 refuses to
        load object arrays unless ``allow_pickle=True`` is passed. Unpickling
        can execute arbitrary code, so only load files produced by this job.
        """
        data = np.load(f'{self._params_stem(path, prefix)}.npy', allow_pickle=True)
        return data[1]

    def _iter_buffers(self, path, prefix):
        """Yield ``(buffer, file_stem)`` for ``self._buffer`` or each entry of it.

        A dict-valued ``self._buffer`` yields one pair per key, with stem
        ``<scope>_buffer_<key>[_<prefix>]``. Any other value yields a single
        pair with stem ``<scope>_buffer[_<prefix>]``.
        """
        suffix = self._prefix_suffix(prefix)
        if isinstance(self._buffer, dict):
            for key, buffer in self._buffer.items():
                yield buffer, os.path.join(path, f'{self._scope}_buffer_{key}{suffix}')
        else:
            yield self._buffer, os.path.join(path, f'{self._scope}_buffer{suffix}')

    # ------------------------------------------------------------------
    # Checkpointing.
    # ------------------------------------------------------------------
    def save_params(self, path, prefix=''):
        """Return ``(file_stem, self.get_params())`` without writing anything.

        The caller is responsible for persisting the pair. ``load_params``
        reads ``<file_stem>.npy`` and takes element ``[1]``.
        """
        return (self._params_stem(path, prefix), self.get_params())

    def save_global_params(self, path, prefix=''):
        """Return ``(file_stem, self.get_global_params())`` without writing anything."""
        return (self._params_stem(path, prefix), self.get_global_params())

    def load_params(self, path, prefix=''):
        """Load ``<path>/<scope>[_<prefix>].npy`` and pass element ``[1]`` to ``set_params``."""
        self.set_params(self._load_saved_params(path, prefix))

    def load_global_params(self, path, prefix=''):
        """Load ``<path>/<scope>[_<prefix>].npy`` and pass element ``[1]`` to ``set_global_params``."""
        self.set_global_params(self._load_saved_params(path, prefix))

    def save_buffer(self, path, prefix=''):
        """Call ``.save(stem)`` on ``self._buffer`` (or on each value if it is a dict).

        Does nothing when the module has no ``_buffer`` attribute.
        """
        if not hasattr(self, '_buffer'):
            return
        for buffer, stem in self._iter_buffers(path, prefix):
            buffer.save(stem)

    def restore_buffer(self, path, prefix=''):
        """Call ``.restore(stem)`` on ``self._buffer`` (or on each value if it is a dict).

        Does nothing when the module has no ``_buffer`` attribute.
        """
        if not hasattr(self, '_buffer'):
            return
        for buffer, stem in self._iter_buffers(path, prefix):
            buffer.restore(stem)

    def set_seed(self, seed):
        """Record ``seed`` and seed numpy's global RNG and TensorFlow with it.

        NOTE: ``tf.set_random_seed`` is the TF 1.x API; TF 2.x exposes
        ``tf.random.set_seed`` instead.
        """
        self._seed = seed
        np.random.seed(self._seed)
        tf.set_random_seed(self._seed)