forked from domluna/memn2n
/
tf_config.py
58 lines (56 loc) · 1.45 KB
/
tf_config.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
import tensorflow as tf
def configure(in_config):
    """Register the model's command-line flags on ``tf.flags``.

    Each flag's default value is read from ``in_config`` (a mapping), so a
    missing key raises ``KeyError`` at configuration time rather than being
    silently defaulted.

    Args:
        in_config: Mapping that must contain a default value for every flag
            name listed below (e.g. ``in_config['learning_rate']``).
    """
    # Declarative spec: (definer, flag name, help text).  The definer fixes
    # the flag's type; the flag name doubles as the in_config lookup key.
    flag_specs = (
        (tf.flags.DEFINE_float, 'learning_rate',
         'Learning rate for Adam Optimizer'),
        (tf.flags.DEFINE_float, 'epsilon',
         'Epsilon value for Adam Optimizer'),
        (tf.flags.DEFINE_float, 'max_grad_norm',
         'Clip gradients to this norm'),
        (tf.flags.DEFINE_integer, 'evaluation_interval',
         'Evaluate and print results every x epochs'),
        (tf.flags.DEFINE_integer, 'batch_size',
         'Batch size for training'),
        (tf.flags.DEFINE_integer, 'hops',
         'Number of hops in the Memory Network'),
        (tf.flags.DEFINE_integer, 'epochs',
         'Number of epochs to train for'),
        (tf.flags.DEFINE_integer, 'embedding_size',
         'Embedding size for embedding matrices'),
        (tf.flags.DEFINE_integer, 'memory_size',
         'Maximum size of memory'),
        (tf.flags.DEFINE_integer, 'task_id',
         'bAbI task id, 1 <= id <= 6'),
        (tf.flags.DEFINE_integer, 'random_state',
         'Random state'),
    )
    for define, name, help_text in flag_specs:
        define(name, in_config[name], help_text)