Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[train] Pin jax for Dreambooth Fine-Tuning template #45389

Merged
merged 1 commit into from
May 21, 2024

Conversation

matthewdeng
Copy link
Contributor

@matthewdeng matthewdeng commented May 16, 2024

Why are these changes needed?

More recent versions of jax (e.g. 0.4.28) will cause the following problem:

╭─────────────────────────────── Traceback (most recent call last) ────────────────────────────────╮
│ /home/ray/default/ray/doc/source/templates/05_dreambooth_finetuning/dreambooth/generate.py:9 in  │
│ <module>                                                                                         │
│                                                                                                  │
│    6 import ray                                                                                  │
│    7                                                                                             │
│    8 from flags import run_model_flags                                                           │
│ ❱  9 from generate_utils import get_pipeline                                                     │
│   10                                                                                             │
│   11                                                                                             │
│   12 def run(args):                                                                              │
│                                                                                                  │
│ /home/ray/default/ray/doc/source/templates/05_dreambooth_finetuning/dreambooth/generate_utils.py │
│ :1 in <module>                                                                                   │
│                                                                                                  │
│ ❱  1 from diffusers import DiffusionPipeline                                                     │
│    2 from diffusers.loaders import LoraLoaderMixin                                               │
│    3 import torch                                                                                │
│    4                                                                                             │
│                                                                                                  │
│ /home/ray/anaconda3/lib/python3.10/site-packages/diffusers/__init__.py:38 in <module>            │
│                                                                                                  │
│    35 except OptionalDependencyNotAvailable:                                                     │
│    36 │   from .utils.dummy_pt_objects import *  # noqa F403                                     │
│    37 else:                                                                                      │
│ ❱  38 │   from .models import (                                                                  │
│    39 │   │   AsymmetricAutoencoderKL,                                                           │
│    40 │   │   AutoencoderKL,                                                                     │
│    41 │   │   ControlNetModel,                                                                   │
│                                                                                                  │
│ /home/ray/anaconda3/lib/python3.10/site-packages/diffusers/models/__init__.py:35 in <module>     │
│                                                                                                  │
│   32 │   from .vq_model import VQModel                                                           │
│   33                                                                                             │
│   34 if is_flax_available():                                                                     │
│ ❱ 35 │   from .controlnet_flax import FlaxControlNetModel                                        │
│   36 │   from .unet_2d_condition_flax import FlaxUNet2DConditionModel                            │
│   37 │   from .vae_flax import FlaxAutoencoderKL                                                 │
│   38                                                                                             │
│                                                                                                  │
│ /home/ray/anaconda3/lib/python3.10/site-packages/diffusers/models/controlnet_flax.py:16 in       │
│ <module>                                                                                         │
│                                                                                                  │
│    13 # limitations under the License.                                                           │
│    14 from typing import Optional, Tuple, Union                                                  │
│    15                                                                                            │
│ ❱  16 import flax                                                                                │
│    17 import flax.linen as nn                                                                    │
│    18 import jax                                                                                 │
│    19 import jax.numpy as jnp                                                                    │
│                                                                                                  │
│ /home/ray/anaconda3/lib/python3.10/site-packages/flax/__init__.py:18 in <module>                 │
│                                                                                                  │
│   15                                                                                             │
│   16 """Flax API."""                                                                             │
│   17                                                                                             │
│ ❱ 18 from .configurations import (                                                               │
│   19 │   config as config,                                                                       │
│   20 )                                                                                           │
│   21                                                                                             │
│                                                                                                  │
│ /home/ray/anaconda3/lib/python3.10/site-packages/flax/configurations.py:92 in <module>           │
│                                                                                                  │
│    89 # Whether to use the lazy rng implementation.                                              │
│    90 flax_lazy_rng = static_bool_env('FLAX_LAZY_RNG', True)                                     │
│    91                                                                                            │
│ ❱  92 flax_filter_frames = define_bool_state(                                                    │
│    93 │   name='filter_frames',                                                                  │
│    94 │   default=True,                                                                          │
│    95 │   help=('Whether to hide flax-internal stack frames from tracebacks.'))                  │
│                                                                                                  │
│ /home/ray/anaconda3/lib/python3.10/site-packages/flax/configurations.py:42 in define_bool_state  │
│                                                                                                  │
│    39   'FLAX_<UPPERCASE_NAME>'. JAX config ensures that the flag can be overwritten             │
│    40   on runtime with `flax.config.update('flax_<config_name>', <value>)`.                     │
│    41   """                                                                                      │
│ ❱  42   return jax_config.define_bool_state('flax_' + name, default, help)                       │
│    43                                                                                            │
│    44                                                                                            │
│    45 def static_bool_env(varname: str, default: bool) -> bool:                                  │
╰──────────────────────────────────────────────────────────────────────────────────────────────────╯
AttributeError: 'Config' object has no attribute 'define_bool_state'

Related issue number

Checks

  • I've signed off every commit(by using the -s flag, i.e., git commit -s) in this PR.
  • I've run scripts/format.sh to lint the changes in this PR.
  • I've included any doc changes needed for https://docs.ray.io/en/master/.
    • I've added any new APIs to the API Reference. For example, if I added a
      method in Tune, I've added it in doc/source/tune/api/ under the
      corresponding .rst file.
  • I've made sure the tests are passing. Note that there might be a few flaky tests, see the recent failures at https://flakey-tests.ray.io/
  • Testing Strategy
    • Unit tests
    • Release tests
    • This PR is not tested :(

Signed-off-by: Matthew Deng <matt@anyscale.com>
@matthewdeng matthewdeng marked this pull request as ready for review May 21, 2024 00:19
@matthewdeng matthewdeng requested a review from a team as a code owner May 21, 2024 00:19
@matthewdeng matthewdeng enabled auto-merge (squash) May 21, 2024 00:20
@github-actions github-actions bot added the go Trigger full test run on premerge label May 21, 2024
Copy link
Contributor

@angelinalg angelinalg left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

stamp

@matthewdeng matthewdeng merged commit f403087 into ray-project:master May 21, 2024
7 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
go Trigger full test run on premerge
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

3 participants