Inconsistent Naming with different input shapes #1162

Closed
ppwwyyxx opened this issue Feb 18, 2016 · 5 comments
@ppwwyyxx
Contributor

Install from binary pip package

  1. Which pip package you installed: CPU-only for Linux
  2. The output from python -c "import tensorflow; print(tensorflow.__version__)": 0.7.0

Steps to reproduce

Run the following code:

import tensorflow as tf

# Fully-defined input shape.
a = tf.placeholder(tf.float32, name='haha', shape=[128, 20, 20, 20])
c = tf.nn.moments(a, axes=[0])[0]
print(c.name)

# Same code, but with an unknown batch dimension.
with tf.Graph().as_default():
    a = tf.placeholder(tf.float32, name='haha', shape=[None, 20, 20, 20])
    c = tf.nn.moments(a, axes=[0])[0]
    print(c.name)

It prints:

moments/Squeeze:0
moments/Squeeze_1:0

In short, with different input shapes, the op returns a tensor with a different name.
I'm not sure whether this is expected behavior or something to be fixed. IMHO, it could cause problems because one may use a fixed batch size for training and a None batch size for inference; getting a different name makes it harder to manage variable load & restore.
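
One workaround (a minimal sketch, not part of the original report; the 'stable_mean'/'stable_variance' names are hypothetical) is to re-export the tensors under names you control with tf.identity, so downstream code no longer depends on how moments() structures its internal graph:

import tensorflow as tf

a = tf.placeholder(tf.float32, name='haha', shape=[None, 20, 20, 20])
mean, variance = tf.nn.moments(a, axes=[0])

# Wrap the outputs in identity ops with fixed names, so anything derived
# from these tensors (e.g. moving-average variables) sees a stable name
# whether or not moments() inserted an extra Squeeze op internally.
mean = tf.identity(mean, name='stable_mean')
variance = tf.identity(variance, name='stable_variance')

print(mean.name)   # stable_mean:0, independent of the input shape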

@vrv

vrv commented Feb 18, 2016

The reason is that moments produces a slightly different graph when the input shape is known (it can be pre-computed). When it's unknown, the code introduces a "squeeze" operator that isn't normally there, in the same op scope as the final mean, so there's a name collision and we increment by 1.

@vincentvanhoucke: can we/should we make moments() return consistent names for the tensors in the returned tuple?
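
For reference, the "_1" suffix is just TensorFlow's standard name uniquification: when two ops request the same name in one graph, the later one gets a numeric suffix. A minimal sketch (the op names here are illustrative, not the ones moments() uses):

import tensorflow as tf

with tf.Graph().as_default():
    x = tf.constant(1.0, name='squeeze_demo')   # gets "squeeze_demo"
    y = tf.constant(2.0, name='squeeze_demo')   # name collision -> "squeeze_demo_1"
    print(x.name)   # squeeze_demo:0
    print(y.name)   # squeeze_demo_1:0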

@vincentvanhoucke
Contributor

These are the names of the ops, not the variables, so it won't be an issue for save and restore.
We should try to provide stable names for endpoints of composite ops, though; that seems sane, even if users should definitely not rely on that naming being stable over time. I'm working on updates to moments() for v0.8; I'll put that on my list of requirements.

@ppwwyyxx
Contributor Author

I agree that op names usually aren't an issue for save and restore. But what happened in my case is that I'm using ExponentialMovingAverage on c, so there is a variable that inherits the inconsistent name from c, which then causes problems for save/restore.
BTW, this is how the non-official batch normalization mentioned in #1122 works (it takes a moving average of the mean/variance).
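
A minimal sketch of that pattern (TF 1.x graph-mode API; the exact shadow-variable name printed below is illustrative): ExponentialMovingAverage.apply() on a tensor creates a shadow variable whose name is derived from the tensor's op name, so the Squeeze vs. Squeeze_1 difference ends up in the checkpoint.

import tensorflow as tf

a = tf.placeholder(tf.float32, shape=[None, 20, 20, 20])
mean, variance = tf.nn.moments(a, axes=[0])

# Maintain moving averages of the batch statistics, as the informal
# batch-norm recipe referenced in #1122 does.
ema = tf.train.ExponentialMovingAverage(decay=0.99)
update_op = ema.apply([mean, variance])   # creates shadow variables

# The shadow variable's name is derived from mean.op.name, e.g. something
# like "moments/Squeeze/ExponentialMovingAverage:0", so an extra "_1" in
# the op name gets baked into saved checkpoints.
print(ema.average(mean).name)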

@vincentvanhoucke
Contributor

A fix for this has been checked in and should make its way to github shortly.

vrv pushed a commit that referenced this issue Feb 25, 2016
Helps with: #917
Also fixes #1162

The main benefit is that the computation of the sufficient statistics is now decoupled from the aggregation of the moments, which means that if you want to perform the accumulation incrementally, you don't have to keep all the inputs around and can instead keep the much more compact sum and sum-of-squares. Accumulation could also be performed locally if you aggregate across multiple devices.
Computing sum and sum-of-squares can also theoretically be performed in parallel now.
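
For intuition (a NumPy sketch added for illustration, not part of the commit): the moments can be recovered from running sums alone, so each accumulator only needs to keep a count, a sum, and a sum-of-squares.

import numpy as np

count = 0
total = 0.0       # running sum of x
total_sq = 0.0    # running sum of x**2

for _ in range(10):   # pretend these are incoming batches
    batch = np.random.randn(128, 20).astype(np.float32)
    count += batch.shape[0]
    total = total + batch.sum(axis=0)
    total_sq = total_sq + np.square(batch).sum(axis=0)

mean = total / count
variance = total_sq / count - np.square(mean)   # E[x^2] - (E[x])^2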

Tested running inception: same performance, same step time.
Batch normalization benchmark is a bit faster on CPU, a bit slower on GPU:

Before:
cpu shape:4/3 #layers:10 mode:py scale:True train:False - 1.139310 secs
gpu shape:4/3 #layers:10 mode:py scale:True train:False - 0.021970 secs
cpu shape:4/3 #layers:10 mode:py scale:True train:True - 2.767147 secs
gpu shape:4/3 #layers:10 mode:py scale:True train:True - 0.074531 secs
cpu shape:4/3 #layers:10 mode:py scale:True train:False - 0.742835 secs
gpu shape:4/3 #layers:10 mode:py scale:True train:False - 0.013473 secs
cpu shape:4/3 #layers:10 mode:py scale:True train:True - 1.738806 secs
gpu shape:4/3 #layers:10 mode:py scale:True train:True - 0.052777 secs
cpu shape:2/1 #layers:10 mode:py scale:True train:False - 0.119180 secs
gpu shape:2/1 #layers:10 mode:py scale:True train:False - 0.011201 secs
cpu shape:2/1 #layers:10 mode:py scale:True train:True - 0.218297 secs
gpu shape:2/1 #layers:10 mode:py scale:True train:True - 0.048526 secs

After:
cpu shape:4/3 #layers:10 mode:py scale:True train:False - 0.998944 secs
gpu shape:4/3 #layers:10 mode:py scale:True train:False - 0.025828 secs
cpu shape:4/3 #layers:10 mode:py scale:True train:True - 2.657428 secs
gpu shape:4/3 #layers:10 mode:py scale:True train:True - 0.086614 secs
cpu shape:4/3 #layers:10 mode:py scale:True train:False - 0.603137 secs
gpu shape:4/3 #layers:10 mode:py scale:True train:False - 0.017668 secs
cpu shape:4/3 #layers:10 mode:py scale:True train:True - 1.519533 secs
gpu shape:4/3 #layers:10 mode:py scale:True train:True - 0.055214 secs
cpu shape:2/1 #layers:10 mode:py scale:True train:False - 0.071344 secs
gpu shape:2/1 #layers:10 mode:py scale:True train:False - 0.016440 secs
cpu shape:2/1 #layers:10 mode:py scale:True train:True - 0.222093 secs
gpu shape:2/1 #layers:10 mode:py scale:True train:True - 0.039967 secs
Change: 115507032
@oanoelsis

Is this fixed now? I have the same problem with tf.train.Saver.restore() when the batch size differs between training and inference, and I want to know whether this is the reason.
