Running the PixArtAlphaPipeline in under 8GB GPU VRAM
It is possible to run the `PixArtAlphaPipeline` under 8GB of GPU VRAM by loading the text encoder in 8-bit numerical precision. Let's walk through a complete example.
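Why the text encoder is the component to shrink comes down to simple arithmetic: weight memory scales with bytes per parameter (4 for float32, 2 for float16, 1 for 8-bit). The sketch below computes this. The parameter count of roughly 4.7 billion for the T5 encoder is an assumed, approximate figure for illustration, not a number from this guide, and the estimate ignores activations, the CUDA context, and quantization overheads.

```python
# Bytes needed to store one parameter in each precision.
BYTES_PER_PARAM = {"float32": 4, "float16": 2, "int8": 1}


def weight_memory_gib(num_params: int, dtype: str) -> float:
    """Approximate memory for the model weights alone, in GiB."""
    return num_params * BYTES_PER_PARAM[dtype] / 1024**3


# Assumed, approximate parameter count for the T5 text encoder (illustrative only).
T5_ENCODER_PARAMS = 4_700_000_000

for dtype in ("float32", "float16", "int8"):
    print(f"{dtype:>8}: {weight_memory_gib(T5_ENCODER_PARAMS, dtype):.1f} GiB")
```

Under this assumption, the float16 encoder alone would already exceed 8 GiB, while the 8-bit version needs about half of that.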
First, install the `bitsandbytes` library:

```bash
pip install -U bitsandbytes
```
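Then load the text encoder in 8-bit:

```python
import torch  # used by the later steps (torch.no_grad, torch.float16)
from transformers import T5EncoderModel
from diffusers import PixArtAlphaPipeline

# Load only the T5 text encoder, quantized to 8-bit via bitsandbytes.
text_encoder = T5EncoderModel.from_pretrained(
    "PixArt-alpha/PixArt-XL-2-1024-MS",
    subfolder="text_encoder",
    load_in_8bit=True,
    device_map="auto",
)
# Wrap the encoder in a pipeline; transformer=None skips loading the denoiser.
pipe = PixArtAlphaPipeline.from_pretrained(
    "PixArt-alpha/PixArt-XL-2-1024-MS",
    text_encoder=text_encoder,
    transformer=None,
    device_map="auto",
)
```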
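To build some intuition for what 8-bit loading does to the weights, here is a simplified per-tensor absmax quantization round trip in plain Python. This is not bitsandbytes' actual LLM.int8() scheme (which quantizes vector-wise and keeps outlier features in 16-bit); it only illustrates how floats are mapped onto integer levels in [-127, 127] and back, and where the rounding error comes from. The weight values are hypothetical.

```python
def absmax_quantize(values: list[float]) -> tuple[list[int], float]:
    """Map floats to the int8 range [-127, 127] using one scale for the whole tensor."""
    absmax = max(abs(v) for v in values)
    scale = absmax / 127 if absmax > 0 else 1.0
    return [round(v / scale) for v in values], scale


def dequantize(q: list[int], scale: float) -> list[float]:
    """Recover approximate floats from the quantized integers."""
    return [x * scale for x in q]


weights = [0.12, -0.53, 0.98, -1.27, 0.004]  # hypothetical weight values
q, scale = absmax_quantize(weights)
restored = dequantize(q, scale)
max_err = max(abs(a - b) for a, b in zip(weights, restored))
print(q)
print(f"scale={scale:.4f}, max round-trip error={max_err:.5f}")
```

Each value can be off by up to half a quantization step (`scale / 2`), and small values such as `0.004` can collapse to zero entirely; that lost detail is the source of the quality caveat at the end of this guide.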
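Now, use the `pipe` to encode a prompt:

```python
with torch.no_grad():
    prompt = "cute cat"
    prompt_embeds, prompt_attention_mask, negative_embeds, negative_prompt_attention_mask = pipe.encode_prompt(prompt)

# The embeddings are all we need from the text encoder, so free it.
del text_encoder
del pipe
flush()  # helper defined in the next step
```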
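`flush()` is a small utility function that frees unused GPU VRAM. It is implemented like so:

```python
import gc

import torch


def flush():
    # Collect unreachable Python objects first, then release PyTorch's cached CUDA memory.
    gc.collect()
    torch.cuda.empty_cache()
```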
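The order inside `flush()` is deliberate. `del` only removes a name, and an object caught in a reference cycle is not freed until the garbage collector runs; `torch.cuda.empty_cache()` can only release cached memory that no live tensor still occupies. The stdlib-only sketch below illustrates the first half of that reasoning with a hypothetical `FakeModule` stand-in for a large model (no GPU involved).

```python
import gc
import weakref


class FakeModule:
    """Hypothetical stand-in for a large model object."""

    def __init__(self):
        self.self_ref = self  # a reference cycle keeps the refcount above zero


def demo() -> tuple[bool, bool]:
    gc.disable()  # stop the automatic collector from running mid-demo
    try:
        module = FakeModule()
        probe = weakref.ref(module)  # observe the object without keeping it alive
        del module
        alive_after_del = probe() is not None  # still alive: the cycle holds it
        gc.collect()
        alive_after_collect = probe() is not None  # the collector broke the cycle
        return alive_after_del, alive_after_collect
    finally:
        gc.enable()


print(demo())
```

If `gc.collect()` were skipped, tensors held by such objects would stay allocated and `empty_cache()` would have nothing extra to release.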
Then compute the latents, passing the prompt embeddings as inputs:

```python
pipe = PixArtAlphaPipeline.from_pretrained(
    "PixArt-alpha/PixArt-XL-2-1024-MS",
    text_encoder=None,
    torch_dtype=torch.float16,
).to("cuda")

latents = pipe(
    negative_prompt=None,
    prompt_embeds=prompt_embeds,
    negative_prompt_embeds=negative_embeds,
    prompt_attention_mask=prompt_attention_mask,
    negative_prompt_attention_mask=negative_prompt_attention_mask,
    num_images_per_prompt=1,
    output_type="latent",
).images

# Denoising is done; keep only the VAE for decoding.
del pipe.transformer
flush()
```
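With `output_type="latent"`, the pipeline returns the denoised latents instead of pixels, and they are small. The sketch below estimates their shape and size. It assumes a VAE with 4 latent channels and an 8x spatial downsampling factor, as is typical of Stable Diffusion-style VAEs; treat both values as assumptions and check `pipe.vae.config` for your checkpoint.

```python
def latent_shape(height: int, width: int, channels: int = 4, downsample: int = 8, batch: int = 1):
    """Latent tensor shape for an image size, under the assumed VAE geometry."""
    if height % downsample or width % downsample:
        raise ValueError("image size must be divisible by the downsampling factor")
    return (batch, channels, height // downsample, width // downsample)


def num_bytes(shape, bytes_per_element: int = 2) -> int:
    """Memory footprint of a tensor with this shape (float16 uses 2 bytes per element)."""
    total = bytes_per_element
    for dim in shape:
        total *= dim
    return total


shape = latent_shape(1024, 1024)
print(shape, f"{num_bytes(shape) / 1024:.0f} KiB in float16")
```

Because the latents occupy kilobytes rather than gigabytes, keeping them around while the transformer is freed costs essentially nothing.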
Notice that while initializing `pipe`, you're setting `text_encoder` to `None` so that it's not loaded.
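Loading the components in separate stages means peak VRAM is set by the largest stage rather than by the sum of all components. The toy comparison below uses hypothetical per-component sizes (illustrative numbers, not measurements) and stages that mirror this walkthrough, where the VAE is loaded alongside whichever component is active.

```python
# Hypothetical weight sizes in GiB (illustrative only, not measured values).
COMPONENTS = {"text_encoder_8bit": 4.4, "transformer_fp16": 1.2, "vae_fp16": 0.2}

# Stages as in the walkthrough: encode the prompt, denoise, then decode.
STAGES = [
    ["text_encoder_8bit", "vae_fp16"],
    ["transformer_fp16", "vae_fp16"],
    ["vae_fp16"],
]


def peak_all_at_once(components: dict[str, float]) -> float:
    """Peak memory if every component is resident at the same time."""
    return sum(components.values())


def peak_staged(components: dict[str, float], stages: list[list[str]]) -> float:
    """Peak memory if only one stage's components are resident at a time."""
    return max(sum(components[name] for name in stage) for stage in stages)


print(f"all at once: {peak_all_at_once(COMPONENTS):.1f} GiB")
print(f"staged:      {peak_staged(COMPONENTS, STAGES):.1f} GiB")
```

Real peaks also include activations and allocator overhead, so measure on your own hardware rather than relying on numbers like these.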
Once the latents are computed, pass them to the VAE to decode them into a real image:

```python
with torch.no_grad():
    image = pipe.vae.decode(latents / pipe.vae.config.scaling_factor, return_dict=False)[0]
# postprocess returns a list of PIL images; take the first (and only) one.
image = pipe.image_processor.postprocess(image, output_type="pil")[0]
image.save("cat.png")
```
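The decode step undoes two normalizations: the latents are divided by `vae.config.scaling_factor` to return them to the VAE's native latent scale, and `postprocess` maps the decoder output from roughly [-1, 1] to 8-bit pixel values. The plain-Python sketch below mirrors that arithmetic for single numbers. It is a simplified illustration, not the diffusers implementation, and the scaling factor `0.18215` is an assumed value; read the real one from `pipe.vae.config`.

```python
SCALING_FACTOR = 0.18215  # assumed value; use pipe.vae.config.scaling_factor in practice


def unscale_latent(z: float, scaling_factor: float = SCALING_FACTOR) -> float:
    """Undo the scaling_factor normalization of the latent space."""
    return z / scaling_factor


def to_pixel(x: float) -> int:
    """Map a decoded value in about [-1, 1] to a uint8 pixel value in [0, 255]."""
    x = min(max(x / 2 + 0.5, 0.0), 1.0)  # shift to [0, 1] and clamp out-of-range values
    return round(x * 255)


print([to_pixel(v) for v in (-1.0, 0.0, 1.0, 1.7)])
```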
All of this, put together, should allow you to run `PixArtAlphaPipeline` under 8GB of GPU VRAM.
Find the script here that runs this example end-to-end and reports the memory used.
Text embeddings computed in 8-bit can reduce the quality of the generated images, because the reduced precision loses information in the representation space. It's recommended to compare the outputs with and without 8-bit.
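Besides comparing the final images by eye, you can compare the prompt embeddings produced with and without 8-bit loading directly. The sketch below computes the cosine similarity and the maximum absolute difference between two flattened embedding vectors. The vectors are made-up placeholders; in practice you might flatten `prompt_embeds` from each run with something like `prompt_embeds.float().cpu().flatten().tolist()`.

```python
import math


def cosine_similarity(a: list[float], b: list[float]) -> float:
    """1.0 means identical direction; lower values mean the embeddings drifted apart."""
    if len(a) != len(b):
        raise ValueError("vectors must have the same length")
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    return dot / (norm_a * norm_b)


def max_abs_diff(a: list[float], b: list[float]) -> float:
    """Largest element-wise deviation between the two vectors."""
    return max(abs(x - y) for x, y in zip(a, b))


# Placeholder values standing in for flattened full-precision and 8-bit embeddings.
reference = [0.21, -0.47, 0.88, 0.05]
quantized = [0.20, -0.48, 0.87, 0.06]

print(f"cosine similarity: {cosine_similarity(reference, quantized):.4f}")
print(f"max abs difference: {max_abs_diff(reference, quantized):.4f}")
```

A similarity close to 1.0 suggests the 8-bit encoder preserves the prompt's meaning well, but the generated images remain the final judge.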