
Releases: huggingface/diffusers

v0.27.2: Fix scheduler `add_noise` 🐞, embeddings in StableCascade, `scale` when using LoRA

20 Mar 01:57

All commits

  • [scheduler] fix a bug in add_noise by @yiyixuxu in #7386
  • [LoRA] fix cross_attention_kwargs problems and tighten tests by @sayakpaul in #7388
  • Fix issue with prompt embeds and latents in SD Cascade Decoder with multiple image embeddings for a single prompt. by @DN6 in #7381

v0.27.1: Clear `scale` argument confusion for LoRA

19 Mar 03:59

All commits

  • Release: v0.27.0 by @DN6 (direct commit on v0.27.1-patch)
  • [LoRA] pop the LoRA scale so that it doesn't get propagated to the weeds by @sayakpaul in #7338
  • Release: 0.27.1-patch by @sayakpaul (direct commit on v0.27.1-patch)

v0.27.0: Stable Cascade, Playground v2.5, EDM-style training, IP-Adapter image embeds, and more

14 Mar 16:00

Stable Cascade

We are adding support for a new text-to-image model building on Würstchen called Stable Cascade, which comes with a non-commercial license. The Stable Cascade line of pipelines differs from Stable Diffusion in that it is built upon three distinct models and allows for hierarchical compression of image patches, achieving remarkable outputs.

from diffusers import StableCascadePriorPipeline, StableCascadeDecoderPipeline
import torch

prior = StableCascadePriorPipeline.from_pretrained(
    "stabilityai/stable-cascade-prior",
    torch_dtype=torch.bfloat16,
).to("cuda")

prompt = "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k"
# Keep the full batch of embeddings; the decoder expects a (batch, C, H, W) tensor
prior_output = prior(prompt=prompt)

decoder = StableCascadeDecoderPipeline.from_pretrained(
    "stabilityai/stable-cascade",
    torch_dtype=torch.bfloat16,
).to("cuda")

image = decoder(image_embeddings=prior_output.image_embeddings, prompt=prompt).images[0]
image

📜 Check out the docs here to know more about the model.

Note: You will need torch>=2.2.0 to use the torch.bfloat16 data type with the Stable Cascade pipeline.

Playground v2.5

PlaygroundAI released a new v2.5 model (playgroundai/playground-v2.5-1024px-aesthetic), which particularly excels at aesthetics. The model closely follows the architecture of Stable Diffusion XL, except for a few tweaks. This release comes with support for this model:

from diffusers import DiffusionPipeline
import torch

pipe = DiffusionPipeline.from_pretrained(
    "playgroundai/playground-v2.5-1024px-aesthetic",
    torch_dtype=torch.float16,
    variant="fp16",
).to("cuda")

prompt = "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k"
image = pipe(prompt=prompt, num_inference_steps=50, guidance_scale=3).images[0]
image

Loading from the original single-file checkpoint is also supported:

from diffusers import StableDiffusionXLPipeline, EDMDPMSolverMultistepScheduler
import torch

url = "https://huggingface.co/playgroundai/playground-v2.5-1024px-aesthetic/blob/main/playground-v2.5-1024px-aesthetic.safetensors"
pipeline = StableDiffusionXLPipeline.from_single_file(url)
pipeline.to(device="cuda", dtype=torch.float16)

prompt = "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k"
image = pipeline(prompt=prompt, guidance_scale=3.0).images[0]
image.save("playground_test_image.png")

You can also perform LoRA DreamBooth training with the playgroundai/playground-v2.5-1024px-aesthetic checkpoint:

accelerate launch train_dreambooth_lora_sdxl.py \
  --pretrained_model_name_or_path="playgroundai/playground-v2.5-1024px-aesthetic"  \
  --instance_data_dir="dog" \
  --output_dir="dog-playground-lora" \
  --mixed_precision="fp16" \
  --instance_prompt="a photo of sks dog" \
  --resolution=1024 \
  --train_batch_size=1 \
  --gradient_accumulation_steps=4 \
  --learning_rate=1e-4 \
  --use_8bit_adam \
  --report_to="wandb" \
  --lr_scheduler="constant" \
  --lr_warmup_steps=0 \
  --max_train_steps=500 \
  --validation_prompt="A photo of sks dog in a bucket" \
  --validation_epochs=25 \
  --seed="0" \
  --push_to_hub

To know more, follow the instructions here.

EDM-style training support

EDM refers to the training and sampling techniques introduced in the following paper: Elucidating the Design Space of Diffusion-Based Generative Models. We have introduced support for training using the EDM formulation in our train_dreambooth_lora_sdxl.py script.

To train stabilityai/stable-diffusion-xl-base-1.0 using the EDM formulation, you just have to specify the --do_edm_style_training flag in your training command, and voila 🤗

If you’re interested in extending this formulation to other training scripts, we refer you to this PR.
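For intuition, the EDM paper wraps the raw network F in a preconditioned denoiser D(x; σ) = c_skip(σ)·x + c_out(σ)·F(c_in(σ)·x; c_noise(σ)). The minimal sketch below computes those coefficients on plain Python lists, assuming the paper's default σ_data = 0.5. `network` is a hypothetical stand-in for the UNet, and this illustrates the formulation only; it is not the code in train_dreambooth_lora_sdxl.py.

```python
import math


def edm_preconditioning(sigma, sigma_data=0.5):
    """Return (c_skip, c_out, c_in, c_noise) as defined in the EDM paper."""
    denom = sigma ** 2 + sigma_data ** 2
    c_skip = sigma_data ** 2 / denom
    c_out = sigma * sigma_data / math.sqrt(denom)
    c_in = 1.0 / math.sqrt(denom)
    c_noise = 0.25 * math.log(sigma)
    return c_skip, c_out, c_in, c_noise


def edm_denoise(x_noisy, sigma, network, sigma_data=0.5):
    """D(x; sigma) = c_skip * x + c_out * F(c_in * x; c_noise)."""
    c_skip, c_out, c_in, c_noise = edm_preconditioning(sigma, sigma_data)
    # `network` is a hypothetical stand-in for the denoising UNet F
    raw = network([c_in * v for v in x_noisy], c_noise)
    return [c_skip * v + c_out * f for v, f in zip(x_noisy, raw)]


# At very low noise, c_skip approaches 1: the denoiser mostly passes x through
print(edm_preconditioning(0.002)[0])
```

The skip connection is what makes the parameterization well behaved across noise levels: the network only has to predict a correction whose scale is normalized by c_out.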

New schedulers with the EDM formulation

To better support the Playground v2.5 model and EDM-style training in general, we are bringing support for EDMDPMSolverMultistepScheduler and EDMEulerScheduler. These support the EDM formulations of the DPMSolverMultistepScheduler and EulerDiscreteScheduler, respectively.
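EDM-style samplers typically step through the noise-level schedule from the same paper, σ_i = (σ_max^(1/ρ) + i/(N-1)·(σ_min^(1/ρ) - σ_max^(1/ρ)))^ρ. The sketch below assumes the paper's defaults (σ_min = 0.002, σ_max = 80, ρ = 7); the values the Diffusers EDM schedulers actually use come from their configs, so check `scheduler.config` rather than relying on these numbers.

```python
def karras_sigmas(num_steps, sigma_min=0.002, sigma_max=80.0, rho=7.0):
    """Noise levels from sigma_max down to sigma_min, evenly spaced in sigma**(1/rho)."""
    if num_steps < 2:
        raise ValueError("num_steps must be at least 2")
    max_inv = sigma_max ** (1.0 / rho)
    min_inv = sigma_min ** (1.0 / rho)
    # Interpolate linearly in sigma**(1/rho), then map back with **rho
    return [
        (max_inv + i / (num_steps - 1) * (min_inv - max_inv)) ** rho
        for i in range(num_steps)
    ]


print([round(s, 4) for s in karras_sigmas(10)])
```

A large ρ concentrates steps at low noise levels, where fine detail is resolved.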

Trajectory Consistency Distillation

Trajectory Consistency Distillation (TCD) enables a model to generate higher quality and more detailed images with fewer steps. Moreover, owing to the effective error mitigation during the distillation process, TCD demonstrates superior performance even under conditions of large inference steps. It was proposed in Trajectory Consistency Distillation.

This release comes with support for a TCDScheduler that enables this kind of fast sampling. Much like LCM-LoRA, TCD requires an additional adapter for the acceleration. The code snippet below shows how to use it:

import torch
from diffusers import StableDiffusionXLPipeline, TCDScheduler

device = "cuda"
base_model_id = "stabilityai/stable-diffusion-xl-base-1.0"
tcd_lora_id = "h1t/TCD-SDXL-LoRA"

pipe = StableDiffusionXLPipeline.from_pretrained(base_model_id, torch_dtype=torch.float16, variant="fp16").to(device)
pipe.scheduler = TCDScheduler.from_config(pipe.scheduler.config)

pipe.load_lora_weights(tcd_lora_id)
pipe.fuse_lora()

prompt = "Painting of the orange cat Otto von Garfield, Count of Bismarck-Schönhausen, Duke of Lauenburg, Minister-President of Prussia. Depicted wearing a Prussian Pickelhaube and eating his favorite meal - lasagna."

image = pipe(
    prompt=prompt,
    num_inference_steps=4,
    guidance_scale=0,
    eta=0.3, 
    generator=torch.Generator(device=device).manual_seed(0),
).images[0]

image

📜 Check out the docs here to know more about TCD.

Many thanks to @mhh0318 for contributing the TCDScheduler in #7174 and the guide in #7259.

IP-Adapter image embeddings and masking

All the pipelines supporting IP-Adapter accept an ip_adapter_image_embeds argument. If you need to run the IP-Adapter multiple times with the same image, you can encode the image once and save the embedding to disk. This saves computation time and is especially useful when building UIs. Additionally, ComfyUI image embeddings for IP-Adapters are fully compatible with Diffusers and should work out of the box.
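The encode-once idea can be sketched as a small disk cache keyed by a hash of the image bytes. In this hypothetical helper, `encode_fn` stands in for whatever produces the IP-Adapter image embeddings in your pipeline; the cached result is what you would then pass as `ip_adapter_image_embeds`. Pickle is used only to keep the sketch dependency-free (with tensors you would more likely use `torch.save`/`torch.load`), and it should only read files you wrote yourself.

```python
import hashlib
import pickle
from pathlib import Path


def load_or_encode(image_bytes, encode_fn, cache_dir):
    """Return embeddings for image_bytes, calling encode_fn at most once per unique image."""
    key = hashlib.sha256(image_bytes).hexdigest()
    path = Path(cache_dir) / f"{key}.pkl"
    if path.exists():
        # Cache hit: reuse the embeddings saved on a previous run
        with path.open("rb") as f:
            return pickle.load(f)
    # Cache miss: encode once and persist for later calls
    embeds = encode_fn(image_bytes)
    path.parent.mkdir(parents=True, exist_ok=True)
    with path.open("wb") as f:
        pickle.dump(embeds, f)
    return embeds
```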

We have also introduced support for providing binary masks to specify which portion of the output image should be assigned to an IP-Adapter. For each input IP-Adapter image, a binary mask and an IP-Adapter must be provided. Thanks to @fabiorigano for contributing this feature through #6847.
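Conceptually, each IP-Adapter's attention output is scaled and then multiplied by its binary mask, so it only contributes inside the masked region. The simplified sketch below works on flat lists of per-position values; `combine_attention` is a hypothetical name, and the real attention processors operate on tensors, with masks resized to each layer's resolution.

```python
def combine_attention(text_out, ip_outs, scales, masks):
    """Add each IP-Adapter's output, weighted by its scale, only where its mask is 1."""
    if not (len(ip_outs) == len(scales) == len(masks)):
        raise ValueError("need one scale and one mask per IP-Adapter image")
    result = list(text_out)
    for ip_out, scale, mask in zip(ip_outs, scales, masks):
        for j, (value, m) in enumerate(zip(ip_out, mask)):
            result[j] += scale * m * value  # m is 0 or 1
    return result


# Adapter 1 affects the left half, adapter 2 the right half
print(combine_attention([1.0] * 4, [[2.0] * 4, [4.0] * 4], [0.5, 0.25],
                        [[1, 1, 0, 0], [0, 0, 1, 1]]))
```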

📜 To know about the exact usage of both of the above, refer to our official guide.

We thank our community members, @fabiorigano, @asomoza, and @cubiq, for their guidance and input on these features.

Guide on merging LoRAs

Merging LoRAs can be a fun and creative way to create new and unique images. Diffusers provides merging support with the set_adapters method which concatenates the weights of the LoRAs to merge.

Now, Diffusers also supports the add_weighted_adapter method from the PEFT library, unlocking more efficient merging methods like TIES, DARE, linear, and even combinations of these merging methods like dare_ties.
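To make the TIES idea concrete, the sketch below merges flat lists of LoRA parameter deltas in three steps: trim each vector to its largest-magnitude entries (`density` is the fraction kept), elect a per-parameter sign from the summed weighted values, and average only the entries that agree with that sign. It is a simplified, pure-Python illustration of the TIES procedure; PEFT's add_weighted_adapter performs the real merge on tensors.

```python
def ties_merge(task_vectors, weights, density):
    """TIES-style merge of equal-length lists of parameter deltas."""
    n = len(task_vectors[0])
    k = int(density * n)  # entries kept per vector after trimming
    trimmed = []
    for tv, w in zip(task_vectors, weights):
        # 1. Trim: keep the k largest-magnitude entries, zero out the rest
        keep = set(sorted(range(n), key=lambda j: abs(tv[j]), reverse=True)[:k])
        trimmed.append([w * tv[j] if j in keep else 0.0 for j in range(n)])
    merged = []
    for j in range(n):
        column = [t[j] for t in trimmed]
        # 2. Elect sign from the summed weighted values (ties count as positive)
        sign = 1.0 if sum(column) >= 0 else -1.0
        # 3. Disjoint merge: average only nonzero entries that agree with the sign
        agreeing = [v for v in column if v * sign > 0]
        merged.append(sum(agreeing) / len(agreeing) if agreeing else 0.0)
    return merged


print(ties_merge([[1.0, -2.0, 3.0, 0.5], [2.0, 1.0, -4.0, 0.1]], [1.0, 1.0], 0.75))
```

Compared with a plain weighted average, dropping disagreeing entries keeps one LoRA from cancelling out another's strongest changes.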

📜 Take a look at the Merge LoRAs guide to learn more about merging in Diffusers.

LEDITS++

We are adding support for the real image editing technique called LEDITS++: Limitless Image Editing using Text-to-Image Models, a parameter-free method requiring neither fine-tuning nor optimization.
To edit real images, the LEDITS++ pipelines first invert the image using the DPM-solver++ scheduler, which facilitates editing with as few as 20 total diffusion steps for inversion and inference combined. LEDITS++ guidance is defined such that it reflects both the direction of the edit (whether we want to push away from or towards the edit concept) and the strength of the effect. The guidance also includes a masking term focused on relevant image regions which, especially for multiple edits, ensures that the corresponding guidance terms for each concept remain mostly isolated, limiting interference.
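A heavily simplified sketch of the guidance term described above: for each edit concept, take the difference between the concept-conditioned and unconditional noise estimates, flip its sign when pushing away from the concept, scale it, and keep it only where its magnitude is among the largest (a stand-in for the masking term). The function and parameter names are hypothetical, and the actual LEDITS++ masks are computed differently, so treat this as intuition only.

```python
def edit_guidance(eps_uncond, eps_edits, reverse, scales, threshold_quantile):
    """Hypothetical sketch of a masked, directional edit-guidance term."""
    n = len(eps_uncond)
    total = [0.0] * n
    for eps_edit, rev, scale in zip(eps_edits, reverse, scales):
        diff = [e - u for e, u in zip(eps_edit, eps_uncond)]
        # Mask: keep positions whose |diff| is at or above the chosen quantile
        magnitudes = sorted(abs(d) for d in diff)
        cutoff = magnitudes[int(threshold_quantile * (n - 1))]
        # Direction: push away from the concept when rev is True
        direction = -1.0 if rev else 1.0
        for j, d in enumerate(diff):
            if abs(d) >= cutoff:
                total[j] += direction * scale * d
    return total


print(edit_guidance([0.0] * 4, [[0.1, 0.5, 0.2, 1.0]], [False], [2.0], 0.5))
```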

The code snippet below shows how to use it:

import torch
import PIL.Image
import requests
from io import BytesIO
from diffusers import LEditsPPPipelineStableDiffusionXL, AutoencoderKL

device = "cuda"
base_model_id = "stabilityai/stable-diffusion-xl-base-1.0"

vae = AutoencoderKL.from_pretrained("madebyollin/sdxl-vae-fp16-fix", torch_dtype=torch.float16)

pipe = LEditsPPPipelineStableDiffusionXL.from_pretrained(
    base_model_id, 
    vae=vae, 
    torch_dtype=torch.float16
).to(device)

def download_image(url):
    response = requests.get(url)
    return PIL.Image.open(BytesIO(response.content)).convert("RGB")

img_url = "https://www.aiml.informatik.tu-darmstadt.de/people/mbrack/tennis.jpg"
image = download_image(img_url)

_ = pipe.invert(
    image = image,
    num_inversion_steps=50,
    skip=0.2
)

edited_image = pipe(
    editing_prompt...

v0.26.3: Patch release to fix DPMSolverSinglestepScheduler and configuring VAE from single file mixin

13 Feb 09:31

All commits

  • Fix configuring VAE from single file mixin by @DN6 in #6950
  • [DPMSolverSinglestepScheduler] correct get_order_list for solver_order=2 and lower_order_final=True by @yiyixuxu in #6953

v0.26.2: Patch fix for adding `self.use_ada_layer_norm_*` params back to `BasicTransformerBlock`

06 Feb 02:13

In v0.26.0, we introduced a bug 🐛 in the BasicTransformerBlock by removing some boolean flags. This caused many popular libraries, such as tomesd, to break. We have fixed that in this release. Thanks to @vladmandic for bringing this to our attention.

All commits

  • add self.use_ada_layer_norm_* params back to BasicTransformerBlock by @yiyixuxu in #6841

v0.26.1: Patch release to fix `torchvision` dependency

02 Feb 09:23

In the v0.26.0 release, we accidentally made torchvision a required dependency, which shouldn't have been the case. This is now fixed.

All commits

v0.26.0: New video pipelines, single-file checkpoint revamp, multi IP-Adapter inference with multiple images

01 Feb 02:03

This new release comes with two new video pipelines, a more unified and consistent experience for single-file checkpoint loading, support for multiple IP-Adapters’ inference with multiple reference images, and more.

I2VGenXL

I2VGenXL is an image-to-video pipeline, proposed in I2VGen-XL: High-Quality Image-to-Video Synthesis via Cascaded Diffusion Models.

import torch
from diffusers import I2VGenXLPipeline
from diffusers.utils import export_to_gif, load_image

repo_id = "ali-vilab/i2vgen-xl"
pipeline = I2VGenXLPipeline.from_pretrained(repo_id, torch_dtype=torch.float16)
# enable_model_cpu_offload() moves each component to the GPU only when it is needed
pipeline.enable_model_cpu_offload()

image_url = "https://huggingface.co/datasets/diffusers/docs-images/resolve/main/i2vgen_xl_images/img_0001.jpg"
image = load_image(image_url).convert("RGB")
prompt = "A green frog floats on the surface of the water on green lotus leaves, with several pink lotus flowers, in a Chinese painting style."
negative_prompt = "Distorted, discontinuous, Ugly, blurry, low resolution, motionless, static, disfigured, disconnected limbs, Ugly faces, incomplete arms"
generator = torch.manual_seed(8888)

frames = pipeline(
    prompt=prompt,
    image=image,
    num_inference_steps=50,
    negative_prompt=negative_prompt,
    generator=generator,
).frames
export_to_gif(frames[0], "i2v.gif")

📜 Check out the docs here.

PIA

PIA is a Personalized Image Animator that aligns with condition images, controls motion with text, and is compatible with various T2I models without specific tuning. PIA uses a base T2I model with temporal alignment layers for image animation. A key component of PIA is the condition module, which transfers appearance information for individual frame synthesis in the latent space, thus allowing a stronger focus on motion alignment. PIA was introduced in PIA: Your Personalized Image Animator via Plug-and-Play Modules in Text-to-Image Models.

import torch
from diffusers import (
    EulerDiscreteScheduler,
    MotionAdapter,
    PIAPipeline,
)
from diffusers.utils import export_to_gif, load_image

adapter = MotionAdapter.from_pretrained("openmmlab/PIA-condition-adapter")
pipe = PIAPipeline.from_pretrained("SG161222/Realistic_Vision_V6.0_B1_noVAE", motion_adapter=adapter, torch_dtype=torch.float16)

pipe.scheduler = EulerDiscreteScheduler.from_config(pipe.scheduler.config)
pipe.enable_model_cpu_offload()
pipe.enable_vae_slicing()

image = load_image(
    "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/pix2pix/cat_6.png?download=true"
)
image = image.resize((512, 512))
prompt = "cat in a field"
negative_prompt = "wrong white balance, dark, sketches,worst quality,low quality"

generator = torch.Generator("cpu").manual_seed(0)
output = pipe(image=image, prompt=prompt, negative_prompt=negative_prompt, generator=generator)
frames = output.frames[0]
export_to_gif(frames, "pia-animation.gif")

📜 Check out the docs here.

Multiple IP-Adapters + Multiple reference images support (“Instant LoRA” Feature)

IP-Adapters are becoming quite popular, so we have added support for performing inference with multiple IP-Adapters and multiple reference images! Thanks to @asomoza for their help. Get started with the code below:

import torch
from diffusers import AutoPipelineForText2Image, DDIMScheduler
from transformers import CLIPVisionModelWithProjection
from diffusers.utils import load_image

image_encoder = CLIPVisionModelWithProjection.from_pretrained(
    "h94/IP-Adapter", 
    subfolder="models/image_encoder",
    torch_dtype=torch.float16,
)

pipeline = AutoPipelineForText2Image.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0",
    torch_dtype=torch.float16,
    image_encoder=image_encoder,
)
pipeline.scheduler = DDIMScheduler.from_config(pipeline.scheduler.config)

pipeline.load_ip_adapter("h94/IP-Adapter", subfolder="sdxl_models", weight_name=["ip-adapter-plus_sdxl_vit-h.safetensors", "ip-adapter-plus-face_sdxl_vit-h.safetensors"])
pipeline.set_ip_adapter_scale([0.7, 0.3])

pipeline.enable_model_cpu_offload()

face_image = load_image("https://huggingface.co/datasets/YiYiXu/testing-images/resolve/main/women_input.png")

style_folder = "https://huggingface.co/datasets/YiYiXu/testing-images/resolve/main/style_ziggy"
style_images = [load_image(f"{style_folder}/img{i}.png") for i in range(10)]

generator = torch.Generator(device="cpu").manual_seed(0)

# One entry per loaded IP-Adapter: a list of style images, then a single face image
image = pipeline(
    prompt="wonderwoman",
    ip_adapter_image=[style_images, face_image],
    negative_prompt="monochrome, lowres, bad anatomy, worst quality, low quality",
    num_inference_steps=50,
    generator=generator,
).images[0]

(Figure: reference style images, reference face image, and output image.)

📜 Check out the docs here.

Single-file checkpoint loading

The from_single_file() utility has been refactored for better readability and to follow semantics similar to from_pretrained(). Support for loading single-file checkpoints and configs from URLs has also been added.

DPM scheduler fix

We introduced a fix for the DPM schedulers, so you can now use them with SDXL to generate high-quality images in fewer steps than with the Euler scheduler.

Apart from these, we have done extensive refactoring to improve the library design and will continue to do so in the coming days.

All commits


Patch release

17 Jan 16:48

Make sure diffusers can correctly be used in offline mode again: #1767 (comment)

v0.25.0: aMUSEd, faster SDXL, interruptible pipelines

27 Dec 13:49

aMUSEd


aMUSEd is a lightweight text-to-image model based on the MUSE architecture. aMUSEd is particularly useful in applications that require a lightweight and fast model, such as generating many images quickly at once. aMUSEd is currently a research release.

aMUSEd is a VQVAE token-based transformer that can generate an image in fewer forward passes than many diffusion models. In contrast with MUSE, it uses the smaller text encoder CLIP-L/14 instead of T5-XXL. Thanks to its small parameter count and its generation process with few forward passes, aMUSEd can generate many images quickly. This benefit is particularly noticeable at larger batch sizes.
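Why so few forward passes? MUSE-style models start from a fully masked token grid and, at each step, predict all masked tokens in parallel, keeping only the most confident ones. The sketch below assumes a cosine masking schedule (as in MaskGIT/MUSE; aMUSEd's exact schedule may differ) and returns how many tokens remain masked after each step. The numbers are illustrative: 1,024 tokens in 12 steps means 12 forward passes instead of one per token.

```python
import math


def masked_tokens_after_each_step(num_tokens, num_steps):
    """Tokens still masked after each parallel decoding step (cosine schedule)."""
    return [
        math.floor(num_tokens * math.cos(math.pi / 2 * (step + 1) / num_steps))
        for step in range(num_steps)
    ]


schedule = masked_tokens_after_each_step(num_tokens=1024, num_steps=12)
print(schedule)  # non-increasing, ending at 0 after the 12th forward pass
```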

Text-to-image generation

import torch
from diffusers import AmusedPipeline

pipe = AmusedPipeline.from_pretrained(
    "amused/amused-512", variant="fp16", torch_dtype=torch.float16
)
pipe = pipe.to("cuda")

prompt = "cowboy"
image = pipe(prompt, generator=torch.manual_seed(8)).images[0]
image.save("text2image_512.png")

Image-to-image generation

import torch
from diffusers import AmusedImg2ImgPipeline
from diffusers.utils import load_image

pipe = AmusedImg2ImgPipeline.from_pretrained(
    "amused/amused-512", variant="fp16", torch_dtype=torch.float16
)
pipe = pipe.to("cuda")

prompt = "apple watercolor"
input_image = (
    load_image(
        "https://huggingface.co/amused/amused-512/resolve/main/assets/image2image_256_orig.png"
    )
    .resize((512, 512))
    .convert("RGB")
)

image = pipe(prompt, input_image, strength=0.7, generator=torch.manual_seed(3)).images[0]
image.save("image2image_512.png")

Inpainting

import torch
from diffusers import AmusedInpaintPipeline
from diffusers.utils import load_image
from PIL import Image

pipe = AmusedInpaintPipeline.from_pretrained(
    "amused/amused-512", variant="fp16", torch_dtype=torch.float16
)
pipe = pipe.to("cuda")

prompt = "a man with glasses"
input_image = (
    load_image(
        "https://huggingface.co/amused/amused-512/resolve/main/assets/inpainting_256_orig.png"
    )
    .resize((512, 512))
    .convert("RGB")
)
mask = (
    load_image(
        "https://huggingface.co/amused/amused-512/resolve/main/assets/inpainting_256_mask.png"
    )
    .resize((512, 512))
    .convert("L")
)

image = pipe(prompt, input_image, mask, generator=torch.manual_seed(3)).images[0]
image.save("inpainting_512.png")

📜 Docs: https://huggingface.co/docs/diffusers/main/en/api/pipelines/amused

🛠️ Models:

Faster SDXL

We’re excited to present an array of optimization techniques that can be used to accelerate the inference latency of text-to-image diffusion models. All of these can be done in native PyTorch without requiring additional C++ code.

(Figure: SDXL inference latency, batch size 1, 30 steps.)

These techniques are not specific to Stable Diffusion XL (SDXL) and can be used to improve other text-to-image diffusion models too. Starting from default fp32 precision, we can achieve a 3x speed improvement by applying different PyTorch optimization techniques. We encourage you to check out the detailed docs provided below.

Note: Compared with the default setup most people use with Diffusers (fp16 + SDPA), applying all the optimizations explained in the blog post below yields a 30% speed-up.

📜 Docs: https://huggingface.co/docs/diffusers/main/en/tutorials/fast_diffusion
🌠 PyTorch blog post: https://pytorch.org/blog/accelerating-generative-ai-3/

Interruptible pipelines

Interrupting the diffusion process is particularly useful when building UIs that work with Diffusers because it allows users to stop the generation process if they're unhappy with the intermediate results. You can incorporate this into your pipeline with a callback.

This callback function should take the following arguments: pipe, i, t, and callback_kwargs (this must be returned). Set the pipeline's _interrupt attribute to True to stop the diffusion process after a certain number of steps. You are also free to implement your own custom stopping logic inside the callback.

In this example, the diffusion process is stopped after 10 steps even though num_inference_steps is set to 50.

from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")
pipe.enable_model_cpu_offload()
num_inference_steps = 50

def interrupt_callback(pipe, i, t, callback_kwargs):
    stop_idx = 10
    if i == stop_idx:
        pipe._interrupt = True

    return callback_kwargs

pipe(
    "A photo of a cat",
    num_inference_steps=num_inference_steps,
    callback_on_step_end=interrupt_callback,
)
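To see what the callback does mechanically, here is a toy loop that mimics the pattern: the callback runs at the end of each step and, once `_interrupt` is set, the remaining steps are skipped. `ToyPipeline` is a hypothetical stand-in, not Diffusers' implementation. Because the callback fires at the end of step index 10, steps 0 through 10 (11 steps) run.

```python
class ToyPipeline:
    """Hypothetical stand-in for a pipeline's denoising loop (not Diffusers code)."""

    def __init__(self):
        self._interrupt = False

    def __call__(self, num_inference_steps, callback_on_step_end=None):
        self._interrupt = False
        steps_run = 0
        callback_kwargs = {}
        for i in range(num_inference_steps):
            if self._interrupt:
                continue  # once interrupted, remaining steps are skipped
            steps_run += 1  # placeholder for one denoising step
            t = num_inference_steps - i  # placeholder timestep value
            if callback_on_step_end is not None:
                callback_kwargs = callback_on_step_end(self, i, t, callback_kwargs)
        return steps_run


def interrupt_callback(pipe, i, t, callback_kwargs):
    # Same logic as the example above: request a stop at step index 10
    if i == 10:
        pipe._interrupt = True
    return callback_kwargs


print(ToyPipeline()(50, callback_on_step_end=interrupt_callback))  # 11
```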

📜 Docs: https://huggingface.co/docs/diffusers/main/en/using-diffusers/callback

peft in our LoRA training examples

We incorporated peft in all the officially supported training examples concerning LoRA. This greatly simplifies the code and improves readability. LoRA training hasn't been easier, thanks to peft!

More memory-friendly version of LCM LoRA SDXL training

We incorporated best practices from peft to make LCM LoRA training for SDXL more memory-friendly. As such, you don't have to initialize two UNets (teacher and student) anymore. This version also integrates with the datasets library for quick experimentation. Check out this section for more details.

All commits


v0.24.0: IP Adapters, Kandinsky 3.0, Stable Video Diffusion, SDXL Turbo

29 Nov 19:21

Stable Video Diffusion, SDXL Turbo, IP Adapters, Kandinsky 3.0

Stable Video Diffusion

Stable Video Diffusion is a powerful image-to-video generation model that can generate high-resolution (576x1024) videos of 2 to 4 seconds conditioned on an input image.

Image to Video Generation

There are two variants of SVD: SVD and SVD-XT. The SVD checkpoint is trained to generate 14 frames, and the SVD-XT checkpoint is further fine-tuned to generate 25 frames.

You need to condition the generation on an initial image, as follows:

import torch

from diffusers import StableVideoDiffusionPipeline
from diffusers.utils import load_image, export_to_video

pipe = StableVideoDiffusionPipeline.from_pretrained(
    "stabilityai/stable-video-diffusion-img2vid-xt", torch_dtype=torch.float16, variant="fp16"
)
pipe.enable_model_cpu_offload()

# Load the conditioning image
image = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/svd/rocket.png?download=true")
image = image.resize((1024, 576))

generator = torch.manual_seed(42)
frames = pipe(image, decode_chunk_size=8, generator=generator).frames[0]

export_to_video(frames, "generated.mp4", fps=7)

Since generating videos is more memory intensive, we can use the decode_chunk_size argument to control how many frames are decoded at once. This will reduce the memory usage. It's recommended to tweak this value based on your GPU memory. Setting decode_chunk_size=1 will decode one frame at a time and will use the least amount of memory, but the video might have some flickering.
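The memory trade-off comes from decoding the latent frames in slices rather than all at once. The sketch below shows the chunking logic on plain lists; `decode_fn` is a hypothetical stand-in for the VAE decoder, and the returned `peak` is the largest number of frames decoded in a single call, which is what bounds peak decoder memory.

```python
def decode_in_chunks(latents, decode_fn, decode_chunk_size):
    """Decode frames in slices of at most decode_chunk_size; return (frames, peak)."""
    if decode_chunk_size < 1:
        raise ValueError("decode_chunk_size must be >= 1")
    frames, peak = [], 0
    for start in range(0, len(latents), decode_chunk_size):
        chunk = latents[start:start + decode_chunk_size]
        peak = max(peak, len(chunk))  # frames held by the decoder at once
        frames.extend(decode_fn(chunk))
    return frames, peak


# 25 latent frames (as with SVD-XT) decoded 8 at a time
frames, peak = decode_in_chunks(list(range(25)), lambda c: [x * 2 for x in c], 8)
print(len(frames), peak)
```

Smaller chunks lower peak memory at the cost of more decoder calls; decoding each chunk independently is also why very small chunks can introduce flicker.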

Additionally, we use model CPU offloading to reduce memory usage.


SDXL Turbo

SDXL Turbo is an adversarial time-distilled Stable Diffusion XL (SDXL) model capable of running inference in as little as 1 step. Also, it does not use classifier-free guidance, further increasing its speed. On a good consumer GPU, you can now generate an image in just 100ms.

Text-to-Image

For text-to-image, pass a text prompt. By default, SDXL Turbo generates a 512x512 image, and that resolution gives the best results. You can try setting the height and width parameters to 768x768 or 1024x1024, but you should expect quality degradations when doing so.

Make sure to set guidance_scale to 0.0 to disable classifier-free guidance, as the model was trained without it. A single inference step is enough to generate high-quality images.
Increasing the number of steps to 2, 3 or 4 should improve image quality.
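Disabling guidance also saves compute. Standard classifier-free guidance combines an unconditional and a conditional noise prediction at every step, eps = eps_uncond + w·(eps_cond - eps_uncond), which costs two UNet evaluations. The sketch below shows that combination on plain lists; the `> 1` threshold in `unet_passes_per_step` is an assumption modeled on how Diffusers pipelines typically decide whether to run the unconditional pass, so verify it against the pipeline you use.

```python
def classifier_free_guidance(eps_uncond, eps_cond, guidance_scale):
    """Combine unconditional and conditional noise predictions."""
    return [u + guidance_scale * (c - u) for u, c in zip(eps_uncond, eps_cond)]


def unet_passes_per_step(guidance_scale):
    """Assumed convention: guidance (and its extra pass) only when guidance_scale > 1."""
    return 2 if guidance_scale > 1.0 else 1


print(classifier_free_guidance([0.0, 1.0], [1.0, 3.0], 2.0))  # [2.0, 5.0]
print(unet_passes_per_step(0.0), unet_passes_per_step(7.5))  # 1 2
```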

from diffusers import AutoPipelineForText2Image
import torch

pipeline_text2image = AutoPipelineForText2Image.from_pretrained("stabilityai/sdxl-turbo", torch_dtype=torch.float16, variant="fp16")
pipeline_text2image = pipeline_text2image.to("cuda")

prompt = "A cinematic shot of a baby racoon wearing an intricate italian priest robe."

image = pipeline_text2image(prompt=prompt, guidance_scale=0.0, num_inference_steps=1).images[0]
image

Image-to-image

For image-to-image generation, make sure that num_inference_steps * strength is greater than or equal to 1. The image-to-image pipeline runs for int(num_inference_steps * strength) steps, e.g. 0.5 * 2.0 = 1 step in our example below.
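The step arithmetic can be checked before calling the pipeline. This hypothetical helper applies the int(num_inference_steps * strength) rule described above (capped at num_inference_steps) and raises if the result would be zero steps.

```python
def effective_img2img_steps(num_inference_steps, strength):
    """Denoising steps actually run by an image-to-image call."""
    steps = min(int(num_inference_steps * strength), num_inference_steps)
    if steps < 1:
        raise ValueError(
            f"num_inference_steps * strength = {num_inference_steps * strength} < 1; "
            "increase num_inference_steps or strength"
        )
    return steps


print(effective_img2img_steps(2, 0.5))  # 1, matching strength=0.5 with 2 steps
```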

from diffusers import AutoPipelineForImage2Image
from diffusers.utils import load_image, make_image_grid

# use from_pipe to avoid consuming additional memory when loading a checkpoint
pipeline = AutoPipelineForImage2Image.from_pipe(pipeline_text2image).to("cuda")

init_image = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/cat.png")
init_image = init_image.resize((512, 512))

prompt = "cat wizard, gandalf, lord of the rings, detailed, fantasy, cute, adorable, Pixar, Disney, 8k"

image = pipeline(prompt, image=init_image, strength=0.5, guidance_scale=0.0, num_inference_steps=2).images[0]
make_image_grid([init_image, image], rows=1, cols=2)

IP Adapters

IP Adapters have proven remarkably powerful for generating images conditioned on other images.

Thanks to @okotaku, we have added IP Adapters to the most important pipelines, allowing you to combine them for a variety of workflows; for example, they work with Img2Img, ControlNet, and LCM-LoRA out of the box.

LCM-LoRA

from diffusers import DiffusionPipeline, LCMScheduler
import torch
from diffusers.utils import load_image

model_id =  "sd-dreambooth-library/herge-style"
lcm_lora_id = "latent-consistency/lcm-lora-sdv1-5"

pipe = DiffusionPipeline.from_pretrained(model_id, torch_dtype=torch.float16)

pipe.load_ip_adapter("h94/IP-Adapter", subfolder="models", weight_name="ip-adapter_sd15.bin")
pipe.load_lora_weights(lcm_lora_id)
pipe.scheduler = LCMScheduler.from_config(pipe.scheduler.config)
pipe.enable_model_cpu_offload()

prompt = "best quality, high quality"
image = load_image("https://user-images.githubusercontent.com/24734142/266492875-2d50d223-8475-44f0-a7c6-08b51cb53572.png")
images = pipe(
    prompt=prompt,
    ip_adapter_image=image,
    num_inference_steps=4,
    guidance_scale=1,
).images[0]


ControlNet

from diffusers import StableDiffusionControlNetPipeline, ControlNetModel
import torch
from diffusers.utils import load_image

controlnet_model_path = "lllyasviel/control_v11f1p_sd15_depth"
controlnet = ControlNetModel.from_pretrained(controlnet_model_path, torch_dtype=torch.float16)

pipeline = StableDiffusionControlNetPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", controlnet=controlnet, torch_dtype=torch.float16)
pipeline.to("cuda")

image = load_image("https://huggingface.co/datasets/YiYiXu/testing-images/resolve/main/statue.png")
depth_map = load_image("https://huggingface.co/datasets/YiYiXu/testing-images/resolve/main/depth.png")

pipeline.load_ip_adapter("h94/IP-Adapter", subfolder="models", weight_name="ip-adapter_sd15.bin")

generator = torch.Generator(device="cpu").manual_seed(33)
images = pipeline(
    prompt='best quality, high quality', 
    image=depth_map,
    ip_adapter_image=image,
    negative_prompt="monochrome, lowres, bad anatomy, worst quality, low quality", 
    num_inference_steps=50,
    generator=generator,
).images
images[0].save("yiyi_test_2_out.png")
(Figure: IP-Adapter image (statue), depth condition map, and output image.)

For more information:

Kandinsky 3.0

Kandinsky has released the 3rd version, which has much improved text-to-image alignment thanks to using Flan-T5 as the text encoder.

Text-to-Image

from diffusers import AutoPipelineForText2Image
import torch

pipe = AutoPipelineForText2Image.from_pretrained("kandinsky-community/kandinsky-3", variant="fp16", torch_dtype=torch.float16)
pipe.enable_model_cpu_offload()
        
prompt = "A photograph of the inside of a subway train. There are raccoons sitting on the seats. One of them is reading a newspaper. The window shows the city in the background."

generator = torch.Generator(device="cpu").manual_seed(0)
image = pipe(prompt, num_inference_steps=25, generator=generator).images[0]

Image-to-Image

from diffusers import AutoPipelineForImage2Image
from diffusers.utils import load_image
import torch

pipe = AutoPipelineForImage2Image.from_pretrained("kandinsky-community/kandinsky-3", variant="fp16", torch_dtype=torch.float16)
pipe.enable_model_cpu_offload()
        
prompt = "A painting of the inside of a subway train with tiny raccoons."
image = load_image("https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/kandinsky3/t2i.png")

generator = torch.Generator(device="cpu").manual_seed(0)
image = pipe(prompt, image=image, strength=0.75, num_inference_steps=25, generator=generator).images[0]

Check it out:

All commits
