
XLA compilation of Stable Diffusion in TensorFlow

TL;DR: You can XLA-compile the isolated parts of Stable Diffusion in TensorFlow and achieve a ~17% speedup in inference.

This repository provides code to serialize the different models involved in Stable Diffusion as SavedModels and to compile them with XLA. Running inference through the XLA-compiled concrete functions yields a good amount of speedup.

We use the Stable Diffusion model shipped with KerasCV.

Table of contents

  • Results
  • Steps
  • Running the benchmark
  • Details of the benchmark
  • Gotchas
  • Acknowledgements

Results

  • KerasCV with XLA: 12.40 seconds
  • SavedModels with XLA: 10.29 seconds
  • SavedModels without XLA: 13.69 seconds

That is a ~25% speedup w.r.t. the non-XLA SavedModels and ~17% w.r.t. KerasCV with XLA.

Steps

We first isolate the sub-models involved in Stable Diffusion and serialize them as stand-alone SavedModels:

  • Text encoder
  • Diffusion model aka UNet
  • Decoder

The SavedModels also include their respective computations. For example, the SavedModel of the text encoder includes the processing of the prompt context and the unconditional context. Similarly, the SavedModel of the UNet includes the computations for the diffusion process.

For the serialization, just run serialize_savedmodels.py.
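As a rough illustration of what the serialization step looks like (the paths below are assumptions, and serialize_savedmodels.py additionally wraps the surrounding computations into the exported functions):

```python
import tensorflow as tf
import keras_cv

# Instantiate the KerasCV Stable Diffusion pipeline; weights are downloaded on first use.
pipeline = keras_cv.models.StableDiffusion(img_height=512, img_width=512)

# Save each isolated sub-model as a stand-alone SavedModel (illustrative paths).
tf.saved_model.save(pipeline.text_encoder, "text_encoder")
tf.saved_model.save(pipeline.diffusion_model, "diffusion_model")
tf.saved_model.save(pipeline.decoder, "decoder")
```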

Once the SavedModels are generated, we load them as concrete functions and XLA-compile them before running inference. We include the complete code for this in benchmark.py.
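A minimal sketch of this step, assuming the SavedModel path from the sketch above and a default serving signature (benchmark.py contains the full version):

```python
import tensorflow as tf

# Load a serialized sub-model and fetch its serving concrete function.
decoder = tf.saved_model.load("decoder")
decoder_fn = decoder.signatures["serving_default"]

# Re-wrapping the concrete function with jit_compile=True makes TensorFlow
# compile it with XLA. The first call traces and compiles; later calls reuse
# the compiled program.
xla_decoder_fn = tf.function(decoder_fn, jit_compile=True)
```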

Running the benchmark

For running the KerasCV benchmark:

python benchmark.py --kerascv --jit_compile

For running with SavedModels (without XLA):

python benchmark.py 

For running with SavedModels (with XLA):

python benchmark.py --jit_compile

Details of the benchmark

The benchmarks were run on an a2-highgpu-1g instance (a GCP machine type with a single NVIDIA A100 40 GB GPU).

Gotchas

  • The text encoder cannot be XLA-compiled. See this issue for more details.
  • To make the SavedModels XLA-compatible, we fix the number of images that can be generated per prompt. Otherwise, the batch size is not a compile-time constant, which makes the computation XLA-incompatible (see the sketch below).
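To illustrate the second point, here is a minimal, hypothetical sketch (the constant, shapes, and wrapper below are assumptions, not the repository's exact code). Pinning the leading dimension in the input signature keeps the batch size a compile-time constant:

```python
import tensorflow as tf
import keras_cv

# Fixed up front; a dynamic value would not be a compile-time constant.
NUM_IMAGES_PER_PROMPT = 3

pipeline = keras_cv.models.StableDiffusion(img_height=512, img_width=512)

# For 512x512 images the decoder consumes latents of shape [batch, 64, 64, 4].
# Fixing the batch dimension makes the function XLA-compilable.
@tf.function(
    jit_compile=True,
    input_signature=[
        tf.TensorSpec(shape=(NUM_IMAGES_PER_PROMPT, 64, 64, 4), dtype=tf.float32)
    ],
)
def decode_latents(latents):
    return pipeline.decoder(latents)
```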

Acknowledgements

Thanks to the ML Developer Programs team at Google for providing GCP credit support.
