# DeepCache
DeepCache accelerates [`StableDiffusionPipeline`] and [`StableDiffusionXLPipeline`] by strategically caching and reusing high-level features while efficiently updating low-level features, taking advantage of the U-Net architecture.
Start by installing DeepCache:

```bash
pip install DeepCache
```
Then load and enable the [`DeepCacheSDHelper`]:
```diff
  import torch
  from diffusers import StableDiffusionPipeline
  pipe = StableDiffusionPipeline.from_pretrained('runwayml/stable-diffusion-v1-5', torch_dtype=torch.float16).to("cuda")

+ from DeepCache import DeepCacheSDHelper
+ helper = DeepCacheSDHelper(pipe=pipe)
+ helper.set_params(
+     cache_interval=3,
+     cache_branch_id=0,
+ )
+ helper.enable()

  image = pipe("a photo of an astronaut on a moon").images[0]
```
The `set_params` method accepts two arguments: `cache_interval` and `cache_branch_id`. `cache_interval` sets the frequency of feature caching, specified as the number of steps between each cache operation. `cache_branch_id` identifies which branch of the network (ordered from the shallowest to the deepest layer) is responsible for executing the caching processes.

Choosing a lower `cache_branch_id` or a larger `cache_interval` speeds up inference at the expense of reduced image quality (ablation experiments for these two hyperparameters can be found in the paper). Once those arguments are set, use the `enable` or `disable` methods to activate or deactivate the `DeepCacheSDHelper`.
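For example, here is a minimal sketch (assuming the pipeline and helper from the snippet above are already loaded) of how the helper can be toggled and retuned; the more aggressive `cache_interval=5` setting is just an illustration of the speed/quality trade-off:

```python
# Deactivate DeepCache to run the original, uncached pipeline as a baseline.
helper.disable()
baseline_image = pipe("a photo of an astronaut on a moon").images[0]

# Cache more aggressively (larger interval, shallowest branch) for a bigger speedup
# at some cost in image quality, then re-enable the helper.
helper.set_params(cache_interval=5, cache_branch_id=0)
helper.enable()
fast_image = pipe("a photo of an astronaut on a moon").images[0]
```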
You can find more generated samples (original pipeline vs DeepCache) and the corresponding inference latency in the WandB report. The prompts are randomly selected from the MS-COCO 2017 dataset.
## Benchmark
We tested how much faster DeepCache accelerates Stable Diffusion v2.1 with 50 inference steps on an NVIDIA RTX A5000, using different configurations for resolution, batch size, cache interval (I), and cache branch (B). Latency is reported in seconds, with the speedup over the original pipeline shown in parentheses.
| Resolution | Batch size | Original | DeepCache (I=3, B=0) | DeepCache (I=5, B=0) | DeepCache (I=5, B=1) |
|---|---|---|---|---|---|
| 512 | 8 | 15.96 | 6.88 (2.32x) | 5.03 (3.18x) | 7.27 (2.20x) |
| 512 | 4 | 8.39 | 3.60 (2.33x) | 2.62 (3.21x) | 3.75 (2.24x) |
| 512 | 1 | 2.61 | 1.12 (2.33x) | 0.81 (3.24x) | 1.11 (2.35x) |
| 768 | 8 | 43.58 | 18.99 (2.29x) | 13.96 (3.12x) | 21.27 (2.05x) |
| 768 | 4 | 22.24 | 9.67 (2.30x) | 7.10 (3.13x) | 10.74 (2.07x) |
| 768 | 1 | 6.33 | 2.72 (2.33x) | 1.97 (3.21x) | 2.98 (2.12x) |
| 1024 | 8 | 101.95 | 45.57 (2.24x) | 33.72 (3.02x) | 53.00 (1.92x) |
| 1024 | 4 | 49.25 | 21.86 (2.25x) | 16.19 (3.04x) | 25.78 (1.91x) |
| 1024 | 1 | 13.83 | 6.07 (2.28x) | 4.43 (3.12x) | 7.15 (1.93x) |
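As a rough point of reference, a timing loop along the following lines can be used to measure a comparable speedup on your own hardware. This is a sketch, not the exact benchmark script: the model ID (`stabilityai/stable-diffusion-2-1`), the prompt, and the single warm-up run are illustrative assumptions.

```python
import time

import torch
from diffusers import StableDiffusionPipeline
from DeepCache import DeepCacheSDHelper

pipe = StableDiffusionPipeline.from_pretrained(
    "stabilityai/stable-diffusion-2-1", torch_dtype=torch.float16
).to("cuda")
prompt = "a photo of an astronaut on a moon"  # illustrative prompt

def time_pipeline(pipe, prompt, **kwargs):
    # Warm up once so one-time CUDA initialization is not included in the timing.
    pipe(prompt, num_inference_steps=50, **kwargs)
    torch.cuda.synchronize()
    start = time.perf_counter()
    pipe(prompt, num_inference_steps=50, **kwargs)
    torch.cuda.synchronize()
    return time.perf_counter() - start

# Original pipeline at 512x512.
original = time_pipeline(pipe, prompt, height=512, width=512)

# Same pipeline with DeepCache enabled (I=3, B=0).
helper = DeepCacheSDHelper(pipe=pipe)
helper.set_params(cache_interval=3, cache_branch_id=0)
helper.enable()
cached = time_pipeline(pipe, prompt, height=512, width=512)
helper.disable()

print(f"original: {original:.2f}s, DeepCache(I=3, B=0): {cached:.2f}s ({original / cached:.2f}x)")
```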