stable-diffusion-demo / src /inference.py
Prgckwb's picture
:tada: change some process
6850d81
raw
history blame contribute delete
No virus
4.33 kB
import gradio as gr
import spaces
import torch
from PIL import Image
from compel import Compel, DiffusersTextualInversionManager
from diffusers import DiffusionPipeline, StableDiffusionPipeline
from diffusers.utils import make_image_grid
from src.const import DIFFUSERS_MODEL_IDS, EXTERNAL_MODEL_MAPPING, DEVICE
def _load_sd15_extras(pipe):
    """Load the bundled SD1.5 textual-inversion embeddings and LoRA adapters into ``pipe``."""
    # Negative textual-inversion embeddings; each is triggered in a prompt by its token.
    pipe.load_textual_inversion("checkpoints/embeddings/BadNegAnatomyV1 neg.pt", token='BadNegAnatomyV1-neg')
    pipe.load_textual_inversion("checkpoints/embeddings/Deep Negative V1 75T.pt", token='DeepNegative')
    pipe.load_textual_inversion("checkpoints/embeddings/easynegative.safetensors", token='EasyNegative')
    pipe.load_textual_inversion("checkpoints/embeddings/Negative Hand Embedding.pt", token='negative_hand-neg')

    # LoRA adapters, registered under the given adapter names.
    pipe.load_lora_weights("checkpoints/lora/detailed style SD1.5.safetensors", adapter_name='detail')
    pipe.load_lora_weights("checkpoints/lora/perfection style SD1.5.safetensors", adapter_name='perfection')
    pipe.load_lora_weights("checkpoints/lora/Hand v3 SD1.5.safetensors", adapter_name='hands')
    # Only 'detail' and 'hands' are passed to set_adapters (0.5 each);
    # 'perfection' is loaded but not listed here.
    pipe.set_adapters(['detail', 'hands'], adapter_weights=[0.5, 0.5])


def load_pipeline(model_id, use_model_offload, safety_checker):
    """Load a float16 diffusion pipeline for ``model_id`` and place it for inference.

    Args:
        model_id: Either an id contained in ``DIFFUSERS_MODEL_IDS`` (loaded
            directly), or a key of ``EXTERNAL_MODEL_MAPPING`` (models derived
            from CIVITAI), which is resolved to the source that is loaded.
        use_model_offload: If True, enable model CPU offload (for low-VRAM
            setups) instead of moving the whole pipeline to ``DEVICE``.
        safety_checker: If False, the pipeline's safety checker is removed.

    Returns:
        The loaded pipeline (``DiffusionPipeline`` resolves the concrete class).

    Raises:
        KeyError: If ``model_id`` is in neither ``DIFFUSERS_MODEL_IDS`` nor
            ``EXTERNAL_MODEL_MAPPING``.
    """
    # Models hosted in Diffusers repositories are loaded by id; CIVITAI-derived
    # models go through the external mapping.
    if model_id in DIFFUSERS_MODEL_IDS:
        source = model_id
    else:
        source = EXTERNAL_MODEL_MAPPING[model_id]
    pipe = DiffusionPipeline.from_pretrained(source, torch_dtype=torch.float16)

    # The bundled embeddings/LoRAs are SD1.5 assets. Other architectures (e.g.
    # SD3 pipelines, which have no `load_textual_inversion`) cannot take them,
    # so only apply them to plain Stable Diffusion pipelines.
    if isinstance(pipe, StableDiffusionPipeline):
        _load_sd15_extras(pipe)

    # Remove the safety checker before device placement so it is never moved
    # to the GPU only to be discarded afterwards.
    if not safety_checker:
        pipe.safety_checker = None

    # Countermeasure for when VRAM is scarce.
    if use_model_offload:
        pipe.enable_model_cpu_offload()
    else:
        pipe = pipe.to(DEVICE)
    return pipe
def _grid_layout(num_images):
    """Return ``(rows, cols)`` used to tile ``num_images`` images into one grid.

    An odd count is stacked in a single column; an even count uses two rows.
    """
    if num_images % 2 == 1:
        return num_images, 1
    return 2, num_images // 2


def _compel_prompt_kwargs(pipe, prompt, negative_prompt):
    """Encode both prompts with Compel and return them as ``pipe(...)`` embedding kwargs."""
    compel_proc = Compel(
        tokenizer=pipe.tokenizer,
        text_encoder=pipe.text_encoder,
        # Lets Compel resolve the textual-inversion tokens loaded into `pipe`.
        textual_inversion_manager=DiffusersTextualInversionManager(pipe),
        truncate_long_prompts=False,
    )
    prompt_embeds = compel_proc(prompt)
    negative_prompt_embeds = compel_proc(negative_prompt)
    # Without truncation the two embeddings may differ in length, but the
    # pipeline requires positive and negative embeddings of the same shape.
    prompt_embeds, negative_prompt_embeds = compel_proc.pad_conditioning_tensors_to_same_length(
        [prompt_embeds, negative_prompt_embeds]
    )
    return {'prompt_embeds': prompt_embeds, 'negative_prompt_embeds': negative_prompt_embeds}


@spaces.GPU(duration=120)
@torch.inference_mode()
def inference(
        prompt: str,
        model_id: str = "stabilityai/stable-diffusion-3-medium-diffusers",
        negative_prompt: str = "",
        width: int = 512,
        height: int = 512,
        guidance_scale: float = 7.5,
        num_inference_steps: int = 50,
        num_images: int = 4,
        safety_checker: bool = True,
        use_model_offload: bool = False,
        seed: int = 8888,
        progress=gr.Progress(track_tqdm=True),
) -> Image.Image:
    """Generate ``num_images`` images for ``prompt`` and return them tiled as one image.

    The pipeline is (re)loaded on every call via ``load_pipeline``. For plain
    Stable Diffusion pipelines the prompts are encoded with Compel; other
    pipelines receive the raw prompt strings.

    Args:
        prompt: Positive text prompt.
        model_id: Model to load (see ``load_pipeline``).
        negative_prompt: Negative text prompt.
        width: Output image width in pixels.
        height: Output image height in pixels.
        guidance_scale: Classifier-free guidance scale passed to the pipeline.
        num_inference_steps: Number of denoising steps.
        num_images: Number of images generated for the prompt.
        safety_checker: If False, the safety checker is removed from the pipeline.
        use_model_offload: If True, use model CPU offload instead of moving the
            whole pipeline to ``DEVICE``.
        seed: Seed for the ``torch.Generator``, for reproducible output.
        progress: Gradio progress tracker.

    Returns:
        A single PIL image containing all generated images as a grid.
    """
    # NOTE(review): Gradio numeric inputs may arrive as floats; manual_seed and
    # the grid layout / num_images_per_prompt need ints.
    seed = int(seed)
    num_images = int(num_images)

    progress(0, 'Loading pipeline...')
    pipe = load_pipeline(model_id, use_model_offload, safety_checker)

    # Fix the seed.
    generator = torch.Generator(device=DEVICE).manual_seed(seed)

    if isinstance(pipe, StableDiffusionPipeline):
        prompt_kwargs = _compel_prompt_kwargs(pipe, prompt, negative_prompt)
    else:
        prompt_kwargs = {'prompt': prompt, 'negative_prompt': negative_prompt}

    progress(0.3, 'Generating images...')
    images = pipe(
        **prompt_kwargs,
        width=width,
        height=height,
        guidance_scale=guidance_scale,
        num_inference_steps=num_inference_steps,
        num_images_per_prompt=num_images,
        generator=generator,
    ).images
    progress(0.9, f'Done generating {num_images} images')

    rows, cols = _grid_layout(num_images)
    return make_image_grid(images, rows=rows, cols=cols)