λ©λͺ¨λ¦¬μ μλ
λ©λͺ¨λ¦¬ λλ μλμ λν΄ π€ Diffusers μΆλ‘ μ μ΅μ ννκΈ° μν λͺ κ°μ§ κΈ°μ κ³Ό μμ΄λμ΄λ₯Ό μ μν©λλ€. μΌλ°μ μΌλ‘, memory-efficient attentionμ μν΄ xFormers μ¬μ©μ μΆμ²νκΈ° λλ¬Έμ, μΆμ²νλ μ€μΉ λ°©λ²μ λ³΄κ³ μ€μΉν΄ 보μΈμ.
λ€μ μ€μ μ΄ μ±λ₯κ³Ό λ©λͺ¨λ¦¬μ λ―ΈμΉλ μν₯μ λν΄ μ€λͺ ν©λλ€.
μ§μ°μκ° | μλ ν₯μ | |
---|---|---|
λ³λ μ€μ μμ | 9.50s | x1 |
cuDNN auto-tuner | 9.37s | x1.01 |
fp16 | 3.61s | x2.63 |
Channels Last λ©λͺ¨λ¦¬ νμ | 3.30s | x2.88 |
traced UNet | 3.21s | x2.96 |
memory-efficient attention | 2.63s | x3.61 |
cuDNN auto-tuner νμ±ννκΈ°
NVIDIA cuDNNμ 컨볼루μ μ κ³μ°νλ λ§μ μκ³ λ¦¬μ¦μ μ§μν©λλ€. Autotunerλ 짧μ λ²€μΉλ§ν¬λ₯Ό μ€ννκ³ μ£Όμ΄μ§ μ λ ₯ ν¬κΈ°μ λν΄ μ£Όμ΄μ§ νλμ¨μ΄μμ μ΅κ³ μ μ±λ₯μ κ°μ§ 컀λμ μ νν©λλ€.
컨볼루μ λ€νΈμν¬λ₯Ό νμ©νκ³ μκΈ° λλ¬Έμ (λ€λ₯Έ μ νλ€μ νμ¬ μ§μλμ§ μμ), λ€μ μ€μ μ ν΅ν΄ μΆλ‘ μ μ cuDNN autotunerλ₯Ό νμ±νν μ μμ΅λλ€:
import torch
torch.backends.cudnn.benchmark = True
fp32 λμ tf32 μ¬μ©νκΈ° (Ampere λ° μ΄ν CUDA μ₯μΉλ€μμ)
Ampere λ° μ΄ν CUDA μ₯μΉμμ νλ ¬κ³± λ° μ»¨λ³Όλ£¨μ μ TensorFloat32(TF32) λͺ¨λλ₯Ό μ¬μ©νμ¬ λ λΉ λ₯΄μ§λ§ μ½κ° λ μ νν μ μμ΅λλ€. κΈ°λ³Έμ μΌλ‘ PyTorchλ 컨볼루μ μ λν΄ TF32 λͺ¨λλ₯Ό νμ±ννμ§λ§ νλ ¬ κ³±μ μ νμ±ννμ§ μμ΅λλ€. λ€νΈμν¬μ μμ ν float32 μ λ°λκ° νμν κ²½μ°κ° μλλ©΄ νλ ¬ κ³±μ μ λν΄μλ μ΄ μ€μ μ νμ±ννλ κ²μ΄ μ’μ΅λλ€. μ΄λ μΌλ°μ μΌλ‘ 무μν μ μλ μμΉμ μ νλ μμ€μ΄ μμ§λ§, κ³μ° μλλ₯Ό ν¬κ² λμΌ μ μμ΅λλ€. κ·Έκ²μ λν΄ μ¬κΈ°μ λ μ½μ μ μμ΅λλ€. μΆλ‘ νκΈ° μ μ λ€μμ μΆκ°νκΈ°λ§ νλ©΄ λ©λλ€:
import torch
torch.backends.cuda.matmul.allow_tf32 = True
λ°μ λ°λ κ°μ€μΉ
λ λ§μ GPU λ©λͺ¨λ¦¬λ₯Ό μ μ½νκ³ λ λΉ λ₯Έ μλλ₯Ό μ»κΈ° μν΄ λͺ¨λΈ κ°μ€μΉλ₯Ό λ°μ λ°λ(half precision)λ‘ μ§μ λΆλ¬μ€κ³ μ€νν μ μμ΅λλ€.
μ¬κΈ°μλ fp16
μ΄λΌλ λΈλμΉμ μ μ₯λ float16 λ²μ μ κ°μ€μΉλ₯Ό λΆλ¬μ€κ³ , κ·Έ λ float16
μ νμ μ¬μ©νλλ‘ PyTorchμ μ§μνλ μμ
μ΄ ν¬ν¨λ©λλ€.
pipe = StableDiffusionPipeline.from_pretrained(
"runwayml/stable-diffusion-v1-5",
torch_dtype=torch.float16,
)
pipe = pipe.to("cuda")
prompt = "a photo of an astronaut riding a horse on mars"
image = pipe(prompt).images[0]
μ΄λ€ νμ΄νλΌμΈμμλ [`torch.autocast`](https://pytorch.org/docs/stable/amp.html#torch.autocast) λ₯Ό μ¬μ©νλ κ²μ κ²μμ μ΄λ―Έμ§λ₯Ό μμ±ν μ μκ³ , μμν float16 μ λ°λλ₯Ό μ¬μ©νλ κ²λ³΄λ€ νμ λ리기 λλ¬Έμ μ¬μ©νμ§ μλ κ²μ΄ μ’μ΅λλ€.
μΆκ° λ©λͺ¨λ¦¬ μ μ½μ μν μ¬λΌμ΄μ€ μ΄ν μ
μΆκ° λ©λͺ¨λ¦¬ μ μ½μ μν΄, ν λ²μ λͺ¨λ κ³μ°νλ λμ λ¨κ³μ μΌλ‘ κ³μ°μ μννλ μ¬λΌμ΄μ€ λ²μ μ μ΄ν μ (attention)μ μ¬μ©ν μ μμ΅λλ€.
Attention slicingμ λͺ¨λΈμ΄ νλ μ΄μμ μ΄ν μ ν€λλ₯Ό μ¬μ©νλ ν, λ°°μΉ ν¬κΈ°κ° 1μΈ κ²½μ°μλ μ μ©ν©λλ€. νλ μ΄μμ μ΄ν μ ν€λκ° μλ κ²½μ° *QK^T* μ΄ν μ 맀νΈλ¦μ€λ μλΉν μμ λ©λͺ¨λ¦¬λ₯Ό μ μ½ν μ μλ κ° ν€λμ λν΄ μμ°¨μ μΌλ‘ κ³μ°λ μ μμ΅λλ€.κ° ν€λμ λν΄ μμ°¨μ μΌλ‘ μ΄ν
μ
κ³μ°μ μννλ €λ©΄, λ€μκ³Ό κ°μ΄ μΆλ‘ μ μ νμ΄νλΌμΈμμ [~StableDiffusionPipeline.enable_attention_slicing
]λ₯Ό νΈμΆνλ©΄ λ©λλ€:
import torch
from diffusers import StableDiffusionPipeline
pipe = StableDiffusionPipeline.from_pretrained(
"runwayml/stable-diffusion-v1-5",
torch_dtype=torch.float16,
)
pipe = pipe.to("cuda")
prompt = "a photo of an astronaut riding a horse on mars"
pipe.enable_attention_slicing()
image = pipe(prompt).images[0]
μΆλ‘ μκ°μ΄ μ½ 10% λλ €μ§λ μ½κ°μ μ±λ₯ μ νκ° μμ§λ§ μ΄ λ°©λ²μ μ¬μ©νλ©΄ 3.2GB μ λμ μμ VRAMμΌλ‘λ Stable Diffusionμ μ¬μ©ν μ μμ΅λλ€!
λ ν° λ°°μΉλ₯Ό μν sliced VAE λμ½λ
μ νλ VRAMμμ λκ·λͺ¨ μ΄λ―Έμ§ λ°°μΉλ₯Ό λμ½λ©νκ±°λ 32κ° μ΄μμ μ΄λ―Έμ§κ° ν¬ν¨λ λ°°μΉλ₯Ό νμ±ννκΈ° μν΄, λ°°μΉμ latent μ΄λ―Έμ§λ₯Ό ν λ²μ νλμ© λμ½λ©νλ μ¬λΌμ΄μ€ VAE λμ½λλ₯Ό μ¬μ©ν μ μμ΅λλ€.
μ΄λ₯Ό [~StableDiffusionPipeline.enable_attention_slicing
] λλ [~StableDiffusionPipeline.enable_xformers_memory_efficient_attention
]κ³Ό κ²°ν©νμ¬ λ©λͺ¨λ¦¬ μ¬μ©μ μΆκ°λ‘ μ΅μνν μ μμ΅λλ€.
VAE λμ½λλ₯Ό ν λ²μ νλμ© μννλ €λ©΄ μΆλ‘ μ μ νμ΄νλΌμΈμμ [~StableDiffusionPipeline.enable_vae_slicing
]μ νΈμΆν©λλ€. μλ₯Ό λ€μ΄:
import torch
from diffusers import StableDiffusionPipeline
pipe = StableDiffusionPipeline.from_pretrained(
"runwayml/stable-diffusion-v1-5",
torch_dtype=torch.float16,
)
pipe = pipe.to("cuda")
prompt = "a photo of an astronaut riding a horse on mars"
pipe.enable_vae_slicing()
images = pipe([prompt] * 32).images
λ€μ€ μ΄λ―Έμ§ λ°°μΉμμ VAE λμ½λκ° μ½κ°μ μ±λ₯ ν₯μμ΄ μ΄λ£¨μ΄μ§λλ€. λ¨μΌ μ΄λ―Έμ§ λ°°μΉμμλ μ±λ₯ μν₯μ μμ΅λλ€.
λ©λͺ¨λ¦¬ μ μ½μ μν΄ κ°μ κΈ°λ₯μ μ¬μ©νμ¬ CPUλ‘ μ€νλ‘λ©
μΆκ° λ©λͺ¨λ¦¬ μ μ½μ μν΄ κ°μ€μΉλ₯Ό CPUλ‘ μ€νλ‘λνκ³ μλ°©ν₯ μ λ¬μ μνν λλ§ GPUλ‘ λ‘λν μ μμ΅λλ€.
CPU μ€νλ‘λ©μ μννλ €λ©΄ [~StableDiffusionPipeline.enable_sequential_cpu_offload
]λ₯Ό νΈμΆνκΈ°λ§ νλ©΄ λ©λλ€:
import torch
from diffusers import StableDiffusionPipeline
pipe = StableDiffusionPipeline.from_pretrained(
"runwayml/stable-diffusion-v1-5",
torch_dtype=torch.float16,
)
prompt = "a photo of an astronaut riding a horse on mars"
pipe.enable_sequential_cpu_offload()
image = pipe(prompt).images[0]
κ·Έλ¬λ©΄ λ©λͺ¨λ¦¬ μλΉλ₯Ό 3GB λ―Έλ§μΌλ‘ μ€μΌ μ μμ΅λλ€.
μ°Έκ³ λ‘ μ΄ λ°©λ²μ μ 체 λͺ¨λΈμ΄ μλ μλΈλͺ¨λ μμ€μμ μλν©λλ€. μ΄λ λ©λͺ¨λ¦¬ μλΉλ₯Ό μ΅μννλ κ°μ₯ μ’μ λ°©λ²μ΄μ§λ§ νλ‘μΈμ€μ λ°λ³΅μ νΉμ±μΌλ‘ μΈν΄ μΆλ‘ μλκ° ν¨μ¬ λ립λλ€. νμ΄νλΌμΈμ UNet κ΅¬μ± μμλ μ¬λ¬ λ² μ€νλ©λλ€('num_inference_steps' λ§νΌ). λ§€λ² UNetμ μλ‘ λ€λ₯Έ μλΈλͺ¨λμ΄ μμ°¨μ μΌλ‘ μ¨λ‘λλ λ€μ νμμ λ°λΌ μ€νλ‘λλλ―λ‘ λ©λͺ¨λ¦¬ μ΄λ νμκ° λ§μ΅λλ€.
λ λ€λ₯Έ μ΅μ ν λ°©λ²μΈ λͺ¨λΈ μ€νλ‘λ©μ μ¬μ©νλ κ²μ κ³ λ €νμμμ€. μ΄λ ν¨μ¬ λΉ λ₯΄μ§λ§ λ©λͺ¨λ¦¬ μ μ½μ΄ ν¬μ§λ μμ΅λλ€.λν ttention slicingκ³Ό μ°κ²°ν΄μ μ΅μ λ©λͺ¨λ¦¬(< 2GB)λ‘λ λμν μ μμ΅λλ€.
import torch
from diffusers import StableDiffusionPipeline
pipe = StableDiffusionPipeline.from_pretrained(
"runwayml/stable-diffusion-v1-5",
torch_dtype=torch.float16,
)
prompt = "a photo of an astronaut riding a horse on mars"
pipe.enable_sequential_cpu_offload()
pipe.enable_attention_slicing(1)
image = pipe(prompt).images[0]
μ°Έκ³ : 'enable_sequential_cpu_offload()'λ₯Ό μ¬μ©ν λ, 미리 νμ΄νλΌμΈμ CUDAλ‘ μ΄λνμ§ μλ κ²μ΄ μ€μν©λλ€.κ·Έλ μ§ μμΌλ©΄ λ©λͺ¨λ¦¬ μλΉμ μ΄λμ΄ μ΅μνλ©λλ€. λ λ§μ μ 보λ₯Ό μν΄ μ΄ μ΄μλ₯Ό 보μΈμ.
λΉ λ₯Έ μΆλ‘ κ³Ό λ©λͺ¨λ¦¬ λ©λͺ¨λ¦¬ μ μ½μ μν λͺ¨λΈ μ€νλ‘λ©
μμ°¨μ CPU μ€νλ‘λ©μ μ΄μ μΉμ μμ μ€λͺ ν κ²μ²λΌ λ§μ λ©λͺ¨λ¦¬λ₯Ό 보쑴νμ§λ§ νμμ λ°λΌ μλΈλͺ¨λμ GPUλ‘ μ΄λνκ³ μ λͺ¨λμ΄ μ€νλ λ μ¦μ CPUλ‘ λ°νλκΈ° λλ¬Έμ μΆλ‘ μλκ° λλ €μ§λλ€.
μ 체 λͺ¨λΈ μ€νλ‘λ©μ κ° λͺ¨λΈμ κ΅¬μ± μμμΈ _modules_μ μ²λ¦¬νλ λμ , μ 체 λͺ¨λΈμ GPUλ‘ μ΄λνλ λμμ λλ€. μ΄λ‘ μΈν΄ μΆλ‘ μκ°μ λ―ΈμΉλ μν₯μ λ―Έλ―Ένμ§λ§(νμ΄νλΌμΈμ 'cuda'λ‘ μ΄λνλ κ²κ³Ό λΉκ΅νμ¬) μ¬μ ν μ½κ°μ λ©λͺ¨λ¦¬λ₯Ό μ μ½ν μ μμ΅λλ€.
μ΄ μλ리μ€μμλ νμ΄νλΌμΈμ μ£Όμ κ΅¬μ± μμ μ€ νλλ§(μΌλ°μ μΌλ‘ ν μ€νΈ μΈμ½λ, unet λ° vae) GPUμ μκ³ , λλ¨Έμ§λ CPUμμ λκΈ°ν κ²μ λλ€. μ¬λ¬ λ°λ³΅μ μν΄ μ€νλλ UNetκ³Ό κ°μ κ΅¬μ± μμλ λ μ΄μ νμνμ§ μμ λκΉμ§ GPUμ λ¨μ μμ΅λλ€.
μ΄ κΈ°λ₯μ μλμ κ°μ΄ νμ΄νλΌμΈμμ enable_model_cpu_offload()
λ₯Ό νΈμΆνμ¬ νμ±νν μ μμ΅λλ€.
import torch
from diffusers import StableDiffusionPipeline
pipe = StableDiffusionPipeline.from_pretrained(
"runwayml/stable-diffusion-v1-5",
torch_dtype=torch.float16,
)
prompt = "a photo of an astronaut riding a horse on mars"
pipe.enable_model_cpu_offload()
image = pipe(prompt).images[0]
μ΄λ μΆκ°μ μΈ λ©λͺ¨λ¦¬ μ μ½μ μν attention slicingκ³Όλ νΈνλ©λλ€.
import torch
from diffusers import StableDiffusionPipeline
pipe = StableDiffusionPipeline.from_pretrained(
"runwayml/stable-diffusion-v1-5",
torch_dtype=torch.float16,
)
prompt = "a photo of an astronaut riding a horse on mars"
pipe.enable_model_cpu_offload()
pipe.enable_attention_slicing(1)
image = pipe(prompt).images[0]
μ΄ κΈ°λ₯μ μ¬μ©νλ €λ©΄ 'accelerate' λ²μ 0.17.0 μ΄μμ΄ νμν©λλ€.
Channels Last λ©λͺ¨λ¦¬ νμ μ¬μ©νκΈ°
Channels Last λ©λͺ¨λ¦¬ νμμ μ°¨μ μμλ₯Ό 보쑴νλ λ©λͺ¨λ¦¬μμ NCHW ν μ λ°°μ΄μ λ체νλ λ°©λ²μ λλ€. Channels Last ν μλ μ±λμ΄ κ°μ₯ μ‘°λ°ν μ°¨μμ΄ λλ λ°©μμΌλ‘ μ λ ¬λ©λλ€(μΌλͺ ν½μ λΉ μ΄λ―Έμ§λ₯Ό μ μ₯). νμ¬ λͺ¨λ μ°μ°μ Channels Last νμμ μ§μνλ κ²μ μλλΌ μ±λ₯μ΄ μ νλ μ μμΌλ―λ‘, μ¬μ©ν΄λ³΄κ³ λͺ¨λΈμ μ μλνλμ§ νμΈνλ κ²μ΄ μ’μ΅λλ€.
μλ₯Ό λ€μ΄ νμ΄νλΌμΈμ UNet λͺ¨λΈμ΄ channels Last νμμ μ¬μ©νλλ‘ μ€μ νλ €λ©΄ λ€μμ μ¬μ©ν μ μμ΅λλ€:
print(pipe.unet.conv_out.state_dict()["weight"].stride()) # (2880, 9, 3, 1)
pipe.unet.to(memory_format=torch.channels_last) # in-place μ°μ°
# 2λ²μ§Έ μ°¨μμμ μ€νΈλΌμ΄λ 1μ κ°μ§λ (2880, 1, 960, 320)λ‘, μ°μ°μ΄ μλν¨μ μ¦λͺ
ν©λλ€.
print(pipe.unet.conv_out.state_dict()["weight"].stride())
μΆμ (tracing)
μΆμ μ λͺ¨λΈμ ν΅ν΄ μμ μ λ ₯ ν μλ₯Ό ν΅ν΄ μ€νλλλ°, ν΄λΉ μ λ ₯μ΄ λͺ¨λΈμ λ μ΄μ΄λ₯Ό ν΅κ³Όν λ νΈμΆλλ μμ μ μΊ‘μ²νμ¬ μ€ν νμΌ λλ 'ScriptFunction'μ΄ λ°νλλλ‘ νκ³ , μ΄λ just-in-time μ»΄νμΌλ‘ μ΅μ νλ©λλ€.
UNet λͺ¨λΈμ μΆμ νκΈ° μν΄ λ€μμ μ¬μ©ν μ μμ΅λλ€:
import time
import torch
from diffusers import StableDiffusionPipeline
import functools
# torch κΈ°μΈκΈ° λΉνμ±ν
torch.set_grad_enabled(False)
# λ³μ μ€μ
n_experiments = 2
unet_runs_per_experiment = 50
# μ
λ ₯ λΆλ¬μ€κΈ°
def generate_inputs():
sample = torch.randn((2, 4, 64, 64), device="cuda", dtype=torch.float16)
timestep = torch.rand(1, device="cuda", dtype=torch.float16) * 999
encoder_hidden_states = torch.randn((2, 77, 768), device="cuda", dtype=torch.float16)
return sample, timestep, encoder_hidden_states
pipe = StableDiffusionPipeline.from_pretrained(
"runwayml/stable-diffusion-v1-5",
torch_dtype=torch.float16,
).to("cuda")
unet = pipe.unet
unet.eval()
unet.to(memory_format=torch.channels_last) # Channels Last λ©λͺ¨λ¦¬ νμ μ¬μ©
unet.forward = functools.partial(unet.forward, return_dict=False) # return_dict=Falseμ κΈ°λ³Έκ°μΌλ‘ μ€μ
# μλ°μ
for _ in range(3):
with torch.inference_mode():
inputs = generate_inputs()
orig_output = unet(*inputs)
# μΆμ
print("tracing..")
unet_traced = torch.jit.trace(unet, inputs)
unet_traced.eval()
print("done tracing")
# μλ°μ
λ° κ·Έλν μ΅μ ν
for _ in range(5):
with torch.inference_mode():
inputs = generate_inputs()
orig_output = unet_traced(*inputs)
# λ²€μΉλ§νΉ
with torch.inference_mode():
for _ in range(n_experiments):
torch.cuda.synchronize()
start_time = time.time()
for _ in range(unet_runs_per_experiment):
orig_output = unet_traced(*inputs)
torch.cuda.synchronize()
print(f"unet traced inference took {time.time() - start_time:.2f} seconds")
for _ in range(n_experiments):
torch.cuda.synchronize()
start_time = time.time()
for _ in range(unet_runs_per_experiment):
orig_output = unet(*inputs)
torch.cuda.synchronize()
print(f"unet inference took {time.time() - start_time:.2f} seconds")
# λͺ¨λΈ μ μ₯
unet_traced.save("unet_traced.pt")
κ·Έ λ€μ, νμ΄νλΌμΈμ unet
νΉμ±μ λ€μκ³Ό κ°μ΄ μΆμ λ λͺ¨λΈλ‘ λ°κΏ μ μμ΅λλ€.
from diffusers import StableDiffusionPipeline
import torch
from dataclasses import dataclass
@dataclass
class UNet2DConditionOutput:
sample: torch.Tensor
pipe = StableDiffusionPipeline.from_pretrained(
"runwayml/stable-diffusion-v1-5",
torch_dtype=torch.float16,
).to("cuda")
# jitted unet μ¬μ©
unet_traced = torch.jit.load("unet_traced.pt")
# pipe.unet μμ
class TracedUNet(torch.nn.Module):
def __init__(self):
super().__init__()
self.in_channels = pipe.unet.config.in_channels
self.device = pipe.unet.device
def forward(self, latent_model_input, t, encoder_hidden_states):
sample = unet_traced(latent_model_input, t, encoder_hidden_states)[0]
return UNet2DConditionOutput(sample=sample)
pipe.unet = TracedUNet()
with torch.inference_mode():
image = pipe([prompt] * 1, num_inference_steps=50).images[0]
Memory-efficient attention
μ΄ν μ λΈλ‘μ λμνμ μ΅μ ννλ μ΅κ·Ό μμ μΌλ‘ GPU λ©λͺ¨λ¦¬ μ¬μ©λμ΄ ν¬κ² ν₯μλκ³ ν₯μλμμ΅λλ€. @tridaoμ κ°μ₯ μ΅κ·Όμ νλμ μ΄ν μ : code, paper.
λ°°μΉ ν¬κΈ° 1(ν둬ννΈ 1κ°)μ 512x512 ν¬κΈ°λ‘ μΆλ‘ μ μ€νν λ λͺ κ°μ§ Nvidia GPUμμ μ»μ μλ ν₯μμ λ€μκ³Ό κ°μ΅λλ€:
| GPU | κΈ°μ€ μ΄ν μ FP16 | λ©λͺ¨λ¦¬ ν¨μ¨μ μΈ μ΄ν μ FP16 | |------------------ |--------------------- |--------------------------------- | | NVIDIA Tesla T4 | 3.5it/s | 5.5it/s | | NVIDIA 3060 RTX | 4.6it/s | 7.8it/s | | NVIDIA A10G | 8.88it/s | 15.6it/s | | NVIDIA RTX A6000 | 11.7it/s | 21.09it/s | | NVIDIA TITAN RTX | 12.51it/s | 18.22it/s | | A100-SXM4-40GB | 18.6it/s | 29.it/s | | A100-SXM-80GB | 18.7it/s | 29.5it/s |
μ΄λ₯Ό νμ©νλ €λ©΄ λ€μμ λ§μ‘±ν΄μΌ ν©λλ€:
- PyTorch > 1.12
- Cuda μ¬μ© κ°λ₯
- xformers λΌμ΄λΈλ¬λ¦¬λ₯Ό μ€μΉν¨
from diffusers import StableDiffusionPipeline
import torch
pipe = StableDiffusionPipeline.from_pretrained(
"runwayml/stable-diffusion-v1-5",
torch_dtype=torch.float16,
).to("cuda")
pipe.enable_xformers_memory_efficient_attention()
with torch.inference_mode():
sample = pipe("a small cat")
# μ ν: μ΄λ₯Ό λΉνμ±ν νκΈ° μν΄ λ€μμ μ¬μ©ν μ μμ΅λλ€.
# pipe.disable_xformers_memory_efficient_attention()