/
split_dependency.py
136 lines (119 loc) · 5.64 KB
/
split_dependency.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
"""Utility for creating multiple dependencies with synchronized save/restore."""
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import functools
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.training import saver as saver_lib
from tensorflow.python.training.checkpointable import base as checkpointable
class _CallbackSaveable(saver_lib.BaseSaverBuilder.SaveableObject):
  """Adapts a save callback and a restore callback to `SaveableObject`.

  The save callback is handed to the base class (both as the `SaveSpec`
  tensor and as the first constructor argument); the restore callback is kept
  on the instance and invoked from `restore`.
  """

  def __init__(self, name, dtype, save_callback, restore_callback):
    """Builds the single `SaveSpec` describing this saveable.

    Args:
      name: Name used for both the `SaveSpec` and the saveable itself.
      dtype: Data type recorded in the `SaveSpec`.
      save_callback: Passed through unchanged as the `tensor` of the
        `SaveSpec` and as the first argument of the base constructor.
      restore_callback: Called with the single restored tensor in `restore`.
    """
    # Kept on the instance before the base constructor runs, as before.
    self._restore_callback = restore_callback
    save_spec = saver_lib.BaseSaverBuilder.SaveSpec(
        tensor=save_callback, slice_spec="", name=name, dtype=dtype)
    specs = [save_spec]
    super(_CallbackSaveable, self).__init__(save_callback, specs, name)

  def restore(self, restored_tensors, restored_shapes):
    """Forwards the one restored tensor to the restore callback.

    Args:
      restored_tensors: An iterable holding exactly one tensor; unpacking
        fails if it holds any other number of items.
      restored_shapes: Unused.

    Returns:
      Whatever the restore callback returns.
    """
    (restored_value,) = restored_tensors
    return self._restore_callback(restored_value)
class _SplitDependency(checkpointable.CheckpointableBase):
  """Looks like a regular variable while synchronizing save/restores.

  Each instance stands for one named component. Instances created together
  are given the same `save_buffer` and `restore_buffer` dictionaries, so one
  call to the fill/consume function serves every component of the group.
  """

  def __init__(self, save_buffer, restore_buffer, name, dtype, num_components,
               fill_save_buffer_fn, consume_restore_buffer_fn):
    """Records the shared buffers and callbacks for one component.

    Args:
      save_buffer: Dictionary that `fill_save_buffer_fn` populates and from
        which `_save` pops this component's value.
      restore_buffer: Dictionary into which `_restore` stores this
        component's restored tensor.
      name: Key identifying this component in both buffers.
      dtype: Data type forwarded to the saveable created in
        `_gather_saveables_for_checkpoint`.
      num_components: Number of entries the restore buffer must hold before
        `consume_restore_buffer_fn` is called.
      fill_save_buffer_fn: Called with the (empty) save buffer to populate
        it.
      consume_restore_buffer_fn: Called with the full restore buffer; its
        return value is returned from `_restore`.
    """
    self._save_buffer = save_buffer
    self._restore_buffer = restore_buffer
    self._name = name
    self._dtype = dtype
    self._num_components = num_components
    self._fill_save_buffer_fn = fill_save_buffer_fn
    self._consume_restore_buffer_fn = consume_restore_buffer_fn

  def _save(self):
    """Pull from the shared buffer, populating it if necessary.

    Returns:
      The value stored under this component's name; the entry is removed
      from the save buffer.

    Raises:
      AssertionError: If the buffer holds entries but none for this
        component, i.e. the components were not saved together.
    """
    if self._name not in self._save_buffer:
      if self._save_buffer:
        raise AssertionError(
            ("Split dependency %s (%s) unsynchronized. Split dependencies must "
             "be saved together.") % (self._name, self))
      try:
        self._fill_save_buffer_fn(self._save_buffer)
      except Exception:
        # A partially filled buffer would make every later save trip the
        # "unsynchronized" check above and hide the real failure.
        self._save_buffer.clear()
        raise
    return self._save_buffer.pop(self._name)

  def _restore(self, tensor):
    """Push into the shared buffer, flushing it if necessary.

    Args:
      tensor: Restored value for this component.

    Returns:
      The result of `consume_restore_buffer_fn` once every component has
      been pushed, otherwise a no-op.

    Raises:
      AssertionError: If this component already has a pending entry in the
        restore buffer, i.e. the components were not restored together.
    """
    if self._name in self._restore_buffer:
      raise AssertionError(
          ("Split dependency %s (%s) unsynchronized. Split dependencies must "
           "be restored together.") % (self._name, self))
    self._restore_buffer[self._name] = tensor
    if len(self._restore_buffer) != self._num_components:
      return control_flow_ops.no_op()
    try:
      return self._consume_restore_buffer_fn(self._restore_buffer)
    finally:
      # Clear even when the consumer raises; otherwise the stale entries
      # make every later restore fail with the "unsynchronized" error.
      self._restore_buffer.clear()

  def _gather_saveables_for_checkpoint(self):
    """Looks to Checkpointable like a regular variable.

    Returns:
      A dictionary with one entry, keyed by
      `checkpointable.VARIABLE_VALUE_KEY`, whose value is a partial
      `_CallbackSaveable` constructor with `dtype` and the save/restore
      callbacks bound. The remaining `name` argument is left for the caller.
    """
    return {
        checkpointable.VARIABLE_VALUE_KEY:
            functools.partial(_CallbackSaveable,
                              dtype=self._dtype,
                              save_callback=self._save,
                              restore_callback=self._restore)
    }
def split_dependency(component_names, component_dtypes,
                     fill_save_buffer_fn, consume_restore_buffer_fn):
  """Creates multiple dependencies with a synchronized save/restore.

  Useful when a single op produces `Tensor`s which should each be saved under
  different objects, or when `Tensor`s saved with many different objects need to
  be restored together as inputs to a single op (i.e. an object which uses a
  single fused op may be swapped out for a subgraph of objects, and these two
  programs are checkpoint compatible).

  Args:
    component_names: A sequence of unique names for the split
      dependencies. `fill_save_buffer_fn` must add these keys to the dictionary
      it is passed, and `consume_restore_buffer_fn` will receive a dictionary
      with these keys.
    component_dtypes: Data types for the `Tensor`s being saved and restored, a
      sequence corresponding to `component_names` (same length).
    fill_save_buffer_fn: A function which takes an empty dictionary as an
      argument and adds `Tensor`s with `component_names` as keys. These
      `Tensor`s will be saved as if they were individual variables.
    consume_restore_buffer_fn: A function which takes a dictionary with
      `component_names` as keys mapping to restored individual `Tensor`s and
      returns a restore op (or if executing eagerly, runs the restoration and
      may return `None`).

  Returns:
    A dictionary mapping from names to Checkpointable objects. If one is
    reachable from an object as a dependency, the others should be too; adding
    dependencies on some but not all of the objects will result in errors.

  Raises:
    ValueError: If `component_names` and `component_dtypes` differ in length,
      or if `component_names` contains duplicates.
  """
  # Materialize once so the length checks and the loop see the same items.
  names = list(component_names)
  dtypes = list(component_dtypes)
  num_components = len(names)
  if num_components != len(dtypes):
    # `zip` would silently drop the unmatched tail, leaving fewer objects
    # than `num_components`, so a restore could never be flushed.
    raise ValueError(
        "component_names and component_dtypes must have the same length; "
        "got %d names and %d dtypes." % (num_components, len(dtypes)))
  if len(set(names)) != num_components:
    # Duplicate names collapse into one dictionary entry, with the same
    # effect as a length mismatch.
    raise ValueError("component_names must be unique; got %s." % (names,))

  # Every component shares these two dictionaries.
  save_buffer = {}
  restore_buffer = {}
  split_dependencies = {}
  for name, dtype in zip(names, dtypes):
    split_dependencies[name] = _SplitDependency(
        save_buffer=save_buffer,
        restore_buffer=restore_buffer,
        name=name,
        dtype=dtype,
        num_components=num_components,
        fill_save_buffer_fn=fill_save_buffer_fn,
        consume_restore_buffer_fn=consume_restore_buffer_fn)
  return split_dependencies