Commit
Add seed arg to MHA layer.
fchollet committed May 8, 2024
1 parent e620cb4 commit da83683
Showing 1 changed file with 5 additions and 1 deletion.
6 changes: 5 additions & 1 deletion keras/src/layers/attention/multi_head_attention.py
@@ -59,6 +59,7 @@ class MultiHeadAttention(Layer):
activity_regularizer: Regularizer for dense layer activity.
kernel_constraint: Constraint for dense layer kernels.
bias_constraint: Constraint for dense layer biases.
seed: Optional integer to seed the dropout layer.
Call arguments:
query: Query tensor of shape `(B, T, dim)`, where `B` is the batch size,
@@ -110,6 +111,7 @@ def __init__(
activity_regularizer=None,
kernel_constraint=None,
bias_constraint=None,
seed=None,
**kwargs,
):
super().__init__(**kwargs)
@@ -137,6 +139,7 @@ def __init__(
f"Received: attention_axes={attention_axes}"
)
self._attention_axes = attention_axes
self.seed = seed

@property
def num_heads(self):
@@ -189,6 +192,7 @@ def get_config(self):
),
"kernel_constraint": constraints.serialize(self._kernel_constraint),
"bias_constraint": constraints.serialize(self._bias_constraint),
"seed": self.seed,
}
return {**base_config, **config}

@@ -359,7 +363,7 @@ def _build_attention(self, rank):
)
self._softmax = Softmax(axis=norm_axes, dtype=self.dtype_policy)
self._dropout_layer = Dropout(
rate=self._dropout, dtype=self.dtype_policy
rate=self._dropout, dtype=self.dtype_policy, seed=self.seed
)
self._inverse_sqrt_key_dim = 1.0 / math.sqrt(float(self._key_dim))


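As a usage sketch (assumed, not part of this commit's diff): after this change, `MultiHeadAttention` accepts a `seed` argument that it forwards to its internal `Dropout` layer applied to the attention probabilities, and the value round-trips through `get_config()`.

```python
import numpy as np
from keras import layers

# Minimal sketch, assuming a Keras build that includes this commit.
# `seed` is forwarded to the attention-probability Dropout layer,
# making dropout behavior reproducible across runs.
mha = layers.MultiHeadAttention(num_heads=2, key_dim=16, dropout=0.1, seed=42)

query = np.random.rand(4, 8, 16).astype("float32")  # (batch, seq, dim)
value = np.random.rand(4, 8, 16).astype("float32")

out = mha(query, value, training=True)  # dropout active and seeded

# The seed is also serialized, so a cloned or reloaded layer keeps it.
assert mha.get_config()["seed"] == 42
```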