|
<!--Copyright 2024 The HuggingFace Team. All rights reserved. |
|
|
|
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with |
|
the License. You may obtain a copy of the License at |
|
|
|
http://www.apache.org/licenses/LICENSE-2.0 |
|
|
|
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on |
|
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the |
|
specific language governing permissions and limitations under the License. |
|
--> |
|
|
|
# Memory and speed
|
|
|
We present some techniques and ideas to optimize 🤗 Diffusers *inference* for memory or speed.

As a general rule, we recommend [xFormers](https://github.com/facebookresearch/xformers) for memory-efficient attention, so please see the recommended [installation instructions](xformers).
|
|
|
We'll discuss how the following settings impact performance and memory.
|
|
|
|                              | Latency | Speedup |
| ---------------------------- | ------- | ------- |
| original                     | 9.50s   | x1      |
| cuDNN auto-tuner             | 9.37s   | x1.01   |
| fp16                         | 3.61s   | x2.63   |
| Channels Last memory format  | 3.30s   | x2.88   |
| traced UNet                  | 3.21s   | x2.96   |
| memory-efficient attention   | 2.63s   | x3.61   |
|
|
|
<em>
Obtained on an NVIDIA TITAN RTX by generating a single 512x512 image from the prompt "a photo of an astronaut riding a horse on mars" with 50 DDIM steps.
</em>
|
|
|
## Enable cuDNN auto-tuner
|
|
|
[NVIDIA cuDNN](https://developer.nvidia.com/cudnn) supports many algorithms to compute a convolution. The autotuner runs a short benchmark and selects the kernel with the best performance on a given hardware for a given input size.
|
|
|
Since we're using **convolutional networks** (other types are currently not supported), we can enable the cuDNN autotuner before launching inference by setting:
|
|
|
```python
import torch

torch.backends.cudnn.benchmark = True
```
|
|
|
### Use tf32 instead of fp32 (on Ampere and later CUDA devices)
|
|
|
On Ampere and later CUDA devices, matrix multiplications and convolutions can use the TensorFloat32 (TF32) mode for faster, but slightly less accurate computations. By default, PyTorch enables TF32 mode for convolutions but not for matrix multiplications. Unless your network requires full float32 precision, we recommend enabling this setting for matrix multiplications, too. It can significantly speed up computations with typically negligible loss of numerical accuracy. You can read more about it [here](https://huggingface.co/docs/transformers/v4.18.0/en/performance#tf32). All you need to do is add this before your inference:
|
|
|
```python
import torch

torch.backends.cuda.matmul.allow_tf32 = True
```
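
For reference, the matching flag for convolutions is `torch.backends.cudnn.allow_tf32`, which PyTorch already enables by default, so only the matmul flag above needs to be set explicitly. A minimal sketch showing both flags side by side:

```python
import torch

# TF32 for matrix multiplications: disabled by default, so enable it explicitly
torch.backends.cuda.matmul.allow_tf32 = True

# TF32 for cuDNN convolutions: already enabled by default, shown here for completeness
torch.backends.cudnn.allow_tf32 = True
```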
|
|
|
## Half precision weights
|
|
|
To save more GPU memory and get more speed, you can load and run the model weights directly in half precision. This involves loading the float16 version of the weights, which were saved in a branch named `fp16`, and telling PyTorch to use the `float16` type when loading them:
|
|
|
```python
import torch
from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5",
    torch_dtype=torch.float16,
)
pipe = pipe.to("cuda")

prompt = "a photo of an astronaut riding a horse on mars"
image = pipe(prompt).images[0]
```
|
|
|
<Tip warning={true}> |
|
It is strongly discouraged to make use of [`torch.autocast`](https://pytorch.org/docs/stable/amp.html#torch.autocast) in any of the pipelines, as it can lead to black images and is always slower than pure float16 precision.
|
</Tip> |
|
|
|
## Sliced attention for additional memory savings

For additional memory savings, you can use a sliced version of attention that performs the computation in steps instead of all at once.
|
|
|
<Tip> |
|
Attention slicing is useful even when a batch size of just 1 is used, as long as the model uses more than one attention head. If there is more than one attention head, the *QK^T* attention matrix can be computed sequentially for each head, which can save a significant amount of memory.
|
</Tip> |
|
|
|
To perform the attention computation sequentially over each head, you only need to invoke [`~StableDiffusionPipeline.enable_attention_slicing`] in your pipeline before inference, like here:
|
|
|
```python
import torch
from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5",
    torch_dtype=torch.float16,
)
pipe = pipe.to("cuda")

prompt = "a photo of an astronaut riding a horse on mars"
pipe.enable_attention_slicing()
image = pipe(prompt).images[0]
```
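
Incidentally, [`~StableDiffusionPipeline.enable_attention_slicing`] also accepts an optional `slice_size` argument: the default `"auto"` splits the computation in two steps, `"max"` saves the most memory by running only one slice at a time, and an integer gives finer-grained control (you'll see `enable_attention_slicing(1)` used further below). A minimal sketch of the variants:

```python
# default: "auto" computes the attention in two steps
pipe.enable_attention_slicing()

# maximum memory savings: one slice at a time
pipe.enable_attention_slicing("max")

# revert to computing attention in one step
pipe.disable_attention_slicing()
```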
|
|
|
There's a small performance penalty of about 10% slower inference times, but this method allows you to use Stable Diffusion in as little as 3.2 GB of VRAM!
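
If you want to verify the savings on your own hardware, a quick way is PyTorch's peak-memory statistics. This is a minimal sketch (exact numbers depend on your GPU, resolution, and which optimizations are enabled):

```python
import torch

# reset the peak-memory counter, run one generation, then read the high-water mark
torch.cuda.reset_peak_memory_stats()
image = pipe(prompt).images[0]
print(f"peak VRAM: {torch.cuda.max_memory_allocated() / 1024**3:.2f} GB")
```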
|
|
|
|
|
## Sliced VAE decode for larger batches
|
|
|
To decode large batches of images with limited VRAM, or to enable batches with 32 images or more, you can use sliced VAE decode that decodes the batch latents one image at a time.
|
|
|
You can combine this with [`~StableDiffusionPipeline.enable_attention_slicing`] or [`~StableDiffusionPipeline.enable_xformers_memory_efficient_attention`] to further minimize memory use.
|
|
|
To perform the VAE decode one image at a time, invoke [`~StableDiffusionPipeline.enable_vae_slicing`] in your pipeline before inference. For example:
|
|
|
```python
import torch
from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5",
    torch_dtype=torch.float16,
)
pipe = pipe.to("cuda")

prompt = "a photo of an astronaut riding a horse on mars"
pipe.enable_vae_slicing()
images = pipe([prompt] * 32).images
```
|
|
|
You may see a small performance boost in VAE decoding on multi-image batches, and there should be no performance impact on single-image batches.
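
If you later want the whole batch decoded in one pass again, slicing can be turned back off with `disable_vae_slicing`:

```python
# revert to decoding the whole batch at once
pipe.disable_vae_slicing()
```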
|
|
|
|
|
<a name="sequential_offloading"></a> |
|
## Offloading to CPU with accelerate for memory savings
|
|
|
For additional memory savings, you can offload the weights to CPU and only load them to GPU when performing the forward pass.
|
|
|
To perform CPU offloading, all you have to do is invoke [`~StableDiffusionPipeline.enable_sequential_cpu_offload`]:
|
|
|
```python
import torch
from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5",
    torch_dtype=torch.float16,
)

prompt = "a photo of an astronaut riding a horse on mars"
pipe.enable_sequential_cpu_offload()
image = pipe(prompt).images[0]
```
|
|
|
And you can get the memory consumption to under 3 GB.
|
|
|
Note that this method works at the submodule level, not at the whole-model level. This is the best way to minimize memory consumption, but inference is much slower due to the iterative nature of the process: the UNet component of the pipeline runs several times (as many as `num_inference_steps`), and each time the different submodules of the UNet are sequentially onloaded and then offloaded as needed, which results in a large number of memory transfers.
|
|
|
<Tip> |
|
Consider using <a href="#model_offloading">model offloading</a> as another optimization method. It is much faster, but the memory savings won't be as large.
|
</Tip> |
|
|
|
It is also possible to chain it with attention slicing to run with a minimal amount of memory (< 2 GB):
|
|
|
|
|
```python
import torch
from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5",
    torch_dtype=torch.float16,
)

prompt = "a photo of an astronaut riding a horse on mars"
pipe.enable_sequential_cpu_offload()
pipe.enable_attention_slicing(1)

image = pipe(prompt).images[0]
```
|
|
|
**Note**: When using `enable_sequential_cpu_offload()`, it is important **not** to move the pipeline to CUDA beforehand, or else the gain in memory consumption will only be minimal. See [this issue](https://github.com/huggingface/diffusers/issues/1934) for more information.
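
A minimal sketch of that caveat (the commented-out line is the one to avoid):

```python
import torch
from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5",
    torch_dtype=torch.float16,
)
# pipe = pipe.to("cuda")  # avoid this: moving everything to CUDA first defeats the offloading
pipe.enable_sequential_cpu_offload()  # let the pipeline manage device placement itself
```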
|
|
|
<a name="model_offloading"></a> |
|
## Model offloading for fast inference and memory savings
|
|
|
[Sequential CPU offloading](#sequential_offloading), as discussed in the previous section, preserves a lot of memory but makes inference slower, because submodules are moved to GPU as needed and immediately returned to CPU when a new module runs.
|
|
|
Full-model offloading is an alternative that moves whole models to the GPU, instead of handling each model's constituent *modules*. This results in a negligible impact on inference time (compared with moving the pipeline to `cuda`), while still providing some memory savings.
|
|
|
In this scenario, only one of the main components of the pipeline (typically the text encoder, UNet, and VAE) will be on the GPU while the others wait on the CPU. Components like the UNet that run for multiple iterations stay on the GPU until they are no longer needed.
|
|
|
This feature can be enabled by invoking `enable_model_cpu_offload()` on the pipeline, as shown below.
|
|
|
```python
import torch
from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5",
    torch_dtype=torch.float16,
)

prompt = "a photo of an astronaut riding a horse on mars"
pipe.enable_model_cpu_offload()
image = pipe(prompt).images[0]
```
|
|
|
This is also compatible with attention slicing for additional memory savings:
|
|
|
```python
import torch
from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5",
    torch_dtype=torch.float16,
)

prompt = "a photo of an astronaut riding a horse on mars"
pipe.enable_model_cpu_offload()
pipe.enable_attention_slicing(1)

image = pipe(prompt).images[0]
```
|
|
|
<Tip> |
|
This feature requires `accelerate` version 0.17.0 or higher.
|
</Tip> |
|
|
|
## Using Channels Last memory format
|
|
|
Channels Last memory format is an alternative way of ordering NCHW tensors in memory that preserves the dimension ordering. Channels Last tensors are ordered in such a way that the channels become the densest dimension (aka storing images pixel-per-pixel). Since not all operators currently support the Channels Last format, using it may result in worse performance, so it's better to try it and see if it works for your model.
|
|
|
|
|
For example, to set the pipeline's UNet model to use the Channels Last format, you can use the following:
|
|
|
```python
print(pipe.unet.conv_out.state_dict()["weight"].stride())  # (2880, 9, 3, 1)
pipe.unet.to(memory_format=torch.channels_last)  # in-place operation
# having a stride of 1 for the 2nd dimension proves that it works
print(pipe.unet.conv_out.state_dict()["weight"].stride())
```
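
The same call works on any convolutional module, so as an experiment you could also try it on the VAE and measure whether it helps on your setup (this is a hedged sketch, not a guaranteed win):

```python
# experiment: apply Channels Last to the VAE as well and benchmark the result
pipe.vae.to(memory_format=torch.channels_last)
```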
|
|
|
## Tracing
|
|
|
Tracing runs an example input tensor through the model and captures the operations that are invoked as that input makes its way through the model's layers, so that an executable or `ScriptFunction` is returned, which is then optimized using just-in-time compilation.
|
|
|
To trace our UNet model, we can use the following:
|
|
|
```python
import time
import torch
from diffusers import StableDiffusionPipeline
import functools

# disable torch gradient tracking
torch.set_grad_enabled(False)

# set variables
n_experiments = 2
unet_runs_per_experiment = 50


# load inputs
def generate_inputs():
    sample = torch.randn((2, 4, 64, 64), device="cuda", dtype=torch.float16)
    timestep = torch.rand(1, device="cuda", dtype=torch.float16) * 999
    encoder_hidden_states = torch.randn((2, 77, 768), device="cuda", dtype=torch.float16)
    return sample, timestep, encoder_hidden_states


pipe = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5",
    torch_dtype=torch.float16,
).to("cuda")
unet = pipe.unet
unet.eval()
unet.to(memory_format=torch.channels_last)  # use Channels Last memory format
unet.forward = functools.partial(unet.forward, return_dict=False)  # set return_dict=False as default

# warmup
for _ in range(3):
    with torch.inference_mode():
        inputs = generate_inputs()
        orig_output = unet(*inputs)

# trace
print("tracing..")
unet_traced = torch.jit.trace(unet, inputs)
unet_traced.eval()
print("done tracing")


# warmup and optimize graph
for _ in range(5):
    with torch.inference_mode():
        inputs = generate_inputs()
        orig_output = unet_traced(*inputs)


# benchmarking
with torch.inference_mode():
    for _ in range(n_experiments):
        torch.cuda.synchronize()
        start_time = time.time()
        for _ in range(unet_runs_per_experiment):
            orig_output = unet_traced(*inputs)
        torch.cuda.synchronize()
        print(f"unet traced inference took {time.time() - start_time:.2f} seconds")
    for _ in range(n_experiments):
        torch.cuda.synchronize()
        start_time = time.time()
        for _ in range(unet_runs_per_experiment):
            orig_output = unet(*inputs)
        torch.cuda.synchronize()
        print(f"unet inference took {time.time() - start_time:.2f} seconds")

# save the model
unet_traced.save("unet_traced.pt")
```
|
|
|
Then, we can replace the `unet` attribute of the pipeline with the traced model as follows:
|
|
|
```python
import torch
from diffusers import StableDiffusionPipeline
from dataclasses import dataclass


@dataclass
class UNet2DConditionOutput:
    sample: torch.Tensor


pipe = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5",
    torch_dtype=torch.float16,
).to("cuda")

# use the jitted unet
unet_traced = torch.jit.load("unet_traced.pt")


# replace pipe.unet with the traced version
class TracedUNet(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.in_channels = pipe.unet.config.in_channels
        self.device = pipe.unet.device

    def forward(self, latent_model_input, t, encoder_hidden_states):
        sample = unet_traced(latent_model_input, t, encoder_hidden_states)[0]
        return UNet2DConditionOutput(sample=sample)


pipe.unet = TracedUNet()

prompt = "a photo of an astronaut riding a horse on mars"
with torch.inference_mode():
    image = pipe([prompt] * 1, num_inference_steps=50).images[0]
```
|
|
|
|
|
## Memory-efficient attention |
|
|
|
Recent work on optimizing bandwidth in the attention block has generated huge speed-ups and gains in GPU memory usage. The most recent type of memory-efficient attention is Flash Attention from @tridao: [code](https://github.com/HazyResearch/flash-attention), [paper](https://arxiv.org/pdf/2205.14135.pdf).
|
|
|
Here are the speed-ups we obtained on a few Nvidia GPUs when running inference at 512x512 with a batch size of 1 (one prompt):
|
|
|
| GPU              | Base Attention FP16 | Memory Efficient Attention FP16 |
|------------------|---------------------|---------------------------------|
| NVIDIA Tesla T4  | 3.5it/s             | 5.5it/s                         |
| NVIDIA 3060 RTX  | 4.6it/s             | 7.8it/s                         |
| NVIDIA A10G      | 8.88it/s            | 15.6it/s                        |
| NVIDIA RTX A6000 | 11.7it/s            | 21.09it/s                       |
| NVIDIA TITAN RTX | 12.51it/s           | 18.22it/s                       |
| A100-SXM4-40GB   | 18.6it/s            | 29.it/s                         |
| A100-SXM-80GB    | 18.7it/s            | 29.5it/s                        |
|
|
|
To leverage it, you need to make sure the following requirements are satisfied:

- PyTorch > 1.12
- CUDA available
- [xformers library installed](xformers)
|
```python
import torch
from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5",
    torch_dtype=torch.float16,
).to("cuda")

pipe.enable_xformers_memory_efficient_attention()

with torch.inference_mode():
    sample = pipe("a small cat")

# optional: you can disable it via
# pipe.disable_xformers_memory_efficient_attention()
```
|
|