
v0.26.0: New video pipelines, single-file checkpoint revamp, multi IP-Adapter inference with multiple images

@sayakpaul released this 01 Feb 02:03 · 531 commits to main since this release

This new release comes with two new video pipelines, a more unified and consistent experience for single-file checkpoint loading, support for multiple IP-Adapters’ inference with multiple reference images, and more.

I2VGenXL

I2VGenXL is an image-to-video pipeline, proposed in I2VGen-XL: High-Quality Image-to-Video Synthesis via Cascaded Diffusion Models.

import torch
from diffusers import I2VGenXLPipeline
from diffusers.utils import export_to_gif, load_image

repo_id = "ali-vilab/i2vgen-xl"
pipeline = I2VGenXLPipeline.from_pretrained(repo_id, torch_dtype=torch.float16)
# enable_model_cpu_offload() manages device placement, so an explicit .to("cuda") is unnecessary.
pipeline.enable_model_cpu_offload()

image_url = "https://huggingface.co/datasets/diffusers/docs-images/resolve/main/i2vgen_xl_images/img_0001.jpg"
image = load_image(image_url).convert("RGB")
prompt = "A green frog floats on the surface of the water on green lotus leaves, with several pink lotus flowers, in a Chinese painting style."
negative_prompt = "Distorted, discontinuous, Ugly, blurry, low resolution, motionless, static, disfigured, disconnected limbs, Ugly faces, incomplete arms"
generator = torch.manual_seed(8888)

frames = pipeline(
    prompt=prompt,
    image=image,
    num_inference_steps=50,
    negative_prompt=negative_prompt,
    generator=generator,
).frames
export_to_gif(frames[0], "i2v.gif")

📜 Check out the docs here.

PIA

PIA (Personalized Image Animator) aligns with condition images, controls motion by text, and is compatible with various T2I models without specific tuning. PIA uses a base T2I model with temporal alignment layers for image animation. A key component of PIA is the condition module, which transfers appearance information for individual frame synthesis in the latent space, allowing a stronger focus on motion alignment. PIA was introduced in PIA: Your Personalized Image Animator via Plug-and-Play Modules in Text-to-Image Models.

import torch
from diffusers import (
    EulerDiscreteScheduler,
    MotionAdapter,
    PIAPipeline,
)
from diffusers.utils import export_to_gif, load_image

adapter = MotionAdapter.from_pretrained("openmmlab/PIA-condition-adapter")
pipe = PIAPipeline.from_pretrained("SG161222/Realistic_Vision_V6.0_B1_noVAE", motion_adapter=adapter, torch_dtype=torch.float16)

pipe.scheduler = EulerDiscreteScheduler.from_config(pipe.scheduler.config)
pipe.enable_model_cpu_offload()
pipe.enable_vae_slicing()

image = load_image(
    "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/pix2pix/cat_6.png?download=true"
)
image = image.resize((512, 512))
prompt = "cat in a field"
negative_prompt = "wrong white balance, dark, sketches,worst quality,low quality"

generator = torch.Generator("cpu").manual_seed(0)
output = pipe(image=image, prompt=prompt, generator=generator)
frames = output.frames[0]
export_to_gif(frames, "pia-animation.gif")
(Output animation for the prompt "cat in a field".)

📜 Check out the docs here.

Multiple IP-Adapters + Multiple reference images support (“Instant LoRA” Feature)

IP-Adapters are becoming quite popular, so we have added support for performing inference with multiple IP-Adapters and multiple reference images! Thanks to @asomoza for their help. Get started with the code below:

import torch
from diffusers import AutoPipelineForText2Image, DDIMScheduler
from transformers import CLIPVisionModelWithProjection
from diffusers.utils import load_image

image_encoder = CLIPVisionModelWithProjection.from_pretrained(
    "h94/IP-Adapter", 
    subfolder="models/image_encoder",
    torch_dtype=torch.float16,
)

pipeline = AutoPipelineForText2Image.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0",
    torch_dtype=torch.float16,
    image_encoder=image_encoder,
)
pipeline.scheduler = DDIMScheduler.from_config(pipeline.scheduler.config)

pipeline.load_ip_adapter("h94/IP-Adapter", subfolder="sdxl_models", weight_name=["ip-adapter-plus_sdxl_vit-h.safetensors", "ip-adapter-plus-face_sdxl_vit-h.safetensors"])
pipeline.set_ip_adapter_scale([0.7, 0.3])

pipeline.enable_model_cpu_offload()

face_image = load_image("https://huggingface.co/datasets/YiYiXu/testing-images/resolve/main/women_input.png")

style_folder = "https://huggingface.co/datasets/YiYiXu/testing-images/resolve/main/style_ziggy"
style_images = [load_image(f"{style_folder}/img{i}.png") for i in range(10)]

generator = torch.Generator(device="cpu").manual_seed(0)

image = pipeline(
    prompt="wonderwoman",
    ip_adapter_image=[style_images, face_image],
    negative_prompt="monochrome, lowres, bad anatomy, worst quality, low quality", 
    num_inference_steps=50,
    generator=generator,
).images[0]

(Figures: reference style images, reference face image, and output image.)

📜 Check out the docs here.

Single-file checkpoint loading

The from_single_file() utility has been refactored for better readability and to follow semantics similar to from_pretrained(). Support for loading single-file checkpoints and configs from URLs has also been added.
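As a minimal sketch, loading a pipeline from a single-file checkpoint now mirrors the from_pretrained() call style. The checkpoint URL below is illustrative; substitute the single-file checkpoint you actually want to load:

import torch
from diffusers import StableDiffusionXLPipeline

# Load a pipeline directly from a single-file checkpoint; URLs are supported too.
ckpt_url = "https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0/blob/main/sd_xl_base_1.0.safetensors"
pipeline = StableDiffusionXLPipeline.from_single_file(ckpt_url, torch_dtype=torch.float16).to("cuda")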

DPM scheduler fix

We introduced a fix for the DPM schedulers, so you can now use them with SDXL to generate high-quality images in fewer steps than with the Euler scheduler.
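As a minimal sketch of swapping the DPM multistep scheduler into an SDXL pipeline (the 25-step count and prompt are illustrative assumptions, not recommendations from this release):

import torch
from diffusers import AutoPipelineForText2Image, DPMSolverMultistepScheduler

pipe = AutoPipelineForText2Image.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
).to("cuda")
# Replace the default scheduler with the DPM multistep scheduler.
pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config)
image = pipe("an astronaut riding a horse on mars", num_inference_steps=25).images[0]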

Apart from these, we have done a great deal of refactoring to improve the library design and will continue to do so in the coming days.

All commits

Significant community contributions

The following contributors have made significant changes to the library over the last release:

  • @a-r-r-o-w
    • [Community] Experimental AnimateDiff Image to Video (open to improvements) (#6509)
    • AnimateDiff Video to Video (#6328)
    • [docs] AnimateDiff Video-to-Video (#6712)
    • fix community README (#6645)
  • @ultranity
    • refactor: extract init/forward function in UNet2DConditionModel (#6478)
  • @lawrence-cj
  • @ayushtues
    • [WIP][Community Pipeline] InstaFlow! One-Step Stable Diffusion with Rectified Flow (#6057)
  • @haofanwang
    • Add InstantID Pipeline (#6673)
    • [Fix bugs] pipeline_controlnet_sd_xl.py (#6653)
  • @brandostrong
    • SD 1.5 Support For Advanced Lora Training (train_dreambooth_lora_sdxl_advanced.py) (#6449)
  • @dg845
    • Add Community Example Consistency Training Script (#6717)
    • Add UFOGenScheduler to Community Examples (#6650)
    • Fix bug in ResnetBlock2D.forward where LoRA Scale gets Overwritten (#6736)