Tracking: development around Rectified Flow #182

Closed
7 tasks done
yqzhishen opened this issue Apr 1, 2024 · 3 comments · Fixed by #184

yqzhishen (Member) commented Apr 1, 2024

We are introducing Rectified Flow, a new ODE-based generative model, to this repository (in the RectifiedFlow branch). The differences between Rectified Flow and the currently used DDPM will result in some API changes. Testing and adaptation may take one or more weeks. Since we are still at an early stage and the code is not yet well organized, the APIs and configurations on the branch may change over time without any backward compatibility. This issue is raised mainly to keep those who are testing and researching on that branch informed of the changes (and possible migration steps).

TODOs

  • Initial implementation with temporary configurations
  • Testing, verifying and comparing
  • Migrate inference APIs and corresponding configurations to a continuous acceleration profile: int64 depth to float32 depth, speedup to steps
  • Reorganize the Rectified Flow code, and adapt DDPM to continuous acceleration (converted to discrete settings)
  • ONNX exporter adaptations and fixes
  • More tests to determine proper default configuration
  • Documentation, ready to merge

yqzhishen pinned this issue Apr 1, 2024
yqzhishen (Member, Author) commented:

The first stage of refactoring and migration to continuous acceleration has been finished.

Rectified Flow models can still run with full compatibility, but the following configurations no longer take effect for Rectified Flow at training time (at inference time, they are converted automatically if the config file does not contain the new keys):

  • timesteps: replaced by time_scale_factor, which can be a float
  • K_step: replaced by T_start (between 0 and 1; 0 means K_step = timesteps, 1 means K_step = 0)
  • K_step_infer: replaced by T_start_infer (between 0 and 1)
  • diff_speedup: replaced by sampling_steps (meaning the actual steps of sampling)
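
The key mapping above can be sketched as a small conversion function. This is an illustration under assumptions, not the repository's actual conversion code: the helper name `convert_legacy_reflow_config` is hypothetical, and it assumes that `time_scale_factor` takes the old `timesteps` value, that `T_start` varies linearly with `K_step` between the two endpoints stated above, and that `sampling_steps` corresponds to `K_step_infer // diff_speedup`.

```python
from typing import Any, Dict


def convert_legacy_reflow_config(config: Dict[str, Any]) -> Dict[str, Any]:
    """Hypothetical helper: derive the new Rectified Flow keys from legacy ones.

    Assumptions (not taken from the repository):
      * time_scale_factor equals the old timesteps value;
      * T_start = 1 - K_step / timesteps, which matches the stated endpoints
        (K_step == timesteps -> 0.0, K_step == 0 -> 1.0) under a linear mapping;
      * sampling_steps = K_step_infer // diff_speedup (at least 1).
    Keys already present in the config are left untouched.
    """
    new_config = dict(config)
    timesteps = config['timesteps']
    k_step = config.get('K_step', timesteps)
    k_step_infer = config.get('K_step_infer', k_step)
    speedup = config.get('diff_speedup', 1)

    # setdefault only fills a key when the config does not already define it.
    new_config.setdefault('time_scale_factor', float(timesteps))
    new_config.setdefault('T_start', 1.0 - k_step / timesteps)
    new_config.setdefault('T_start_infer', 1.0 - k_step_infer / timesteps)
    new_config.setdefault('sampling_steps', max(1, k_step_infer // speedup))
    return new_config


legacy = {'timesteps': 1000, 'K_step': 1000, 'K_step_infer': 400, 'diff_speedup': 10}
converted = convert_legacy_reflow_config(legacy)
print(converted['T_start'], converted['T_start_infer'], converted['sampling_steps'])
```

Using `setdefault` mirrors the behavior described above: conversion only happens when the new keys are missing from the config file.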

Inference API (scripts/infer.py) has been changed as follows:

  • --depth now accepts a float value between 0 and 1
  • --speedup is removed and replaced by --steps
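
A minimal sketch of how the new flags could be validated, written with the standard-library `argparse` module rather than the CLI framework that `scripts/infer.py` actually uses. Only the flag names `--depth` and `--steps` come from the text above; the parser, the `unit_float` and `positive_int` helpers, and the positive-integer check on `--steps` are illustrative assumptions.

```python
import argparse


def unit_float(value: str) -> float:
    """Parse a float and require it to lie in [0, 1], like the new --depth."""
    depth = float(value)
    if not 0.0 <= depth <= 1.0:
        raise argparse.ArgumentTypeError(f'depth must be between 0 and 1, got {depth}')
    return depth


def positive_int(value: str) -> int:
    """Parse the number of sampling steps; assumed here to be at least 1."""
    steps = int(value)
    if steps < 1:
        raise argparse.ArgumentTypeError(f'steps must be a positive integer, got {steps}')
    return steps


def build_parser() -> argparse.ArgumentParser:
    # Hypothetical parser illustrating the new interface; there is no --speedup flag.
    parser = argparse.ArgumentParser(description='Illustrative inference flags')
    parser.add_argument('--depth', type=unit_float, default=None,
                        help='depth as a float between 0 and 1')
    parser.add_argument('--steps', type=positive_int, default=None,
                        help='actual number of sampling steps')
    return parser


args = build_parser().parse_args(['--depth', '0.6', '--steps', '20'])
print(args.depth, args.steps)
```

An out-of-range `--depth` or a leftover `--speedup` makes `argparse` print a usage error and exit, which surfaces stale command lines early.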

yqzhishen (Member, Author) commented:

ONNX export is now supported, but exporting some early Rectified Flow models will raise a KeyError. Please manually add the missing keys to the configuration file.
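
Before exporting, a small check like the following can list which keys an older configuration lacks. It is a sketch under assumptions: the helper name `find_missing_keys` is hypothetical, the config is treated as an already-loaded dictionary, and the key list simply reuses the new keys from the first refactoring stage. The keys that actually trigger the KeyError may differ, so confirm them against the error message.

```python
from typing import Any, Dict, Iterable, List

# Assumed list: the new keys introduced when moving to continuous acceleration.
# The keys the exporter actually requires may differ; check the KeyError message.
ASSUMED_REFLOW_KEYS = ('time_scale_factor', 'T_start', 'T_start_infer', 'sampling_steps')


def find_missing_keys(config: Dict[str, Any],
                      required: Iterable[str] = ASSUMED_REFLOW_KEYS) -> List[str]:
    """Return the required keys that the loaded configuration does not define."""
    return [key for key in required if key not in config]


old_config = {'timesteps': 1000, 'K_step': 1000, 'time_scale_factor': 1000.0}
print(find_missing_keys(old_config))
```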

yqzhishen (Member, Author) commented:

The second stage of refactoring was completed in commit dc6896b.

Due to adjustments in the state dict, models trained on this branch before that commit should be migrated with the following script:

```python
import collections
import pathlib
from typing import Dict, Any

import click
import torch


@click.command()
@click.argument(
    'in_ckpt', type=click.Path(
        exists=True, dir_okay=False, file_okay=True, readable=True, path_type=pathlib.Path
    )
)
@click.argument(
    'out_ckpt', type=click.Path(
        exists=False, dir_okay=False, file_okay=True, writable=True, path_type=pathlib.Path
    )
)
def migrate_reflow(in_ckpt: pathlib.Path, out_ckpt: pathlib.Path):
    """Rewrite a checkpoint saved before dc6896b into the new state dict layout."""
    ckpt = torch.load(in_ckpt, map_location='cpu')
    in_state_dict: Dict[str, Any] = ckpt['state_dict']
    out_state_dict = collections.OrderedDict()
    for k, v in in_state_dict.items():
        if 'denoise_fn' in k:
            # Parameters under denoise_fn are renamed to velocity_fn.
            out_state_dict[k.replace('denoise_fn', 'velocity_fn')] = v
        elif 'spec_min' in k or 'spec_max' in k:
            # spec_min / spec_max entries are dropped from the migrated state dict.
            continue
        else:
            # All other entries are copied unchanged.
            out_state_dict[k] = v
    # Save only the checkpoint category and the migrated weights.
    torch.save({'category': ckpt['category'], 'state_dict': out_state_dict}, out_ckpt)


if __name__ == '__main__':
    migrate_reflow()
```

The following configuration keys have been renamed:

  • diffusion_type: RectifiedFlow -> diffusion_type: reflow
  • diff_decoder_type -> backbone_type
  • diff_loss_type -> main_loss_type
  • lognorm loss now has its own switch: main_loss_log_norm (only for Rectified Flow models)
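
The renames above can be applied to an already-loaded config dictionary with a sketch like this one. The helper name `rename_config_keys` and the example values `'wavenet'` and `'l2'` are hypothetical placeholders. `main_loss_log_norm` is deliberately not derived, since the issue does not say how the log-norm loss was enabled before the change; set it by hand.

```python
from typing import Any, Dict

# Key renames as listed in the issue.
RENAMED_KEYS = {
    'diff_decoder_type': 'backbone_type',
    'diff_loss_type': 'main_loss_type',
}


def rename_config_keys(config: Dict[str, Any]) -> Dict[str, Any]:
    """Hypothetical helper: apply the second-stage key and value renames.

    main_loss_log_norm is not set here; add it manually if needed.
    """
    new_config: Dict[str, Any] = {}
    for key, value in config.items():
        new_key = RENAMED_KEYS.get(key, key)
        if new_key != key and new_key in config:
            continue  # an explicitly set new key wins over its legacy name
        new_config[new_key] = value
    # The diffusion type value itself changed: RectifiedFlow -> reflow.
    if new_config.get('diffusion_type') == 'RectifiedFlow':
        new_config['diffusion_type'] = 'reflow'
    return new_config


old = {'diffusion_type': 'RectifiedFlow', 'diff_decoder_type': 'wavenet', 'diff_loss_type': 'l2'}
print(rename_config_keys(old))
```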

yqzhishen linked a pull request on Apr 17, 2024 that will close this issue
yqzhishen unpinned this issue May 23, 2024