TL;DR: You can XLA-compile the isolated sub-models of Stable Diffusion in TensorFlow and achieve a ~17% speedup in inference.
This repository provides code to serialize the different models involved in Stable Diffusion as SavedModels and to compile them with XLA. By running the XLA-compiled concrete functions, we obtain a good amount of speedup in the inference process.
We use the Stable Diffusion model shipped with KerasCV.
Benchmarks
- KerasCV with XLA: 12.40 seconds
- SavedModels with XLA: 10.29 seconds
- SavedModels without XLA: 13.69 seconds
Speedup: ~25% w.r.t. the non-XLA SavedModels and ~17% w.r.t. KerasCV.
We first isolate the sub-models involved in Stable Diffusion and serialize them as stand-alone SavedModels:
- Text encoder
- Diffusion model aka UNet
- Decoder
Each SavedModel also includes its respective computations. For example, the SavedModel of the text encoder includes the processing of the prompt context and the unconditional context. Similarly, the SavedModel of the UNet includes the computations for the diffusion process.
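The serialization pattern can be sketched roughly as follows. This is a minimal illustration with a toy module standing in for an actual Stable Diffusion sub-model; the class and variable names here are hypothetical, and the real export logic lives in `serialize_savedmodels.py`.

```python
import tensorflow as tf


# Toy stand-in for a Stable Diffusion sub-model (e.g. the text encoder).
class ToyEncoder(tf.Module):
    def __init__(self):
        super().__init__()
        # Illustrative projection weights.
        self.proj = tf.Variable(tf.random.normal([77, 8]), name="proj")

    # A fixed input signature gives the exported concrete function
    # static shapes, which matters for XLA compilation later.
    @tf.function(input_signature=[tf.TensorSpec([1, 77], tf.int32)])
    def __call__(self, tokens):
        one_hot = tf.one_hot(tokens, depth=77, dtype=tf.float32)
        return tf.matmul(one_hot, self.proj)


encoder = ToyEncoder()
tf.saved_model.save(encoder, "/tmp/toy_text_encoder")
```

The same recipe applies per sub-model: wrap its forward computation (plus any pre/post-processing) in a `tf.function` with a fixed signature, then export it with `tf.saved_model.save`.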
For the serialization, just run `serialize_savedmodels.py`.
Once the SavedModels are generated, we load them as concrete functions and XLA-compile them before running inference. The complete code for this is in `benchmark.py`.
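The load-and-compile step can be sketched as follows. To keep the example self-contained, it saves a toy module first; the real code loads the Stable Diffusion sub-model SavedModels instead, and all names here are illustrative.

```python
import tensorflow as tf


# Save a tiny stand-in SavedModel (the real code saves the
# Stable Diffusion sub-models).
class Toy(tf.Module):
    @tf.function(input_signature=[tf.TensorSpec([1, 4], tf.float32)])
    def __call__(self, x):
        return tf.nn.relu(x) * 2.0


tf.saved_model.save(Toy(), "/tmp/toy_savedmodel")

# Load the SavedModel back.
loaded = tf.saved_model.load("/tmp/toy_savedmodel")

# Re-wrap the restored function with jit_compile=True so that calls
# go through XLA.
xla_fn = tf.function(loaded.__call__, jit_compile=True)

out = xla_fn(tf.ones([1, 4]))
```

The first call to `xla_fn` triggers XLA compilation; subsequent calls with the same input signature reuse the compiled executable, which is where the inference speedup comes from.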
For running the KerasCV benchmark:
python benchmark.py --kerascv --jit_compile
For running with SavedModels (without XLA):
python benchmark.py
For running with SavedModels (with XLA):
python benchmark.py --jit_compile
The benchmarks were run on a GCP `a2-highgpu-1g` instance.
- The text encoder cannot be XLA-compiled. See this issue for more details.
- To make the SavedModels XLA-compatible, we fix the number of images that can be generated per prompt. Otherwise, this number is not a compile-time constant, which makes the models XLA-incompatible.
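The point about the fixed number of images can be illustrated as follows. Here `NUM_IMAGES_PER_PROMPT` is a hypothetical name: because it is a Python constant baked in at trace time, the `tf.repeat` below has a static output shape, which XLA requires.

```python
import tensorflow as tf

# Fixed at serialization time (illustrative value and name).
NUM_IMAGES_PER_PROMPT = 3


@tf.function(
    jit_compile=True,
    input_signature=[tf.TensorSpec([1, 77], tf.int32)],
)
def encode(tokens):
    # NUM_IMAGES_PER_PROMPT is a compile-time constant, so the
    # repeated context has a static shape of [3, 77] here.
    context = tf.cast(tokens, tf.float32)
    return tf.repeat(context, NUM_IMAGES_PER_PROMPT, axis=0)


out = encode(tf.ones([1, 77], tf.int32))
```

If the repeat count were instead a runtime tensor (e.g. read from an input), the output shape would be dynamic and XLA compilation would fail.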
Thanks to the ML Developer Programs team at Google for providing GCP credit support.