Commit

Fix unexpected dtype policy changes when quantization fails (#19690)
james77777778 committed May 8, 2024
1 parent 43e5155 commit e620cb4
Showing 8 changed files with 174 additions and 72 deletions.
2 changes: 2 additions & 0 deletions keras/src/dtype_policies/dtype_policy.py
@@ -266,6 +266,8 @@ def _get_all_valid_policies(self):
 
 @keras_export("keras.dtype_policies.QuantizedFloat8DTypePolicy")
 class QuantizedFloat8DTypePolicy(QuantizedDTypePolicy):
+    default_amax_history_length = 1024
+
     def __init__(self, name, amax_history_length=1024):
         super().__init__(name)
         if not isinstance(amax_history_length, int):
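
Note: the class-level default_amax_history_length makes the float8 default discoverable without instantiating a policy. A minimal sketch of the fallback pattern this constant enables (the resolve_amax_history_length helper and DummyPolicy stand-in are illustrative, not part of the commit):

from keras.src.dtype_policies import QuantizedFloat8DTypePolicy


class DummyPolicy:
    """Hypothetical stand-in for a policy without `amax_history_length`."""

    name = "float32"


def resolve_amax_history_length(policy):
    # Prefer the policy's own value; otherwise fall back to the class default.
    return getattr(
        policy,
        "amax_history_length",
        QuantizedFloat8DTypePolicy.default_amax_history_length,
    )


print(resolve_amax_history_length(DummyPolicy()))  # 1024
print(
    resolve_amax_history_length(
        QuantizedFloat8DTypePolicy("float8_from_float32", 512)
    )
)  # 512
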
5 changes: 5 additions & 0 deletions keras/src/dtype_policies/dtype_policy_test.py
@@ -283,6 +283,11 @@ def test_properties_for_float8(self):
         policy = QuantizedFloat8DTypePolicy("float8_from_mixed_bfloat16", 512)
         self.assertEqual(policy.amax_history_length, 512)
 
+        # Test default_amax_history_length
+        self.assertEqual(
+            QuantizedFloat8DTypePolicy.default_amax_history_length, 1024
+        )
+
     def test_invalid_properties_for_float8(self):
         with self.assertRaisesRegex(TypeError, "must be an integer."):
             QuantizedFloat8DTypePolicy("float8_from_float32", "512")
45 changes: 23 additions & 22 deletions keras/src/layers/core/dense.py
@@ -218,7 +218,7 @@ def save_own_variables(self, store):
                 target_variables.append(self.outputs_grad_amax_history)
             else:
                 raise NotImplementedError(
-                    self.QUANTIZATION_MODE_ERROR_TEMPLATE.format(mode)
+                    self.QUANTIZATION_MODE_ERROR_TEMPLATE.format(mode=mode)
                 )
         for i, variable in enumerate(target_variables):
             store[str(i)] = variable
@@ -247,7 +247,7 @@ def load_own_variables(self, store):
                 target_variables.append(self.outputs_grad_amax_history)
             else:
                 raise NotImplementedError(
-                    self.QUANTIZATION_MODE_ERROR_TEMPLATE.format(mode)
+                    self.QUANTIZATION_MODE_ERROR_TEMPLATE.format(mode=mode)
                 )
         for i, variable in enumerate(target_variables):
             variable.assign(store[str(i)])
@@ -327,7 +327,7 @@ def quantized_build(self, input_shape, mode):
             self._float8_build()
         else:
             raise NotImplementedError(
-                self.QUANTIZATION_MODE_ERROR_TEMPLATE.format(mode)
+                self.QUANTIZATION_MODE_ERROR_TEMPLATE.format(mode=mode)
             )
 
     def _int8_build(
@@ -353,14 +353,15 @@ def _int8_build(
         self._is_quantized = True
 
     def _float8_build(self):
-        if not isinstance(
-            self.dtype_policy, dtype_policies.QuantizedFloat8DTypePolicy
-        ):
-            raise TypeError(
-                "`self.dtype_policy` must be the type of "
-                f"QuantizedFloat8DTypePolicy. Received {self.dtype_policy}"
-            )
-        amax_history_length = self.dtype_policy.amax_history_length
+        from keras.src.dtype_policies import QuantizedFloat8DTypePolicy
+
+        # If `self.dtype_policy` is not QuantizedFloat8DTypePolicy, then set
+        # `amax_history_length` to its default value.
+        amax_history_length = getattr(
+            self.dtype_policy,
+            "amax_history_length",
+            QuantizedFloat8DTypePolicy.default_amax_history_length,
+        )
         # We set `trainable=True` because we will use the gradients to overwrite
         # these variables
         scale_kwargs = {
@@ -410,7 +411,7 @@ def quantized_call(self, inputs, training=None):
         else:
             mode = self.dtype_policy.quantization_mode
             raise NotImplementedError(
-                self.QUANTIZATION_MODE_ERROR_TEMPLATE.format(mode)
+                self.QUANTIZATION_MODE_ERROR_TEMPLATE.format(mode=mode)
             )
 
     def _int8_call(self, inputs):
@@ -550,15 +551,6 @@ def quantize(self, mode):
         )
         self._check_quantize_args(mode, self.compute_dtype)
 
-        # Set new dtype policy
-        if not isinstance(
-            self.dtype_policy, dtype_policies.QuantizedDTypePolicy
-        ):
-            quantized_dtype = f"{mode}_from_{self.dtype_policy.name}"
-            # We set the internal `self._dtype_policy` instead of using the
-            # setter to avoid double `quantize` call
-            self._dtype_policy = dtype_policies.get(quantized_dtype)
-
         self._tracker.unlock()
         if mode == "int8":
             # Quantize `self._kernel` to int8 and compute corresponding scale
@@ -580,10 +572,19 @@
             self._float8_build()
         else:
             raise NotImplementedError(
-                self.QUANTIZATION_MODE_ERROR_TEMPLATE.format(mode)
+                self.QUANTIZATION_MODE_ERROR_TEMPLATE.format(mode=mode)
             )
         self._tracker.lock()
 
+        # Set new dtype policy
+        if not isinstance(
+            self.dtype_policy, dtype_policies.QuantizedDTypePolicy
+        ):
+            quantized_dtype = f"{mode}_from_{self.dtype_policy.name}"
+            # We set the internal `self._dtype_policy` instead of using the
+            # setter to avoid double `quantize` call
+            self._dtype_policy = dtype_policies.get(quantized_dtype)
+
         # Release memory manually because sometimes the backend doesn't
         gc.collect()
 
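
The dense.py change above is the core of the fix: quantize() now swaps in the quantized dtype policy only after the mode-specific build has succeeded, so a rejected mode can no longer leave the layer with a quantized policy but no quantized variables. A toy reproduction of that ordering, with illustrative names rather than the real Keras classes:

class ToyLayer:
    """Hypothetical layer used only to illustrate the ordering of operations."""

    def __init__(self):
        self.dtype_policy = "float32"

    def quantize(self, mode):
        # 1) Validate/build first; an unsupported mode raises here.
        if mode not in ("int8", "float8"):
            raise NotImplementedError(f"Invalid quantization mode. Received: {mode}")
        # 2) ... allocate quantized variables ...
        # 3) Only then swap the dtype policy.
        self.dtype_policy = f"{mode}_from_float32"


layer = ToyLayer()
try:
    layer.quantize("int7")
except NotImplementedError:
    pass
assert layer.dtype_policy == "float32"  # policy is untouched after the failure
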
39 changes: 34 additions & 5 deletions keras/src/layers/core/dense_test.py
@@ -332,12 +332,9 @@ def test_enable_lora_when_already_enabled(self):
         with self.assertRaisesRegex(ValueError, "lora is already enabled"):
             layer.enable_lora(rank=2)
 
-    """Test quantization-related (int8 and float8) methods"""
+    # Test quantization-related (int8 and float8) methods
 
-    @pytest.mark.skipif(
-        backend.backend() == "numpy",
-        reason=f"{backend.backend()} does not support ops.custom_gradient.",
-    )
+    @pytest.mark.requires_trainable_backend
     def test_quantize_int8(self):
         layer = layers.Dense(units=16)
         layer.build((None, 8))
@@ -451,6 +448,38 @@ def test_quantize_by_setting_dtype_policy(
         layer.dtype_policy = policy
         self.assertLen(layer.variables, expected_num_variables)
 
+    @parameterized.named_parameters(
+        ("int7", "int7"),
+        ("float7", "float7"),
+    )
+    def test_quantize_invalid_mode(self, mode):
+        layer = layers.Dense(units=2)
+        layer.build((None, 2))
+        x = np.random.random((1, 2))
+        # dtype_policy should not be altered by failed quantization
+        original_dtype_policy = layer.dtype_policy
+
+        # Test quantize
+        with self.assertRaisesRegex(ValueError, "Invalid quantization mode."):
+            layer.quantize(mode)
+        self.assertEqual(layer.dtype_policy, original_dtype_policy)
+
+        # Test quantized_build
+        with self.assertRaisesRegex(
+            NotImplementedError, "Invalid quantization mode."
+        ):
+            layer.quantized_build((None, 2), mode)
+        self.assertEqual(layer.dtype_policy, original_dtype_policy)
+
+        # Test quantized_call
+        with self.assertRaisesRegex(
+            NotImplementedError, "Invalid quantization mode."
+        ):
+            # Explicitly set quantization_mode
+            layer._dtype_policy.quantization_mode = mode
+            layer.quantized_call(x)
+        self.assertEqual(layer.dtype_policy, original_dtype_policy)
+
     @pytest.mark.requires_trainable_backend
     def test_quantize_int8_dtype_argument(self):
         self.run_layer_test(
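
The test files also swap the per-test numpy-backend skipif for the shared requires_trainable_backend marker. For reference, a sketch of how such a marker can be wired in a conftest.py; this is a generic pytest pattern under the assumption that only the numpy backend lacks trainable support, not necessarily the exact Keras conftest:

import pytest

from keras.src import backend


def pytest_configure(config):
    # Register the marker so `pytest --strict-markers` accepts it.
    config.addinivalue_line(
        "markers",
        "requires_trainable_backend: skip when the backend cannot train",
    )


def pytest_collection_modifyitems(config, items):
    skip_marker = pytest.mark.skipif(
        backend.backend() == "numpy",
        reason="The numpy backend does not support ops.custom_gradient.",
    )
    for item in items:
        if "requires_trainable_backend" in item.keywords:
            item.add_marker(skip_marker)
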
45 changes: 23 additions & 22 deletions keras/src/layers/core/einsum_dense.py
@@ -273,7 +273,7 @@ def save_own_variables(self, store):
                 target_variables.append(self.outputs_grad_amax_history)
             else:
                 raise NotImplementedError(
-                    self.QUANTIZATION_MODE_ERROR_TEMPLATE.format(mode)
+                    self.QUANTIZATION_MODE_ERROR_TEMPLATE.format(mode=mode)
                 )
         for i, variable in enumerate(target_variables):
             store[str(i)] = variable
@@ -302,7 +302,7 @@ def load_own_variables(self, store):
                 target_variables.append(self.outputs_grad_amax_history)
             else:
                 raise NotImplementedError(
-                    self.QUANTIZATION_MODE_ERROR_TEMPLATE.format(mode)
+                    self.QUANTIZATION_MODE_ERROR_TEMPLATE.format(mode=mode)
                 )
         for i, variable in enumerate(target_variables):
             variable.assign(store[str(i)])
@@ -391,7 +391,7 @@ def quantized_build(self, input_shape, mode):
             self._float8_build()
         else:
             raise NotImplementedError(
-                self.QUANTIZATION_MODE_ERROR_TEMPLATE.format(mode)
+                self.QUANTIZATION_MODE_ERROR_TEMPLATE.format(mode=mode)
             )
 
     def _int8_build(
@@ -439,14 +439,15 @@ def _int8_build(
         self._is_quantized = True
 
     def _float8_build(self):
-        if not isinstance(
-            self.dtype_policy, dtype_policies.QuantizedFloat8DTypePolicy
-        ):
-            raise TypeError(
-                "`self.dtype_policy` must be the type of "
-                f"QuantizedFloat8DTypePolicy. Received {self.dtype_policy}"
-            )
-        amax_history_length = self.dtype_policy.amax_history_length
+        from keras.src.dtype_policies import QuantizedFloat8DTypePolicy
+
+        # If `self.dtype_policy` is not QuantizedFloat8DTypePolicy, then set
+        # `amax_history_length` to its default value.
+        amax_history_length = getattr(
+            self.dtype_policy,
+            "amax_history_length",
+            QuantizedFloat8DTypePolicy.default_amax_history_length,
+        )
         # We set `trainable=True` because we will use the gradients to overwrite
         # these variables
         scale_kwargs = {
@@ -496,7 +497,7 @@ def quantized_call(self, inputs, training=None):
         else:
             mode = self.dtype_policy.quantization_mode
             raise NotImplementedError(
-                self.QUANTIZATION_MODE_ERROR_TEMPLATE.format(mode)
+                self.QUANTIZATION_MODE_ERROR_TEMPLATE.format(mode=mode)
             )
 
     def _int8_call(self, inputs):
@@ -665,15 +666,6 @@ def quantize(self, mode):
         )
         self._check_quantize_args(mode, self.compute_dtype)
 
-        # Set new dtype policy
-        if not isinstance(
-            self.dtype_policy, dtype_policies.QuantizedDTypePolicy
-        ):
-            quantized_dtype = f"{mode}_from_{self.dtype_policy.name}"
-            # We set the internal `self._dtype_policy` instead of using the
-            # setter to avoid double `quantize` call
-            self._dtype_policy = dtype_policies.get(quantized_dtype)
-
         self._tracker.unlock()
         if mode == "int8":
             (
@@ -717,10 +709,19 @@
             self._float8_build()
         else:
             raise NotImplementedError(
-                self.QUANTIZATION_MODE_ERROR_TEMPLATE.format(mode)
+                self.QUANTIZATION_MODE_ERROR_TEMPLATE.format(mode=mode)
             )
         self._tracker.lock()
 
+        # Set new dtype policy
+        if not isinstance(
+            self.dtype_policy, dtype_policies.QuantizedDTypePolicy
+        ):
+            quantized_dtype = f"{mode}_from_{self.dtype_policy.name}"
+            # We set the internal `self._dtype_policy` instead of using the
+            # setter to avoid double `quantize` call
+            self._dtype_policy = dtype_policies.get(quantized_dtype)
+
         # Release memory manually because sometimes the backend doesn't
         gc.collect()
 
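
Alongside the reordering, every raise site switches from .format(mode) to .format(mode=mode). With a named {mode} placeholder, positional formatting raises KeyError instead of rendering the message, which would mask the intended NotImplementedError. A small demonstration; the template text here is illustrative and the exact Keras wording may differ:

TEMPLATE = (
    "Invalid quantization mode. Expected 'int8' or 'float8'. "
    "Received: quantization_mode={mode}"
)

mode = "int7"
try:
    TEMPLATE.format(mode)  # old positional call
except KeyError as exc:
    print("positional .format() fails:", exc)  # KeyError: 'mode'

print(TEMPLATE.format(mode=mode))  # keyword form renders as intended
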
48 changes: 39 additions & 9 deletions keras/src/layers/core/einsum_dense_test.py
@@ -380,12 +380,9 @@ def test_lora_rank_argument(self):
             supports_masking=False,
         )
 
-    """Test quantization-related (int8 and float8) methods"""
+    # Test quantization-related (int8 and float8) methods
 
-    @pytest.mark.skipif(
-        backend.backend() == "numpy",
-        reason=f"{backend.backend()} does not support ops.custom_gradient.",
-    )
+    @pytest.mark.requires_trainable_backend
     def test_quantize_int8(self):
         layer = layers.EinsumDense(
             equation="ab,bcd->acd",
@@ -474,10 +471,7 @@
         ("btd,ndh->btnh", "btd,ndh->btnh", (None, 2, 8), (1, 2, 4)),
         ("btd,df->btf", "btd,df->btf", (None, 4), (1, 2, 4)),
     )
-    @pytest.mark.skipif(
-        backend.backend() == "numpy",
-        reason=f"{backend.backend()} does not support ops.custom_gradient.",
-    )
+    @pytest.mark.requires_trainable_backend
     def test_quantize_int8_with_specific_equations(
         self, equation, output_shape, input_shape
     ):
@@ -575,6 +569,42 @@ def test_quantize_by_setting_dtype_policy(
         layer.dtype_policy = policy
         self.assertLen(layer.variables, expected_num_variables)
 
+    @parameterized.named_parameters(
+        ("int7", "int7"),
+        ("float7", "float7"),
+    )
+    def test_quantize_invalid_mode(self, mode):
+        layer = layers.EinsumDense(
+            equation="ab,bcd->acd",
+            output_shape=(8, 32),
+            bias_axes="d",
+        )
+        layer.build((None, 3))
+        x = np.random.random((1, 3))
+        # dtype_policy should not be altered by failed quantization
+        original_dtype_policy = layer.dtype_policy
+
+        # Test quantize
+        with self.assertRaisesRegex(ValueError, "Invalid quantization mode."):
+            layer.quantize(mode)
+        self.assertEqual(layer.dtype_policy, original_dtype_policy)
+
+        # Test quantized_build
+        with self.assertRaisesRegex(
+            NotImplementedError, "Invalid quantization mode."
+        ):
+            layer.quantized_build((None, 2), mode)
+        self.assertEqual(layer.dtype_policy, original_dtype_policy)
+
+        # Test quantized_call
+        with self.assertRaisesRegex(
+            NotImplementedError, "Invalid quantization mode."
+        ):
+            # Explicitly set quantization_mode
+            layer._dtype_policy.quantization_mode = mode
+            layer.quantized_call(x)
+        self.assertEqual(layer.dtype_policy, original_dtype_policy)
+
     @pytest.mark.requires_trainable_backend
     def test_quantize_int8_dtype_argument(self):
         self.run_layer_test(
