
v0.27.0: Stable Cascade, Playground v2.5, EDM-style training, IP-Adapter image embeds, and more

@sayakpaul released this 14 Mar 16:00 · 296 commits to main since this release

Stable Cascade

We are adding support for Stable Cascade, a new text-to-image model built on Würstchen, which comes with a non-commercial license. The Stable Cascade line of pipelines differs from Stable Diffusion in that it is built upon three distinct models and allows for hierarchical compression of images, achieving remarkable outputs.

from diffusers import StableCascadePriorPipeline, StableCascadeDecoderPipeline
import torch

prior = StableCascadePriorPipeline.from_pretrained(
    "stabilityai/stable-cascade-prior",
    torch_dtype=torch.bfloat16,
).to("cuda")

prompt = "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k"
image_emb = prior(prompt=prompt).image_embeddings

decoder = StableCascadeDecoderPipeline.from_pretrained(
    "stabilityai/stable-cascade",
    torch_dtype=torch.bfloat16,
).to("cuda")

image = decoder(image_embeddings=image_emb, prompt=prompt).images[0]
image

📜 Check out the docs here to learn more about the model.

Note: You will need torch>=2.2.0 to use the torch.bfloat16 data type with the Stable Cascade pipeline.

Playground v2.5

PlaygroundAI released a new v2.5 model (playgroundai/playground-v2.5-1024px-aesthetic), which particularly excels at aesthetics. The model closely follows the architecture of Stable Diffusion XL, except for a few tweaks. This release comes with support for this model:

from diffusers import DiffusionPipeline
import torch

pipe = DiffusionPipeline.from_pretrained(
    "playgroundai/playground-v2.5-1024px-aesthetic",
    torch_dtype=torch.float16,
    variant="fp16",
).to("cuda")

prompt = "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k"
image = pipe(prompt=prompt, num_inference_steps=50, guidance_scale=3).images[0]
image

Loading from the original single-file checkpoint is also supported:

from diffusers import StableDiffusionXLPipeline, EDMDPMSolverMultistepScheduler
import torch

url = "https://huggingface.co/playgroundai/playground-v2.5-1024px-aesthetic/blob/main/playground-v2.5-1024px-aesthetic.safetensors"
pipeline = StableDiffusionXLPipeline.from_single_file(url)
pipeline.to(device="cuda", dtype=torch.float16)

prompt = "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k"
image = pipeline(prompt=prompt, guidance_scale=3.0).images[0]
image.save("playground_test_image.png")

You can also perform LoRA DreamBooth training with the playgroundai/playground-v2.5-1024px-aesthetic checkpoint:

accelerate launch train_dreambooth_lora_sdxl.py \
  --pretrained_model_name_or_path="playgroundai/playground-v2.5-1024px-aesthetic"  \
  --instance_data_dir="dog" \
  --output_dir="dog-playground-lora" \
  --mixed_precision="fp16" \
  --instance_prompt="a photo of sks dog" \
  --resolution=1024 \
  --train_batch_size=1 \
  --gradient_accumulation_steps=4 \
  --learning_rate=1e-4 \
  --use_8bit_adam \
  --report_to="wandb" \
  --lr_scheduler="constant" \
  --lr_warmup_steps=0 \
  --max_train_steps=500 \
  --validation_prompt="A photo of sks dog in a bucket" \
  --validation_epochs=25 \
  --seed="0" \
  --push_to_hub

To learn more, follow the instructions here.

EDM-style training support

EDM refers to the training and sampling techniques introduced in the following paper: Elucidating the Design Space of Diffusion-Based Generative Models. We have introduced support for training using the EDM formulation in our train_dreambooth_lora_sdxl.py script.

To train stabilityai/stable-diffusion-xl-base-1.0 using the EDM formulation, you just have to specify the --do_edm_style_training flag in your training command, and voilà 🤗

If you're interested in extending this formulation to other training scripts, we refer you to this PR.
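For instance, adapting the DreamBooth LoRA SDXL command shown earlier, a minimal sketch of an EDM-style launch might look like this (paths and hyperparameters are illustrative):

accelerate launch train_dreambooth_lora_sdxl.py \
  --pretrained_model_name_or_path="stabilityai/stable-diffusion-xl-base-1.0" \
  --do_edm_style_training \
  --instance_data_dir="dog" \
  --output_dir="dog-edm-lora" \
  --instance_prompt="a photo of sks dog" \
  --resolution=1024 \
  --train_batch_size=1 \
  --learning_rate=1e-4 \
  --max_train_steps=500 \
  --seed="0"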

New schedulers with the EDM formulation

To better support the Playground v2.5 model and EDM-style training in general, we are bringing support for EDMDPMSolverMultistepScheduler and EDMEulerScheduler. These support the EDM formulations of the DPMSolverMultistepScheduler and EulerDiscreteScheduler, respectively.
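As a rough sketch of how these plug in: like any Diffusers scheduler, they can be swapped into a pipeline with from_config. The example below assumes the Playground v2.5 checkpoint ships configured with an EDM scheduler and switches it to EDMEulerScheduler:

import torch
from diffusers import DiffusionPipeline, EDMEulerScheduler

pipe = DiffusionPipeline.from_pretrained(
    "playgroundai/playground-v2.5-1024px-aesthetic",
    torch_dtype=torch.float16,
    variant="fp16",
).to("cuda")

# Reuse the existing scheduler config to construct the EDM-formulation Euler scheduler
pipe.scheduler = EDMEulerScheduler.from_config(pipe.scheduler.config)

prompt = "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k"
image = pipe(prompt=prompt, num_inference_steps=25, guidance_scale=3.0).images[0]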

Trajectory Consistency Distillation

Trajectory Consistency Distillation (TCD) enables a model to generate higher-quality, more detailed images in fewer steps. Moreover, owing to effective error mitigation during the distillation process, TCD also demonstrates superior performance even at large numbers of inference steps. It was proposed in Trajectory Consistency Distillation.

This release comes with support for a TCDScheduler that enables this kind of fast sampling. Much like LCM-LoRA, TCD requires an additional adapter for the acceleration. The code snippet below shows example usage:

import torch
from diffusers import StableDiffusionXLPipeline, TCDScheduler

device = "cuda"
base_model_id = "stabilityai/stable-diffusion-xl-base-1.0"
tcd_lora_id = "h1t/TCD-SDXL-LoRA"

pipe = StableDiffusionXLPipeline.from_pretrained(base_model_id, torch_dtype=torch.float16, variant="fp16").to(device)
pipe.scheduler = TCDScheduler.from_config(pipe.scheduler.config)

pipe.load_lora_weights(tcd_lora_id)
pipe.fuse_lora()

prompt = "Painting of the orange cat Otto von Garfield, Count of Bismarck-Sch枚nhausen, Duke of Lauenburg, Minister-President of Prussia. Depicted wearing a Prussian Pickelhaube and eating his favorite meal - lasagna."

image = pipe(
    prompt=prompt,
    num_inference_steps=4,
    guidance_scale=0,
    eta=0.3, 
    generator=torch.Generator(device=device).manual_seed(0),
).images[0]

image

📜 Check out the docs here to learn more about TCD.

Many thanks to @mhh0318 for contributing the TCDScheduler in #7174 and the guide in #7259.

IP-Adapter image embeddings and masking

All the pipelines supporting IP-Adapter accept an ip_adapter_image_embeds argument. If you need to run the IP-Adapter multiple times with the same image, you can encode the image once and save the embedding to disk. This saves computation time and is especially useful when building UIs. Additionally, ComfyUI image embeddings for IP-Adapters are fully compatible with Diffusers and should work out of the box.
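As a sketch of this workflow, the snippet below precomputes an embedding with the prepare_ip_adapter_image_embeds helper (refactored in #7016), saves it, and reuses it at call time; the reference image path and output file name are placeholders:

import torch
from diffusers import AutoPipelineForText2Image
from diffusers.utils import load_image

pipeline = AutoPipelineForText2Image.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
).to("cuda")
pipeline.load_ip_adapter("h94/IP-Adapter", subfolder="sdxl_models", weight_name="ip-adapter_sdxl.bin")

ip_image = load_image("ip_reference.png")  # placeholder reference image

# Encode the reference image once...
image_embeds = pipeline.prepare_ip_adapter_image_embeds(
    ip_adapter_image=ip_image,
    ip_adapter_image_embeds=None,
    device="cuda",
    num_images_per_prompt=1,
    do_classifier_free_guidance=True,
)
# ...and save the embedding to disk for later runs
torch.save(image_embeds, "image_embeds.ipadpt")

# Later: load the saved embedding and skip image encoding entirely
image_embeds = torch.load("image_embeds.ipadpt")
image = pipeline(
    prompt="a polar bear sitting in a chair drinking a milkshake",
    ip_adapter_image_embeds=image_embeds,
    num_inference_steps=50,
).images[0]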

We have also introduced support for providing binary masks to specify which portion of the output image should be assigned to an IP-Adapter. For each input IP-Adapter image, a binary mask and an IP-Adapter must be provided. Thanks to @fabiorigano for contributing this feature through #6847.
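A minimal sketch of the masking flow: masks are preprocessed with IPAdapterMaskProcessor and handed to the pipeline through cross_attention_kwargs. The mask paths below are placeholders, and the final (commented) call assumes a pipeline with two IP-Adapters and two reference face images already loaded:

from diffusers.image_processor import IPAdapterMaskProcessor
from diffusers.utils import load_image

# One binary mask per IP-Adapter input image (white = region assigned to that image)
mask1 = load_image("mask_left.png")   # placeholder
mask2 = load_image("mask_right.png")  # placeholder

# Preprocess the masks to the output resolution
processor = IPAdapterMaskProcessor()
masks = processor.preprocess([mask1, mask2], height=1024, width=1024)

# image = pipeline(
#     prompt="2 girls",
#     ip_adapter_image=[[face_image1], [face_image2]],
#     cross_attention_kwargs={"ip_adapter_masks": masks},
# ).images[0]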

📜 To learn about the exact usage of both of the above, refer to our official guide.

We thank our community members, @fabiorigano, @asomoza, and @cubiq, for their guidance and input on these features.

Guide on merging LoRAs

Merging LoRAs can be a fun and creative way to create new and unique images. Diffusers provides merging support via the set_adapters method, which concatenates the weights of the LoRAs to merge.

Now, Diffusers also supports the add_weighted_adapter method from the PEFT library, unlocking more efficient merging methods like TIES, DARE, linear, and even combinations of these, such as dare_ties.
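For the set_adapters path, a minimal sketch looks like the following; the two LoRA repositories, weight file names, and adapter names are illustrative:

import torch
from diffusers import DiffusionPipeline

pipe = DiffusionPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
).to("cuda")

# Load two LoRAs under explicit adapter names
pipe.load_lora_weights("ostris/ikea-instructions-lora-sdxl", weight_name="ikea_instructions_xl_v1_5.safetensors", adapter_name="ikea")
pipe.load_lora_weights("lordjia/by-feng-zikai", weight_name="fengzikai_v1.0_XL.safetensors", adapter_name="feng")

# Merge by weighting and concatenating the two adapters
pipe.set_adapters(["ikea", "feng"], adapter_weights=[0.7, 0.8])

image = pipe(
    "A bowl of ramen shaped like a cute kawaii bear, by Feng Zikai",
    generator=torch.manual_seed(0),
).images[0]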

📜 Take a look at the Merge LoRAs guide to learn more about merging in Diffusers.

LEDITS++

We are adding support for the real-image editing technique LEDITS++: Limitless Image Editing using Text-to-Image Models, a parameter-free method requiring no fine-tuning or optimization.
To edit real images, the LEDITS++ pipelines first invert the image using the DPM-solver++ scheduler, which facilitates editing with as few as 20 total diffusion steps for inversion and inference combined. LEDITS++ guidance is defined such that it reflects both the direction of the edit (whether to push away from or towards the edit concept) and the strength of the effect. The guidance also includes a masking term focused on relevant image regions which, especially for multiple edits, ensures that the corresponding guidance terms for each concept remain mostly isolated, limiting interference.

The code snippet below shows example usage:

import torch
import PIL
import requests
from io import BytesIO
from diffusers import LEditsPPPipelineStableDiffusionXL, AutoencoderKL

device = "cuda"
base_model_id = "stabilityai/stable-diffusion-xl-base-1.0"

vae = AutoencoderKL.from_pretrained("madebyollin/sdxl-vae-fp16-fix", torch_dtype=torch.float16)

pipe = LEditsPPPipelineStableDiffusionXL.from_pretrained(
    base_model_id, 
    vae=vae, 
    torch_dtype=torch.float16
).to(device)

def download_image(url):
    response = requests.get(url)
    return PIL.Image.open(BytesIO(response.content)).convert("RGB")

img_url = "https://www.aiml.informatik.tu-darmstadt.de/people/mbrack/tennis.jpg"
image = download_image(img_url)

_ = pipe.invert(
    image=image,
    num_inversion_steps=50,
    skip=0.2,
)

edited_image = pipe(
    editing_prompt=["tennis ball", "tomato"],
    reverse_editing_direction=[True, False],
    edit_guidance_scale=[5.0, 10.0],
    edit_threshold=[0.9, 0.85],
).images[0]
(Example images: the input photo's tennis ball edited into a tomato.)

📜 Check out the docs here to learn more about LEDITS++.

Thanks to @manuelbrack for contributing this in #6074.

All commits

  • Fix flaky IP Adapter test by @DN6 in #6960
  • Move SDXL T2I Adapter lora test into PEFT workflow by @DN6 in #6965
  • Allow passing config_file argument to ControlNetModel when using from_single_file by @DN6 in #6959
  • [PEFT / docs] Add a note about torch.compile by @younesbelkada in #6864
  • [Core] Harmonize single file ckpt model loading by @sayakpaul in #6971
  • fix: controlnet inpaint single file. by @sayakpaul in #6975
  • [docs] IP-Adapter by @stevhliu in #6897
  • fix IPAdapter unload_ip_adapter test by @yiyixuxu in #6972
  • [advanced sdxl lora script] - fix #6967 bug when using prior preservation loss by @linoytsaban in #6968
  • [IP Adapters] feat: allow low_cpu_mem_usage in ip adapter loading by @sayakpaul in #6946
  • Fix diffusers import prompt2prompt by @ihkap11 in #6927
  • add: peft to the benchmark workflow by @sayakpaul in #6989
  • Fix procecss process by @co63oc in #6591
  • Standardize model card for textual inversion sdxl by @Stepheni12 in #6963
  • Update textual_inversion.py by @Bhavay-2001 in #6952
  • [docs] Fix callout by @stevhliu in #6998
  • [docs] Video generation by @stevhliu in #6701
  • start depcrecation cycle for lora_attention_proc 馃憢 by @sayakpaul in #7007
  • Add documentation for strength parameter in Controlnet_img2img pipelines by @tlpss in #6951
  • Fixed typos in dosctrings of init() and in forward() of Unet3DConditionModel by @MK-2012 in #6663
  • [SVD] fix a bug when passing image as tensor by @yiyixuxu in #6999
  • Fix deprecation warning for torch.utils._pytree._register_pytree_node in PyTorch 2.2 by @zyinghua in #7008
  • [IP2P] Make text encoder truly optional in InstructPi2Pix by @sayakpaul in #6995
  • IP-Adapter attention masking by @fabiorigano in #6847
  • Fix Pixart Slow Tests by @DN6 in #6962
  • [from_single_file] pass torch_dtype to set_module_tensor_to_device by @yiyixuxu in #6994
  • [Refactor] FreeInit for AnimateDiff based pipelines by @DN6 in #6874
  • [Community Pipelines]Accelerate inference of stable diffusion xl (SDXL) by IPEX on CPU by @ustcuna in #6683
  • Add section on AnimateLCM to docs by @DN6 in #7024
  • IP-Adapter support for StableDiffusionXLControlNetInpaintPipeline by @rootonchair in #6941
  • Supper IP Adapter weight loading in StableDiffusionXLControlNetInpaintPipeline by @tontan2545 in #7031
  • Fix alt text and image links in AnimateLCM docs by @DN6 in #7029
  • Update ControlNet Inpaint single file test by @DN6 in #7022
  • Fix load_model_dict_into_meta for ControlNet from_single_file by @DN6 in #7034
  • Remove disable_full_determinism from StableVideoDiffusion xformers test. by @DN6 in #7039
  • update header by @pravdomil in #6596
  • fix doc example for fom_single_file by @yiyixuxu in #7015
  • Fix typos in text_to_image examples by @standardAI in #7050
  • Update checkpoint_merger pipeline to pass the "variant" argument by @lstein in #6670
  • allow explicit tokenizer & text_encoder in unload_textual_inversion by @H3zi in #6977
  • re-add unet refactor PR by @yiyixuxu in #7044
  • IPAdapterTesterMixin by @a-r-r-o-w in #6862
  • [Refactor] save_model_card function in text_to_image examples by @standardAI in #7051
  • Fix typos by @standardAI in #7068
  • Fix docstring of community pipeline imagic by @chongdashu in #7062
  • Change images to image. The variable images is not used anywhere by @bimsarapathiraja in #7074
  • fix: TensorRTStableDiffusionPipeline cannot set guidance_scale by @caiyueliang in #7065
  • [Refactor] StableDiffusionReferencePipeline inheriting from DiffusionPipeline by @standardAI in #7071
  • Fix truthy-ness condition in pipelines that use denoising_start by @a-r-r-o-w in #6912
  • Fix head_to_batch_dim for IPAdapterAttnProcessor by @fabiorigano in #7077
  • [docs] Minor updates by @stevhliu in #7063
  • Modularize Dreambooth LoRA SD inferencing during and after training by @rootonchair in #6654
  • Modularize Dreambooth LoRA SDXL inferencing during and after training by @rootonchair in #6655
  • [Community] Bug fix + Latest IP-Adapter impl. for AnimateDiff img2vid/controlnet by @a-r-r-o-w in #7086
  • Pass use_linear_projection parameter to mid block in UNetMotionModel by @Stepheni12 in #7035
  • Resize image before crop by @jiqing-feng in #7095
  • Small change to download in dance diffusion convert script by @DN6 in #7070
  • Fix EMA in train_text_to_image_sdxl.py by @standardAI in #7048
  • Make LoRACompatibleConv padding_mode work. by @jinghuan-Chen in #6031
  • [Easy] edit issue and PR templates by @sayakpaul in #7092
  • FIX [PEFT / Core] Copy the state dict when passing it to load_lora_weights by @younesbelkada in #7058
  • [Core] pass revision in the loading_kwargs. by @sayakpaul in #7019
  • [Examples] Multiple enhancements to the ControlNet training scripts by @sayakpaul in #7096
  • move to uv in the Dockerfiles. by @sayakpaul in #7094
  • Add tests to check configs when using single file loading by @DN6 in #7099
  • denormalize latents with the mean and std if available by @patil-suraj in #7111
  • [Dockerfile] remove uv from docker jax tpu by @sayakpaul in #7115
  • Add EDMEulerScheduler by @patil-suraj in #7109
  • add DPM scheduler with EDM formulation by @patil-suraj in #7120
  • [Docs] Fix typos by @standardAI in #7118
  • DPMSolverMultistep add rescale_betas_zero_snr by @Beinsezii in #7097
  • [Tests] make test steps dependent on certain things and general cleanup of the workflows by @sayakpaul in #7026
  • fix kwarg in the SDXL LoRA DreamBooth by @sayakpaul in #7124
  • [Diffusers CI] Switch slow test runners by @DN6 in #7123
  • [stalebot] don't close the issue if the stale label is removed by @yiyixuxu in #7106
  • refactor: move model helper function in pipeline to a mixin class by @ultranity in #6571
  • [docs] unet type hints by @a-r-r-o-w in #7134
  • use uv for installing stuff in the workflows. by @sayakpaul in #7116
  • limit documentation workflow runs for relevant changes. by @sayakpaul in #7125
  • add: support for notifying the maintainers about the docker ci status. by @sayakpaul in #7113
  • Fix setting fp16 dtype in AnimateDiff convert script. by @DN6 in #7127
  • [Docs] Fix typos by @standardAI in #7131
  • [ip-adapter] refactor prepare_ip_adapter_image_embeds and skip load image_encoder by @yiyixuxu in #7016
  • [CI] fix path filtering in the documentation workflows by @sayakpaul in #7153
  • [Urgent][Docker CI] pin uv version for now and a minor change in the Slack notification by @sayakpaul in #7155
  • Fix LCM benchmark test by @sayakpaul in #7158
  • [CI] Remove max parallel flag on slow test runners by @DN6 in #7162
  • Fix vae_encodings_fn hash in train_text_to_image_sdxl.py by @lhoestq in #7171
  • fix: loading problem for sdxl lora dreambooth by @sayakpaul in #7166
  • Map speedup by @kopyl in #6745
  • [stalebot] fix a bug by @yiyixuxu in #7156
  • Support EDM-style training in DreamBooth LoRA SDXL script by @sayakpaul in #7126
  • Fix PixArt 256px inference by @lawrence-cj in #6789
  • [ip-adapter] fix problem using embeds with the plus version of ip adapters by @asomoza in #7189
  • feat: add ip adapter benchmark by @sayakpaul in #6936
  • [Docs] more elaborate example for peft torch.compile by @sayakpaul in #7161
  • adding callback_on_step_end for StableDiffusionLDM3DPipeline by @rootonchair in #7149
  • Update requirements.txt to remove huggingface-cli by @sayakpaul in #7202
  • [advanced dreambooth lora sdxl] add DoRA training feature by @linoytsaban in #7072
  • FIx torch and cuda version in ONNX tests by @DN6 in #7164
  • [training scripts] add tags of diffusers-training by @linoytsaban in #7206
  • fix a bug in from_config by @yiyixuxu in #7192
  • Fix: UNet2DModel::init type hints; fixes issue #4806 by @fpgaminer in #7175
  • Fix typos by @standardAI in #7181
  • Enable PyTorch's FakeTensorMode for EulerDiscreteScheduler scheduler by @thiagocrepaldi in #7151
  • [docs] Improve SVD pipeline docs by @a-r-r-o-w in #7087
  • [Docs] Update callback.md code example by @rootonchair in #7150
  • [Core] errors should be caught as soon as possible. by @sayakpaul in #7203
  • [Community] PromptDiffusion Pipeline by @iczaw in #6752
  • add TCD Scheduler by @mhh0318 in #7174
  • SDXL Turbo support and example launch by @bram-w in #6473
  • [bug] Fix float/int guidance scale not working in StableVideoDiffusionPipeline by @JinayJain in #7143
  • [Pipiline] Wuerstchen v3 aka Stable Cascasde pipeline by @kashif in #6487
  • Update train_dreambooth_lora_sdxl_advanced.py by @landmann in #7227
  • [Core] move out the utilities from pipeline_utils.py by @sayakpaul in #7234
  • Refactor Prompt2Prompt: Inherit from DiffusionPipeline by @ihkap11 in #7211
  • add DoRA training feature to sdxl dreambooth lora script by @linoytsaban in #7235
  • fix: remove duplicated code in TemporalBasicTransformerBlock. by @AsakusaRinne in #7212
  • [Examples] fix: prior preservation setting in DreamBooth LoRA SDXL script. by @sayakpaul in #7242
  • fix: support for loading playground v2.5 single file checkpoint. by @sayakpaul in #7230
  • Raise an error when trying to use SD Cascade Decoder with dtype bfloat16 and torch < 2.2 by @DN6 in #7244
  • Remove the line. Using it create wrong output by @bimsarapathiraja in #7075
  • [docs] Merge LoRAs by @stevhliu in #7213
  • use self.device by @pravdomil in #6595
  • [docs] Community tips by @stevhliu in #7137
  • [Core] throw error when patch inputs and layernorm are provided for Transformers2D by @sayakpaul in #7200
  • [Tests] fix: VAE tiling tests when setting the right device by @sayakpaul in #7246
  • [Utils] Improve " # Copied from ..." statements in the pipelines by @sayakpaul in #6917
  • [Easy] fix: save_model_card utility of the DreamBooth SDXL LoRA script by @sayakpaul in #7258
  • Make mid block optional for flax UNet by @mar-muel in #7083
  • Solve missing clip_sample implementation in FlaxDDIMScheduler. by @hi-sushanta in #7017
  • [Tests] fix config checking tests by @sayakpaul in #7247
  • [docs] IP-Adapter image embedding by @stevhliu in #7226
  • Adds denoising_end parameter to ControlNetPipeline for SDXL by @UmerHA in #6175
  • Add npu support by @MengqingCao in #7144
  • [Community Pipeline] Skip Marigold depth_colored with color_map=None by @qqii in #7170
  • update the signature of from_single_file by @yiyixuxu in #7216
  • [UNet_Spatio_Temporal_Condition] fix default num_attention_heads in unet_spatio_temporal_condition by @Wang-Xiaodong1899 in #7205
  • [docs/nits] Fix return values based on return_dict and minor doc updates by @a-r-r-o-w in #7105
  • [Chore] remove tf mention by @sayakpaul in #7245
  • Fix gmflow_dir by @pravdomil in #6583
  • Support latents_mean and latents_std by @haofanwang in #7132
  • Inline InputPadder by @pravdomil in #6582
  • [Dockerfiles] add: a workflow to check if docker containers can be built in case of modifications by @sayakpaul in #7129
  • instruct pix2pix pipeline: remove sigma scaling when computing classifier free guidance by @erliding in #7006
  • Change export_to_video default by @DN6 in #6990
  • [Chore] switch to logger.warning by @sayakpaul in #7289
  • [LoRA] use the PyTorch classes wherever needed and start depcrecation cycles by @sayakpaul in #7204
  • Add single file support for Stable Cascade by @DN6 in #7274
  • Fix passing pooled prompt embeds to Cascade Decoder and Combined Pipeline by @DN6 in #7287
  • Fix loading Img2Img refiner components in from_single_file by @DN6 in #7282
  • [Chore] clean residue from copy-pasting in the UNet single file loader by @sayakpaul in #7295
  • Update Cascade documentation by @DN6 in #7257
  • Update Stable Cascade Conversion Scripts by @DN6 in #7271
  • [Pipeline] Add LEDITS++ pipelines by @manuelbrack in #6074
  • [PyPI publishing] feat: automate the process of pypi publication to some extent. by @sayakpaul in #7270
  • add: support for notifying maintainers about the nightly test status by @sayakpaul in #7117
  • Fix Wrong Text-encoder Grad Setting in Custom_Diffusion Training by @Rbrq03 in #7302
  • Add Intro page of TCD by @mhh0318 in #7259
  • Fix typos in UNet2DConditionModel documentation by @alexanderbonnet in #7291
  • Change step_offset scheduler docstrings by @Beinsezii in #7128
  • update get_order_list if statement by @kghamilton89 in #7309
  • add: pytest log installation by @sayakpaul in #7313
  • [Tests] Fix incorrect constant in VAE scaling test. by @DN6 in #7301
  • log loss per image by @noskill in #7278
  • add edm schedulers in doc by @patil-suraj in #7319
  • [Advanced DreamBooth LoRA SDXL] Support EDM-style training (follow up of #7126) by @linoytsaban in #7182
  • Update Cascade Tests by @DN6 in #7324
  • Release: v0.27.0 by @DN6 (direct commit on v0.27.0-release)

Significant community contributions

The following contributors have made significant changes to the library over the last release:

  • @ihkap11
    • Fix diffusers import prompt2prompt (#6927)
    • Refactor Prompt2Prompt: Inherit from DiffusionPipeline (#7211)
  • @ustcuna
    • [Community Pipelines]Accelerate inference of stable diffusion xl (SDXL) by IPEX on CPU (#6683)
  • @rootonchair
    • IP-Adapter support for StableDiffusionXLControlNetInpaintPipeline (#6941)
    • Modularize Dreambooth LoRA SD inferencing during and after training (#6654)
    • Modularize Dreambooth LoRA SDXL inferencing during and after training (#6655)
    • adding callback_on_step_end for StableDiffusionLDM3DPipeline (#7149)
    • [Docs] Update callback.md code example (#7150)
  • @standardAI
    • Fix typos in text_to_image examples (#7050)
    • [Refactor] save_model_card function in text_to_image examples (#7051)
    • Fix typos (#7068)
    • [Refactor] StableDiffusionReferencePipeline inheriting from DiffusionPipeline (#7071)
    • Fix EMA in train_text_to_image_sdxl.py (#7048)
    • [Docs] Fix typos (#7118)
    • [Docs] Fix typos (#7131)
    • Fix typos (#7181)
  • @a-r-r-o-w
    • IPAdapterTesterMixin (#6862)
    • Fix truthy-ness condition in pipelines that use denoising_start (#6912)
    • [Community] Bug fix + Latest IP-Adapter impl. for AnimateDiff img2vid/controlnet (#7086)
    • [docs] unet type hints (#7134)
    • [docs] Improve SVD pipeline docs (#7087)
    • [docs/nits] Fix return values based on return_dict and minor doc updates (#7105)
  • @ultranity
    • refactor: move model helper function in pipeline to a mixin class (#6571)
  • @iczaw
    • [Community] PromptDiffusion Pipeline (#6752)
  • @mhh0318
    • add TCD Scheduler (#7174)
    • Add Intro page of TCD (#7259)
  • @manuelbrack
    • [Pipeline] Add LEDITS++ pipelines (#6074)