T-GATE

T-GATE accelerates inference for Stable Diffusion, PixArt, and Latent Consistency Model pipelines by skipping the cross-attention calculation once it converges. The method requires no additional training and can speed up inference by 10-50%. T-GATE is also compatible with other optimization methods such as DeepCache.
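The idea of skipping cross-attention after it converges can be illustrated with a toy sketch. This is not T-GATE's implementation: the class GatedCrossAttention, the function run, and the scalar "attention" are hypothetical stand-ins, and the sketch assumes the cached output is simply the last value computed before the gate step. It only shows how a gate step bounds the number of cross-attention evaluations.

```python
from typing import Callable, Optional


class GatedCrossAttention:
    """Toy stand-in for a cross-attention layer that is frozen after `gate_step`."""

    def __init__(self, attend: Callable[[float], float], gate_step: int) -> None:
        self.attend = attend        # the expensive cross-attention computation
        self.gate_step = gate_step  # first step that reuses the cached output
        self.cache: Optional[float] = None
        self.calls = 0              # how many times `attend` actually ran

    def __call__(self, step: int, hidden: float) -> float:
        # Before the gate step (or if nothing is cached yet), compute and cache.
        if step < self.gate_step or self.cache is None:
            self.calls += 1
            self.cache = self.attend(hidden)
        # From the gate step onward, the cached output is returned unchanged.
        return self.cache


def run(num_inference_steps: int, gate_step: int) -> int:
    """Run a fake denoising loop and return the number of attention evaluations."""
    layer = GatedCrossAttention(attend=lambda h: 0.5 * h, gate_step=gate_step)
    hidden = 1.0
    for step in range(num_inference_steps):
        hidden = hidden + layer(step, hidden)
    return layer.calls


calls = run(num_inference_steps=25, gate_step=10)
print(f"cross-attention evaluated {calls} times over 25 steps")
```

In this toy loop the remaining steps reuse the cached value, which is where the savings come from; how the real library builds its cache is not covered by this sketch.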

Before you begin, install T-GATE and its dependencies.

pip install tgate
pip install -U torch diffusers transformers accelerate DeepCache

To use T-GATE with a pipeline, you need to use its corresponding loader.

| Pipeline | T-GATE Loader |
|---|---|
| PixArt | TgatePixArtLoader |
| Stable Diffusion XL | TgateSDXLLoader |
| Stable Diffusion XL + DeepCache | TgateSDXLDeepCacheLoader |
| Stable Diffusion | TgateSDLoader |
| Stable Diffusion + DeepCache | TgateSDDeepCacheLoader |
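If you choose the loader programmatically, the table above can be expressed as a small lookup. This is only a sketch: LOADER_NAMES and loader_name are hypothetical helpers, not part of tgate, and the loader names are copied from the table as plain strings so the snippet runs without tgate installed.

```python
from typing import Dict, Tuple

# (pipeline family, uses DeepCache) -> loader class name, copied from the table.
LOADER_NAMES: Dict[Tuple[str, bool], str] = {
    ("PixArt", False): "TgatePixArtLoader",
    ("Stable Diffusion XL", False): "TgateSDXLLoader",
    ("Stable Diffusion XL", True): "TgateSDXLDeepCacheLoader",
    ("Stable Diffusion", False): "TgateSDLoader",
    ("Stable Diffusion", True): "TgateSDDeepCacheLoader",
}


def loader_name(family: str, deepcache: bool = False) -> str:
    """Return the loader class name for a pipeline family (hypothetical helper)."""
    try:
        return LOADER_NAMES[(family, deepcache)]
    except KeyError:
        raise ValueError(
            f"no loader listed for {family!r} with deepcache={deepcache}"
        ) from None


print(loader_name("Stable Diffusion XL", deepcache=True))
```

The table lists no PixArt + DeepCache combination, so the helper raises a ValueError for it rather than guessing a name.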

Next, create the T-GATE loader for your pipeline, passing it the pipeline, the gate step (the time step at which to stop calculating cross-attention), and the number of inference steps. Then call the tgate method on the wrapped pipeline with a prompt, the gate step, and the number of inference steps.

Let's see how to enable this for several different pipelines.

Accelerate PixArtAlphaPipeline with T-GATE:

import torch
from diffusers import PixArtAlphaPipeline
from tgate import TgatePixArtLoader

pipe = PixArtAlphaPipeline.from_pretrained("PixArt-alpha/PixArt-XL-2-1024-MS", torch_dtype=torch.float16)

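gate_step = 8        # time step at which cross-attention stops being calculated
inference_step = 25  # total number of inference steps

# Wrap the pipeline with the matching T-GATE loader.
pipe = TgatePixArtLoader(
    pipe,
    gate_step=gate_step,
    num_inference_steps=inference_step,
).to("cuda")

image = pipe.tgate(
    "An alpaca made of colorful building blocks, cyberpunk.",
    gate_step=gate_step,
    num_inference_steps=inference_step,
).images[0]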

Accelerate StableDiffusionXLPipeline with T-GATE:

import torch
from diffusers import StableDiffusionXLPipeline
from diffusers import DPMSolverMultistepScheduler
from tgate import TgateSDXLLoader

pipe = StableDiffusionXLPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0",
    torch_dtype=torch.float16,
    variant="fp16",
    use_safetensors=True,
)
pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config)

gate_step = 10       # time step at which cross-attention stops being calculated
inference_step = 25  # total number of inference steps

# Wrap the pipeline with the matching T-GATE loader.
pipe = TgateSDXLLoader(
    pipe,
    gate_step=gate_step,
    num_inference_steps=inference_step,
).to("cuda")

image = pipe.tgate(
    "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k.",
    gate_step=gate_step,
    num_inference_steps=inference_step,
).images[0]

Accelerate StableDiffusionXLPipeline with DeepCache and T-GATE:

import torch
from diffusers import StableDiffusionXLPipeline
from diffusers import DPMSolverMultistepScheduler
from tgate import TgateSDXLDeepCacheLoader

pipe = StableDiffusionXLPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0",
    torch_dtype=torch.float16,
    variant="fp16",
    use_safetensors=True,
)
pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config)

gate_step = 10       # time step at which cross-attention stops being calculated
inference_step = 25  # total number of inference steps

# This loader takes the DeepCache settings; the gate step is passed to tgate below.
pipe = TgateSDXLDeepCacheLoader(
    pipe,
    cache_interval=3,
    cache_branch_id=0,
).to("cuda")

image = pipe.tgate(
    "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k.",
    gate_step=gate_step,
    num_inference_steps=inference_step,
).images[0]

Accelerate latent-consistency/lcm-sdxl with T-GATE:

import torch
from diffusers import StableDiffusionXLPipeline
from diffusers import UNet2DConditionModel, LCMScheduler
from tgate import TgateSDXLLoader

unet = UNet2DConditionModel.from_pretrained(
    "latent-consistency/lcm-sdxl",
    torch_dtype=torch.float16,
    variant="fp16",
)
pipe = StableDiffusionXLPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0",
    unet=unet,
    torch_dtype=torch.float16,
    variant="fp16",
)
pipe.scheduler = LCMScheduler.from_config(pipe.scheduler.config)

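gate_step = 1       # time step at which cross-attention stops being calculated
inference_step = 4  # LCM needs only a few inference steps

# Wrap the pipeline with the SDXL loader and enable its LCM mode.
pipe = TgateSDXLLoader(
    pipe,
    gate_step=gate_step,
    num_inference_steps=inference_step,
    lcm=True,
).to("cuda")

image = pipe.tgate(
    "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k.",
    gate_step=gate_step,
    num_inference_steps=inference_step,
).images[0]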

T-GATE also supports StableDiffusionPipeline and PixArt-alpha/PixArt-LCM-XL-2-1024-MS.

Benchmarks

| Model | MACs | Params | Latency | Zero-shot 10K-FID on MS-COCO |
|---|---|---|---|---|
| SD-1.5 | 16.938T | 859.520M | 7.032s | 23.927 |
| SD-1.5 w/ T-GATE | 9.875T | 815.557M | 4.313s | 20.789 |
| SD-2.1 | 38.041T | 865.785M | 16.121s | 22.609 |
| SD-2.1 w/ T-GATE | 22.208T | 815.433M | 9.878s | 19.940 |
| SD-XL | 149.438T | 2.570B | 53.187s | 24.628 |
| SD-XL w/ T-GATE | 84.438T | 2.024B | 27.932s | 22.738 |
| Pixart-Alpha | 107.031T | 611.350M | 61.502s | 38.669 |
| Pixart-Alpha w/ T-GATE | 65.318T | 462.585M | 37.867s | 35.825 |
| DeepCache (SD-XL) | 57.888T | - | 19.931s | 23.755 |
| DeepCache w/ T-GATE | 43.868T | - | 14.666s | 23.999 |
| LCM (SD-XL) | 11.955T | 2.570B | 3.805s | 25.044 |
| LCM w/ T-GATE | 11.171T | 2.024B | 3.533s | 25.028 |
| LCM (Pixart-Alpha) | 8.563T | 611.350M | 4.733s | 36.086 |
| LCM w/ T-GATE | 7.623T | 462.585M | 4.543s | 37.048 |

Latency is measured on an NVIDIA GeForce GTX 1080 Ti, MACs and Params are calculated with calflops, and FID is calculated with pytorch-fid.
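To read the latency column as a relative speedup, the short sketch below computes the fractional latency reduction for each baseline and T-GATE pair. The numbers are copied from the table above; the names LATENCY_S and latency_reduction are illustrative only.

```python
from typing import Dict, Tuple

# model -> (baseline latency in seconds, latency with T-GATE in seconds)
LATENCY_S: Dict[str, Tuple[float, float]] = {
    "SD-1.5": (7.032, 4.313),
    "SD-2.1": (16.121, 9.878),
    "SD-XL": (53.187, 27.932),
    "Pixart-Alpha": (61.502, 37.867),
    "DeepCache (SD-XL)": (19.931, 14.666),
    "LCM (SD-XL)": (3.805, 3.533),
    "LCM (Pixart-Alpha)": (4.733, 4.543),
}


def latency_reduction(baseline_s: float, tgate_s: float) -> float:
    """Fraction of the baseline latency that T-GATE removes."""
    return (baseline_s - tgate_s) / baseline_s


reductions = {name: latency_reduction(*pair) for name, pair in LATENCY_S.items()}

for name, reduction in reductions.items():
    print(f"{name}: {reduction:.1%} lower latency")
```

On these numbers the reduction ranges from roughly 4% for LCM (Pixart-Alpha) to about 47% for SD-XL, so the few-step LCM pipelines gain the least.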