Add more robust error checking to rewrite.change_batch_size() #14

Open
frreiss opened this issue Jan 21, 2019 · 0 comments

The change_batch_size rewrite (see #4) works by putting the new batch size in place at the input nodes, then propagating the batch size through the rest of the graph by shape inference. If the user does not specify all the input nodes, then the remaining input nodes keep their original batch size and produce conflicting batch sizes downstream. This can result in an error (if a node ends up with two mutually inconsistent input batch sizes) or in the rewrite having no apparent effect on output batch sizes (if the user changes the batch size to None). As I noted in #13, the script batch_size_example.py has the latter problem. The batch size changes to None, but implicit inputs inside the batch normalization layers change the output batch size back to 64.
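To make the two failure modes concrete, here is a minimal sketch in plain Python. It does not use the real rewrite or TensorFlow; the graph layout, the node names (images, bn_mean, batch_norm, logits) and the helper functions merge and propagate are all hypothetical and only imitate how an unknown (None) batch size yields to a known one during shape inference.

```python
def merge(a, b):
    """Merge two batch sizes; None means 'unknown' and yields to a known size."""
    if a is None:
        return b
    if b is None or a == b:
        return a
    raise ValueError(f"conflicting batch sizes {a} and {b}")


def propagate(graph, source_sizes):
    """Propagate batch sizes through a toy graph.

    graph: {node name: [input node names]}, listed in topological order.
    source_sizes: {source node name: batch size, or None for unknown}.
    """
    sizes = {}
    for name, inputs in graph.items():
        if not inputs:
            # Source node: its batch size is given, not inferred.
            sizes[name] = source_sizes[name]
        else:
            size = None
            for inp in inputs:
                size = merge(size, sizes[inp])
            sizes[name] = size
    return sizes


# Hypothetical graph: "bn_mean" stands in for a hidden input inside a
# batch normalization layer that the user forgot to list.
graph = {
    "images": [],
    "bn_mean": [],
    "batch_norm": ["images", "bn_mean"],
    "logits": ["batch_norm"],
}
original = {"images": 64, "bn_mean": 64}

# Case 1: only "images" is changed to None. The hidden input still says 64,
# so the output batch size silently stays 64.
relaxed = propagate(graph, {**original, "images": None})

# Case 2: only "images" is changed to 32. The hidden input now conflicts.
try:
    propagate(graph, {**original, "images": 32})
    conflict = None
except ValueError as e:
    conflict = str(e)
```

In this toy model, relaxed["logits"] is still 64 even though the user asked for None, and the second case fails at batch_norm with a conflict between 32 and 64.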

The proper fix for this problem is as follows:

  • Add error checking code to the change_batch_size rewrite. If the batch size of a node doesn't change, or if shape inference fails, then the rewrite should output a detailed error message. The message should contain the name of the node, the node's input shapes, and the names of the nodes that produced those shapes.
  • Use the error checking code to track down all the hidden inputs in the batch_size_example.py script and add them to the inputs set. Note that it may be necessary to run the input graph through the freeze graph script to remove variables. You can invoke the freeze graph script from Python by adding from tensorflow.python.tools import freeze_graph to the beginning of your script and calling freeze_graph.freeze_graph_with_def_protos() directly.
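
As a rough illustration of what the error checking could look like, here is a sketch in plain Python over a toy graph. Nothing here is the actual rewrite code: BatchSizeError, check_rewrite, and the node names are hypothetical, and the sketch assumes the requested batch size differs from the original one (otherwise "did not change" would be a false alarm).

```python
class BatchSizeError(Exception):
    """Hypothetical error type for a failed batch size rewrite."""


def check_rewrite(graph, old_sizes, new_source_sizes):
    """Propagate new batch sizes and fail loudly on the first suspicious node.

    graph: {node name: [input node names]}, listed in topological order.
    old_sizes: {node name: batch size before the rewrite}.
    new_source_sizes: {source node name: batch size after the rewrite}.
    None stands for an unknown batch size.
    """
    new_sizes = {}
    for name, inputs in graph.items():
        if not inputs:
            new_sizes[name] = new_source_sizes[name]
            continue
        in_sizes = [new_sizes[i] for i in inputs]
        # "producer -> batch size" for every input, for the error message.
        detail = ", ".join(f"{i} -> {s}" for i, s in zip(inputs, in_sizes))
        known = {s for s in in_sizes if s is not None}
        if len(known) > 1:
            raise BatchSizeError(
                f"shape inference failed at node '{name}': "
                f"inconsistent input batch sizes ({detail})")
        new_sizes[name] = known.pop() if known else None
        if new_sizes[name] == old_sizes[name]:
            raise BatchSizeError(
                f"batch size of node '{name}' did not change "
                f"(still {old_sizes[name]}); inputs: {detail}")
    return new_sizes


# Hypothetical graph with one hidden input ("bn_mean").
graph = {
    "images": [],
    "bn_mean": [],
    "batch_norm": ["images", "bn_mean"],
    "logits": ["batch_norm"],
}
old_sizes = {name: 64 for name in graph}

messages = []
for new_sources in ({"images": None, "bn_mean": 64},   # hidden input missed
                    {"images": 32, "bn_mean": 64}):    # hidden input conflicts
    try:
        check_rewrite(graph, old_sizes, new_sources)
    except BatchSizeError as e:
        messages.append(str(e))

# With every input listed, the rewrite goes through.
fixed = check_rewrite(graph, old_sizes, {"images": None, "bn_mean": None})
```

Both error messages name batch_norm and show bn_mean -> 64 among its inputs, which is the kind of hint needed to find the hidden inputs and add them to the inputs set.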