
Commit

Replace NumPy Operations with JAX NumPy Equivalents for JIT Compilation Compatibility (#23356)

* Replace numpy operations with jax.numpy for JIT compatibility

Replaced NumPy operations with their jax.numpy equivalents in the Transformers library. This change was necessary to prevent errors during JIT compilation. Specifically, the modifications replace NumPy's in-place slice assignments with jax.numpy's immutable `.at[...].set(...)` update method.
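For context, a minimal sketch of the pattern the commit applies (the function name and the start-token value 0 below are illustrative, not taken from the library). Under `jax.jit`, inputs are abstract tracers and JAX arrays are immutable, so NumPy-style in-place slice assignment must become a functional `.at[...].set(...)` update that returns a new array:

import jax
import jax.numpy as jnp

@jax.jit
def shift_right(input_ids):
    # np.zeros_like(...) followed by `shifted[:, 1:] = ...` would fail here:
    # jit traces abstract values, and JAX arrays do not support item assignment.
    shifted = jnp.zeros_like(input_ids)
    shifted = shifted.at[:, 1:].set(input_ids[:, :-1])  # out-of-place slice update
    shifted = shifted.at[:, 0].set(0)                   # illustrative start-token id
    return shifted

print(shift_right(jnp.array([[5, 6, 7], [8, 9, 10]])))
# [[0 5 6]
#  [0 8 9]]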

* rm numpy import

* rm numpy import and fix np->jnp

* fixed slices bug

* fixed decoder_start_tokens -> decoder_start_token_id

* fixed jnp in modeling mt5

* doc fix

* rm numpy import

* make
gojiteji committed May 16, 2023
1 parent c2393ca commit ba6815e
Showing 9 changed files with 42 additions and 45 deletions.
11 changes: 5 additions & 6 deletions src/transformers/models/bart/modeling_flax_bart.py
@@ -22,7 +22,6 @@
 import flax.linen as nn
 import jax
 import jax.numpy as jnp
-import numpy as np
 from flax.core.frozen_dict import FrozenDict, freeze, unfreeze
 from flax.linen import combine_masks, make_causal_mask
 from flax.linen.attention import dot_product_attention_weights
@@ -218,15 +217,15 @@
 """


-def shift_tokens_right(input_ids: np.array, pad_token_id: int, decoder_start_token_id: int) -> np.ndarray:
+def shift_tokens_right(input_ids: jnp.array, pad_token_id: int, decoder_start_token_id: int) -> jnp.ndarray:
     """
     Shift input ids one token to the right.
     """
-    shifted_input_ids = np.zeros_like(input_ids)
-    shifted_input_ids[:, 1:] = input_ids[:, :-1]
-    shifted_input_ids[:, 0] = decoder_start_token_id
+    shifted_input_ids = jnp.zeros_like(input_ids)
+    shifted_input_ids = shifted_input_ids.at[:, 1:].set(input_ids[:, :-1])
+    shifted_input_ids = shifted_input_ids.at[:, 0].set(decoder_start_token_id)

-    shifted_input_ids = np.where(shifted_input_ids == -100, pad_token_id, shifted_input_ids)
+    shifted_input_ids = jnp.where(shifted_input_ids == -100, pad_token_id, shifted_input_ids)
     return shifted_input_ids
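As a quick illustration (not part of the commit), the updated BART helper can now be traced by `jax.jit`; the token ids and the pad/start ids below are made-up values:

labels = jnp.array([[101, 7, -100, -100]])
shifted = jax.jit(shift_tokens_right)(labels, pad_token_id=1, decoder_start_token_id=2)
# -> [[2, 101, 7, 1]]: ids move one step right, decoder_start_token_id fills
#    column 0, and the -100 ignore-index values are replaced by pad_token_id.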


@@ -209,11 +209,11 @@ def shift_tokens_right(input_ids: np.array, pad_token_id: int, decoder_start_tok
     """
     Shift input ids one token to the right.
     """
-    shifted_input_ids = np.zeros_like(input_ids)
-    shifted_input_ids[:, 1:] = input_ids[:, :-1]
-    shifted_input_ids[:, 0] = decoder_start_token_id
+    shifted_input_ids = jnp.zeros_like(input_ids)
+    shifted_input_ids = shifted_input_ids.at[:, 1:].set(input_ids[:, :-1])
+    shifted_input_ids = shifted_input_ids.at[:, 0].set(decoder_start_token_id)

-    shifted_input_ids = np.where(shifted_input_ids == -100, pad_token_id, shifted_input_ids)
+    shifted_input_ids = jnp.where(shifted_input_ids == -100, pad_token_id, shifted_input_ids)
     return shifted_input_ids


@@ -23,7 +23,6 @@
 import flax.linen as nn
 import jax
 import jax.numpy as jnp
-import numpy as np
 from flax.core.frozen_dict import FrozenDict, freeze, unfreeze
 from flax.linen import combine_masks, make_causal_mask
 from flax.linen.attention import dot_product_attention_weights
@@ -221,11 +220,11 @@ def shift_tokens_right(input_ids: jnp.ndarray, pad_token_id: int, decoder_start_
     """
     Shift input ids one token to the right.
     """
-    shifted_input_ids = np.zeros_like(input_ids)
-    shifted_input_ids[:, 1:] = input_ids[:, :-1]
-    shifted_input_ids[:, 0] = decoder_start_token_id
+    shifted_input_ids = jnp.zeros_like(input_ids)
+    shifted_input_ids = shifted_input_ids.at[:, 1:].set(input_ids[:, :-1])
+    shifted_input_ids = shifted_input_ids.at[:, 0].set(decoder_start_token_id)

-    shifted_input_ids = np.where(shifted_input_ids == -100, pad_token_id, shifted_input_ids)
+    shifted_input_ids = jnp.where(shifted_input_ids == -100, pad_token_id, shifted_input_ids)
     return shifted_input_ids


8 changes: 4 additions & 4 deletions src/transformers/models/longt5/modeling_flax_longt5.py
@@ -60,11 +60,11 @@ def shift_tokens_right(input_ids: np.array, pad_token_id: int, decoder_start_tok
     """
     Shift input ids one token to the right.
     """
-    shifted_input_ids = np.zeros_like(input_ids)
-    shifted_input_ids[:, 1:] = input_ids[:, :-1]
-    shifted_input_ids[:, 0] = decoder_start_token_id
+    shifted_input_ids = jnp.zeros_like(input_ids)
+    shifted_input_ids = shifted_input_ids.at[:, 1:].set(input_ids[:, :-1])
+    shifted_input_ids = shifted_input_ids.at[:, 0].set(decoder_start_token_id)

-    shifted_input_ids = np.where(shifted_input_ids == -100, pad_token_id, shifted_input_ids)
+    shifted_input_ids = jnp.where(shifted_input_ids == -100, pad_token_id, shifted_input_ids)
     return shifted_input_ids


8 changes: 4 additions & 4 deletions src/transformers/models/marian/modeling_flax_marian.py
@@ -231,11 +231,11 @@ def shift_tokens_right(input_ids: jnp.ndarray, pad_token_id: int, decoder_start_
     """
     Shift input ids one token to the right.
     """
-    shifted_input_ids = np.zeros_like(input_ids)
-    shifted_input_ids[:, 1:] = input_ids[:, :-1]
-    shifted_input_ids[:, 0] = decoder_start_token_id
+    shifted_input_ids = jnp.zeros_like(input_ids)
+    shifted_input_ids = shifted_input_ids.at[:, 1:].set(input_ids[:, :-1])
+    shifted_input_ids = shifted_input_ids.at[:, 0].set(decoder_start_token_id)

-    shifted_input_ids = np.where(shifted_input_ids == -100, pad_token_id, shifted_input_ids)
+    shifted_input_ids = jnp.where(shifted_input_ids == -100, pad_token_id, shifted_input_ids)
     return shifted_input_ids


15 changes: 7 additions & 8 deletions src/transformers/models/mbart/modeling_flax_mbart.py
@@ -22,7 +22,6 @@
 import flax.linen as nn
 import jax
 import jax.numpy as jnp
-import numpy as np
 from flax.core.frozen_dict import FrozenDict, freeze, unfreeze
 from flax.linen import combine_masks, make_causal_mask
 from flax.linen.attention import dot_product_attention_weights
@@ -223,20 +222,20 @@ def shift_tokens_right(input_ids: jnp.ndarray, pad_token_id: int) -> jnp.ndarray
     Shift input ids one token to the right, and wrap the last non pad token (the <LID> token) Note that MBart does not
     have a single `decoder_start_token_id` in contrast to other Bart-like models.
     """
-    prev_output_tokens = np.array(input_ids).copy()
+    prev_output_tokens = jnp.array(input_ids).copy()

     if pad_token_id is None:
         raise ValueError("self.model.config.pad_token_id has to be defined.")

     # replace possible -100 values in labels by `pad_token_id`
-    prev_output_tokens = np.where(prev_output_tokens == -100, pad_token_id, input_ids)
-    index_of_eos = (np.where(prev_output_tokens != pad_token_id, 1, 0).sum(axis=-1) - 1).reshape(-1, 1)
-    decoder_start_tokens = np.array(
-        [prev_output_tokens[i, eos_idx] for i, eos_idx in enumerate(index_of_eos)], dtype=np.int32
+    prev_output_tokens = jnp.where(prev_output_tokens == -100, pad_token_id, input_ids)
+    index_of_eos = (jnp.where(prev_output_tokens != pad_token_id, 1, 0).sum(axis=-1) - 1).reshape(-1, 1)
+    decoder_start_tokens = jnp.array(
+        [prev_output_tokens[i, eos_idx] for i, eos_idx in enumerate(index_of_eos)], dtype=jnp.int32
     ).squeeze()

-    prev_output_tokens[:, 1:] = prev_output_tokens[:, :-1].copy()
-    prev_output_tokens[:, 0] = decoder_start_tokens
+    prev_output_tokens = prev_output_tokens.at[:, 1:].set(prev_output_tokens[:, :-1])
+    prev_output_tokens = prev_output_tokens.at[:, 0].set(decoder_start_tokens)

     return prev_output_tokens
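An illustrative eager call of the MBart variant (made-up ids, with 2 standing in for `</s>`, 250004 for a language id, and `pad_token_id=1`); the last non-pad token wraps around to the front:

ids = jnp.array([[37, 44, 2, 250004]])
print(shift_tokens_right(ids, pad_token_id=1))
# -> [[250004, 37, 44, 2]]: the <LID> token becomes the decoder start token.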

12 changes: 6 additions & 6 deletions src/transformers/models/mt5/modeling_flax_mt5.py
@@ -14,7 +14,7 @@
 # limitations under the License.
 """ Flax mT5 model."""

-import numpy as np
+import jax.numpy as jnp

 from ...utils import logging
 from ..t5.modeling_flax_t5 import FlaxT5EncoderModel, FlaxT5ForConditionalGeneration, FlaxT5Model
@@ -27,15 +27,15 @@


 # Copied from transformers.models.bart.modeling_flax_bart.shift_tokens_right
-def shift_tokens_right(input_ids: np.array, pad_token_id: int, decoder_start_token_id: int) -> np.ndarray:
+def shift_tokens_right(input_ids: jnp.array, pad_token_id: int, decoder_start_token_id: int) -> jnp.ndarray:
     """
     Shift input ids one token to the right.
     """
-    shifted_input_ids = np.zeros_like(input_ids)
-    shifted_input_ids[:, 1:] = input_ids[:, :-1]
-    shifted_input_ids[:, 0] = decoder_start_token_id
+    shifted_input_ids = jnp.zeros_like(input_ids)
+    shifted_input_ids = shifted_input_ids.at[:, 1:].set(input_ids[:, :-1])
+    shifted_input_ids = shifted_input_ids.at[:, 0].set(decoder_start_token_id)

-    shifted_input_ids = np.where(shifted_input_ids == -100, pad_token_id, shifted_input_ids)
+    shifted_input_ids = jnp.where(shifted_input_ids == -100, pad_token_id, shifted_input_ids)
     return shifted_input_ids


8 changes: 4 additions & 4 deletions src/transformers/models/pegasus/modeling_flax_pegasus.py
@@ -214,11 +214,11 @@ def shift_tokens_right(input_ids: np.array, pad_token_id: int, decoder_start_tok
     """
     Shift input ids one token to the right.
     """
-    shifted_input_ids = np.zeros_like(input_ids)
-    shifted_input_ids[:, 1:] = input_ids[:, :-1]
-    shifted_input_ids[:, 0] = decoder_start_token_id
+    shifted_input_ids = jnp.zeros_like(input_ids)
+    shifted_input_ids = shifted_input_ids.at[:, 1:].set(input_ids[:, :-1])
+    shifted_input_ids = shifted_input_ids.at[:, 0].set(decoder_start_token_id)

-    shifted_input_ids = np.where(shifted_input_ids == -100, pad_token_id, shifted_input_ids)
+    shifted_input_ids = jnp.where(shifted_input_ids == -100, pad_token_id, shifted_input_ids)
     return shifted_input_ids


8 changes: 4 additions & 4 deletions src/transformers/models/t5/modeling_flax_t5.py
@@ -60,11 +60,11 @@ def shift_tokens_right(input_ids: np.array, pad_token_id: int, decoder_start_tok
     """
     Shift input ids one token to the right.
     """
-    shifted_input_ids = np.zeros_like(input_ids)
-    shifted_input_ids[:, 1:] = input_ids[:, :-1]
-    shifted_input_ids[:, 0] = decoder_start_token_id
+    shifted_input_ids = jnp.zeros_like(input_ids)
+    shifted_input_ids = shifted_input_ids.at[:, 1:].set(input_ids[:, :-1])
+    shifted_input_ids = shifted_input_ids.at[:, 0].set(decoder_start_token_id)

-    shifted_input_ids = np.where(shifted_input_ids == -100, pad_token_id, shifted_input_ids)
+    shifted_input_ids = jnp.where(shifted_input_ids == -100, pad_token_id, shifted_input_ids)
     return shifted_input_ids


