/
fused_rnn_cell.py
183 lines (154 loc) · 5.94 KB
/
fused_rnn_cell.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Module for constructing fused RNN cells."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import abc
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import rnn
class FusedRNNCell(object):
  """Abstract base class for fused RNN cells.

  A fused RNN cell represents an entire recurrent network unrolled over the
  time dimension: the loop over time steps lives *inside* the cell rather
  than outside it, unlike `rnn_cell.RNNCell` subclasses which process a
  single step. Fused implementations are typically faster but more complex
  and less flexible.

  Concrete subclasses must implement `__call__` with the signature below.
  """
  # Python-2 style metaclass declaration, consistent with the __future__
  # imports at the top of this file.
  __metaclass__ = abc.ABCMeta

  @abc.abstractmethod
  def __call__(self,
               inputs,
               initial_state=None,
               dtype=None,
               sequence_length=None,
               scope=None):
    """Run this fused RNN over `inputs`, starting from `initial_state`.

    Args:
      inputs: a `3-D` tensor shaped `[time_len x batch_size x input_size]`,
        or a list of `time_len` tensors each shaped
        `[batch_size x input_size]`.
      initial_state: a tensor shaped `[batch_size x state_size]`, or — for
        cells with tuple state — a tuple with shapes
        `[batch_size x s] for s in state_size`. When omitted, the cell must
        build a zero initial state of type `dtype`.
      dtype: data type of the initial state and expected output. Required
        when `initial_state` is omitted or the RNN state has a
        heterogeneous dtype.
      sequence_length: an `int32` or `int64` vector of size `[batch_size]`
        giving each sequence's length, with values in `[0, time_len)`.
        Defaults to `time_len` for every element.
      scope: `VariableScope` or `string` naming the created subgraph;
        defaults to the class name.

    Returns:
      A pair `(output, final_state)` where:

      - output: a `3-D` tensor shaped `[time_len x batch_size x output_size]`
        or a list of `time_len` tensors shaped `[batch_size x output_size]`,
        matching the type of `inputs`.
      - final_state: a single `2-D` tensor, or a tuple of tensors whose
        arity and shapes match `initial_state`.
    """
    pass
class FusedRNNCellAdaptor(FusedRNNCell):
  """This is an adaptor for RNNCell classes to be used with `FusedRNNCell`."""

  def __init__(self, cell, use_dynamic_rnn=False):
    """Initialize the adaptor.

    Args:
      cell: an instance of a subclass of a `rnn_cell.RNNCell`.
      use_dynamic_rnn: whether to use dynamic (or static) RNN.
    """
    self._cell = cell
    self._use_dynamic_rnn = use_dynamic_rnn

  def __call__(self,
               inputs,
               initial_state=None,
               dtype=None,
               sequence_length=None,
               scope=None):
    # Remember the caller's input type so the outputs can match it.
    inputs_were_list = isinstance(inputs, list)
    if self._use_dynamic_rnn:
      # dynamic_rnn consumes a single time-major tensor.
      rnn_inputs = array_ops.stack(inputs) if inputs_were_list else inputs
      outputs, state = rnn.dynamic_rnn(
          self._cell,
          rnn_inputs,
          sequence_length=sequence_length,
          initial_state=initial_state,
          dtype=dtype,
          time_major=True,
          scope=scope)
      if inputs_were_list:
        # Convert outputs back to a list of per-step tensors.
        outputs = array_ops.unstack(outputs)
    else:
      # static_rnn consumes a Python list of per-step tensors.
      rnn_inputs = inputs if inputs_were_list else array_ops.unstack(inputs)
      outputs, state = rnn.static_rnn(
          self._cell,
          rnn_inputs,
          initial_state=initial_state,
          dtype=dtype,
          sequence_length=sequence_length,
          scope=scope)
      if not inputs_were_list:
        # Convert outputs back to a single time-major tensor.
        outputs = array_ops.stack(outputs)
    return outputs, state
class TimeReversedFusedRNN(FusedRNNCell):
  """This is an adaptor to time-reverse a FusedRNNCell.

  For example,

  ```python
  cell = tf.contrib.rnn.BasicRNNCell(10)
  fw_lstm = tf.contrib.rnn.FusedRNNCellAdaptor(cell, use_dynamic_rnn=True)
  bw_lstm = tf.contrib.rnn.TimeReversedFusedRNN(fw_lstm)
  fw_out, fw_state = fw_lstm(inputs)
  bw_out, bw_state = bw_lstm(inputs)
  ```
  """

  def __init__(self, cell):
    self._cell = cell

  def _reverse(self, t, lengths):
    """Time-reverse a tensor or a list of tensors.

    Assumes the leading dimension is the time dimension.

    Args:
      t: 3D tensor or list of 2D tensors to be reversed.
      lengths: 1D tensor of lengths, or `None`.

    Returns:
      A reversed tensor or list of tensors.
    """
    if isinstance(t, list):
      # A list is reversed step-wise in Python; lengths are ignored here.
      return t[::-1]
    if lengths is None:
      # No per-example lengths: flip the whole time axis.
      return array_ops.reverse_v2(t, [0])
    # Per-example lengths: reverse only each sequence's valid prefix
    # (seq_dim=0 is time, batch_dim=1).
    return array_ops.reverse_sequence(t, lengths, 0, 1)

  def __call__(self,
               inputs,
               initial_state=None,
               dtype=None,
               sequence_length=None,
               scope=None):
    # Run the wrapped cell on time-reversed inputs, then reverse its
    # outputs so they line up with the original time order.
    reversed_inputs = self._reverse(inputs, sequence_length)
    outputs, state = self._cell(
        reversed_inputs,
        initial_state=initial_state,
        dtype=dtype,
        sequence_length=sequence_length,
        scope=scope)
    return self._reverse(outputs, sequence_length), state