CombinedPipeline fails to accept bfloat16 image tensor as input #7598

Closed
Michael-F-Ellis opened this issue Apr 7, 2024 · 14 comments · Fixed by #7894

@Michael-F-Ellis

StableCascadeCombinedPipeline accepts an images argument that can be a PIL image, a torch tensor, or a list of either. Unfortunately, I can't get it to accept a bfloat16 tensor as the image: it raises a runtime error inside CLIP's image preprocessing. I tried float32, but then HF's A10G Large runs out of memory.

Here's the error I see when I pass an image encoded as torch.bfloat16:

File "/home/user/app/app.py", line 50, in generate_image
    results =  pipe(
  File "/home/user/.pyenv/versions/3.10.14/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
  File "/home/user/.pyenv/versions/3.10.14/lib/python3.10/site-packages/diffusers/pipelines/stable_cascade/pipeline_stable_cascade_combined.py", line 268, in __call__
    prior_outputs = self.prior_pipe(
  File "/home/user/.pyenv/versions/3.10.14/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
  File "/home/user/.pyenv/versions/3.10.14/lib/python3.10/site-packages/diffusers/pipelines/stable_cascade/pipeline_stable_cascade_prior.py", line 504, in __call__
    image_embeds_pooled, uncond_image_embeds_pooled = self.encode_image(
  File "/home/user/.pyenv/versions/3.10.14/lib/python3.10/site-packages/diffusers/pipelines/stable_cascade/pipeline_stable_cascade_prior.py", line 254, in encode_image
    image = self.feature_extractor(image, return_tensors="pt").pixel_values
  File "/home/user/.pyenv/versions/3.10.14/lib/python3.10/site-packages/transformers/image_processing_utils.py", line 551, in __call__
    return self.preprocess(images, **kwargs)
  File "/home/user/.pyenv/versions/3.10.14/lib/python3.10/site-packages/transformers/models/clip/image_processing_clip.py", line 306, in preprocess
    images = [to_numpy_array(image) for image in images]
  File "/home/user/.pyenv/versions/3.10.14/lib/python3.10/site-packages/transformers/models/clip/image_processing_clip.py", line 306, in <listcomp>
    images = [to_numpy_array(image) for image in images]
  File "/home/user/.pyenv/versions/3.10.14/lib/python3.10/site-packages/transformers/image_utils.py", line 174, in to_numpy_array
    return to_numpy(img)
  File "/home/user/.pyenv/versions/3.10.14/lib/python3.10/site-packages/transformers/utils/generic.py", line 308, in to_numpy
    return framework_to_numpy[framework](obj)
  File "/home/user/.pyenv/versions/3.10.14/lib/python3.10/site-packages/transformers/utils/generic.py", line 293, in <lambda>
    "pt": lambda obj: obj.detach().cpu().numpy(),
TypeError: Got unsupported ScalarType BFloat16
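
For context, the failure reproduces outside the pipeline with any bfloat16 tensor, since numpy has no bfloat16 dtype; a minimal sketch:

import torch

t = torch.zeros(3, dtype=torch.bfloat16)
# t.numpy() raises: TypeError: Got unsupported ScalarType BFloat16
print(t.float().numpy())  # upcasting to float32 first works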

FWIW, here are relevant snippets from the code that produced the above error:

# Define a transform to convert a PIL image to a bfloat16 tensor (method given by Claude 3 Sonnet)
def transform(image):
    # Convert the image to a PyTorch tensor
    input_tensor = torch.from_numpy(np.array(image)).permute(2, 0, 1).unsqueeze(0)
    # Convert the tensor to 'bfloat16' dtype
    input_tensor = input_tensor.to(torch.bfloat16)
    return input_tensor

# Ensure model and scheduler are initialized in GPU-enabled function
if torch.cuda.is_available():
    pipe = StableCascadeCombinedPipeline.from_pretrained(repo, torch_dtype=torch.bfloat16)
    pipe.to("cuda")

# The generate function
@spaces.GPU(enable_queue=True)
def generate_image(prompt, image):  
    if image is not None:
        # Convert the PIL image to Torch tensor
        # and move it to GPU
        img_tensor = transform(image)
        img_tensor = [img_tensor.to("cuda")]
    else:
        img_tensor = None

    seed = random.randint(-100000, 100000)

    results = pipe(
                prompt=prompt,
                images=img_tensor,
                height=1024,
                width=1024,
                num_inference_steps=20, 
                generator=torch.Generator(device="cuda").manual_seed(seed)
            )
    return results.images[0]

Originally posted by @Michael-F-Ellis in #7571 (comment)

@tolgacangoz
Contributor

tolgacangoz commented Apr 7, 2024

Since diffusers' CLIP comes from transformers, it might be more appropriate to open this issue there, or perhaps in torch or numpy.

@Michael-F-Ellis
Author

> Since diffusers' CLIP comes from transformers, it might be more appropriate to open this issue there. Or maybe in numpy.

I found this issue in pytorch. It's been around since 2022 and the OP states, "Numpy doesn't support bfloat16, and doesn't plan to do so."

So it sounds like a 'can't get there from here' situation. Anyway, my actual goal is doing img2img with StableCascade. Any suggestion on how to do it without having to resort to float32 for the model and image would be much appreciated.
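
For reference, one idea that avoids the numpy round-trip entirely: keep the pipeline weights in bfloat16 and pass only the image in a numpy-representable dtype (a plain PIL image, or a float32 tensor). A sketch; to_float32_tensor is a hypothetical helper, not pipeline API:

import numpy as np
import torch

# Hypothetical helper: only the input image avoids bfloat16; the model itself
# can stay in bfloat16, since the feature extractor converts to numpy internally.
def to_float32_tensor(image):
    t = torch.from_numpy(np.array(image)).permute(2, 0, 1).unsqueeze(0)
    return t.to(torch.float32)

and then call the pipeline with images=[to_float32_tensor(image)].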

@tolgacangoz
Contributor

Could you try this one?

@Michael-F-Ellis
Author

Thanks. I tried

input_tensor = torch.from_numpy(np.array(image)).permute(2, 0, 1).unsqueeze(0)
# Convert the tensor to bfloat16, then to an ml_dtypes.bfloat16 numpy array
input_tensor = input_tensor.to(torch.bfloat16).float().numpy().astype(ml_dtypes.bfloat16)

It loaded OK, but I get a runtime AttributeError when trying to send the image to CUDA.

  File "/home/user/app/app.py", line 46, in generate_image
    img_tensor = [img_tensor.cuda]
AttributeError: 'numpy.ndarray' object has no attribute 'cuda'

I previously had img_tensor.to('cuda'), but that complained that the object had no to() method.
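
For what it's worth, numpy arrays have no to() or cuda() method, so the array has to be wrapped back into a torch tensor; and as far as I can tell, torch.from_numpy() doesn't accept ml_dtypes' bfloat16 either, so the round trip would look something like this sketch (arr stands for the array produced by the transform above):

import numpy as np
import torch

# arr: a numpy array with dtype ml_dtypes.bfloat16 (hypothetical, from the transform above)
t = torch.from_numpy(arr.astype(np.float32)).to(torch.bfloat16).to("cuda")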

@tolgacangoz
Contributor

> Could you try this one?

I imagined this modification in transformers' utils/generic.py, around line 293:

+    import ml_dtypes

    framework_to_numpy = {
-        "pt": lambda obj: obj.detach().cpu().numpy(),
+        "pt": lambda obj: obj.detach().cpu().float().numpy().astype(ml_dtypes.bfloat16),

@Michael-F-Ellis
Author

Ahh, sorry for misunderstanding. Apologies for my next dumb question:
Given that I'm trying to run this in a HF Space, how do I get at the transformers/utils/generic.py file to hack it?

@tolgacangoz
Contributor

tolgacangoz commented Apr 7, 2024

I guess we can add a requirements.txt file at the root of the repository in an HF Space. You can point it at your forked and modified transformers, and when the Space is reinitialized/restarted/rebuilt, it will be installed.
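
For reference, pip can install straight from a Git URL, so the Space's requirements.txt could look something like this sketch (the fork path and branch are placeholders):

git+https://github.com/<your-username>/transformers.git@<your-branch>
diffusers
torch
accelerate
gradio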

@Michael-F-Ellis
Author

> Btw, Nvidia T4 medium might be enough if you make this change:

Thanks! Using a cheaper GPU is appealing; I'll try that after I solve the A10 conundrum :-)

Meanwhile, forking the transformers repo and making your suggested changes got me farther down the road. It's no longer choking in generic.py. Now I'm getting:

  File "/home/user/.pyenv/versions/3.10.14/lib/python3.10/site-packages/torch/nn/modules/conv.py", line 456, in _conv_forward
    return F.conv2d(input, weight, bias, self.stride,
RuntimeError: Input type (CUDABFloat16Type) and weight type (CPUBFloat16Type) should be the same

I guess I understand even less than I thought. My code has

# Ensure model and scheduler are initialized in GPU-enabled function
if torch.cuda.is_available():
    pipe = StableCascadeCombinedPipeline.from_pretrained(repo, torch_dtype=torch.bfloat16)
    pipe.to("cuda")

Does that not convert the weights to CUDA format? Or do I need to somehow get the image to be seen as CPUBFloat16Type?

BTW, if I pass None for the image, the code runs correctly and produces an image based strictly on the prompt, so it's clearly doing the right thing with the weights.
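
One way to confirm which sub-model stayed on the CPU is to print every component's device; a diagnostic sketch using the standard DiffusionPipeline components accessor:

# Print each sub-model's device to spot anything pipe.to("cuda") left behind;
# non-model components (tokenizer, scheduler) have no device, hence the guard.
for name, component in pipe.components.items():
    if hasattr(component, "device"):
        print(f"{name}: {component.device}")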

@tolgacangoz
Contributor

tolgacangoz commented Apr 8, 2024

It seems that pipe.to('cuda') doesn't fully do its job here, so one needs to add this:

    pipe = StableCascadeCombinedPipeline.from_pretrained(repo, torch_dtype=torch.bfloat16)
    pipe.to('cuda')
+   pipe.prior_image_encoder.to('cuda')

@Michael-F-Ellis
Author

> pipe.prior_image_encoder.to('cuda')

That did it. Thanks! I really appreciate the time and effort you put into resolving this.

It also enables sending a PIL image as input without transforming it.

@Michael-F-Ellis
Author

I've just verified that the fix eliminates the need for a hacked version of transformers. I'll post a cleaned-up and simplified version of the code so others can use it as a starting point.

@Michael-F-Ellis
Author

As promised, here's a minimal app.py that's been tested in HF Spaces on an A10G Large:

import gradio as gr
import spaces
from diffusers import StableCascadeCombinedPipeline
import torch
import random

# Constants
repo = "stabilityai/stable-cascade"


# Ensure model and scheduler are initialized in GPU-enabled function
if torch.cuda.is_available():
    pipe = StableCascadeCombinedPipeline.from_pretrained(repo, torch_dtype=torch.bfloat16)
    pipe.to("cuda")
    # As of 2024-04-08, pipe.to("cuda") does not move prior image encoder to GPU,
    # so we need to do it manually. 
    # See https://github.com/huggingface/diffusers/issues/7598
    pipe.prior_image_encoder.to('cuda') 

# The generate function
@spaces.GPU(enable_queue=True)
def generate_image(prompt, image):  
    seed = random.randint(-100000, 100000)

    results = pipe(
                prompt=prompt,
                images=[image] if image is not None else None,
                # default output size for SC is 1024x1024
                height=1024,
                width=1024,
                num_inference_steps=20, # 20 steps takes ~17 seconds on an A10G GPU
                generator=torch.Generator(device="cuda").manual_seed(seed)
            )
    return results.images[0]


# ------------- Gradio Interface -----------------------
description = """
A minimal demo using Stable Cascade combined pipeline for image-to-image generation.
A more useful version would provide UI components for controlling the input parameters
and save the settings with the generated images.
"""

with gr.Blocks(css="style.css") as demo:
    gr.HTML("<h1><center>Stable Cascade Img2Img ⚡</center></h1>")
    gr.Markdown(description)
    with gr.Group():
        with gr.Row():
            prompt = gr.Textbox(label='Enter your prompt', scale=8, value="holding a puppy")
            submit = gr.Button(scale=1, variant='primary')
    imgin = gr.Image(label='Input Image', type='pil', height=1024, width=1024, interactive=True)
    imgout = gr.Image(label='Generated Image', height=1024, width=1024)

    prompt.submit(fn=generate_image,
                 inputs=[prompt, imgin],
                 outputs=imgout,
                 )
    submit.click(fn=generate_image,
                 inputs=[prompt, imgin],
                 outputs=imgout,
                 )
    
demo.queue().launch()

The requirements.txt file for this app is:

transformers
diffusers
torch
accelerate
gradio


github-actions bot commented May 8, 2024

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.

github-actions bot added the stale label May 8, 2024
yiyixuxu removed the stale label May 9, 2024
@yiyixuxu
Collaborator

yiyixuxu commented May 9, 2024

Hey @Michael-F-Ellis, I fixed this in #7894, so the to() method for StableCascadeCombinedPipeline should work correctly now:

import torch
from diffusers import StableCascadeCombinedPipeline

pipe = StableCascadeCombinedPipeline.from_pretrained(
    "stabilityai/stable-cascade", variant="bf16", torch_dtype=torch.bfloat16
)
pipe.to("cuda")

@standardAI thanks for finding the cause of the error!
