Skip to content

Commit

Permalink
Fixed saver relative paths for latest_checkpoint
Browse files Browse the repository at this point in the history
This would be cleaner if we made all paths listed in the "latest"
file relative to its directory, allowing the removal of the
added `os.path.isabs` checks.

That would make the `os.path.join` in `saver.latest_checkpoint` much less
surprising.

But at least this way, there is no effect on currently working code.

Fixes #571
Change-Id: I47d8536b9b2ed3dcc193d6e6b7f4573a4e22c9b3
  • Loading branch information
MarkDaoust authored and Vijay Vasudevan committed Jan 8, 2016
1 parent 96fac8a commit eef2aaa
Show file tree
Hide file tree
Showing 2 changed files with 88 additions and 0 deletions.
17 changes: 17 additions & 0 deletions tensorflow/python/training/saver.py
Expand Up @@ -489,6 +489,16 @@ def update_checkpoint_state(save_dir,
all_model_checkpoint_paths.append(model_checkpoint_path)
# Writes the "checkpoint" file for the coordinator for later restoration.
coord_checkpoint_filename = _GetCheckpointFilename(save_dir, latest_filename)

# Relative paths need to be rewritten to be relative to the "save_dir".
# Absolute paths are left untouched.
# NOTE(review): the conflict check that follows compares
# coord_checkpoint_filename against the rewritten path -- confirm that the
# comparison is still meaningful once the path is relative to save_dir.
if not os.path.isabs(model_checkpoint_path):
  model_checkpoint_path = os.path.relpath(model_checkpoint_path, save_dir)

# Rewrite only the relative entries; absolute entries must be kept as-is
# rather than filtered out, otherwise they vanish from the recorded history.
all_model_checkpoint_paths = [
    p if os.path.isabs(p) else os.path.relpath(p, save_dir)
    for p in all_model_checkpoint_paths
]

if coord_checkpoint_filename == model_checkpoint_path:
raise RuntimeError("Save path '%s' conflicts with path used for "
"checkpoint state. Please use a different save path." %
Expand Down Expand Up @@ -854,6 +864,10 @@ def save(self, sess, save_path, global_step=None, latest_filename=None):
"""
if latest_filename is None:
latest_filename = "checkpoint"

if os.path.split(latest_filename)[0]:
raise ValueError("'latest_filename' must not contain path components")

if global_step is not None:
if not isinstance(global_step, compat.integral_types):
global_step = training_util.global_step(sess, global_step)
Expand Down Expand Up @@ -905,8 +919,11 @@ def latest_checkpoint(checkpoint_dir, latest_filename=None):
# Pick the latest checkpoint based on checkpoint state.
ckpt = get_checkpoint_state(checkpoint_dir, latest_filename)
if ckpt and ckpt.model_checkpoint_path:

# "os.path.join" discards "checkpoint_dir" when the second argument is an
# absolute path, so absolute checkpoint paths are returned unchanged.
checkpoint_pattern = os.path.join(
checkpoint_dir, ckpt.model_checkpoint_path)

if gfile.Glob(checkpoint_pattern):
return checkpoint_pattern

Expand Down
71 changes: 71 additions & 0 deletions tensorflow/python/training/saver_test.py
Expand Up @@ -20,6 +20,9 @@

import os.path
import time
import contextlib
import shutil
import tempfile

import tensorflow.python.platform

Expand Down Expand Up @@ -583,5 +586,73 @@ def testNonReshape(self):
self.assertEqual(20.0, v1.eval())


class LatestCheckpointWithRelativePaths(tf.test.TestCase):
  """Checks `tf.train.latest_checkpoint` with a relative save path."""

  @staticmethod
  @contextlib.contextmanager
  def tempWorkingDir(temppath):
    """Context manager that switches the working directory to `temppath`.

    The working directory in effect on entry is restored on exit, including
    when the body raises.

    Args:
      temppath: Directory to change into for the duration of the block.
    """
    cwd = os.getcwd()
    os.chdir(temppath)
    try:
      yield
    finally:
      os.chdir(cwd)

  @staticmethod
  @contextlib.contextmanager
  def tempDir():
    """Context manager that creates a temporary directory and yields its path.

    The directory tree is removed on exit, including when the body raises.
    """
    tempdir = tempfile.mkdtemp()
    try:
      yield tempdir
    finally:
      shutil.rmtree(tempdir)

  def testRelativePath(self):
    """Saves three checkpoints to a relative path and restores the latest."""
    # Make sure we have a clean directory to work in.
    with self.tempDir() as tempdir:

      # Jump to that directory until this test is done.
      with self.tempWorkingDir(tempdir):

        # Save training snapshots to a relative path.
        traindir = 'train/'
        os.mkdir(traindir)

        filename = 'snapshot'
        filepath = os.path.join(traindir, filename)

        with self.test_session() as sess:
          # Build a simple graph.
          v0 = tf.Variable(0.0)
          inc = v0.assign_add(1.0)

          save = tf.train.Saver({'v0': v0})

          # Record a short training history: v0 is 0.0, 1.0, then 2.0.
          tf.initialize_all_variables().run()
          save.save(sess, filepath, global_step=0)
          inc.eval()
          save.save(sess, filepath, global_step=1)
          inc.eval()
          save.save(sess, filepath, global_step=2)

        with self.test_session() as sess:
          # Build a new graph with different initialization.
          v0 = tf.Variable(-1.0)

          # Create a new saver.
          save = tf.train.Saver({'v0': v0})
          tf.initialize_all_variables().run()

          # Get the most recent checkpoint name from the training history file.
          name = tf.train.latest_checkpoint(traindir)
          self.assertIsNotNone(name)

          # Restore "v0" from that checkpoint; the last save happened after
          # two increments, so the restored value must be 2.0.
          save.restore(sess, name)
          self.assertEqual(2.0, v0.eval())



# Call tf.test.main() only when this file is executed as a script, not when
# it is imported.
if __name__ == "__main__":
  tf.test.main()

0 comments on commit eef2aaa

Please sign in to comment.