import gradio as gr
import spaces
import torch
from PIL import Image
from compel import Compel, DiffusersTextualInversionManager
from diffusers import DiffusionPipeline, StableDiffusionPipeline
from diffusers.utils import make_image_grid

from src.const import DIFFUSERS_MODEL_IDS, EXTERNAL_MODEL_MAPPING, DEVICE


def load_pipeline(model_id, use_model_offload, safety_checker):
    """Load a diffusers pipeline for `model_id`, optionally enabling CPU offload and disabling the safety checker."""
    # Model available as a Diffusers repository on the Hugging Face Hub
    if model_id in DIFFUSERS_MODEL_IDS:
        pipe = DiffusionPipeline.from_pretrained(
            model_id,
            torch_dtype=torch.float16,
        )

    # Model originating from CIVITAI, resolved via EXTERNAL_MODEL_MAPPING
    else:
        pipe = DiffusionPipeline.from_pretrained(
            EXTERNAL_MODEL_MAPPING[model_id],
            torch_dtype=torch.float16,
        )

        # Load Textual Inversion
        pipe.load_textual_inversion("checkpoints/embeddings/BadNegAnatomyV1 neg.pt", token='BadNegAnatomyV1-neg')
        pipe.load_textual_inversion("checkpoints/embeddings/Deep Negative V1 75T.pt", token='DeepNegative')
        pipe.load_textual_inversion("checkpoints/embeddings/easynegative.safetensors", token='EasyNegative')
        pipe.load_textual_inversion("checkpoints/embeddings/Negative Hand Embedding.pt", token='negative_hand-neg')

        # Load LoRA
        pipe.load_lora_weights("checkpoints/lora/detailed style SD1.5.safetensors", adapter_name='detail')
        pipe.load_lora_weights("checkpoints/lora/perfection style SD1.5.safetensors", adapter_name='perfection')
        pipe.load_lora_weights("checkpoints/lora/Hand v3 SD1.5.safetensors", adapter_name='hands')
        pipe.set_adapters(['detail', 'hands'], adapter_weights=[0.5, 0.5])

    # Mitigation for environments with limited VRAM
    if use_model_offload:
        pipe.enable_model_cpu_offload()
    else:
        pipe = pipe.to(DEVICE)

    if not safety_checker:
        pipe.safety_checker = None

    return pipe


@spaces.GPU(duration=120)
@torch.inference_mode()
def inference(
        prompt: str,
        model_id: str = "stabilityai/stable-diffusion-3-medium-diffusers",
        negative_prompt: str = "",
        width: int = 512,
        height: int = 512,
        guidance_scale: float = 7.5,
        num_inference_steps: int = 50,
        num_images: int = 4,
        safety_checker: bool = True,
        use_model_offload: bool = False,
        seed: int = 8888,
        progress=gr.Progress(track_tqdm=True),
) -> Image.Image:
    """Generate `num_images` images from `prompt` and return them combined into a single PIL image grid."""
    progress(0, 'Loading pipeline...')
    pipe = load_pipeline(model_id, use_model_offload, safety_checker)

    # Fix the random seed for reproducible results
    generator = torch.Generator(device=DEVICE).manual_seed(seed)

    if isinstance(pipe, StableDiffusionPipeline):
        # Build prompt embeddings with Compel (prompt weighting, long prompts, registered textual inversions)
        textual_inversion_manager = DiffusersTextualInversionManager(pipe)
        compel_procs = Compel(
            tokenizer=pipe.tokenizer,
            text_encoder=pipe.text_encoder,
            textual_inversion_manager=textual_inversion_manager,
            truncate_long_prompts=False,
        )
        prompt_embed = compel_procs(prompt)
        negative_prompt_embed = compel_procs(negative_prompt)

        prompt_embed, negative_prompt_embed = compel_procs.pad_conditioning_tensors_to_same_length(
            [prompt_embed, negative_prompt_embed]
        )

        progress(0.3, 'Generating images...')
        images = pipe(
            prompt_embeds=prompt_embed,
            negative_prompt_embeds=negative_prompt_embed,
            width=width,
            height=height,
            guidance_scale=guidance_scale,
            num_inference_steps=num_inference_steps,
            num_images_per_prompt=num_images,
            generator=generator,
        ).images
    else:
        # Other pipelines (e.g. SD3) are not handled by Compel here; pass raw prompt strings
        progress(0.3, 'Generating images...')
        images = pipe(
            prompt=prompt,
            negative_prompt=negative_prompt,
            width=width,
            height=height,
            guidance_scale=guidance_scale,
            num_inference_steps=num_inference_steps,
            num_images_per_prompt=num_images,
            generator=generator,
        ).images

    progress(0.9, f'Done generating {num_images} images')
    # Arrange the outputs into one grid: a single column for odd counts, two rows for even counts
    if num_images % 2 == 1:
        image = make_image_grid(images, rows=num_images, cols=1)
    else:
        image = make_image_grid(images, rows=2, cols=num_images // 2)

    return image
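

# ---------------------------------------------------------------------------
# Usage sketch (not part of the original file): one way `inference` could be
# wired into a standalone Gradio app. The component choices, value ranges, and
# the assumption that DIFFUSERS_MODEL_IDS is an iterable of Hub model IDs are
# illustrative only; the actual UI may be defined elsewhere in the repository.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    demo = gr.Interface(
        fn=inference,
        inputs=[
            gr.Textbox(label="Prompt"),
            gr.Dropdown(
                choices=list(DIFFUSERS_MODEL_IDS),
                value="stabilityai/stable-diffusion-3-medium-diffusers",
                label="Model",
            ),
            gr.Textbox(label="Negative prompt"),
            gr.Slider(256, 1024, value=512, step=64, label="Width"),
            gr.Slider(256, 1024, value=512, step=64, label="Height"),
            gr.Slider(1.0, 20.0, value=7.5, step=0.5, label="Guidance scale"),
            gr.Slider(1, 100, value=50, step=1, label="Inference steps"),
            gr.Slider(1, 8, value=4, step=1, label="Number of images"),
            gr.Checkbox(value=True, label="Safety checker"),
            gr.Checkbox(value=False, label="Model CPU offload"),
            gr.Number(value=8888, precision=0, label="Seed"),
        ],
        outputs=gr.Image(label="Generated images"),
    )
    demo.launch()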