
Refactor keras.dtype_policies #19711

Merged
merged 9 commits into keras-team:master on May 15, 2024

Conversation

@james77777778 (Contributor) commented May 13, 2024

EDITED:
Please refer to #19711 (comment) for the new updates.


I think it would be beneficial to give QuantizedDTypePolicy some flexibility with respect to the global dtype policy, keras.config.dtype_policy().

Additionally, there is a new property in DTypePolicy, is_quantized, which should be useful for quantization-related methods.
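As a rough illustration of what such a property buys, here is a standalone sketch (not the actual Keras implementation; the class bodies here are simplified stand-ins):

```python
# Illustrative sketch only -- simplified stand-ins, not Keras source.
# The idea: expose `is_quantized` on the policy hierarchy so that
# quantization-related code can branch on a property instead of
# inspecting the policy name string.
class DTypePolicy:
    def __init__(self, name):
        self._name = name

    @property
    def name(self):
        return self._name

    @property
    def is_quantized(self):
        # Base (float) policies are never quantized.
        return False


class QuantizedDTypePolicy(DTypePolicy):
    @property
    def is_quantized(self):
        return True


print(DTypePolicy("float32").is_quantized)                     # False
print(QuantizedDTypePolicy("int8_from_float32").is_quantized)  # True
```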

With this PR, we can do the following:

import keras
from keras import dtype_policies
from keras import layers
from keras import models


@keras.saving.register_keras_serializable("MyPackage")
class MySubclass(layers.Layer):
    def __init__(self, **kwargs):
        dtypes = kwargs.pop("dtypes", {})
        super().__init__(**kwargs)
        self.layer = layers.Dense(8, dtype=dtypes.pop("layer", None))

    def call(self, inputs, training=None):
        return self.layer(inputs)

    def get_config(self):
        config = super().get_config()
        config.pop("dtype")
        if self.layer.dtype_policy.is_quantized:
            _config = dtype_policies.serialize(self.layer.dtype_policy)
            _config["config"]["source_name"] = None
            config.update({"dtypes": {"layer": _config}})
        return config


inputs = layers.Input(shape=[None, 4])
outputs = MySubclass()(inputs)
model = models.Model(inputs, outputs)

"""global dtype policy (float32)"""

model.quantize("int8")
for layer in model._flatten_layers(include_self=False, recursive=True):
    print(layer.name, layer.dtype_policy)
model.save("model.keras")

"""global dtype policy (bfloat16)"""

keras.config.set_dtype_policy("bfloat16")
new_model = models.load_model("model.keras")
for layer in new_model._flatten_layers(include_self=False, recursive=True):
    print(layer.name, layer.dtype_policy)

Outputs:

# During saving (global dtype policy: float32)
input_layer <FloatDTypePolicy "float32">
my_subclass <FloatDTypePolicy "float32">
dense <QuantizedDTypePolicy "int8_from_float32">

# During loading (global dtype policy: bfloat16)
input_layer <FloatDTypePolicy "bfloat16">
my_subclass <FloatDTypePolicy "bfloat16">
dense_1 <QuantizedDTypePolicy "int8_from_bfloat16">

@mattdangerw has pointed out that the dtype policies of quantized saved models are currently immutable with respect to the global dtype policy. keras-team/keras-nlp#1612 (comment)
With this PR, a slight modification in get_config is enough to support that feature.
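The mechanism behind the example above can be sketched in isolation: once the serialized config's source name is cleared, the loader can re-derive the source dtype from the global policy at load time. A minimal standalone sketch (all names here are assumptions for illustration, not the Keras API):

```python
# Illustrative sketch only (not Keras source): how clearing the stored
# source dtype lets a quantized policy follow the global dtype policy
# at deserialization time.
_GLOBAL_DTYPE_POLICY = "float32"


def set_global_policy(name):
    """Stand-in for keras.config.set_dtype_policy (hypothetical)."""
    global _GLOBAL_DTYPE_POLICY
    _GLOBAL_DTYPE_POLICY = name


def resolve_quantized_policy_name(mode, source_name=None):
    # A `None` source falls back to whatever the global policy is
    # when the config is deserialized, e.g. mode "int8" plus a global
    # policy of "bfloat16" resolves to "int8_from_bfloat16".
    source = source_name if source_name is not None else _GLOBAL_DTYPE_POLICY
    return f"{mode}_from_{source}"


print(resolve_quantized_policy_name("int8"))   # int8_from_float32
set_global_policy("bfloat16")
print(resolve_quantized_policy_name("int8"))   # int8_from_bfloat16
```

A stored (non-None) source name pins the policy instead, reproducing the old immutable behavior.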

@codecov-commenter commented May 13, 2024

Codecov Report

All modified and coverable lines are covered by tests ✅

Project coverage is 78.53%. Comparing base (310c275) to head (ecf2523).
Report is 1 commit behind head on master.

Additional details and impacted files
@@           Coverage Diff           @@
##           master   #19711   +/-   ##
=======================================
  Coverage   78.52%   78.53%           
=======================================
  Files         498      498           
  Lines       45769    45756   -13     
  Branches     8456     8454    -2     
=======================================
- Hits        35942    35936    -6     
+ Misses       8091     8087    -4     
+ Partials     1736     1733    -3     
Flag Coverage Δ
keras 78.38% <100.00%> (+<0.01%) ⬆️
keras-jax 61.95% <100.00%> (+<0.01%) ⬆️
keras-numpy 56.29% <87.93%> (-0.01%) ⬇️
keras-tensorflow 63.41% <100.00%> (-0.01%) ⬇️
keras-torch 61.99% <100.00%> (-0.01%) ⬇️

Flags with carried forward coverage won't be shown.

@fchollet (Member) left a comment

Thanks for the PR!

keras/src/dtype_policies/dtype_policy.py (outdated review thread, resolved)
@@ -202,6 +202,10 @@ def __repr__(self):
return f'<FloatDTypePolicy "{self._name}">'


GLOBAL_DEFAULT_PLACEHOLDER = "global_default"
@fchollet (Member):

Please use a more explicit name, e.g. "DEFAULT_DTYPE_POLICY". Why use this string as the initial value, instead of e.g. None?

@james77777778 (Contributor, author):

Why use this string as the initial value, instead of e.g. None?

Currently, DTypePolicy and its subclasses rely on string values for parsing.
It is not clear to me how we can pass None in combination with the quantization mode.

Should we refactor QuantizedDTypePolicy to support a signature for both the quantization mode and the source dtype policy?

Ex:

policy = QuantizedDTypePolicy(mode="int8", source_dtype_policy="mixed_bfloat16")

@fchollet (Member):

Currently, DTypePolicy and its subclasses rely on string value for parsing.
It is not clear for me how we can pass None in combination with the quantization mode.

We could just modify DTypePolicy to support None, meaning "default".

Should we refactor QuantizedDTypePolicy to support a signature for both the quantization mode and the source dtype policy?

Yes, that's a great idea!

@james77777778 james77777778 changed the title Add flexibility to QuantizedDTypePolicy Refactor keras.dtype_policies May 15, 2024
@james77777778 (Contributor, author) commented May 15, 2024

I've significantly refactored keras.dtype_policies.

Some notes:

  • Replicate all methods from FloatDTypePolicy to DTypePolicy so that FloatDTypePolicy becomes an alias for DTypePolicy. The reason is that the overridden __new__ in DTypePolicy caused numerous issues, and addressing them would introduce unnecessary complexity.
  • Introduce a new signature for QuantizedDTypePolicy and QuantizedFloat8DTypePolicy.
  • Utilize dtype_policies.serialize in get_config of keras.layers.Layer. This is required because we now use different signatures for different dtype policies.
  • Update the tests.

Incompatible change warning:

  • A string like "int8_from_float32" is still accepted by keras.dtype_policies.get, but it can no longer be passed directly to QuantizedDTypePolicy or QuantizedFloat8DTypePolicy.
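For reference, a legacy name like "int8_from_float32" decomposes into a quantization mode and a source policy name. A rough sketch of that parsing (a hypothetical helper for illustration, not the actual keras.dtype_policies.get implementation):

```python
# Hypothetical parsing helper, sketching how a legacy policy name such as
# "int8_from_float32" splits into a quantization mode and a source
# dtype-policy name. Not the actual Keras implementation; the mode list
# is an assumption.
QUANTIZATION_MODES = ("int8", "float8")


def split_legacy_policy_name(name):
    for mode in QUANTIZATION_MODES:
        prefix = f"{mode}_from_"
        if name.startswith(prefix):
            return mode, name[len(prefix):]
    return None, name  # not a quantized-policy name


print(split_legacy_policy_name("int8_from_float32"))  # ('int8', 'float32')
print(split_legacy_policy_name("mixed_bfloat16"))     # (None, 'mixed_bfloat16')
```

Under the new signature, the two parts would instead be passed as separate constructor arguments.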

To add flexibility to quantized dtype policy:

import keras
from keras import dtype_policies
from keras import layers
from keras import models


@keras.saving.register_keras_serializable("MyPackage")
class MySubclass(layers.Layer):
    def __init__(self, **kwargs):
        dtypes = kwargs.pop("dtypes", {})
        super().__init__(**kwargs)
        self.layer = layers.Dense(8, dtype=dtypes.pop("layer", None))

    def call(self, inputs, training=None):
        return self.layer(inputs)

    def get_config(self):
        config = super().get_config()
        config.pop("dtype")
        if self.layer.dtype_policy.is_quantized:
            _config = dtype_policies.serialize(self.layer.dtype_policy)
            _config["config"]["source_name"] = None
            config.update({"dtypes": {"layer": _config}})
        return config


inputs = layers.Input(shape=[None, 4])
outputs = MySubclass()(inputs)
model = models.Model(inputs, outputs)

"""global dtype policy (float32)"""

model.quantize("int8")
for layer in model._flatten_layers(include_self=False, recursive=True):
    print(layer.name, layer.dtype_policy)
model.save("model.keras")

"""global dtype policy (bfloat16)"""

keras.config.set_dtype_policy("bfloat16")
new_model = models.load_model("model.keras")
for layer in new_model._flatten_layers(include_self=False, recursive=True):
    print(layer.name, layer.dtype_policy)

The outputs:

# global dtype policy: float32
input_layer <FloatDTypePolicy "float32">
my_subclass <FloatDTypePolicy "float32">
dense <QuantizedDTypePolicy "int8_from_float32">

# global dtype policy: bfloat16
input_layer <FloatDTypePolicy "bfloat16">
my_subclass <FloatDTypePolicy "bfloat16">
dense_1 <QuantizedDTypePolicy "int8_from_bfloat16">

@gbaned gbaned added this to Assigned Reviewer in PR Queue via automation May 15, 2024
@fchollet (Member) left a comment
Nice work -- it's definitely cleaner this way! LGTM

PR Queue automation moved this from Assigned Reviewer to Approved by Reviewer May 15, 2024
@fchollet fchollet merged commit 3105247 into keras-team:master May 15, 2024
6 checks passed
PR Queue automation moved this from Approved by Reviewer to Merged May 15, 2024
@james77777778 james77777778 deleted the flexible-quantized-dtype branch May 16, 2024 00:33
mloc added a commit to mloc/tensorboard that referenced this pull request May 21, 2024
Keras' output format was slightly changed in keras-team/keras#19711; in
some cases dtypes will now be exported as a config map instead of just a
string.
This fixes test breakages when using ToT keras.
mloc added a commit to mloc/tensorboard that referenced this pull request May 21, 2024
Keras' output format was slightly changed in keras-team/keras#19711; for
non-input layers dtypes will now be exported as a config map instead of just a
string.
This fixes test breakages when using ToT keras.
arcra pushed a commit to tensorflow/tensorboard that referenced this pull request May 21, 2024
Keras' output format was slightly changed in
keras-team/keras#19711; for non-input layers
dtypes will now be exported as a config map instead of just a string.
This fixes test breakages when using ToT keras.

Alternative to #6855