
Keras 3.2.1 breaks LoRA with ModelParallel #19496

Open
martin-gorner opened this issue Apr 11, 2024 · 9 comments
Labels: Gemma (Gemma model specific issues), stat:awaiting keras-eng (Awaiting response from Keras engineer), type:Bug

Comments

@martin-gorner
Contributor

You can test with the Gemma chat demo: bit.ly/gemma-pirate-demo

With keras 3.1.1, the line gemma_lm.backbone.enable_lora(rank=8) executes successfully.

With keras 3.2.1, the line gemma_lm.backbone.enable_lora(rank=8) errors out with:

File ~/.local/lib/python3.10/site-packages/keras/src/layers/core/einsum_dense.py:236, in EinsumDense.enable_lora(self, rank, a_initializer, b_initializer)
    229 self._tracker.unlock()
    230 self.lora_kernel_a = self.add_weight(
    231     name="lora_kernel_a",
    232     shape=(self.kernel.shape[:-1] + (rank,)),
    233     initializer=initializers.get(a_initializer),
    234     regularizer=self.kernel_regularizer,
    235 )
--> 236 self.lora_kernel_b = self.add_weight(
    237     name="lora_kernel_b",
    238     shape=(rank, self.kernel.shape[-1]),
    239     initializer=initializers.get(b_initializer),
    240     regularizer=self.kernel_regularizer,
    241 )
    242 self._kernel.trainable = False
    243 self._tracker.lock()

File ~/.local/lib/python3.10/site-packages/keras/src/layers/layer.py:507, in Layer.add_weight(self, shape, initializer, dtype, trainable, autocast, regularizer, constraint, aggregation, name)
    505 initializer = initializers.get(initializer)
    506 with self._open_name_scope():
--> 507     variable = backend.Variable(
    508         initializer=initializer,
    509         shape=shape,
    510         dtype=dtype,
    511         trainable=trainable,
    512         autocast=autocast,
    513         aggregation=aggregation,
    514         name=name,
    515     )
    516 # Will be added to layer.losses
    517 variable.regularizer = regularizers.get(regularizer)

File ~/.local/lib/python3.10/site-packages/keras/src/backend/common/variables.py:161, in KerasVariable.__init__(self, initializer, shape, dtype, trainable, autocast, aggregation, name)
    159     else:
    160         value = initializer
--> 161     self._initialize(value)
    162     self._shape = tuple(self._value.shape)
    163 self._ndim = len(self._shape)

File ~/.local/lib/python3.10/site-packages/keras/src/backend/jax/core.py:32, in Variable._initialize(self, value)
     30 else:
     31     self._layout = None
---> 32 self._direct_assign(value)

File ~/.local/lib/python3.10/site-packages/keras/src/backend/jax/core.py:36, in Variable._direct_assign(self, value)
     34 def _direct_assign(self, value):
     35     if getattr(self, "_layout", None) is not None:
---> 36         value = distribution_lib.distribute_variable(value, self._layout)
     37     self._value = value

File ~/.local/lib/python3.10/site-packages/keras/src/backend/jax/distribution_lib.py:59, in distribute_variable(value, layout)
     56     return value
     58 if layout.is_fully_addressable:
---> 59     return jax.device_put(value, layout)
     60 else:
     61     # Need to only distribute the value to local addressible devices, and
     62     # repack them back into global format.
     63     mapping = layout.addressable_devices_indices_map(value.shape)

File ~/.local/lib/python3.10/site-packages/jax/_src/api.py:2494, in device_put(x, device, src)
   2489 if ((device is None or
   2490      isinstance(device, (xc.Device, Sharding, TransferToMemoryKind))) and
   2491     (src is None or
   2492      isinstance(src, (xc.Device, Sharding, TransferToMemoryKind)))):
   2493   for leaf in tree_leaves(x):
-> 2494     _check_sharding(leaf, s=device)
   2495   return tree_map(
   2496       lambda y: dispatch.device_put_p.bind(
   2497           y, device=device, src=_infer_src_sharding(src, y)), x)
   2499 x_flat, treedef = tree_flatten(x)

File ~/.local/lib/python3.10/site-packages/jax/_src/api.py:2457, in _check_sharding(x, s)
   2455 aval = shaped_abstractify(x)
   2456 if isinstance(s, XLACompatibleSharding) and not isinstance(s, PmapSharding):
-> 2457   pjit.pjit_check_aval_sharding(
   2458       (s,), (aval,), None, "device_put args", allow_uneven_sharding=False)
   2459 s.shard_shape(aval.shape)

File ~/.local/lib/python3.10/site-packages/jax/_src/pjit.py:1236, in pjit_check_aval_sharding(shardings, flat_avals, names, what_aval, allow_uneven_sharding)
   1234     s._to_xla_hlo_sharding(len(shape))
   1235 except ValueError as e:
-> 1236   raise ValueError(
   1237       f'One of {what_aval}{name_str} is incompatible with its sharding '
   1238       f'annotation {s}: {e}')
   1239 # Use the `OpSharding` proto to find out how many ways each dimension of
   1240 # the aval is sharded. This approach will work across all
   1241 # XLACompatibleSharding.
   1242 hlo_sharding = s._to_xla_hlo_sharding(len(shape))

ValueError: One of device_put args is incompatible with its sharding annotation NamedSharding(mesh=Mesh('batch': 1, 'model': 8), spec=PartitionSpec('model', None, None)): Sharding NamedSharding(mesh=Mesh('batch': 1, 'model': 8), spec=PartitionSpec('model', None, None)) is only valid for values of rank at least 3, but was applied to a value of rank 2.
@github-actions github-actions bot added the Gemma Gemma model specific issues label Apr 11, 2024
@fchollet fchollet added the keras-team-review-pending Pending review by a Keras team member. label Apr 11, 2024
@sachinprasadhs sachinprasadhs removed the keras-team-review-pending Pending review by a Keras team member. label Apr 11, 2024
@fchollet
Member

CC @james77777778

@james77777778
Contributor


I didn't spot an issue on the Keras side. enable_lora is untouched from 3.1.1 to 3.2.1. Additionally, the newly introduced quantization method (quantize) is not called in the demo.
https://github.com/keras-team/keras/compare/v3.1.1...v3.2.1?diff=unified&w=

Could it possibly be a bug in JAX?

@james77777778
Contributor

I tried this script on Kaggle and got a different error when executing the following more than twice:

# Create a device mesh with shape (1, 8) to partition weights across all 8 TPU cores.
devices = keras.distribution.list_devices()  # 8 TPUs
device_mesh = keras.distribution.DeviceMesh(
    shape=(1, len(devices)),
    axis_names=("batch", "model"),
    devices=devices,
)

# Create a LayoutMap to partition relevant weights
layout_map = keras_nlp.models.GemmaBackbone.get_layout_map(device_mesh)
distribution = keras.distribution.ModelParallel(device_mesh, layout_map)

# Make ModelParallel loading using the LayoutMap the default
keras.distribution.set_distribution(distribution)

# Initialize GemmaBackbone
gemma_backbone = keras_nlp.models.GemmaBackbone.from_preset("gemma_1.1_instruct_2b_en")

Ref:
https://github.com/keras-team/keras-nlp/blob/master/keras_nlp/models/gemma/gemma_backbone_test.py#L89

The error:

Attaching 'config.json' from model 'keras/gemma/keras/gemma_1.1_instruct_2b_en/3' to your Kaggle notebook...
Attaching 'config.json' from model 'keras/gemma/keras/gemma_1.1_instruct_2b_en/3' to your Kaggle notebook...

---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
Cell In[26], line 1
----> 1 gemma_backbone = keras_nlp.models.GemmaBackbone.from_preset("gemma_1.1_instruct_2b_en")

File /usr/local/lib/python3.10/site-packages/keras_nlp/src/models/backbone.py:200, in Backbone.from_preset(cls, preset, load_weights, **kwargs)
    194 if not issubclass(preset_cls, cls):
    195     raise ValueError(
    196         f"Preset has type `{preset_cls.__name__}` which is not a "
    197         f"a subclass of calling class `{cls.__name__}`. Call "
    198         f"`from_preset` directly on `{preset_cls.__name__}` instead."
    199     )
--> 200 return load_from_preset(
    201     preset,
    202     load_weights=load_weights,
    203     config_overrides=kwargs,
    204 )

File /usr/local/lib/python3.10/site-packages/keras_nlp/src/utils/preset_utils.py:376, in load_from_preset(preset, load_weights, config_file, config_overrides)
    374     config = json.load(config_file)
    375 config["config"] = {**config["config"], **config_overrides}
--> 376 layer = keras.saving.deserialize_keras_object(config)
    378 # Load any assets for our tokenizers.
    379 tokenizer = get_tokenizer(layer)

File /usr/local/lib/python3.10/site-packages/keras/src/saving/serialization_lib.py:711, in deserialize_keras_object(config, custom_objects, safe_mode, **kwargs)
    709 with custom_obj_scope, safe_mode_scope:
    710     try:
--> 711         instance = cls.from_config(inner_config)
    712     except TypeError as e:
    713         raise TypeError(
    714             f"{cls} could not be deserialized properly. Please"
    715             " ensure that components that are Python object"
   (...)
    719             f"\n\nconfig={config}.\n\nException encountered: {e}"
    720         )

File /usr/local/lib/python3.10/site-packages/keras_nlp/src/models/backbone.py:135, in Backbone.from_config(cls, config)
    131 @classmethod
    132 def from_config(cls, config):
    133     # The default `from_config()` for functional models will return a
    134     # vanilla `keras.Model`. We override it to get a subclass instance back.
--> 135     return cls(**config)

File /usr/local/lib/python3.10/site-packages/keras_nlp/src/models/gemma/gemma_backbone.py:149, in GemmaBackbone.__init__(self, vocabulary_size, num_layers, num_query_heads, num_key_value_heads, hidden_dim, intermediate_dim, head_dim, layer_norm_epsilon, dropout, dtype, **kwargs)
    147 x = x * ops.cast(ops.sqrt(hidden_dim), x.dtype)
    148 for transformer_layer in self.transformer_layers:
--> 149     x = transformer_layer(x, padding_mask=padding_mask_input)
    150 sequence_output = self.layer_norm(x)
    151 super().__init__(
    152     inputs={
    153         "token_ids": token_id_input,
   (...)
    158     **kwargs,
    159 )

File /usr/local/lib/python3.10/site-packages/keras/src/utils/traceback_utils.py:122, in filter_traceback.<locals>.error_handler(*args, **kwargs)
    119     filtered_tb = _process_traceback_frames(e.__traceback__)
    120     # To get the full stack trace, call:
    121     # `keras.config.disable_traceback_filtering()`
--> 122     raise e.with_traceback(filtered_tb) from None
    123 finally:
    124     del filtered_tb

File /usr/local/lib/python3.10/site-packages/keras_nlp/src/models/gemma/gemma_decoder_block.py:96, in GemmaDecoderBlock.build(self, input_shape)
     94 def build(self, input_shape):
     95     self.pre_attention_norm.build(input_shape)
---> 96     self.attention.build(input_shape)
     98     shape = input_shape
     99     self.pre_ffw_norm.build(shape)

File /usr/local/lib/python3.10/site-packages/keras_nlp/src/models/gemma/gemma_attention.py:64, in CachedGemmaAttention.build(self, inputs_shape)
     55 self.query_dense.build(inputs_shape)
     57 self.key_dense = keras.layers.EinsumDense(
     58     "bsd,kdh->bskh",
     59     output_shape=(None, self.num_key_value_heads, self.head_dim),
   (...)
     62     name="key",
     63 )
---> 64 self.key_dense.build(inputs_shape)
     66 self.value_dense = keras.layers.EinsumDense(
     67     "bsd,kdh->bskh",
     68     output_shape=(None, self.num_key_value_heads, self.head_dim),
   (...)
     71     name="value",
     72 )
     73 self.value_dense.build(inputs_shape)

File /usr/local/lib/python3.10/site-packages/jax/_src/api.py:2519, in device_put(x, device, src)
   2514 if ((device is None or
   2515      isinstance(device, (xc.Device, Sharding, TransferToMemoryKind))) and
   2516     (src is None or
   2517      isinstance(src, (xc.Device, Sharding, TransferToMemoryKind)))):
   2518   for leaf in tree_leaves(x):
-> 2519     _check_sharding(leaf, s=device)
   2520   return tree_map(
   2521       lambda y: dispatch.device_put_p.bind(
   2522           y, device=device, src=_infer_src_sharding(src, y)), x)
   2524 x_flat, treedef = tree_flatten(x)

File /usr/local/lib/python3.10/site-packages/jax/_src/api.py:2482, in _check_sharding(x, s)
   2480 aval = shaped_abstractify(x)
   2481 if isinstance(s, XLACompatibleSharding) and not isinstance(s, PmapSharding):
-> 2482   pjit.pjit_check_aval_sharding(
   2483       (s,), (aval,), None, "device_put args", allow_uneven_sharding=False)
   2484 s.shard_shape(aval.shape)

File /usr/local/lib/python3.10/site-packages/jax/_src/pjit.py:1034, in pjit_check_aval_sharding(shardings, flat_avals, names, what_aval, allow_uneven_sharding)
   1032 for i, size in enumerate(num_ways_dim_sharded):
   1033   if not allow_uneven_sharding and shape[i] % size != 0:
-> 1034     raise ValueError(f"One of {what_aval}{name_str} was given the sharding "
   1035                      f"of {s}, which implies that "
   1036                      f"the global size of its dimension {i} should be "
   1037                      f"divisible by {size}, but it is equal to {shape[i]} "
   1038                      f"(full shape: {shape})")

ValueError: One of device_put args was given the sharding of NamedSharding(mesh=Mesh('batch': 1, 'model': 8), spec=PartitionSpec('model', 'batch', None)), which implies that the global size of its dimension 0 should be divisible by 8, but it is equal to 1 (full shape: (1, 2048, 256))

This suggests there is an issue with the initialization of GemmaBackbone under keras.distribution, and that it is unrelated to enable_lora.

@SuryanarayanaY SuryanarayanaY added the stat:awaiting keras-eng Awaiting response from Keras engineer label Apr 15, 2024
@mostafamdy

Has anyone found a solution or a workaround?

@CzarScar

I'm hitting the same issue on Kaggle.

@mostafamdy

Change the Keras version back to 3.1.1.
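
For example, in a Kaggle or Colab notebook cell (a sketch of the version-pinning workaround, not an official fix; restart the runtime afterwards so the pinned version is picked up):

# Pin Keras to the last release where enable_lora works with ModelParallel
!pip install -q keras==3.1.1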

martin-gorner added a commit to martin-gorner/keras that referenced this issue Apr 30, 2024
@mattdangerw
Member

Started looking...

I think the issue is a bug in the guides (and in get_layout_map in KerasNLP), not an issue with Keras itself. Basically, as of 3.2 the LoRA weights have a full path, e.g. decoder_block_8/attention/query/lora_kernel_a instead of query/lora_kernel_a. I believe that is correct, but it is new as of 3.2, and it has the effect that our guides will attempt to distribute the LoRA variables (which do not all have the correct shape for distribution).
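
A quick way to see the new paths (a sketch, assuming a Gemma model loaded without a distribution set so that enable_lora succeeds; the variable .path string is what the LayoutMap keys are matched against):

gemma_lm.backbone.enable_lora(rank=8)
for v in gemma_lm.backbone.weights:
    if "lora" in v.path:
        # On Keras >= 3.2 this prints full paths such as
        # "decoder_block_0/attention/query/lora_kernel_a"
        print(v.path)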

We can restore the original behavior of the guide by updating our layout map paths a little bit to be stricter (so they do not select the LoRA kernel variables):

# Regex to match against the query, key and value matrices in attention layers
layout_map["decoder_block.*attention.*(query|key|value)/kernel"] = ("model", None, None)
layout_map["decoder_block.*attention_output/kernel"] = ("model", None, None)
layout_map["decoder_block.*ffw_gating.*/kernel"] = (None, "model")
layout_map["decoder_block.*ffw_linear/kernel"] = ("model", None)

It also might be OK to shard the LoRA kernel "A" variables on the same rank as our actual kernels, but I can't imagine that would impact the runtime very much given the LoRA variable sizes. I'll test this out tomorrow.
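
For completeness, a sketch of the full flow with the stricter layout map applied (the regex keys are copied from above; the preset name and API calls follow the reproduction script earlier in this thread, so treat the details as illustrative rather than a confirmed KerasNLP fix):

import keras
import keras_nlp

# Device mesh over all accelerators: 1-way data parallel, N-way model parallel
devices = keras.distribution.list_devices()
device_mesh = keras.distribution.DeviceMesh(
    shape=(1, len(devices)), axis_names=("batch", "model"), devices=devices
)

# Stricter regexes: match only the real kernels, never lora_kernel_a / lora_kernel_b
layout_map = keras.distribution.LayoutMap(device_mesh)
layout_map["decoder_block.*attention.*(query|key|value)/kernel"] = ("model", None, None)
layout_map["decoder_block.*attention_output/kernel"] = ("model", None, None)
layout_map["decoder_block.*ffw_gating.*/kernel"] = (None, "model")
layout_map["decoder_block.*ffw_linear/kernel"] = ("model", None)

keras.distribution.set_distribution(
    keras.distribution.ModelParallel(device_mesh, layout_map)
)

gemma_lm = keras_nlp.models.GemmaCausalLM.from_preset("gemma_1.1_instruct_2b_en")
gemma_lm.backbone.enable_lora(rank=8)  # LoRA variables no longer match the LayoutMap, so no sharding error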

@martin-gorner
Contributor Author

Tested, the fix works well. Thank you!

@josharian

Possibly related: keras-team/keras-nlp#1613
