
[Performance 1/6] use_checkpoint = False #15803

Open · wants to merge 3 commits into base: dev
Conversation

@huchenlei (Contributor) commented on May 15, 2024

Description

According to lllyasviel/stable-diffusion-webui-forge#716 (comment),
the calls to `parameters()` in the checkpoint function are a significant overhead in A1111. However, the checkpoint function is only needed for training; disabling it does not affect inference at all.

This PR disables checkpointing in A1111 in exchange for a performance improvement. On my local setup (4090) this saves roughly 100 ms/it; before the patch the duration is ~580 ms/it.
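For illustration, a minimal sketch of the general idea (not necessarily this PR's exact diff): flip `use_checkpoint` off in the loaded model config before the UNet is constructed. This assumes an OmegaConf config laid out like the stock v1-inference.yaml, where the flag lives at model.params.unet_config.params.use_checkpoint; the function name is made up for the example.

```python
from omegaconf import OmegaConf

def load_config_without_checkpointing(config_path: str):
    """Sketch: load a v1-inference.yaml-style config with checkpointing disabled."""
    config = OmegaConf.load(config_path)
    unet_params = config.model.params.unet_config.params
    # Gradient checkpointing trades compute for memory during training;
    # for pure inference it only adds overhead, so turn it off.
    if "use_checkpoint" in unet_params:
        unet_params.use_checkpoint = False
    return config
```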

Screenshots/videos:

[image]

@huchenlei huchenlei changed the title use_checkpoint = False [Performance 1/6] use_checkpoint = False May 15, 2024
@huchenlei huchenlei changed the base branch from master to dev May 15, 2024 19:30
 def BasicTransformerBlock_forward(self, x, context=None):
-    return checkpoint(self._forward, x, context)
+    return checkpoint(self._forward, x, context, flag=False)
Contributor commented:

The checkpoint here is torch.utils.checkpoint.checkpoint, which does not take a flag=False argument. I think you confused it with ldm.modules.diffusionmodules.util.checkpoint. sd_hijack_checkpoint.py already removes the checkpointing in ldm, but we might need to do the same in sgm as well.
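For reference, a sketch of the signature difference being pointed out; the description of the ldm helper is paraphrased, not a verbatim copy of its source.

```python
import torch
import torch.utils.checkpoint

def _forward(x, context=None):
    return x * 2

x = torch.randn(2, 4, requires_grad=True)

# torch.utils.checkpoint.checkpoint(function, *args, **kwargs) has no `flag`
# parameter; depending on the mode, an unexpected flag=False keyword is either
# rejected or forwarded to the wrapped function, but it never disables
# checkpointing.
y = torch.utils.checkpoint.checkpoint(_forward, x, use_reentrant=False)

# ldm.modules.diffusionmodules.util.checkpoint(func, inputs, params, flag), by
# contrast, takes the flag as its last positional argument and simply calls
# func(*inputs) when the flag is False, i.e. no checkpointing at all.
```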

Contributor commented:

As far as I understand, what you are actually looking for is:

ldm.modules.attention.BasicTransformerBlock.forward = ldm.modules.attention.BasicTransformerBlock._forward
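A minimal sketch of that monkey patch, assuming the ldm package shipped with the webui repositories is importable; the two helper names are made up for illustration, and keeping a reference to the original forward lets training code paths restore it.

```python
import ldm.modules.attention

_original_forward = ldm.modules.attention.BasicTransformerBlock.forward

def disable_transformer_checkpointing():
    # Route forward() straight to _forward(), skipping the checkpoint wrapper.
    ldm.modules.attention.BasicTransformerBlock.forward = (
        ldm.modules.attention.BasicTransformerBlock._forward
    )

def restore_transformer_checkpointing():
    # Put the original checkpointed forward back (e.g. before training).
    ldm.modules.attention.BasicTransformerBlock.forward = _original_forward
```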

Contributor commented:

A closer look indicates that the checkpoint here is only called when training occurs (textual inversion and hypernetworks), so disabling the checkpoint here may be undesirable.
