๋ค์ํ Stable Diffusion ํฌ๋งท ๋ถ๋ฌ์ค๊ธฐ
Stable Diffusion ๋ชจ๋ธ๋ค์ ํ์ต ๋ฐ ์ ์ฅ๋ ํ๋ ์์ํฌ์ ๋ค์ด๋ก๋ ์์น์ ๋ฐ๋ผ ๋ค์ํ ํ์์ผ๋ก ์ ๊ณต๋ฉ๋๋ค. ์ด๋ฌํ ํ์์ ๐ค Diffusers์์ ์ฌ์ฉํ ์ ์๋๋ก ๋ณํํ๋ฉด ์ถ๋ก ์ ์ํ ๋ค์ํ ์ค์ผ์ค๋ฌ ์ฌ์ฉ, ์ฌ์ฉ์ ์ง์ ํ์ดํ๋ผ์ธ ๊ตฌ์ถ, ์ถ๋ก ์๋ ์ต์ ํ๋ฅผ ์ํ ๋ค์ํ ๊ธฐ๋ฒ๊ณผ ๋ฐฉ๋ฒ ๋ฑ ๋ผ์ด๋ธ๋ฌ๋ฆฌ์์ ์ง์ํ๋ ๋ชจ๋ ๊ธฐ๋ฅ์ ์ฌ์ฉํ ์ ์์ต๋๋ค.
์ฐ๋ฆฌ๋ .safetensors
ํ์์ ์ถ์ฒํฉ๋๋ค. ์๋ํ๋ฉด ๊ธฐ์กด์ pickled ํ์ผ์ ์ทจ์ฝํ๊ณ ๋จธ์ ์์ ์ฝ๋๋ฅผ ์คํํ ๋ ์
์ฉ๋ ์ ์๋ ๊ฒ์ ๋นํด ํจ์ฌ ๋ ์์ ํฉ๋๋ค. (safetensors ๋ถ๋ฌ์ค๊ธฐ ๊ฐ์ด๋์์ ์์ธํ ์์๋ณด์ธ์.)
์ด ๊ฐ์ด๋์์๋ ๋ค๋ฅธ Stable Diffusion ํ์์ ๐ค Diffusers์ ํธํ๋๋๋ก ๋ณํํ๋ ๋ฐฉ๋ฒ์ ์ค๋ช ํฉ๋๋ค.
PyTorch .ckpt
์ฒดํฌํฌ์ธํธ ๋๋ .ckpt
ํ์์ ์ผ๋ฐ์ ์ผ๋ก ๋ชจ๋ธ์ ์ ์ฅํ๋ ๋ฐ ์ฌ์ฉ๋ฉ๋๋ค. .ckpt
ํ์ผ์ ์ ์ฒด ๋ชจ๋ธ์ ํฌํจํ๋ฉฐ ์ผ๋ฐ์ ์ผ๋ก ํฌ๊ธฐ๊ฐ ๋ช GB์
๋๋ค. .ckpt
ํ์ผ์ [~StableDiffusionPipeline.from_ckpt] ๋ฉ์๋๋ฅผ ์ฌ์ฉํ์ฌ ์ง์ ๋ถ๋ฌ์์ ์ฌ์ฉํ ์๋ ์์ง๋ง, ์ผ๋ฐ์ ์ผ๋ก ๋ ๊ฐ์ง ํ์์ ๋ชจ๋ ์ฌ์ฉํ ์ ์๋๋ก .ckpt
ํ์ผ์ ๐ค Diffusers๋ก ๋ณํํ๋ ๊ฒ์ด ๋ ์ข์ต๋๋ค.
.ckpt
ํ์ผ์ ๋ณํํ๋ ๋ ๊ฐ์ง ์ต์
์ด ์์ต๋๋ค. Space๋ฅผ ์ฌ์ฉํ์ฌ ์ฒดํฌํฌ์ธํธ๋ฅผ ๋ณํํ๊ฑฐ๋ ์คํฌ๋ฆฝํธ๋ฅผ ์ฌ์ฉํ์ฌ .ckpt
ํ์ผ์ ๋ณํํฉ๋๋ค.
Space๋ก ๋ณํํ๊ธฐ
.ckpt
ํ์ผ์ ๋ณํํ๋ ๊ฐ์ฅ ์ฝ๊ณ ํธ๋ฆฌํ ๋ฐฉ๋ฒ์ SD์์ Diffusers๋ก ์คํ์ด์ค๋ฅผ ์ฌ์ฉํ๋ ๊ฒ์
๋๋ค. Space์ ์ง์นจ์ ๋ฐ๋ผ .ckpt ํ์ผ์ ๋ณํ ํ ์ ์์ต๋๋ค.
์ด ์ ๊ทผ ๋ฐฉ์์ ๊ธฐ๋ณธ ๋ชจ๋ธ์์๋ ์ ์๋ํ์ง๋ง ๋ ๋ง์ ์ฌ์ฉ์ ์ ์ ๋ชจ๋ธ์์๋ ์ด๋ ค์์ ๊ฒช์ ์ ์์ต๋๋ค. ๋น pull request๋ ์ค๋ฅ๋ฅผ ๋ฐํํ๋ฉด Space๊ฐ ์คํจํ ๊ฒ์
๋๋ค.
์ด ๊ฒฝ์ฐ ์คํฌ๋ฆฝํธ๋ฅผ ์ฌ์ฉํ์ฌ .ckpt
ํ์ผ์ ๋ณํํด ๋ณผ ์ ์์ต๋๋ค.
์คํฌ๋ฆฝํธ๋ก ๋ณํํ๊ธฐ
๐ค Diffusers๋ .ckpt
ํ์ผ ๋ณํ์ ์ํ ๋ณํ ์คํฌ๋ฆฝํธ๋ฅผ ์ ๊ณตํฉ๋๋ค. ์ด ์ ๊ทผ ๋ฐฉ์์ ์์ Space๋ณด๋ค ๋ ์์ ์ ์
๋๋ค.
์์ํ๊ธฐ ์ ์ ์คํฌ๋ฆฝํธ๋ฅผ ์คํํ ๐ค Diffusers์ ๋ก์ปฌ ํด๋ก (clone)์ด ์๋์ง ํ์ธํ๊ณ Hugging Face ๊ณ์ ์ ๋ก๊ทธ์ธํ์ฌ pull request๋ฅผ ์ด๊ณ ๋ณํ๋ ๋ชจ๋ธ์ ํ๋ธ์ ํธ์ํ ์ ์๋๋ก ํ์ธ์.
huggingface-cli login
์คํฌ๋ฆฝํธ๋ฅผ ์ฌ์ฉํ๋ ค๋ฉด:
- ๋ณํํ๋ ค๋
.ckpt
ํ์ผ์ด ํฌํจ๋ ๋ฆฌํฌ์งํ ๋ฆฌ๋ฅผ Git์ผ๋ก ํด๋ก (clone)ํฉ๋๋ค.
์ด ์์ ์์๋ TemporalNet .ckpt ํ์ผ์ ๋ณํํด ๋ณด๊ฒ ์ต๋๋ค:
git lfs install
git clone https://huggingface.co/CiaraRowles/TemporalNet
- ์ฒดํฌํฌ์ธํธ๋ฅผ ๋ณํํ ๋ฆฌํฌ์งํ ๋ฆฌ์์ pull request๋ฅผ ์ฝ๋๋ค:
cd TemporalNet && git fetch origin refs/pr/13:pr/13
git checkout pr/13
- ๋ณํ ์คํฌ๋ฆฝํธ์์ ๊ตฌ์ฑํ ์ ๋ ฅ ์ธ์๋ ์ฌ๋ฌ ๊ฐ์ง๊ฐ ์์ง๋ง ๊ฐ์ฅ ์ค์ํ ์ธ์๋ ๋ค์๊ณผ ๊ฐ์ต๋๋ค:
checkpoint_path
: ๋ณํํ.ckpt
ํ์ผ์ ๊ฒฝ๋ก๋ฅผ ์ ๋ ฅํฉ๋๋ค.original_config_file
: ์๋ ์ํคํ ์ฒ์ ๊ตฌ์ฑ์ ์ ์ํ๋ YAML ํ์ผ์ ๋๋ค. ์ด ํ์ผ์ ์ฐพ์ ์ ์๋ ๊ฒฝ์ฐ.ckpt
ํ์ผ์ ์ฐพ์ GitHub ๋ฆฌํฌ์งํ ๋ฆฌ์์ YAML ํ์ผ์ ๊ฒ์ํด ๋ณด์ธ์.dump_path
: ๋ณํ๋ ๋ชจ๋ธ์ ๊ฒฝ๋ก
์๋ฅผ ๋ค์ด, TemporalNet ๋ชจ๋ธ์ Stable Diffusion v1.5 ๋ฐ ControlNet ๋ชจ๋ธ์ด๊ธฐ ๋๋ฌธ์ ControlNet ๋ฆฌํฌ์งํ ๋ฆฌ์์ cldm_v15.yaml ํ์ผ์ ๊ฐ์ ธ์ฌ ์ ์์ต๋๋ค.
- ์ด์ ์คํฌ๋ฆฝํธ๋ฅผ ์คํํ์ฌ .ckpt ํ์ผ์ ๋ณํํ ์ ์์ต๋๋ค:
python ../diffusers/scripts/convert_original_stable_diffusion_to_diffusers.py --checkpoint_path temporalnetv3.ckpt --original_config_file cldm_v15.yaml --dump_path ./ --controlnet
- ๋ณํ์ด ์๋ฃ๋๋ฉด ๋ณํ๋ ๋ชจ๋ธ์ ์ ๋ก๋ํ๊ณ ๊ฒฐ๊ณผ๋ฌผ์ pull request pull request๋ฅผ ํ ์คํธํ์ธ์!
git push origin pr/13:refs/pr/13
Keras .pb or .h5
๐งช ์ด ๊ธฐ๋ฅ์ ์คํ์ ์ธ ๊ธฐ๋ฅ์ ๋๋ค. ํ์ฌ๋ก์๋ Stable Diffusion v1 ์ฒดํฌํฌ์ธํธ๋ง ๋ณํ KerasCV Space์์ ์ง์๋ฉ๋๋ค.
KerasCV๋ Stable Diffusion v1 ๋ฐ v2์ ๋ํ ํ์ต์ ์ง์ํฉ๋๋ค. ๊ทธ๋ฌ๋ ์ถ๋ก ๋ฐ ๋ฐฐํฌ๋ฅผ ์ํ Stable Diffusion ๋ชจ๋ธ ์คํ์ ์ ํ์ ์ผ๋ก ์ง์ํ๋ ๋ฐ๋ฉด, ๐ค Diffusers๋ ๋ค์ํ noise schedulers, flash attention, and other optimization techniques ๋ฑ ์ด๋ฌํ ๋ชฉ์ ์ ์ํ ๋ณด๋ค ์๋ฒฝํ ๊ธฐ๋ฅ์ ๊ฐ์ถ๊ณ ์์ต๋๋ค.
Convert KerasCV Space ๋ณํ์ .pb
๋๋ .h5
์ PyTorch๋ก ๋ณํํ ๋ค์, ์ถ๋ก ํ ์ ์๋๋ก [StableDiffusionPipeline
] ์ผ๋ก ๊ฐ์ธ์ ์ค๋นํฉ๋๋ค. ๋ณํ๋ ์ฒดํฌํฌ์ธํธ๋ Hugging Face Hub์ ๋ฆฌํฌ์งํ ๋ฆฌ์ ์ ์ฅ๋ฉ๋๋ค.
์์ ๋ก, textual-inversion์ผ๋ก ํ์ต๋ [sayakpaul/textual-inversion-kerasio](https://huggingface.co/sayakpaul/textual-inversion-kerasio/tree/main)
์ฒดํฌํฌ์ธํธ๋ฅผ ๋ณํํด ๋ณด๊ฒ ์ต๋๋ค. ์ด๊ฒ์ ํน์ ํ ํฐ <my-funny-cat>
์ ์ฌ์ฉํ์ฌ ๊ณ ์์ด๋ก ์ด๋ฏธ์ง๋ฅผ ๊ฐ์ธํํฉ๋๋ค.
KerasCV Space ๋ณํ์์๋ ๋ค์์ ์ ๋ ฅํ ์ ์์ต๋๋ค:
- Hugging Face ํ ํฐ.
- UNet ๊ณผ ํ ์คํธ ์ธ์ฝ๋(text encoder) ๊ฐ์ค์น๋ฅผ ๋ค์ด๋ก๋ํ๋ ๊ฒฝ๋ก์ ๋๋ค. ๋ชจ๋ธ์ ์ด๋ป๊ฒ ํ์ตํ ์ง ๋ฐฉ์์ ๋ฐ๋ผ, UNet๊ณผ ํ ์คํธ ์ธ์ฝ๋์ ๊ฒฝ๋ก๋ฅผ ๋ชจ๋ ์ ๊ณตํ ํ์๋ ์์ต๋๋ค. ์๋ฅผ ๋ค์ด, textual-inversion์๋ ํ ์คํธ ์ธ์ฝ๋์ ์๋ฒ ๋ฉ๋ง ํ์ํ๊ณ ํ ์คํธ-์ด๋ฏธ์ง(text-to-image) ๋ชจ๋ธ ๋ณํ์๋ UNet ๊ฐ์ค์น๋ง ํ์ํฉ๋๋ค.
- Placeholder ํ ํฐ์ textual-inversion ๋ชจ๋ธ์๋ง ์ ์ฉ๋ฉ๋๋ค.
output_repo_prefix
๋ ๋ณํ๋ ๋ชจ๋ธ์ด ์ ์ฅ๋๋ ๋ฆฌํฌ์งํ ๋ฆฌ์ ์ด๋ฆ์ ๋๋ค.
Submit (์ ์ถ) ๋ฒํผ์ ํด๋ฆญํ๋ฉด KerasCV ์ฒดํฌํฌ์ธํธ๊ฐ ์๋์ผ๋ก ๋ณํ๋ฉ๋๋ค! ์ฒดํฌํฌ์ธํธ๊ฐ ์ฑ๊ณต์ ์ผ๋ก ๋ณํ๋๋ฉด, ๋ณํ๋ ์ฒดํฌํฌ์ธํธ๊ฐ ํฌํจ๋ ์ ๋ฆฌํฌ์งํ ๋ฆฌ๋ก ์ฐ๊ฒฐ๋๋ ๋งํฌ๊ฐ ํ์๋ฉ๋๋ค. ์ ๋ฆฌํฌ์งํ ๋ฆฌ๋ก ์ฐ๊ฒฐ๋๋ ๋งํฌ๋ฅผ ๋ฐ๋ผ๊ฐ๋ฉด ๋ณํ๋ ๋ชจ๋ธ์ ์ฌ์ฉํด ๋ณผ ์ ์๋ ์ถ๋ก ์์ ฏ์ด ํฌํจ๋ ๋ชจ๋ธ ์นด๋๊ฐ ์์ฑ๋ KerasCV Space ๋ณํ์ ํ์ธํ ์ ์์ต๋๋ค.
์ฝ๋๋ฅผ ์ฌ์ฉํ์ฌ ์ถ๋ก ์ ์คํํ๋ ค๋ฉด ๋ชจ๋ธ ์นด๋์ ์ค๋ฅธ์ชฝ ์๋จ ๋ชจ์๋ฆฌ์ ์๋ Use in Diffusers ๋ฒํผ์ ํด๋ฆญํ์ฌ ์์ ์ฝ๋๋ฅผ ๋ณต์ฌํ์ฌ ๋ถ์ฌ๋ฃ์ต๋๋ค:
from diffusers import DiffusionPipeline
pipeline = DiffusionPipeline.from_pretrained("sayakpaul/textual-inversion-cat-kerascv_sd_diffusers_pipeline")
๊ทธ๋ฌ๋ฉด ๋ค์๊ณผ ๊ฐ์ ์ด๋ฏธ์ง๋ฅผ ์์ฑํ ์ ์์ต๋๋ค:
from diffusers import DiffusionPipeline
pipeline = DiffusionPipeline.from_pretrained("sayakpaul/textual-inversion-cat-kerascv_sd_diffusers_pipeline")
pipeline.to("cuda")
placeholder_token = "<my-funny-cat-token>"
prompt = f"two {placeholder_token} getting married, photorealistic, high quality"
image = pipeline(prompt, num_inference_steps=50).images[0]
A1111 LoRA files
Automatic1111 (A1111)์ Stable Diffusion์ ์ํด ๋๋ฆฌ ์ฌ์ฉ๋๋ ์น UI๋ก, Civitai ์ ๊ฐ์ ๋ชจ๋ธ ๊ณต์ ํ๋ซํผ์ ์ง์ํฉ๋๋ค. ํนํ LoRA ๊ธฐ๋ฒ์ผ๋ก ํ์ต๋ ๋ชจ๋ธ์ ํ์ต ์๋๊ฐ ๋น ๋ฅด๊ณ ์์ ํ ํ์ธํ๋๋ ๋ชจ๋ธ๋ณด๋ค ํ์ผ ํฌ๊ธฐ๊ฐ ํจ์ฌ ์๊ธฐ ๋๋ฌธ์ ์ธ๊ธฐ๊ฐ ๋์ต๋๋ค.
๐ค Diffusers๋ [~loaders.LoraLoaderMixin.load_lora_weights
]:๋ฅผ ์ฌ์ฉํ์ฌ A1111 LoRA ์ฒดํฌํฌ์ธํธ ๋ถ๋ฌ์ค๊ธฐ๋ฅผ ์ง์ํฉ๋๋ค:
from diffusers import DiffusionPipeline, UniPCMultistepScheduler
import torch
pipeline = DiffusionPipeline.from_pretrained(
"andite/anything-v4.0", torch_dtype=torch.float16, safety_checker=None
).to("cuda")
pipeline.scheduler = UniPCMultistepScheduler.from_config(pipeline.scheduler.config)
Civitai์์ LoRA ์ฒดํฌํฌ์ธํธ๋ฅผ ๋ค์ด๋ก๋ํ์ธ์; ์ด ์์ ์์๋ Howls Moving Castle,Interior/Scenery LoRA (Ghibli Stlye) ์ฒดํฌํฌ์ธํธ๋ฅผ ์ฌ์ฉํ์ง๋ง, ์ด๋ค LoRA ์ฒดํฌํฌ์ธํธ๋ ์์ ๋กญ๊ฒ ์ฌ์ฉํด ๋ณด์ธ์!
!wget https://civitai.com/api/download/models/19998 -O howls_moving_castle.safetensors
๋ฉ์๋๋ฅผ ์ฌ์ฉํ์ฌ ํ์ดํ๋ผ์ธ์ LoRA ์ฒดํฌํฌ์ธํธ๋ฅผ ๋ถ๋ฌ์ต๋๋ค:
pipeline.load_lora_weights(".", weight_name="howls_moving_castle.safetensors")
์ด์ ํ์ดํ๋ผ์ธ์ ์ฌ์ฉํ์ฌ ์ด๋ฏธ์ง๋ฅผ ์์ฑํ ์ ์์ต๋๋ค:
prompt = "masterpiece, illustration, ultra-detailed, cityscape, san francisco, golden gate bridge, california, bay area, in the snow, beautiful detailed starry sky"
negative_prompt = "lowres, cropped, worst quality, low quality, normal quality, artifacts, signature, watermark, username, blurry, more than one bridge, bad architecture"
images = pipeline(
prompt=prompt,
negative_prompt=negative_prompt,
width=512,
height=512,
num_inference_steps=25,
num_images_per_prompt=4,
generator=torch.manual_seed(0),
).images
๋ง์ง๋ง์ผ๋ก, ๋์คํ๋ ์ด์ ์ด๋ฏธ์ง๋ฅผ ํ์ํ๋ ํฌํผ ํจ์๋ฅผ ๋ง๋ญ๋๋ค:
from PIL import Image
def image_grid(imgs, rows=2, cols=2):
w, h = imgs[0].size
grid = Image.new("RGB", size=(cols * w, rows * h))
for i, img in enumerate(imgs):
grid.paste(img, box=(i % cols * w, i // cols * h))
return grid
image_grid(images)