Skip to content

Commit

Permalink
fix(gpu): allow soft placement
Browse files Browse the repository at this point in the history
  • Loading branch information
Joppe Geluykens committed Apr 10, 2018
1 parent 29c3f31 commit dd820b3
Show file tree
Hide file tree
Showing 2 changed files with 3 additions and 14 deletions.
14 changes: 1 addition & 13 deletions tfstackgan/python/estimator/python/head_impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -246,23 +246,11 @@ def create_estimator_spec(self, features, mode, logits, labels=None,
train_ops = control_flow_ops.group(train_ops.generator_train_op,
train_ops.discriminator_train_op)

# scaffold = monitored_session.Scaffold(
# ready_op=control_flow_ops.no_op())

with ops.device('/cpu:0'):
ready_op = control_flow_ops.group(
variables.report_uninitialized_variables(),
resources.report_uninitialized_resources())

scaffold = monitored_session.Scaffold(
ready_op=ready_op)

return model_fn_lib.EstimatorSpec(
loss=scalar_loss,
mode=model_fn_lib.ModeKeys.TRAIN,
train_op=train_ops, # train_ops.global_step_inc_op,
training_hooks=None, # training_hooks
scaffold=scaffold)
training_hooks=None) # training_hooks
else:
raise ValueError('Mode not recognized: %s' % mode)

Expand Down
3 changes: 2 additions & 1 deletion train.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,8 @@ def main(_):

# 3) Do some configuration,
# - Session config
sess_config = tf.ConfigProto(log_device_placement=True)
sess_config = tf.ConfigProto(log_device_placement=True,
allow_soft_placement=True)
sess_config.gpu_options.allow_growth = True
# - Distribution config
distribution = tf.contrib.distribute.OneDeviceStrategy(device='/gpu:0')
Expand Down

0 comments on commit dd820b3

Please sign in to comment.