[oneDNN] Fuse group of primitive ops for batch normalization to Fu… #44950
base: master
Conversation
…nd scale mul op skipped
Thank you for the PR and I'm so sorry for the delay! I appreciate the detailed explanation.
@penpornk - thanks for your previous comments. I pushed a new commit to:
@penpornk Can you please review this PR? Thanks!
1 similar comment
This PR is brought up to the latest baseline, and the test case is updated for the math op function usage.
@penpornk Can you please review this PR? Thanks!
5 similar comments
    self.assertAllClose(original_result, optimized_result)

  @test_util.run_deprecated_v1
  def testFuseDecomposedBatchNorm_FormatUnsupportCase(self):
nit: FormatUnsupportCase -> FormatUnsupportedCase
This test is failing with:
tensorflow.python.framework.errors_impl.UnimplementedError: The Conv2D op currently only supports the NHWC tensor format on the CPU. The op was given the format: NCHW
[[{{node conv_op}}]]
The fusion is only implemented for the NHWC format, and I added the data format check as a precaution.
Is it okay not to add the mismatch test case for the data format?
If it is still needed, do you know of any flag I can add to skip this test case on CPU devices?
(I am testing on CPU and couldn't reproduce this test failure.)
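As a minimal sketch of the precautionary check described above (plain Python over a toy node dict; the function and node representation are hypothetical, not the PR's actual code), the rewrite pass simply leaves any non-NHWC subgraph untouched:

```python
def maybe_fuse_batch_norm(node):
    """Return a fused replacement node, or None to skip fusion.

    Mirrors the precautionary data-format check: only NHWC subgraphs
    are rewritten; NCHW (and anything else) is left unfused.
    """
    if node.get("data_format", "NHWC") != "NHWC":
        return None  # leave the decomposed pattern as-is
    return {"op": "FusedBatchNorm", "data_format": "NHWC"}

# NCHW graphs are skipped; NHWC graphs get the fused op.
assert maybe_fuse_batch_norm({"data_format": "NCHW"}) is None
assert maybe_fuse_batch_norm({"data_format": "NHWC"})["op"] == "FusedBatchNorm"
```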
Maybe it's supported with oneDNN, but not without?
Maybe you can condition the test on oneDNN, or otherwise catch the error and skip the test if the format itself isn't supported.
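One way to realize the catch-and-skip suggestion is sketched below. This is a self-contained illustration with a stand-in exception and helper; in the real test the exception would be `tf.errors.UnimplementedError` raised by running the NCHW `Conv2D` graph:

```python
import unittest


class FakeUnimplementedError(RuntimeError):
    """Stand-in for tf.errors.UnimplementedError (hypothetical)."""


def run_conv_nchw():
    # Stand-in for building and running the NCHW Conv2D graph; on a
    # CPU-only build without oneDNN support this would raise
    # UnimplementedError, which we simulate here.
    raise FakeUnimplementedError(
        "Conv2D currently only supports NHWC on the CPU")


class FormatTest(unittest.TestCase):
    def test_fuse_decomposed_bn_format_unsupported(self):
        try:
            run_conv_nchw()
        except FakeUnimplementedError:
            # Format not supported on this device: skip, don't fail.
            self.skipTest("NCHW Conv2D is not supported on this device")


if __name__ == "__main__":
    unittest.main()
```

Skipping (rather than asserting on the error message) keeps the test green on builds where NCHW `Conv2D` is actually available.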
@@ -751,7 +752,10 @@ def testFuseDecomposedBatchNorm_PatternMismatchCase(self):
     self.assertAllClose(original_result, optimized_result)

   @test_util.run_deprecated_v1
-  def testFuseDecomposedBatchNorm_FormatUnsupportCase(self):
+  def testFuseDecomposedBatchNorm_FormatUnsupportedCase(self):
+    if tf_config.list_physical_devices("CPU"):
This will always be true, since there's always at least one CPU device (the host).
Plus, don't you want this to run on CPU?
This test case was added to check that if the frozen graph has NCHW format, we do not fuse it. The data format check just makes sure we don't handle the NCHW case. If current TensorFlow won't even allow creating a Conv2D op with NCHW format on CPU, the code check is a double guarantee.
Can we skip adding this NCHW data format test case?
#44950 (comment) This data format test case doesn't really add any value; again, can we leave it out?
Okay
@cantonios, I updated the data format test case to just verify that the fusion result is as expected. I think this should avoid the exception from executing the NCHW Conv2D on CPU.
A TF test failed when oneDNN custom ops aren't enabled. Could you please take a look? Command to reproduce:
In legacy models, batch normalization is performed by a group of individual ops instead of a single FusedBatchNorm op. This PR adds support in the optimize_for_inference tool to identify this decomposed batch normalization pattern in the graph and replace it with a single FusedBatchNorm op. In the inference use case, a FusedBatchNorm op can be further folded with the convolution op to reduce computation and thus increase performance.
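The rewrite is valid because the group of primitive ops and FusedBatchNorm compute the same affine transform. A sketch in NumPy (function and variable names are illustrative, not taken from the PR):

```python
import numpy as np


def decomposed_batch_norm(x, mean, var, gamma, beta, eps=1e-3):
    """Batch norm as the group of primitive ops found in legacy
    frozen graphs: Rsqrt -> Mul (gamma) -> Mul/Sub/Add."""
    scale = gamma * (1.0 / np.sqrt(var + eps))   # Rsqrt + Mul
    return x * scale + (beta - mean * scale)     # Mul + Sub + Add


def fused_batch_norm_inference(x, mean, var, gamma, beta, eps=1e-3):
    """What a single FusedBatchNorm computes in inference mode."""
    return gamma * (x - mean) / np.sqrt(var + eps) + beta


rng = np.random.default_rng(0)
x = rng.standard_normal((2, 4, 4, 3)).astype(np.float32)  # NHWC input
mean = rng.standard_normal(3).astype(np.float32)
var = rng.random(3).astype(np.float32)
gamma = rng.standard_normal(3).astype(np.float32)
beta = rng.standard_normal(3).astype(np.float32)

# Both forms produce the same values, so the graph rewrite preserves
# the model's output.
assert np.allclose(decomposed_batch_norm(x, mean, var, gamma, beta),
                   fused_batch_norm_inference(x, mean, var, gamma, beta),
                   atol=1e-5)
```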
An example of the graph transformation performed by the optimize_for_inference tool with this PR's change is shown below. The original subgraph is part of DenseNet-169. After applying optimize_for_inference, performance is about 1.8x that of the original graph when measured on CPU for DenseNet-169.
Original subgraph with the group of ops that make up batch normalization:
Subgraph after the optimize_for_inference tool, which fuses the individual batch normalization ops into a FusedBatchNorm op and then folds the FusedBatchNorm into the convolution ops:
Another example is from GoogleNet-v3. In this model the gamma factor is 1, so the multiplication by gamma is absent compared to the batch normalization ops in DenseNet-169. The performance gain after applying optimize_for_inference is about 1.6x over the original graph when measured on CPU for GoogleNet-v3.
Original subgraph with the group of primitive ops that make up batch normalization when gamma is 1:
Subgraph after applying the optimize_for_inference tool:
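The gamma-less GoogleNet-v3 pattern is still fusable because the rewriter can supply a constant gamma of ones for the fused op. A NumPy sketch of the equivalence (names are illustrative, not from the PR):

```python
import numpy as np


def decomposed_bn_no_gamma(x, mean, var, beta, eps=1e-3):
    # GoogleNet-v3 style pattern: gamma == 1, so the Mul by gamma
    # is simply absent from the graph.
    scale = 1.0 / np.sqrt(var + eps)
    return x * scale + (beta - mean * scale)


def fused_bn(x, mean, var, gamma, beta, eps=1e-3):
    return gamma * (x - mean) / np.sqrt(var + eps) + beta


rng = np.random.default_rng(1)
x = rng.standard_normal((1, 2, 2, 4)).astype(np.float32)  # NHWC input
mean = rng.standard_normal(4).astype(np.float32)
var = rng.random(4).astype(np.float32)
beta = rng.standard_normal(4).astype(np.float32)
ones = np.ones(4, np.float32)  # gamma synthesized for the fused op

# Feeding a constant gamma of ones reproduces the gamma-less pattern.
assert np.allclose(decomposed_bn_no_gamma(x, mean, var, beta),
                   fused_bn(x, mean, var, ones, beta), atol=1e-5)
```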