
rfc: graph: support accumulation mode #1824

Open
wants to merge 1 commit into base: rfcs

Conversation

Contributor

@wzt1997 wzt1997 commented Mar 8, 2024

This is to propose adding accumulation data type support in oneDNN Graph API.

Link to the rendered document.

@TaoLv TaoLv added the RFC A design document label Mar 8, 2024
@wzt1997 wzt1997 force-pushed the zhitao/rfc/accumulation_mode branch 3 times, most recently from e32c8cc to b6f30c6 Compare March 13, 2024 08:11
```cpp
// For demonstration purposes; we may provide the graph parameter in a wrapper or define
// more constructors. Please refer to the implementation part for a brief introduction.
graph g(kind, accumulation_mode);
```
Contributor

Better to define the API clearly here, because we already have an fpmath_mode parameter on this constructor.
We also need to define the behavior for combinations of fpmath_mode and accumulation_mode. For example, what will the computation behavior be if fpmath_mode=bf16 and accumulation_mode=f16?

Contributor Author

OK. Will confirm with oneDNN about the behavior. Since oneDNN currently does not support checking whether a given accumulation mode is compatible, users will most likely receive an exception for unsupported combinations when primitive creation fails during compilation.

Contributor

@TaoLv Both serve different purposes:

  • fpmath_mode allows conversion of inputs during computation
  • accumulation_mode allows conversion of intermediate accumulators.

In the example you gave, a primitive implementation would be allowed to convert inputs to bf16 (in order to use AMX/XMX instructions), and also to use f16 for accumulators (e.g. in instructions or when writing partial accumulators to memory).

@wzt1997 wzt1997 force-pushed the zhitao/rfc/accumulation_mode branch 2 times, most recently from d433e2f to c8df861 Compare March 18, 2024 07:52
Comment on lines 5 to 6
Please refer to the [documentation](https://oneapi-src.github.io/oneDNN/dev_guide_attributes_accumulation_mode.html#doxid-dev-guide-attributes-accumulation-mode)
and [rfc](https://github.com/mgouicem/oneDNN/blob/mgouicem/rfcs/f16-accumulation/rfcs/20230118-f16-accumulation/README.md) for the motivation and design detail of primitive API.

Hi, please help address a few questions & comments I have.

  1. Please correct me if I'm wrong about the motivation for this feature, as that would help me understand it better. https://oneapi-src.github.io/oneDNN/dev_guide_attributes_fpmath_mode.html seems to suggest that FP32 inputs would also be downcasted before computation if fpmath_mode is enabled for a lower dtype. Combining that with this proposal, which also allows choosing the accumulation dtype, resembles a framework's mixed-precision computation feature: with Automatic Mixed Precision, apart from deciding which ops' inputs to downcast, frameworks can also control which dtype is used for accumulation. With this feature, framework developers would have more insight into, and more flexibility over, the computation they offload to oneDNN.

  2. If the whole graph used low precision, that might be problematic because some ops may have poorer accuracy with BF16, for example.

  3. How would you try to strike a balance between good accuracy & good performance?

Thanks!

Contributor Author

Thanks for the comment!

For the first point, in my understanding the answer is yes. fpmath mode and accumulation mode give framework developers different capabilities for controlling the computation. You may refer to the discussion above for more details.

Regarding points 2 and 3, I understand the trade-off between performance and accuracy when using low-precision accumulation. In my view, it is the responsibility of framework users to set the attribute based on their desired outcomes; users may choose different accumulation modes for different scenarios. In addition, with the graph-level API, framework users can create multiple graphs if different accumulation modes are needed.

```c
typedef enum {
    dnnl_accumulation_mode_strict,
    dnnl_accumulation_mode_relaxed,
    dnnl_accumulation_mode_any,
    // ...
```

For some ops, any may result in poor accuracy, thus decreasing the accuracy of the whole workload.

So, would you ensure that any strikes a balance between good accuracy and good performance? That is, for specific ops, would low-precision accumulation be avoided if it is known to decrease accuracy?

Thanks!

Contributor Author

Well, we have not designed such an automatic mechanism in the library currently. In the current design, the Graph API will not modify the attribute values users provide, but will pass them to the DNNL backend. As far as I know, the backend does not support such a feature currently either.

Contributor

@sanchitintel From the oneDNN perspective, there is no way to know when decreasing accuracy is fine and when it is not.
For example, many low-accuracy recipes use full precision for the first layer, but oneDNN has no knowledge of whether a given layer is the first one.

Here, it is the responsibility of the user to set the attribute properly based on their accuracy needs. On the oneDNN side, we will just use the fastest implementation available that complies with that attribute.

However, when there is no performance benefit to using low-accuracy accumulators, we typically keep accumulators in f32.


Thanks for the info, @mgouicem!

@wzt1997 wzt1997 force-pushed the zhitao/rfc/accumulation_mode branch 3 times, most recently from f3a30b5 to e929a52 Compare April 8, 2024 01:50
@wzt1997 wzt1997 force-pushed the zhitao/rfc/accumulation_mode branch 2 times, most recently from a8d2282 to 35e2252 Compare April 10, 2024 06:58
@wzt1997 wzt1997 force-pushed the zhitao/rfc/accumulation_mode branch from 35e2252 to 4a0797d Compare April 16, 2024 07:28
@wzt1997 wzt1997 force-pushed the zhitao/rfc/accumulation_mode branch from 4a0797d to aeeff6a Compare April 16, 2024 07:42