Spaces: StreamMultiDiffusion

Commit: Uploaded files without images

Files changed:
- .gitignore +1 -0
- README.md +7 -5
- app.py +1179 -0
- checkpoints/put_checkpoint_models_here.txt +0 -0
- data.py +252 -0
- examples/prompt_background.txt +8 -0
- examples/prompt_background_advanced.txt +0 -0
- examples/prompt_boy.txt +15 -0
- examples/prompt_girl.txt +16 -0
- examples/prompt_props.txt +43 -0
- model.py +1212 -0
- prompt_util.py +154 -0
- requirements.txt +13 -0
- timer/LICENSE_timer.txt +21 -0
- timer/index.html +95 -0
- timer/script.js +73 -0
- timer/style.css +504 -0
- util.py +315 -0
.gitignore
ADDED
@@ -0,0 +1 @@
+*/.ipynb_checkpoints
README.md
CHANGED
@@ -1,13 +1,15 @@
 ---
 title: StreamMultiDiffusion
-emoji:
-colorFrom:
-colorTo:
+emoji: 🦦🦦🦦🦦
+colorFrom: #feecd6
+colorTo: #732a14
 sdk: gradio
-sdk_version: 4.
+sdk_version: 4.26.0
 app_file: app.py
-pinned:
+pinned: true
 license: mit
+models:
+- KBlueLeaf/kohaku-v2.1
 ---
 
 Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app.py
ADDED
@@ -0,0 +1,1179 @@
# Copyright (c) 2024 Jaerin Lee

# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:

# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.

# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.

import sys

sys.path.append('../../src')

import argparse
import random
import time
import json
import os
import glob
import pathlib
from functools import partial
from pprint import pprint

import numpy as np
from PIL import Image
import torch

import gradio as gr
from huggingface_hub import snapshot_download

# from model import StreamMultiDiffusionSDXL
from model import StreamMultiDiffusion
from util import seed_everything
from prompt_util import preprocess_prompts, _quality_dict, _style_dict


### Utils


def log_state(state):
    pprint(vars(opt))
    if isinstance(state, gr.State):
        state = state.value
    pprint(vars(state))


def is_empty_image(im: Image.Image) -> bool:
    if im is None:
        return True
    im = np.array(im)
    has_alpha = (im.shape[2] == 4)
    if not has_alpha:
        return False
    elif im.sum() == 0:
        return True
    else:
        return False


### Argument passing

# parser = argparse.ArgumentParser(description='Semantic Palette demo powered by StreamMultiDiffusion with SDXL support.')
# parser.add_argument('-H', '--height', type=int, default=1024)
# parser.add_argument('-W', '--width', type=int, default=1024)
parser = argparse.ArgumentParser(description='Semantic Palette demo powered by StreamMultiDiffusion.')
parser.add_argument('-H', '--height', type=int, default=768)
parser.add_argument('-W', '--width', type=int, default=768)
parser.add_argument('--model', type=str, default=None, help='Hugging Face model repository or local path for an SD1.5 model checkpoint to run.')
parser.add_argument('--bootstrap_steps', type=int, default=1)
parser.add_argument('--guidance_scale', type=float, default=0)  # 1.2
parser.add_argument('--run_time', type=float, default=60)
parser.add_argument('--seed', type=int, default=-1)
parser.add_argument('--device', type=int, default=0)
parser.add_argument('--port', type=int, default=8000)
opt = parser.parse_args()


### Global variables and data structures

device = f'cuda:{opt.device}' if opt.device >= 0 else 'cpu'


if opt.model is None:
    # opt.model = 'cagliostrolab/animagine-xl-3.1'
    # opt.model = 'ironjr/BlazingDriveV11m'
    opt.model = 'KBlueLeaf/kohaku-v2.1'
else:
    if opt.model.endswith('.safetensors'):
        opt.model = os.path.abspath(os.path.join('checkpoints', opt.model))

# model = StreamMultiDiffusionSDXL(
model = StreamMultiDiffusion(
    device,
    hf_key=opt.model,
    height=opt.height,
    width=opt.width,
    cfg_type="full",
    autoflush=True,
    use_tiny_vae=True,
    mask_type='continuous',
    bootstrap_steps=opt.bootstrap_steps,
    bootstrap_mix_steps=opt.bootstrap_steps,
    guidance_scale=opt.guidance_scale,
    seed=opt.seed,
)


prompt_suggestions = [
    '1girl, souryuu asuka langley, neon genesis evangelion, solo, upper body, v, smile, looking at viewer',
    '1boy, solo, portrait, looking at viewer, white t-shirt, brown hair',
    '1girl, arima kana, oshi no ko, solo, upper body, from behind',
]

opt.max_palettes = 3
opt.default_prompt_strength = 1.0
opt.default_mask_strength = 1.0
opt.default_mask_std = 0.0
opt.default_negative_prompt = (
    'nsfw, worst quality, bad quality, normal quality, cropped, framed'
)
opt.verbose = True
opt.colors = [
    '#000000',
    '#2692F3',
    '#F89E12',
    '#16C232',
    # '#F92F6C',
    # '#AC6AEB',
    # '#92C62C',
    # '#92C6EC',
    # '#FECAC0',
]


### Event handlers

def add_palette(state):
    old_actives = state.active_palettes
    state.active_palettes = min(state.active_palettes + 1, opt.max_palettes)

    if opt.verbose:
        log_state(state)

    if state.active_palettes != old_actives:
        return [state] + [
            gr.update() if state.active_palettes != opt.max_palettes else gr.update(visible=False)
        ] + [
            gr.update() if i != state.active_palettes - 1 else gr.update(value=state.prompt_names[i + 1], visible=True)
            for i in range(opt.max_palettes)
        ]
    else:
        return [state] + [gr.update() for i in range(opt.max_palettes + 1)]


def select_palette(state, button, idx):
    if idx < 0 or idx > opt.max_palettes:
        idx = 0
    old_idx = state.current_palette
    if old_idx == idx:
        return [state] + [gr.update() for _ in range(opt.max_palettes + 7)]

    state.current_palette = idx

    if opt.verbose:
        log_state(state)

    updates = [state] + [
        gr.update() if i not in (idx, old_idx) else
        gr.update(variant='secondary') if i == old_idx else gr.update(variant='primary')
        for i in range(opt.max_palettes + 1)
    ]
    label = 'Background' if idx == 0 else f'Palette {idx}'
    updates.extend([
        gr.update(value=button, interactive=(idx > 0)),
        gr.update(value=state.prompts[idx], label=f'Edit Prompt for {label}'),
        gr.update(value=state.neg_prompts[idx], label=f'Edit Negative Prompt for {label}'),
        (
            gr.update(value=state.mask_strengths[idx - 1], interactive=True) if idx > 0 else
            gr.update(value=opt.default_mask_strength, interactive=False)
        ),
        (
            gr.update(value=state.prompt_strengths[idx - 1], interactive=True) if idx > 0 else
            gr.update(value=opt.default_prompt_strength, interactive=False)
        ),
        (
            gr.update(value=state.mask_stds[idx - 1], interactive=True) if idx > 0 else
            gr.update(value=opt.default_mask_std, interactive=False)
        ),
    ])
    return updates

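
# Note (illustrative, not part of the original logic): the list returned by
# select_palette must line up positionally with the `outputs` wired to each
# palette button below: [state], then the opt.max_palettes + 1 palette
# buttons, then tbox_name, tbox_prompt, tbox_neg_prompt, slider_alpha,
# slider_strength, and slider_std. Gradio applies each gr.update()
# positionally, which is why the early-exit branch above returns exactly
# opt.max_palettes + 7 placeholder updates.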


def change_prompt_strength(state, strength):
    if state.current_palette == 0:
        return state

    state.prompt_strengths[state.current_palette - 1] = strength
    if opt.verbose:
        log_state(state)

    return state


def change_std(state, std):
    if state.current_palette == 0:
        return state

    state.mask_stds[state.current_palette - 1] = std
    if opt.verbose:
        log_state(state)

    return state


def change_mask_strength(state, strength):
    if state.current_palette == 0:
        return state

    state.mask_strengths[state.current_palette - 1] = strength
    if opt.verbose:
        log_state(state)

    return state


def reset_seed(state, seed):
    state.seed = seed
    if opt.verbose:
        log_state(state)

    return state


def rename_prompt(state, name):
    state.prompt_names[state.current_palette] = name
    if opt.verbose:
        log_state(state)

    return [state] + [
        gr.update() if i != state.current_palette else gr.update(value=name)
        for i in range(opt.max_palettes + 1)
    ]


def change_prompt(state, prompt):
    state.prompts[state.current_palette] = prompt
    if opt.verbose:
        log_state(state)

    return state


def change_neg_prompt(state, neg_prompt):
    state.neg_prompts[state.current_palette] = neg_prompt
    if opt.verbose:
        log_state(state)

    return state


# def select_style(state, style_name):
#     state.style_name = style_name
#     if opt.verbose:
#         log_state(state)

#     return state


# def select_quality(state, quality_name):
#     state.quality_name = quality_name
#     if opt.verbose:
#         log_state(state)

#     return state


def import_state(state, json_text):
    current_palette = state.current_palette
    # active_palettes = state.active_palettes
    state_dict = json.loads(json_text)
    for k in ('inpainting_mode', 'is_running', 'active_palettes', 'current_palette'):
        if k in state_dict:
            del state_dict[k]
    state = argparse.Namespace(**state_dict)
    state.active_palettes = opt.max_palettes
    return [state] + [
        gr.update(value=v, visible=True) for v in state.prompt_names
    ] + [
        # state.style_name,
        # state.quality_name,
        state.prompts[current_palette],
        state.prompt_names[current_palette],
        state.neg_prompts[current_palette],
        state.prompt_strengths[current_palette - 1],
        state.mask_strengths[current_palette - 1],
        state.mask_stds[current_palette - 1],
        state.seed,
    ]
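
# A minimal sketch of the palette JSON that import_state() accepts. The field
# names follow the state defined in _define_state() below; the values are
# illustrative only. Transient keys such as "is_running", "inpainting_mode",
# "active_palettes", and "current_palette" are stripped on import.
#
# {
#     "prompt_names": ["🌄 Background", "👧 Girl", "🐶 Dog", "💐 Garden"],
#     "prompts": ["", "A girl smiling at viewer", "Doggy body part", "Flower garden"],
#     "neg_prompts": ["nsfw, worst quality, ...", "...", "...", "..."],
#     "prompt_strengths": [1.0, 1.0, 1.0],
#     "mask_strengths": [1.0, 1.0, 1.0],
#     "mask_stds": [0.0, 0.0, 0.0],
#     "seed": -1,
#     "model_id": "KBlueLeaf/kohaku-v2.1",
#     "style_name": "(None)",
#     "quality_name": "Standard v3.1"
# }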


### Main worker

def generate():
    return model()


def register(state, drawpad):
    seed_everything(state.seed if state.seed >= 0 else np.random.randint(2147483647))
    print('Generate!')

    background = drawpad['background'].convert('RGBA')
    inpainting_mode = np.asarray(background).sum() != 0
    if not inpainting_mode:
        background = Image.new(size=(opt.width, opt.height), mode='RGB', color=(255, 255, 255))
    print('Inpainting mode: ', inpainting_mode)

    user_input = np.asarray(drawpad['layers'][0])  # (H, W, 4)
    foreground_mask = torch.tensor(user_input[..., -1])[None, None]  # (1, 1, H, W)
    user_input = torch.tensor(user_input[..., :-1])  # (H, W, 3)

    palette = torch.tensor([
        tuple(int(s[i+1:i+3], 16) for i in (0, 2, 4))
        for s in opt.colors[1:]
    ])  # (N, 3)
    masks = (palette[:, None, None, :] == user_input[None]).all(dim=-1)[:, None, ...]  # (N, 1, H, W)
    # has_masks = [i for i, m in enumerate(masks.sum(dim=(1, 2, 3)) == 0) if not m]
    has_masks = list(range(opt.max_palettes))
    print('Has mask: ', has_masks)
    masks = masks * foreground_mask
    masks = masks[has_masks]

    # if inpainting_mode:
    prompts = [state.prompts[v + 1] for v in has_masks]
    negative_prompts = [state.neg_prompts[v + 1] for v in has_masks]
    mask_strengths = [state.mask_strengths[v] for v in has_masks]
    mask_stds = [state.mask_stds[v] for v in has_masks]
    prompt_strengths = [state.prompt_strengths[v] for v in has_masks]
    # else:
    #     masks = torch.cat([torch.ones_like(foreground_mask), masks], dim=0)
    #     prompts = [state.prompts[0]] + [state.prompts[v + 1] for v in has_masks]
    #     negative_prompts = [state.neg_prompts[0]] + [state.neg_prompts[v + 1] for v in has_masks]
    #     mask_strengths = [1] + [state.mask_strengths[v] for v in has_masks]
    #     mask_stds = [0] + [state.mask_stds[v] for v in has_masks]
    #     prompt_strengths = [1] + [state.prompt_strengths[v] for v in has_masks]

    # prompts, negative_prompts = preprocess_prompts(
    #     prompts, negative_prompts, style_name=state.style_name, quality_name=state.quality_name)

    model.update_background(
        background.convert('RGB'),
        prompt=None,
        negative_prompt=None,
    )
    state.prompts[0] = model.background.prompt
    state.neg_prompts[0] = model.background.negative_prompt

    model.update_layers(
        prompts=prompts,
        negative_prompts=negative_prompts,
        masks=masks.to(device),
        mask_strengths=mask_strengths,
        mask_stds=mask_stds,
        prompt_strengths=prompt_strengths,
    )

    state.inpainting_mode = inpainting_mode
    return state
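
# A minimal, self-contained sketch of the mask-extraction trick used in
# register() above (a hypothetical helper, not called anywhere): an exact
# per-pixel color match of an (N, 3) palette against an (H, W, 3) canvas via
# broadcasting yields one boolean mask per semantic brush.
def _demo_palette_masks():
    colors = ['#2692F3', '#F89E12', '#16C232']
    palette = torch.tensor([
        tuple(int(s[i + 1:i + 3], 16) for i in (0, 2, 4)) for s in colors
    ])  # (N, 3)
    canvas = torch.zeros(4, 4, 3, dtype=torch.long)  # all-black drawpad
    canvas[0, 0] = palette[1]  # paint one pixel with the second brush color
    # (N, 1, 1, 3) == (1, H, W, 3) -> (N, H, W, 3); all(dim=-1) -> (N, H, W)
    masks = (palette[:, None, None, :] == canvas[None]).all(dim=-1)[:, None, ...]
    assert masks.shape == (3, 1, 4, 4) and masks[1, 0, 0, 0]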

def run(state, drawpad):
    state = register(state, drawpad)
    state.is_running = True

    tic = time.time()
    while True:
        yield [state, generate()]
        toc = time.time()
        tdelta = toc - tic
        if tdelta > opt.run_time:
            state.is_running = False
            return [state, generate()]
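
# Note (illustrative): `run` is a generator, so the Gradio event wired to it
# below streams every yielded [state, image] pair to its outputs, refreshing
# the result image in real time until the opt.run_time budget is exhausted.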

def hide_element():
    return gr.update(visible=False)


def show_element():
    return gr.update(visible=True)


def draw(state, drawpad):
    if not state.is_running:
        return

    user_input = np.asarray(drawpad['layers'][0])  # (H, W, 4)
    foreground_mask = torch.tensor(user_input[..., -1])[None, None]  # (1, 1, H, W)
    user_input = torch.tensor(user_input[..., :-1])  # (H, W, 3)

    palette = torch.tensor([
        tuple(int(s[i+1:i+3], 16) for i in (0, 2, 4))
        for s in opt.colors[1:]
    ])  # (N, 3)
    masks = (palette[:, None, None, :] == user_input[None]).all(dim=-1)[:, None, ...]  # (N, 1, H, W)
    # has_masks = [i for i, m in enumerate(masks.sum(dim=(1, 2, 3)) == 0) if not m]
    has_masks = list(range(opt.max_palettes))
    print('Has mask: ', has_masks)
    masks = masks * foreground_mask
    masks = masks[has_masks]

    # if state.inpainting_mode:
    mask_strengths = [state.mask_strengths[v] for v in has_masks]
    mask_stds = [state.mask_stds[v] for v in has_masks]
    # else:
    #     masks = torch.cat([torch.ones_like(foreground_mask), masks], dim=0)
    #     mask_strengths = [1] + [state.mask_strengths[v] for v in has_masks]
    #     mask_stds = [0] + [state.mask_stds[v] for v in has_masks]

    for i in range(len(has_masks)):
        model.update_single_layer(
            idx=i,
            mask=masks[i],
            mask_strength=mask_strengths[i],
            mask_std=mask_stds[i],
        )


### Load examples


root = pathlib.Path(__file__).parent
print(root)
example_root = os.path.join(root, 'examples')
example_images = glob.glob(os.path.join(example_root, '*.png'))
example_images = [Image.open(i) for i in example_images]

# with open(os.path.join(example_root, 'prompt_background_advanced.txt')) as f:
#     prompts_background = [l.strip() for l in f.readlines() if l.strip() != '']

# with open(os.path.join(example_root, 'prompt_girl.txt')) as f:
#     prompts_girl = [l.strip() for l in f.readlines() if l.strip() != '']

# with open(os.path.join(example_root, 'prompt_boy.txt')) as f:
#     prompts_boy = [l.strip() for l in f.readlines() if l.strip() != '']

# with open(os.path.join(example_root, 'prompt_props.txt')) as f:
#     prompts_props = [l.strip() for l in f.readlines() if l.strip() != '']
#     prompts_props = {l.split(',')[0].strip(): ','.join(l.split(',')[1:]).strip() for l in prompts_props}

# prompt_background = lambda: random.choice(prompts_background)
# prompt_girl = lambda: random.choice(prompts_girl)
# prompt_boy = lambda: random.choice(prompts_boy)
# prompt_props = lambda: np.random.choice(list(prompts_props.keys()), size=(opt.max_palettes - 2), replace=False).tolist()


### Main application

css = f"""
#run-button {{
    font-size: 18pt;
    background-image: linear-gradient(to right, #4338ca 0%, #26a0da 51%, #4338ca 100%);
    margin: 0;
    padding: 15px 45px;
    text-align: center;
    /* text-transform: uppercase; */
    transition: 0.5s;
    background-size: 200% auto;
    color: white;
    box-shadow: 0 0 20px #eee;
    border-radius: 10px;
    /* display: block; */
    background-position: right center;
}}

#run-button:hover {{
    background-position: left center;
    color: #fff;
    text-decoration: none;
}}

#run-anim {{
    padding: 40px 45px;
}}

#semantic-palette {{
    border-style: solid;
    border-width: 0.2em;
    border-color: #eee;
}}

#semantic-palette:hover {{
    box-shadow: 0 0 20px #eee;
}}

#output-screen {{
    width: 100%;
    aspect-ratio: {opt.width} / {opt.height};
}}

.layer-wrap {{
    display: none;
}}
"""

for i in range(opt.max_palettes + 1):
    css = css + f"""
.secondary#semantic-palette-{i} {{
    background-image: linear-gradient(to right, #374151 0%, #374151 71%, {opt.colors[i]} 100%);
    color: white;
}}

.primary#semantic-palette-{i} {{
    background-image: linear-gradient(to right, #4338ca 0%, #4338ca 71%, {opt.colors[i]} 100%);
    color: white;
}}
"""

css = css + f"""

.mask-red {{
    left: 0;
    width: 0;
    color: #BE002A;
    -webkit-animation: text-red {opt.run_time:.1f}s ease infinite;
    animation: text-red {opt.run_time:.1f}s ease infinite;
    z-index: 2;
    background: transparent;
}}
.mask-white {{
    right: 0;
}}

/* Flames */

#red-flame {{
    opacity: 0;
    -webkit-animation: show-flames {opt.run_time:.1f}s ease infinite, red-flame 120ms ease infinite;
    animation: show-flames {opt.run_time:.1f}s ease infinite, red-flame 120ms ease infinite;
    transform-origin: center bottom;
}}

#yellow-flame {{
    opacity: 0;
    -webkit-animation: show-flames {opt.run_time:.1f}s ease infinite, yellow-flame 120ms ease infinite;
    animation: show-flames {opt.run_time:.1f}s ease infinite, yellow-flame 120ms ease infinite;
    transform-origin: center bottom;
}}

#white-flame {{
    opacity: 0;
    -webkit-animation: show-flames {opt.run_time:.1f}s ease infinite, red-flame 100ms ease infinite;
    animation: show-flames {opt.run_time:.1f}s ease infinite, red-flame 100ms ease infinite;
    transform-origin: center bottom;
}}
"""

with open(os.path.join(root, 'timer', 'style.css')) as f:
    added_css = ''.join(f.readlines())
    css = css + added_css

# js = ''

# with open(os.path.join(root, 'timer', 'script.js')) as f:
#     added_js = ''.join(f.readlines())
#     js = js + added_js

head = f"""
<link href='https://fonts.googleapis.com/css?family=Oswald' rel='stylesheet' type='text/css'>
<script src='https://code.jquery.com/jquery-2.2.4.min.js'></script>
"""


with gr.Blocks(theme=gr.themes.Soft(), css=css, head=head) as demo:

    iface = argparse.Namespace()

    def _define_state():
        state = argparse.Namespace()

        # Cursor.
        state.is_running = False
        state.inpainting_mode = False
        state.current_palette = 0  # 0: Background; 1,2,3,...: Layers
        state.model_id = opt.model
        state.style_name = '(None)'
        state.quality_name = 'Standard v3.1'

        # State variables (one-hot).
        state.active_palettes = 5

        # Front-end initialized to the default values.
        # prompt_props_ = prompt_props()
        state.prompt_names = [
            '🌄 Background',
            '👧 Girl',
            '🐶 Dog',
            '💐 Garden',
        ] + ['🎨 New Palette' for _ in range(opt.max_palettes - 3)]
        state.prompts = [
            '',
            'A girl smiling at viewer',
            'Doggy body part',
            'Flower garden',
        ] + ['' for _ in range(opt.max_palettes - 3)]
        state.neg_prompts = [
            opt.default_negative_prompt
            + (', humans, humans, humans' if i == 0 else '')
            for i in range(opt.max_palettes + 1)
        ]
        state.prompt_strengths = [opt.default_prompt_strength for _ in range(opt.max_palettes)]
        state.mask_strengths = [opt.default_mask_strength for _ in range(opt.max_palettes)]
        state.mask_stds = [opt.default_mask_std for _ in range(opt.max_palettes)]
        state.seed = opt.seed
        return state

    state = gr.State(value=_define_state)


    ### Demo user interface

    gr.HTML(
        """
        <div style="display: flex; justify-content: center; align-items: center; text-align: center;">
            <div>
                <h1>🦦🦦 StreamMultiDiffusion: Real-Time Interactive Generation with Region-Based Semantic Control 🦦🦦</h1>
                <h5 style="margin: 0;">If you ❤️ our project, please visit our GitHub and give us a 🌟!</h5>
                <br/>
                <div style="display: flex; justify-content: center; align-items: center; text-align: center;">
                    <a href='https://arxiv.org/abs/2403.09055'>
                        <img src="https://img.shields.io/badge/arXiv-2403.09055-red">
                    </a>
                    <a href='https://jaerinlee.com/research/StreamMultiDiffusion'>
                        <img src='https://img.shields.io/badge/Project-Page-green' alt='Project Page'>
                    </a>
                    <a href='https://github.com/ironjr/StreamMultiDiffusion'>
                        <img src='https://img.shields.io/github/stars/ironjr/StreamMultiDiffusion?label=Github&color=blue'>
                    </a>
                    <a href='https://twitter.com/_ironjr_'>
                        <img src='https://img.shields.io/twitter/url?label=_ironjr_&url=https%3A%2F%2Ftwitter.com%2F_ironjr_'>
                    </a>
                    <a href='https://github.com/ironjr/StreamMultiDiffusion/blob/main/LICENSE'>
                        <img src='https://img.shields.io/badge/license-MIT-lightgrey'>
                    </a>
                    <a href='https://huggingface.co/papers/2403.09055'>
                        <img src='https://img.shields.io/badge/%F0%9F%A4%97%20Paper-StreamMultiDiffusion-yellow'>
                    </a>
                    <a href='https://huggingface.co/spaces/ironjr/StreamMultiDiffusion'>
                        <img src='https://img.shields.io/badge/%F0%9F%A4%97%20Demo-StreamMultiDiffusion-yellow'>
                    </a>
                    <a href='https://huggingface.co/spaces/ironjr/SemanticPalette'>
                        <img src='https://img.shields.io/badge/%F0%9F%A4%97%20Demo-SemanticPaletteSD1.5-yellow'>
                    </a>
                    <a href='https://huggingface.co/spaces/ironjr/SemanticPaletteXL'>
                        <img src='https://img.shields.io/badge/%F0%9F%A4%97%20Demo-SemanticPaletteSDXL-yellow'>
                    </a>
                    <a href='https://colab.research.google.com/github/camenduru/SemanticPalette-jupyter/blob/main/SemanticPalette_jupyter.ipynb'>
                        <img src='https://colab.research.google.com/assets/colab-badge.svg'>
                    </a>
                </div>
            </div>
        </div>
        <div>
            <br/>
        </div>
        """
    )

    with gr.Row():

        with gr.Column(scale=1):

            with gr.Group(elem_id='semantic-palette'):

                gr.HTML(
                    """
                    <div style="justify-content: center; align-items: center;">
                        <br/>
                        <h3 style="margin: 0; text-align: center;"><b>🧠 Semantic Palette 🎨</b></h3>
                        <br/>
                    </div>
                    """
                )

                iface.btn_semantics = [gr.Button(
                    value=state.value.prompt_names[0],
                    variant='primary',
                    elem_id='semantic-palette-0',
                )]
                for i in range(opt.max_palettes):
                    iface.btn_semantics.append(gr.Button(
                        value=state.value.prompt_names[i + 1],
                        variant='secondary',
                        visible=(i < state.value.active_palettes),
                        elem_id=f'semantic-palette-{i + 1}',
                    ))

                iface.btn_add_palette = gr.Button(
                    value='Create New Semantic Brush',
                    variant='primary',
                    visible=(state.value.active_palettes < opt.max_palettes),
                )

            with gr.Accordion(label='Import/Export Semantic Palette', open=True):
                iface.tbox_state_import = gr.Textbox(label='Put Palette JSON Here To Import')
                iface.json_state_export = gr.JSON(label='Exported Palette')
                iface.btn_export_state = gr.Button("Export Palette ➡️ JSON", variant='primary')
                iface.btn_import_state = gr.Button("Import JSON ➡️ Palette", variant='secondary')

            gr.HTML(
                """
                <div>
                    <br/>
                </div>
                <div style="justify-content: center; align-items: center;">
                    <h3 style="margin: 0; text-align: center;"><b>❓Usage❓</b></h3>
                    <br/>
                    <div style="justify-content: center; align-items: left; text-align: left;">
                        <p>1-1. Type in the background prompt. A background prompt is not required if you paint over the whole drawpad.</p>
                        <p>1-2. (Optional: <em><b>Inpainting mode</b></em>) Uploading a background image switches the app into inpainting mode; removing the image returns it to creation mode. In inpainting mode, increasing the <em>Mask Blur STD</em> above 8 for every colored palette is recommended for smooth boundaries.</p>
                        <p>2. Select a semantic brush by clicking on one in the <b>Semantic Palette</b> above, then edit the prompt for that semantic brush.</p>
                        <p>2-1. If you want to draw more diverse images, try <b>Create New Semantic Brush</b>.</p>
                        <p>3. Start drawing in the <b>Semantic Drawpad</b> tab. Each brush color is directly linked to a semantic brush.</p>
                        <p>4. Click the [<b>GENERATE!</b>] button to create your (large-scale) artwork!</p>
                    </div>
                </div>
                """
            )

            gr.HTML(
                """
                <div style="display: flex; justify-content: center; align-items: center; text-align: center;">
                    <h5 style="margin: 0;"><b>... or run in your own 🤗 space!</b></h5>
                </div>
                """
            )

            gr.DuplicateButton()

        with gr.Column(scale=4):

            with gr.Row():

                with gr.Column(scale=2):

                    iface.ctrl_semantic = gr.ImageEditor(
                        image_mode='RGBA',
                        sources=['upload', 'clipboard', 'webcam'],
                        transforms=['crop'],
                        crop_size=(opt.width, opt.height),
                        brush=gr.Brush(
                            colors=opt.colors[1:],
                            color_mode="fixed",
                        ),
                        type='pil',
                        label='Semantic Drawpad',
                        elem_id='drawpad',
                    )

                    # with gr.Accordion(label='Prompt Engineering', open=False):
                    #     iface.quality_select = gr.Dropdown(
                    #         label='Quality Presets',
                    #         interactive=True,
                    #         choices=list(_quality_dict.keys()),
                    #         value='Standard v3.1',
                    #     )

                    #     iface.style_select = gr.Radio(
                    #         label='Style Preset',
                    #         container=True,
                    #         interactive=True,
                    #         choices=list(_style_dict.keys()),
                    #         value='(None)',
                    #     )

                with gr.Column(scale=2):

                    iface.image_slot = gr.Image(
                        interactive=False,
                        show_label=False,
                        show_download_button=True,
                        type='pil',
                        label='Generated Result',
                        elem_id='output-screen',
                        value=lambda: random.choice(example_images),
                    )

                    iface.btn_generate = gr.Button(
                        value=f'Lemme try! ({int(opt.run_time // 60)} min)',
                        variant='primary',
                        # scale=1,
                        elem_id='run-button',
                    )

                    iface.run_animation = gr.HTML(
                        f"""
                        <div id="deadline">
                        <svg preserveAspectRatio="none" id="line" viewBox="0 0 581 158" enable-background="new 0 0 581 158">
                        <g id="fire">
                        <rect id="mask-fire-black" x="511" y="41" width="38" height="34"/>
                        <g>
                        <defs>
                        <rect id="mask_fire" x="511" y="41" width="38" height="34"/>
                        </defs>
                        <clipPath id="mask-fire_1_">
                        <use xlink:href="#mask_fire" overflow="visible"/>
                        </clipPath>
                        <g id="group-fire" clip-path="url(#mask-fire_1_)">
                        <path id="red-flame" fill="#B71342" d="M528.377,100.291c6.207,0,10.947-3.272,10.834-8.576 c-0.112-5.305-2.934-8.803-8.237-10.383c-5.306-1.581-3.838-7.9-0.79-9.707c-7.337,2.032-7.581,5.891-7.11,8.238 c0.789,3.951,7.56,4.402,5.077,9.48c-2.482,5.079-8.012,1.129-6.319-2.257c-2.843,2.233-4.78,6.681-2.259,9.703 C521.256,98.809,524.175,100.291,528.377,100.291z"/>
                        <path id="yellow-flame" opacity="0.71" fill="#F7B523" d="M528.837,100.291c4.197,0,5.108-1.854,5.974-5.417 c0.902-3.724-1.129-6.207-5.305-9.931c-2.396-2.137-1.581-4.176-0.565-6.32c-4.401,1.918-3.384,5.304-2.482,6.658 c1.511,2.267,2.099,2.364,0.42,5.8c-1.679,3.435-5.42,0.764-4.275-1.527c-1.921,1.512-2.373,4.04-1.528,6.563 C522.057,99.051,525.994,100.291,528.837,100.291z"/>
                        <path id="white-flame" opacity="0.81" fill="#FFFFFF" d="M529.461,100.291c-2.364,0-4.174-1.322-4.129-3.469 c0.04-2.145,1.117-3.56,3.141-4.198c2.022-0.638,1.463-3.195,0.302-3.925c2.798,0.821,2.89,2.382,2.711,3.332 c-0.301,1.597-2.883,1.779-1.938,3.834c0.912,1.975,3.286,0.938,2.409-0.913c1.086,0.903,1.826,2.701,0.864,3.924 C532.18,99.691,531.064,100.291,529.461,100.291z"/>
                        </g>
                        </g>
                        </g>
                        <g id="progress-trail">
                        <path fill="#FFFFFF" d="M491.979,83.878c1.215-0.73-0.62-5.404-3.229-11.044c-2.583-5.584-5.034-10.066-7.229-8.878
                        c-2.854,1.544-0.192,6.286,2.979,11.628C487.667,80.917,490.667,84.667,491.979,83.878z"/>
                        <path fill="#FFFFFF" d="M571,76v-5h-23.608c0.476-9.951-4.642-13.25-4.642-13.25l-3.125,4c0,0,3.726,2.7,3.625,5.125
                        c-0.071,1.714-2.711,3.18-4.962,4.125H517v5h10v24h-25v-5.666c0,0,0.839,0,2.839-0.667s6.172-3.667,4.005-6.333
                        s-7.49,0.333-9.656,0.166s-6.479-1.5-8.146,1.917c-1.551,3.178,0.791,5.25,5.541,6.083l-0.065,4.5H16c-2.761,0-5,2.238-5,5v17
                        c0,2.762,2.239,5,5,5h549c2.762,0,5-2.238,5-5v-17c0-2.762-2.238-5-5-5h-3V76H571z"/>
                        <path fill="#FFFFFF" d="M535,65.625c1.125,0.625,2.25-1.125,2.25-1.125l11.625-22.375c0,0,0.75-0.875-1.75-2.125
                        s-3.375,0.25-3.375,0.25s-8.75,21.625-9.875,23.5S533.875,65,535,65.625z"/>
                        </g>
                        <g>
                        <defs>
                        <path id="SVGID_1_" d="M484.5,75.584c-3.172-5.342-5.833-10.084-2.979-11.628c2.195-1.188,4.646,3.294,7.229,8.878
                        c2.609,5.64,4.444,10.313,3.229,11.044C490.667,84.667,487.667,80.917,484.5,75.584z M571,76v-5h-23.608
                        c0.476-9.951-4.642-13.25-4.642-13.25l-3.125,4c0,0,3.726,2.7,3.625,5.125c-0.071,1.714-2.711,3.18-4.962,4.125H517v5h10v24h-25
                        v-5.666c0,0,0.839,0,2.839-0.667s6.172-3.667,4.005-6.333s-7.49,0.333-9.656,0.166s-6.479-1.5-8.146,1.917
                        c-1.551,3.178,0.791,5.25,5.541,6.083l-0.065,4.5H16c-2.761,0-5,2.238-5,5v17c0,2.762,2.239,5,5,5h549c2.762,0,5-2.238,5-5v-17
                        c0-2.762-2.238-5-5-5h-3V76H571z M535,65.625c1.125,0.625,2.25-1.125,2.25-1.125l11.625-22.375c0,0,0.75-0.875-1.75-2.125
                        s-3.375,0.25-3.375,0.25s-8.75,21.625-9.875,23.5S533.875,65,535,65.625z"/>
                        </defs>
                        <clipPath id="SVGID_2_">
                        <use xlink:href="#SVGID_1_" overflow="visible"/>
                        </clipPath>
                        <rect id="progress-time-fill" x="-100%" y="34" clip-path="url(#SVGID_2_)" fill="#BE002A" width="586" height="103"/>
                        </g>

                        <g id="death-group">
                        <path id="death" fill="#BE002A" d="M-46.25,40.416c-5.42-0.281-8.349,3.17-13.25,3.918c-5.716,0.871-10.583-0.918-10.583-0.918
                        C-67.5,49-65.175,50.6-62.083,52c5.333,2.416,4.083,3.5,2.084,4.5c-16.5,4.833-15.417,27.917-15.417,27.917L-75.5,84.75
                        c-1,12.25-20.25,18.75-20.25,18.75s39.447,13.471,46.25-4.25c3.583-9.333-1.553-16.869-1.667-22.75
                        c-0.076-3.871,2.842-8.529,6.084-12.334c3.596-4.22,6.958-10.374,6.958-15.416C-38.125,43.186-39.833,40.75-46.25,40.416z
                        M-40,51.959c-0.882,3.004-2.779,6.906-4.154,6.537s-0.939-4.32,0.112-7.704c0.82-2.64,2.672-5.96,3.959-5.583
                        C-39.005,45.523-39.073,48.8-40,51.959z"/>
                        <path id="death-arm" fill="#BE002A" d="M-53.375,75.25c0,0,9.375,2.25,11.25,0.25s2.313-2.342,3.375-2.791
                        c1.083-0.459,4.375-1.75,4.292-4.75c-0.101-3.627,0.271-4.594,1.333-5.043c1.083-0.457,2.75-1.666,2.75-1.666
                        s0.708-0.291,0.5-0.875s-0.791-2.125-1.583-2.959c-0.792-0.832-2.375-1.874-2.917-1.332c-0.542,0.541-7.875,7.166-7.875,7.166
                        s-2.667,2.791-3.417,0.125S-49.833,61-49.833,61s-3.417,1.416-3.417,1.541s-1.25,5.834-1.25,5.834l-0.583,5.833L-53.375,75.25z"/>
                        <path id="death-tool" fill="#BE002A" d="M-20.996,26.839l-42.819,91.475l1.812,0.848l38.342-81.909c0,0,8.833,2.643,12.412,7.414
                        c5,6.668,4.75,14.084,4.75,14.084s4.354-7.732,0.083-17.666C-10,32.75-19.647,28.676-19.647,28.676l0.463-0.988L-20.996,26.839z"/>
                        </g>
                        <path id="designer-body" fill="#FEFFFE" d="M514.75,100.334c0,0,1.25-16.834-6.75-16.5c-5.501,0.229-5.583,3-10.833,1.666
                        c-3.251-0.826-5.084-15.75-0.834-22c4.948-7.277,12.086-9.266,13.334-7.833c2.25,2.583-2,10.833-4.5,14.167
                        c-2.5,3.333-1.833,10.416,0.5,9.916s8.026-0.141,10,2.25c3.166,3.834,4.916,17.667,4.916,17.667l0.917,2.5l-4,0.167L514.75,100.334z
                        "/>

                        <circle id="designer-head" fill="#FEFFFE" cx="516.083" cy="53.25" r="6.083"/>

                        <g id="designer-arm-grop">
                        <path id="designer-arm" fill="#FEFFFE" d="M505.875,64.875c0,0,5.875,7.5,13.042,6.791c6.419-0.635,11.833-2.791,13.458-4.041s2-3.5,0.25-3.875
                        s-11.375,5.125-16,3.25c-5.963-2.418-8.25-7.625-8.25-7.625l-2,1.125L505.875,64.875z"/>
                        <path id="designer-pen" fill="#FEFFFE" d="M525.75,59.084c0,0-0.423-0.262-0.969,0.088c-0.586,0.375-0.547,0.891,0.547,0.891l7.172,8.984l1.261,0.453
                        l-0.104-1.328L525.75,59.084z"/>
                        </g>
                        </svg>

                        <div class="deadline-timer">
                        Remaining <span class="day">{opt.run_time}</span> <span class="days">s</span>
                        </div>

                        </div>
                        """,
                        elem_id='run-anim',
                        visible=False,
                    )

            with gr.Group(elem_id='control-panel'):

                with gr.Row():
                    iface.tbox_prompt = gr.Textbox(
                        label='Edit Prompt for Background',
                        info='What do you want to draw?',
                        value=state.value.prompts[0],
                        placeholder=lambda: random.choice(prompt_suggestions),
                        scale=2,
                    )

                    iface.slider_strength = gr.Slider(
                        label='Prompt Strength',
                        info='Blends fg & bg at the prompt level; >0.8 preferred.',
                        minimum=0.5,
                        maximum=1.0,
                        value=opt.default_prompt_strength,
                        scale=1,
                    )

                with gr.Row():
                    iface.tbox_neg_prompt = gr.Textbox(
                        label='Edit Negative Prompt for Background',
                        info='Add unwanted objects for this semantic brush.',
                        value=opt.default_negative_prompt,
                        scale=2,
                    )

                    iface.tbox_name = gr.Textbox(
                        label='Edit Brush Name',
                        info='Just for your convenience.',
                        value=state.value.prompt_names[0],
                        placeholder='🌄 Background',
                        scale=1,
                    )

                with gr.Row():
                    iface.slider_alpha = gr.Slider(
                        label='Mask Alpha',
                        info='Factor multiplied into the mask before quantization. Extremely sensitive; >0.98 preferred.',
                        minimum=0.5,
                        maximum=1.0,
                        value=opt.default_mask_strength,
                    )

                    iface.slider_std = gr.Slider(
                        label='Mask Blur STD',
                        info='Blends fg & bg at the latent level; 0 for generation, 8-32 for inpainting.',
                        minimum=0.0001,
                        maximum=100.0,
                        value=opt.default_mask_std,
                    )

                    iface.slider_seed = gr.Slider(
                        label='Seed',
                        info='The global seed.',
                        minimum=-1,
                        maximum=2147483647,
                        step=1,
                        value=opt.seed,
                    )

    ### Attach event handlers

    for idx, btn in enumerate(iface.btn_semantics):
        btn.click(
            fn=partial(select_palette, idx=idx),
            inputs=[state, btn],
            outputs=[state] + iface.btn_semantics + [
                iface.tbox_name,
                iface.tbox_prompt,
                iface.tbox_neg_prompt,
                iface.slider_alpha,
                iface.slider_strength,
                iface.slider_std,
            ],
            api_name=f'select_palette_{idx}',
        )

    iface.btn_add_palette.click(
        fn=add_palette,
        inputs=state,
        outputs=[state, iface.btn_add_palette] + iface.btn_semantics[1:],
        api_name='create_new',
    )

    run_event = iface.btn_generate.click(
        fn=hide_element,
        inputs=None,
        outputs=iface.btn_generate,
        api_name='hide_run_button',
    ).then(
        fn=show_element,
        inputs=None,
        outputs=iface.run_animation,
        api_name='show_run_animation',
    )

    run_event.then(
        fn=run,
        inputs=[state, iface.ctrl_semantic],
        outputs=[state, iface.image_slot],
        api_name='run',
    ).then(
        fn=hide_element,
        inputs=None,
        outputs=iface.run_animation,
        api_name='hide_run_animation',
    ).then(
        fn=show_element,
        inputs=None,
        outputs=iface.btn_generate,
        api_name='show_run_button',
    )

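    # Note (illustrative): the chained events above sequence the UI like a
    # small state machine: hide the run button, reveal the countdown
    # animation, stream `run` into the image slot, then swap the animation
    # back for the button once the run finishes.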
    run_event.then(
        fn=None,
        inputs=None,
        outputs=None,
        api_name='run_animation',
        js=f"""
        async () => {{
            // timer arguments:
            //   #1 - time of animation in milliseconds,
            //   #2 - days to deadline
            const animationTime = {opt.run_time};
            const days = {opt.run_time};

            jQuery('#progress-time-fill, #death-group').css({{'animation-duration': animationTime+'s'}});

            var deadlineAnimation = function () {{
                setTimeout(function() {{
                    jQuery('#designer-arm-grop').css({{'animation-duration': '1.5s'}});
                }}, 0);

                setTimeout(function() {{
                    jQuery('#designer-arm-grop').css({{'animation-duration': '1.0s'}});
                }}, {int(opt.run_time * 1000 * 0.2)});

                setTimeout(function() {{
                    jQuery('#designer-arm-grop').css({{'animation-duration': '0.7s'}});
                }}, {int(opt.run_time * 1000 * 0.4)});

                setTimeout(function() {{
                    jQuery('#designer-arm-grop').css({{'animation-duration': '0.3s'}});
                }}, {int(opt.run_time * 1000 * 0.6)});

                setTimeout(function() {{
                    jQuery('#designer-arm-grop').css({{'animation-duration': '0.2s'}});
                }}, {int(opt.run_time * 1000 * 0.75)});
            }};

            var deadlineTextFinished = function () {{
                var el = jQuery('.deadline-timer');
                var html = 'Done! Retry?';
                el.html(html);
            }};

            function timer(totalTime, deadline) {{
                var time = totalTime * 1000;
                var dayDuration = time / deadline;
                var actualDay = deadline;

                var timer = setInterval(countTime, dayDuration);

                function countTime() {{
                    --actualDay;
                    jQuery('.deadline-timer .day').text(actualDay);

                    if (actualDay == 0) {{
                        clearInterval(timer);
                        // jQuery('.deadline-timer .day').text(deadline);
                        deadlineTextFinished();
                    }}
                }}
            }}

            var deadlineText = function () {{
                var el = jQuery('.deadline-timer');
                var htmlBase = 'Remaining <span class="day">{opt.run_time}</span> <span class="days">s</span>';
                var html = '<div class="mask-red"><div class="inner">' + htmlBase + '</div></div><div class="mask-white"><div class="inner">' + htmlBase + '</div></div>';
                el.html(html);
            }};

            var runAnimation = function() {{
                timer(animationTime, days);
                deadlineAnimation();
                deadlineText();

                console.log('begin interval', animationTime * 1000);
            }};
            runAnimation();
        }}
        """,
    )

    iface.slider_alpha.input(
        fn=change_mask_strength,
        inputs=[state, iface.slider_alpha],
        outputs=state,
        api_name='change_alpha',
    )
    iface.slider_std.input(
        fn=change_std,
        inputs=[state, iface.slider_std],
        outputs=state,
        api_name='change_std',
    )
    iface.slider_strength.input(
        fn=change_prompt_strength,
        inputs=[state, iface.slider_strength],
        outputs=state,
        api_name='change_strength',
    )
    iface.slider_seed.input(
        fn=reset_seed,
        inputs=[state, iface.slider_seed],
        outputs=state,
        api_name='reset_seed',
    )

    iface.tbox_name.input(
        fn=rename_prompt,
        inputs=[state, iface.tbox_name],
        outputs=[state] + iface.btn_semantics,
        api_name='prompt_rename',
    )
    iface.tbox_prompt.input(
        fn=change_prompt,
        inputs=[state, iface.tbox_prompt],
        outputs=state,
        api_name='prompt_edit',
    )
    iface.tbox_neg_prompt.input(
        fn=change_neg_prompt,
        inputs=[state, iface.tbox_neg_prompt],
        outputs=state,
        api_name='neg_prompt_edit',
    )

    # iface.style_select.change(
    #     fn=select_style,
    #     inputs=[state, iface.style_select],
    #     outputs=state,
    #     api_name='style_select',
    # )
    # iface.quality_select.change(
    #     fn=select_quality,
    #     inputs=[state, iface.quality_select],
    #     outputs=state,
    #     api_name='quality_select',
    # )

    iface.btn_export_state.click(lambda x: vars(x), state, iface.json_state_export)
    iface.btn_import_state.click(import_state, [state, iface.tbox_state_import], [
        state,
        *iface.btn_semantics,
        # iface.style_select,
        # iface.quality_select,
        iface.tbox_prompt,
        iface.tbox_name,
        iface.tbox_neg_prompt,
        iface.slider_strength,
        iface.slider_alpha,
        iface.slider_std,
        iface.slider_seed,
    ])

    # Realtime user input.
    iface.ctrl_semantic.change(
        fn=draw,
        inputs=[state, iface.ctrl_semantic],
        outputs=None,
        api_name='draw',
    )


if __name__ == '__main__':
    demo.launch(server_port=opt.port)
checkpoints/put_checkpoint_models_here.txt
ADDED
File without changes
data.py
ADDED
@@ -0,0 +1,252 @@
# Copyright (c) 2024 Jaerin Lee

# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:

# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.

# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.

import copy
from typing import Optional, Union
from PIL import Image
import torch


class BackgroundObject:
    def __init__(
        self,
        image: Optional[Image.Image] = None,
        prompt: Optional[str] = None,
        negative_prompt: Optional[str] = None,
    ) -> None:
        self.image = image
        self.prompt = prompt
        self.negative_prompt = negative_prompt

    @property
    def is_empty(self) -> bool:
        return (
            self.image is None and
            self.prompt is None and
            self.negative_prompt is None
        )

    def extra_repr(self) -> str:
        return ''

    def __repr__(self) -> str:
        strings = []
        if self.image is not None:
            if isinstance(self.image, Image.Image):
                image_str = f'Image(size={str(self.image.size)})'
            else:
                image_str = f'Tensor(shape={str(self.image.shape)})'
            strings.append(f'image={image_str}')
        if self.prompt is not None:
            strings.append(f'prompt="{self.prompt}"')
        if self.negative_prompt is not None:
            strings.append(f'negative_prompt="{self.negative_prompt}"')
        extra_repr = self.extra_repr()
        if extra_repr != '':
            strings.append(extra_repr)
        return f'{type(self).__name__}({", ".join(strings)})'


class LayerObject:
    def __init__(
        self,
        idx: Optional[int] = None,
        prompt: Optional[str] = None,
        negative_prompt: Optional[str] = None,
        suffix: Optional[str] = None,
        prompt_strength: Optional[float] = None,
        mask: Optional[Union[torch.Tensor, Image.Image]] = None,
        mask_std: Optional[float] = None,
        mask_strength: Optional[float] = None,
    ) -> None:
        self.idx = idx
        self.prompt = prompt
        self.negative_prompt = negative_prompt
        self.suffix = suffix
        self.prompt_strength = prompt_strength
        self.mask = mask
        self.mask_std = mask_std
        self.mask_strength = mask_strength

    @property
    def is_empty(self) -> bool:
        return (
            self.prompt is None and
            self.negative_prompt is None and
            self.prompt_strength is None and
            self.mask is None and
            self.mask_strength is None and
            self.mask_std is None
        )

    def merge(self, other: 'LayerObject') -> bool:  # Overridden or not.
        if self.idx != other.idx:
            # Merge only the modification requests for the same layer.
            return False

        if self.prompt is None and other.prompt is not None:
            self.prompt = copy.deepcopy(other.prompt)
        if self.negative_prompt is None and other.negative_prompt is not None:
            self.negative_prompt = copy.deepcopy(other.negative_prompt)
        if self.suffix is None and other.suffix is not None:
            self.suffix = copy.deepcopy(other.suffix)
        if self.prompt_strength is None and other.prompt_strength is not None:
            self.prompt_strength = copy.deepcopy(other.prompt_strength)
        if self.mask is None and other.mask is not None:
            self.mask = copy.deepcopy(other.mask)
        if self.mask_strength is None and other.mask_strength is not None:
            self.mask_strength = copy.deepcopy(other.mask_strength)
        if self.mask_std is None and other.mask_std is not None:
            self.mask_std = copy.deepcopy(other.mask_std)
        return True
119 |
+
|
120 |
+
def extra_repr(self) -> str:
|
121 |
+
return ''
|
122 |
+
|
123 |
+
def __repr__(self) -> str:
|
124 |
+
strings = []
|
125 |
+
if self.idx is not None:
|
126 |
+
strings.append(f'idx={self.idx}')
|
127 |
+
if self.prompt is not None:
|
128 |
+
strings.append(f'prompt="{self.prompt}"')
|
129 |
+
if self.negative_prompt is not None:
|
130 |
+
strings.append(f'negative_prompt="{self.negative_prompt}"')
|
131 |
+
if self.suffix is not None:
|
132 |
+
strings.append(f'suffix="{self.suffix}"')
|
133 |
+
if self.prompt_strength is not None:
|
134 |
+
strings.append(f'prompt_strength={self.prompt_strength}')
|
135 |
+
if self.mask is not None:
|
136 |
+
if isinstance(self.mask, Image.Image):
|
137 |
+
mask_str = f'Image(size={str(self.mask.size)})'
|
138 |
+
else:
|
139 |
+
mask_str = f'Tensor(shape={str(self.mask.shape)})'
|
140 |
+
strings.append(f'mask={mask_str}')
|
141 |
+
if self.mask_std is not None:
|
142 |
+
strings.append(f'mask_std={self.mask_std}')
|
143 |
+
if self.mask_strength is not None:
|
144 |
+
strings.append(f'mask_strength={self.mask_strength}')
|
145 |
+
extra_repr = self.extra_repr()
|
146 |
+
if extra_repr != '':
|
147 |
+
strings.append(extra_repr)
|
148 |
+
return f'{type(self).__name__}({", ".join(strings)})'
|
149 |
+
|
150 |
+
|
151 |
+
class BackgroundState(BackgroundObject):
|
152 |
+
def __init__(
|
153 |
+
self,
|
154 |
+
image: Optional[Image.Image] = None,
|
155 |
+
prompt: Optional[str] = None,
|
156 |
+
negative_prompt: Optional[str] = None,
|
157 |
+
latent: Optional[torch.Tensor] = None,
|
158 |
+
embed: Optional[torch.Tensor] = None,
|
159 |
+
) -> None:
|
160 |
+
super().__init__(image, prompt, negative_prompt)
|
161 |
+
self.latent = latent
|
162 |
+
self.embed = embed
|
163 |
+
|
164 |
+
@property
|
165 |
+
def is_incomplete(self) -> bool:
|
166 |
+
return (
|
167 |
+
self.image is None or
|
168 |
+
self.prompt is None or
|
169 |
+
self.negative_prompt is None or
|
170 |
+
self.latent is None or
|
171 |
+
self.embed is None
|
172 |
+
)
|
173 |
+
|
174 |
+
def extra_repr(self) -> str:
|
175 |
+
strings = []
|
176 |
+
if self.latent is not None:
|
177 |
+
strings.append(f'latent=Tensor(shape={str(self.latent.shape)})')
|
178 |
+
if self.embed is not None:
|
179 |
+
strings.append(f'embed=Tuple[Tensor(shape={str(self.embed[0].shape)})]')
|
180 |
+
return ', '.join(strings)
|
181 |
+
|
182 |
+
|
183 |
+
# TODO
|
184 |
+
# class LayerState:
|
185 |
+
# def __init__(
|
186 |
+
# self,
|
187 |
+
# prompst: List[str] = [],
|
188 |
+
# negative_prompts: List[str] = [],
|
189 |
+
# suffix: List[str] = [],
|
190 |
+
# masks: Optional[torch.Tensor] = None,
|
191 |
+
# mask_std: Optional[torch.Tensor] = None,
|
192 |
+
# mask_strength: Optional[torch.Tensor] = None,
|
193 |
+
# original_masks: Optional[Union[torch.Tensor, List[Image.Image]]] = None,
|
194 |
+
# ) -> None:
|
195 |
+
# self.prompts = prompts
|
196 |
+
# self.negative_prompts = negative_prompts
|
197 |
+
# self.suffix = suffix
|
198 |
+
# self.masks = masks
|
199 |
+
# self.mask_std = mask_std
|
200 |
+
# self.mask_strength = mask_strength
|
201 |
+
# self.original_masks = original_masks
|
202 |
+
|
203 |
+
# def __len__(self) -> int:
|
204 |
+
# self.check_integrity(True)
|
205 |
+
# return len(self.prompts)
|
206 |
+
|
207 |
+
# @property
|
208 |
+
# def is_empty(self) -> bool:
|
209 |
+
# self.check_integrity(True)
|
210 |
+
# return len(self.prompt) == 0
|
211 |
+
|
212 |
+
# def check_integrity(self, throw_error: bool = True) -> bool:
|
213 |
+
# p = len(self.prompts)
|
214 |
+
# flag = (
|
215 |
+
# p != len(self.negative_prompts) or
|
216 |
+
# p != len(self.suffix) or
|
217 |
+
# p != len(self.masks) or
|
218 |
+
# p != len(self.mask_std) or
|
219 |
+
# p != len(self.mask_strength) or
|
220 |
+
# p != len(self.original_masks)
|
221 |
+
# )
|
222 |
+
# if flag and throw_error:
|
223 |
+
# print(
|
224 |
+
# f'LayerState(\n\tlen(prompts): {p},\n\tlen(negative_prompts): {len(self.negative_prompts)},\n\t'
|
225 |
+
# f'len(suffix): {len(self.suffix)},\n\tlen(masks): {len(self.masks)},\n\t'
|
226 |
+
# f'len(mask_std): {len(self.mask_std)},\n\tlen(mask_strength): {len(self.mask_strength)},\n\t'
|
227 |
+
# f'len(original_masks): {len(self.original_masks)}\n)'
|
228 |
+
# )
|
229 |
+
# raise ValueError('LayerState is corrupted!')
|
230 |
+
# return not flag
|
231 |
+
|
232 |
+
# def extra_repr(self) -> str:
|
233 |
+
# strings = []
|
234 |
+
# if self.idx is not None:
|
235 |
+
# strings.append(f'idx={self.idx}')
|
236 |
+
# if self.prompt is not None:
|
237 |
+
# strings.append(f'prompt="{self.prompt}"')
|
238 |
+
# if self.negative_prompt is not None:
|
239 |
+
# strings.append(f'negative_prompt="{self.negative_prompt}"')
|
240 |
+
# if self.suffix is not None:
|
241 |
+
# strings.append(f'suffix="{self.suffix}"')
|
242 |
+
# if self.mask is not None:
|
243 |
+
# if isinstance(self.mask, Image.Image):
|
244 |
+
# mask_str = f'PIL.Image.Image(size={str(self.mask.size)})'
|
245 |
+
# else:
|
246 |
+
# mask_str = f'torch.Tensor(shape={str(self.mask.shape)})'
|
247 |
+
# strings.append(f'mask={mask_str}')
|
248 |
+
# if self.mask_std is not None:
|
249 |
+
# strings.append(f'mask_std={self.mask_std}')
|
250 |
+
# if self.mask_strength is not None:
|
251 |
+
# strings.append(f'mask_strength={self.mask_strength}')
|
252 |
+
# return f'{type(self).__name__}({", ".join(strings)})'
|
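
data.py defines the plain containers that model.py's update buffer queues for its lazy update policy: a single BackgroundObject plus a deque of LayerObjects. A minimal sketch of the merge semantics defined above; the example values are hypothetical:

    from data import LayerObject

    first = LayerObject(idx=0, prompt='1girl, knight')
    second = LayerObject(idx=0, prompt='ignored', mask_std=8.0)
    first.merge(second)                             # True: same idx, so the requests coalesce.
    first.prompt                                    # '1girl, knight' -- fields already set are kept.
    first.mask_std                                  # 8.0 -- only the unset field is filled in.
    first.merge(LayerObject(idx=1, prompt='1boy'))  # False: different layer, nothing is merged.
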
examples/prompt_background.txt
ADDED
@@ -0,0 +1,8 @@
Maximalism, best quality, high quality, no humans, background, clear sky, black sky, starry universe, planets
Maximalism, best quality, high quality, no humans, background, clear sky, blue sky
Maximalism, best quality, high quality, no humans, background, universe, void, black, galaxy, galaxy, stars, stars, stars
Maximalism, best quality, high quality, no humans, background, galaxy
Maximalism, best quality, high quality, no humans, background, sky, daylight
Maximalism, best quality, high quality, no humans, background, skyscrapers, rooftop, city of light, helicopters, bright night, sky
Maximalism, best quality, high quality, flowers, flowers, flowers, flower garden, no humans, background
Maximalism, best quality, high quality, flowers, flowers, flowers, flower garden
examples/prompt_background_advanced.txt
ADDED
The diff for this file is too large to render.
See raw diff
examples/prompt_boy.txt
ADDED
@@ -0,0 +1,15 @@
1boy, looking at viewer, brown hair, blue shirt
1boy, looking at viewer, brown hair, red shirt
1boy, looking at viewer, brown hair, purple shirt
1boy, looking at viewer, brown hair, orange shirt
1boy, looking at viewer, brown hair, yellow shirt
1boy, looking at viewer, brown hair, green shirt
1boy, looking back, side shaved hair, cyberpunk clothes, robotic suit, large body
1boy, looking back, short hair, renaissance clothes, noble boy
1boy, looking back, long hair, ponytail, leather jacket, heavy metal boy
1boy, looking at viewer, a king, kingly grace, majestic clothes, crown
1boy, looking at viewer, an astronaut, brown hair, faint smile, engineer
1boy, looking at viewer, a medieval knight, helmet, swordsman, plate armour
1boy, looking at viewer, black haired, old eastern clothes
1boy, looking back, messy hair, suit, short beard, noir
1boy, looking at viewer, cute face, light smile, starry eyes, jeans
examples/prompt_girl.txt
ADDED
@@ -0,0 +1,16 @@
1girl, looking at viewer, pretty face, light smile, haughty smile, proud, long wavy hair, charcoal dark eyes, chinese clothes
1girl, looking at viewer, princess, pretty face, light smile, haughty smile, proud, long wavy hair, charcoal dark eyes, majestic gown
1girl, looking at viewer, astronaut girl, long red hair, space suit, black starry eyes, happy face, pretty face
1girl, looking at viewer, fantasy adventurer, backpack
1girl, looking at viewer, astronaut girl, spacesuit, eva, happy face
1girl, looking at viewer, soldier, rusty clothes, backpack, pretty face, sad smile, tears
1girl, looking at viewer, majestic clothes, long hair, glittering eyes, pretty face
1girl, looking at viewer, from behind, majestic clothes, long hair, glittering eyes
1girl, looking at viewer, evil smile, very short hair, suit, evil genius
1girl, looking at viewer, elven queen, green hair, haughty face, eyes wide open, crazy smile, brown jacket, leaves
1girl, looking at viewer, purple hair, happy face, black leather jacket
1girl, looking at viewer, pink hair, happy face, blue jeans, black leather jacket
1girl, looking at viewer, knight, medium length hair, red hair, plate armour, blue eyes, sad, pretty face, determined face
1girl, looking at viewer, pretty face, light smile, orange hair, casual clothes
1girl, looking at viewer, pretty face, large smile, open mouth, uniform, mcdonald employee, short wavy hair
1girl, looking at viewer, brown hair, ponytail, happy face, bright smile, blue jeans and white shirt
examples/prompt_props.txt
ADDED
@@ -0,0 +1,43 @@
🏯 Palace, Gyeongbokgung palace
🌳 Garden, Chinese garden
🏛️ Rome, Ancient city of Rome
🧱 Wall, Castle wall
🔴 Mars, Martian desert, Red rocky desert
🌻 Grassland, Grasslands
🏡 Village, A fantasy village
🐉 Dragon, a flying chinese dragon
🌏 Earth, Earth seen from ISS
🚀 Space Station, the international space station
🪻 Grassland, Rusty grassland with flowers
🖼️ Tapestry, majestic tapestry, glittering effect, glowing in light, mural painting with mountain
🏙️ City Ruin, city, ruins, ruins, ruins, deserted
🏙️ Renaissance City, renaissance city, renaissance city, renaissance city
🌷 Flowers, Flower garden
🌼 Flowers, Flower garden, spring garden
🌹 Flowers, Flowers, flowers, flowers
⛰️ Dolomites Mountains, Dolomites
⛰️ Himalayas Mountains, Himalayas
⛰️ Alps Mountains, Alps
⛰️ Mountains, Mountains
❄️⛰️ Mountains, Winter mountains
🌷⛰️ Mountains, Spring mountains
🌞⛰️ Mountains, Summer mountains
🌵 Desert, A sandy desert, dunes
🪨🌵 Desert, A rocky desert
💦 Waterfall, A giant waterfall
🌊 Ocean, Ocean
⛱️ Seashore, Seashore
🌅 Sea Horizon, Sea horizon
🌊 Lake, Clear blue lake
💻 Computer, A giant supercomputer
🌳 Tree, A giant tree
🌳 Forest, A forest
🌳🌳 Forest, A dense forest
🌲 Forest, Winter forest
🌴 Forest, Summer forest, tropical forest
👒 Hat, A hat
🐶 Dog, Doggy body parts
😻 Cat, A cat
🦉 Owl, A small sitting owl
🦅 Eagle, A small sitting eagle
🚀 Rocket, A flying rocket
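
Each line of the props file above follows a fixed layout: an emoji-decorated display label before the first comma, then the prompt text. A minimal parsing sketch under that assumption; the repo's actual loader (presumably in app.py or prompt_util.py) is not shown in this diff and may differ:

    # Hypothetical loader for the '<emoji label>, <prompt>' line format above.
    from typing import List, Tuple

    def load_props(path: str) -> List[Tuple[str, str]]:
        entries = []
        with open(path, encoding='utf-8') as f:
            for line in f:
                line = line.strip()
                if not line:
                    continue
                label, _, prompt = line.partition(',')  # Split at the first comma only.
                entries.append((label.strip(), prompt.strip()))
        return entries

    # e.g. load_props('examples/prompt_props.txt')[0] == ('🏯 Palace', 'Gyeongbokgung palace')
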
model.py
ADDED
@@ -0,0 +1,1212 @@
# Copyright (c) 2024 Jaerin Lee

# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:

# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.

# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.

from transformers import Blip2Processor, Blip2ForConditionalGeneration
from diffusers import DiffusionPipeline, LCMScheduler, EulerDiscreteScheduler, AutoencoderTiny
from huggingface_hub import hf_hub_download

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as T
from einops import rearrange

from collections import deque
from typing import Tuple, List, Literal, Optional, Union
from PIL import Image

from util import load_model, gaussian_lowpass, shift_to_mask_bbox_center
from data import BackgroundObject, LayerObject, BackgroundState #, LayerState


class StreamMultiDiffusion(nn.Module):
    def __init__(
        self,
        device: torch.device,
        dtype: torch.dtype = torch.float16,
        sd_version: Literal['1.5'] = '1.5',
        hf_key: Optional[str] = None,
        lora_key: Optional[str] = None,
        use_tiny_vae: bool = True,
        t_index_list: List[int] = [0, 4, 12, 25, 37],  # [0, 5, 16, 18, 20, 37], Magic number.
        width: int = 512,
        height: int = 512,
        frame_buffer_size: int = 1,
        num_inference_steps: int = 50,
        guidance_scale: float = 1.2,
        delta: float = 1.0,
        cfg_type: Literal['none', 'full', 'self', 'initialize'] = 'none',
        seed: int = 2024,
        autoflush: bool = True,
        default_mask_std: float = 8.0,
        default_mask_strength: float = 1.0,
        default_prompt_strength: float = 0.95,
        bootstrap_steps: int = 1,
        bootstrap_mix_steps: float = 1.0,
        # bootstrap_leak_sensitivity: float = 0.2,
        preprocess_mask_cover_alpha: float = 0.3,  # TODO
        prompt_queue_capacity: int = 256,
        mask_type: Literal['discrete', 'semi-continuous', 'continuous'] = 'continuous',
        use_xformers: bool = True,
    ) -> None:
        super().__init__()

        self.device = device
        self.dtype = dtype
        self.seed = seed
        self.sd_version = sd_version

        self.autoflush = autoflush
        self.default_mask_std = default_mask_std
        self.default_mask_strength = default_mask_strength
        self.default_prompt_strength = default_prompt_strength
        self.bootstrap_steps = (
            bootstrap_steps > torch.arange(len(t_index_list))).to(dtype=self.dtype, device=self.device)
        self.bootstrap_mix_steps = bootstrap_mix_steps
        self.bootstrap_mix_ratios = (
            bootstrap_mix_steps - torch.arange(len(t_index_list), dtype=self.dtype, device=self.device)).clip_(0, 1)
        # self.bootstrap_leak_sensitivity = bootstrap_leak_sensitivity
        self.preprocess_mask_cover_alpha = preprocess_mask_cover_alpha
        self.mask_type = mask_type

        ### State definition

        # [0. Start] -(prepare)-> [1. Initialized]
        # [1. Initialized] -(update_background)-> [2. Background Registered] (len(self.prompts)==0)
        # [2. Background Registered] -(update_layers)-> [3. Unflushed] (len(self.prompts)>0)

        # [3. Unflushed] -(flush)-> [4. Ready]
        # [4. Ready] -(any updates)-> [3. Unflushed]
        # [4. Ready] -(__call__)-> [4. Ready], continuously returns generated images.

        self.ready_checklist = {
            'initialized': False,
            'background_registered': False,
            'layers_ready': False,
            'flushed': False,
        }

        ### Session state update queue: for the lazy update policy of streaming applications.

        self.update_buffer = {
            'background': None,  # Maintains a single instance of BackgroundObject.
            'layers': deque(maxlen=prompt_queue_capacity),  # Maintains a queue of LayerObjects.
        }

        print('[INFO] Loading Stable Diffusion...')
        get_scheduler = lambda pipe: LCMScheduler.from_config(pipe.scheduler.config)
        lora_weight_name = None
        if self.sd_version == '1.5':
            if hf_key is not None:
                print(f'[INFO] Using custom model key: {hf_key}')
                model_key = hf_key
            else:
                model_key = 'runwayml/stable-diffusion-v1-5'
                lora_key = 'latent-consistency/lcm-lora-sdv1-5'
                lora_weight_name = 'pytorch_lora_weights.safetensors'
        # elif self.sd_version == 'xl':
        #     model_key = 'stabilityai/stable-diffusion-xl-base-1.0'
        #     lora_key = 'latent-consistency/lcm-lora-sdxl'
        #     lora_weight_name = 'pytorch_lora_weights.safetensors'
        else:
            raise ValueError(f'Stable Diffusion version {self.sd_version} not supported.')

        ### Internally stored "session" states

        self.state = {
            'background': BackgroundState(),  # Maintains a single instance of BackgroundState.
            # 'layers': LayerState(),  # Maintains a single instance of LayerState.
            'model_key': model_key,  # The Hugging Face model ID.
        }

        # Create models.
        self.i2t_processor = Blip2Processor.from_pretrained('Salesforce/blip2-opt-2.7b')
        self.i2t_model = Blip2ForConditionalGeneration.from_pretrained('Salesforce/blip2-opt-2.7b')

        self.pipe = load_model(model_key, self.sd_version, self.device, self.dtype)

        self.pipe.load_lora_weights(lora_key, weight_name=lora_weight_name, adapter_name='lcm')
        self.pipe.fuse_lora(
            fuse_unet=True,
            fuse_text_encoder=True,
            lora_scale=1.0,
            safe_fusing=False,
        )
        if use_xformers:
            self.pipe.enable_xformers_memory_efficient_attention()

        self.vae = (
            AutoencoderTiny.from_pretrained('madebyollin/taesd').to(device=self.device, dtype=self.dtype)
            if use_tiny_vae else self.pipe.vae
        )
        # self.tokenizer = self.pipe.tokenizer
        self.text_encoder = self.pipe.text_encoder
        self.unet = self.pipe.unet
        self.vae_scale_factor = self.pipe.vae_scale_factor

        self.scheduler = get_scheduler(self.pipe)
        self.scheduler.set_timesteps(num_inference_steps)

        self.generator = None

        # Lock the canvas size--changing the canvas size can be implemented by reloading the module.
        self.height = height
        self.width = width
        self.latent_height = int(height // self.pipe.vae_scale_factor)
        self.latent_width = int(width // self.pipe.vae_scale_factor)

        # For bootstrapping.
        self.white = self.encode_imgs(torch.ones(1, 3, height, width, dtype=self.dtype, device=self.device))

        # StreamDiffusion settings.
        self.t_list = t_index_list
        assert len(self.t_list) > 1, 'Current version only supports diffusion models with multiple steps.'
        self.frame_bff_size = frame_buffer_size  # f
        self.denoising_steps_num = len(self.t_list)  # t
        self.cfg_type = cfg_type
        self.num_inference_steps = num_inference_steps
        self.guidance_scale = 1.0 if self.cfg_type == 'none' else guidance_scale
        self.delta = delta

        self.batch_size = self.denoising_steps_num * frame_buffer_size  # T = t*f
        if self.cfg_type == 'initialize':
            self.trt_unet_batch_size = (self.denoising_steps_num + 1) * self.frame_bff_size
        elif self.cfg_type == 'full':
            self.trt_unet_batch_size = 2 * self.denoising_steps_num * self.frame_bff_size
        else:
            self.trt_unet_batch_size = self.denoising_steps_num * frame_buffer_size

        print('[INFO] Model is loaded!')

        self.reset_seed(self.generator, seed)
        self.reset_latent()
        self.prepare()

        print('[INFO] Parameters prepared!')

        self.ready_checklist['initialized'] = True

+
@property
|
207 |
+
def background(self) -> BackgroundState:
|
208 |
+
return self.state['background']
|
209 |
+
|
210 |
+
# @property
|
211 |
+
# def layers(self) -> LayerState:
|
212 |
+
# return self.state['layers']
|
213 |
+
|
214 |
+
@property
|
215 |
+
def num_layers(self) -> int:
|
216 |
+
return len(self.prompts) if hasattr(self, 'prompts') else 0
|
217 |
+
|
218 |
+
@property
|
219 |
+
def is_ready_except_flush(self) -> bool:
|
220 |
+
return all(v for k, v in self.ready_checklist.items() if k != 'flushed')
|
221 |
+
|
222 |
+
@property
|
223 |
+
def is_flush_needed(self) -> bool:
|
224 |
+
return self.autoflush and not self.ready_checklist['flushed']
|
225 |
+
|
226 |
+
@property
|
227 |
+
def is_ready(self) -> bool:
|
228 |
+
return self.is_ready_except_flush and not self.is_flush_needed
|
229 |
+
|
230 |
+
@property
|
231 |
+
def is_dirty(self) -> bool:
|
232 |
+
return not (self.update_buffer['background'] is None and len(self.update_buffer['layers']) == 0)
|
233 |
+
|
234 |
+
@property
|
235 |
+
def has_background(self) -> bool:
|
236 |
+
return self.background.is_empty
|
237 |
+
|
238 |
+
# @property
|
239 |
+
# def has_layers(self) -> bool:
|
240 |
+
# return len(self.layers) > 0
|
241 |
+
|
242 |
+
def __repr__(self) -> str:
|
243 |
+
return (
|
244 |
+
f'{type(self).__name__}(\n\tbackground: {str(self.background)},\n\t'
|
245 |
+
f'model_key: {self.state["model_key"]}\n)'
|
246 |
+
# f'layers: {str(self.layers)},\n\tmodel_key: {self.state["model_key"]}\n)'
|
247 |
+
)
|
248 |
+
|
249 |
+
def check_integrity(self, throw_error: bool = True) -> bool:
|
250 |
+
p = len(self.prompts)
|
251 |
+
flag = (
|
252 |
+
p != len(self.negative_prompts) or
|
253 |
+
p != len(self.prompt_strengths) or
|
254 |
+
p != len(self.masks) or
|
255 |
+
p != len(self.mask_strengths) or
|
256 |
+
p != len(self.mask_stds) or
|
257 |
+
p != len(self.original_masks)
|
258 |
+
)
|
259 |
+
if flag and throw_error:
|
260 |
+
print(
|
261 |
+
f'LayerState(\n\tlen(prompts): {p},\n\tlen(negative_prompts): {len(self.negative_prompts)},\n\t'
|
262 |
+
f'len(prompt_strengths): {len(self.prompt_strengths)},\n\tlen(masks): {len(self.masks)},\n\t'
|
263 |
+
f'len(mask_stds): {len(self.mask_stds)},\n\tlen(mask_strengths): {len(self.mask_strengths)},\n\t'
|
264 |
+
f'len(original_masks): {len(self.original_masks)}\n)'
|
265 |
+
)
|
266 |
+
raise ValueError('[ERROR] LayerState is corrupted!')
|
267 |
+
return not flag
|
268 |
+
|
269 |
+
def check_ready(self) -> bool:
|
270 |
+
all_except_flushed = all(v for k, v in self.ready_checklist.items() if k != 'flushed')
|
271 |
+
if all_except_flushed:
|
272 |
+
if self.is_flush_needed:
|
273 |
+
self.flush()
|
274 |
+
return True
|
275 |
+
|
276 |
+
print('[WARNING] MagicDraw module is not ready yet! Complete the checklist:')
|
277 |
+
for k, v in self.ready_checklist.items():
|
278 |
+
prefix = ' [ v ] ' if v else ' [ x ] '
|
279 |
+
print(prefix + k.replace('_', ' '))
|
280 |
+
return False
|
281 |
+
|
282 |
+
def reset_seed(self, generator: Optional[torch.Generator] = None, seed: Optional[int] = None) -> None:
|
283 |
+
generator = torch.Generator(self.device) if generator is None else generator
|
284 |
+
seed = self.seed if seed is None else seed
|
285 |
+
self.generator = generator
|
286 |
+
self.generator.manual_seed(seed)
|
287 |
+
|
288 |
+
self.init_noise = torch.randn((self.batch_size, 4, self.latent_height, self.latent_width),
|
289 |
+
generator=generator, device=self.device, dtype=self.dtype)
|
290 |
+
self.stock_noise = torch.zeros_like(self.init_noise)
|
291 |
+
|
292 |
+
self.ready_checklist['flushed'] = False
|
293 |
+
|
294 |
+
def reset_latent(self) -> None:
|
295 |
+
# initialize x_t_latent (it can be any random tensor)
|
296 |
+
b = (self.denoising_steps_num - 1) * self.frame_bff_size
|
297 |
+
self.x_t_latent_buffer = torch.zeros(
|
298 |
+
(b, 4, self.latent_height, self.latent_width), dtype=self.dtype, device=self.device)
|
299 |
+
|
300 |
+
def reset_state(self) -> None:
|
301 |
+
# TODO Reset states for context switch between multiple users.
|
302 |
+
pass
|
303 |
+
|
    def prepare(self) -> None:
        # Make the sub-timesteps list based on the indices in t_list and the values in the timesteps list.
        self.timesteps = self.scheduler.timesteps.to(self.device)
        self.sub_timesteps = []
        for t in self.t_list:
            self.sub_timesteps.append(self.timesteps[t])
        sub_timesteps_tensor = torch.tensor(self.sub_timesteps, dtype=torch.long, device=self.device)
        self.sub_timesteps_tensor = sub_timesteps_tensor.repeat_interleave(self.frame_bff_size, dim=0)

        c_skip_list = []
        c_out_list = []
        for timestep in self.sub_timesteps:
            c_skip, c_out = self.scheduler.get_scalings_for_boundary_condition_discrete(timestep)
            c_skip_list.append(c_skip)
            c_out_list.append(c_out)
        self.c_skip = torch.stack(c_skip_list).view(len(self.t_list), 1, 1, 1).to(dtype=self.dtype, device=self.device)
        self.c_out = torch.stack(c_out_list).view(len(self.t_list), 1, 1, 1).to(dtype=self.dtype, device=self.device)

        alpha_prod_t_sqrt_list = []
        beta_prod_t_sqrt_list = []
        for timestep in self.sub_timesteps:
            alpha_prod_t_sqrt = self.scheduler.alphas_cumprod[timestep].sqrt()
            beta_prod_t_sqrt = (1 - self.scheduler.alphas_cumprod[timestep]).sqrt()
            alpha_prod_t_sqrt_list.append(alpha_prod_t_sqrt)
            beta_prod_t_sqrt_list.append(beta_prod_t_sqrt)
        alpha_prod_t_sqrt = (torch.stack(alpha_prod_t_sqrt_list).view(len(self.t_list), 1, 1, 1)
                             .to(dtype=self.dtype, device=self.device))
        beta_prod_t_sqrt = (torch.stack(beta_prod_t_sqrt_list).view(len(self.t_list), 1, 1, 1)
                            .to(dtype=self.dtype, device=self.device))
        self.alpha_prod_t_sqrt = alpha_prod_t_sqrt.repeat_interleave(self.frame_bff_size, dim=0)
        self.beta_prod_t_sqrt = beta_prod_t_sqrt.repeat_interleave(self.frame_bff_size, dim=0)

        noise_lvs = ((1 - self.scheduler.alphas_cumprod.to(self.device)[self.sub_timesteps_tensor]) ** 0.5)
        self.noise_lvs = noise_lvs[None, :, None, None, None]
        self.next_noise_lvs = torch.cat([noise_lvs[1:], noise_lvs.new_zeros(1)])[None, :, None, None, None]

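    # NOTE: A quick sanity-check sketch for the noise levels computed above.
    # With the default t_index_list=[0, 4, 12, 25, 37] and 50 inference steps,
    # they come out near [0.9977, 0.9912, 0.9735, 0.8499, 0.5840], the
    # thresholds quoted in the process_mask docstring below:
    #
    #     lvs = (1 - self.scheduler.alphas_cumprod.to(self.device)[self.sub_timesteps_tensor]) ** 0.5
    #     print([round(float(v), 4) for v in lvs])
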
    @torch.no_grad()
    def get_text_prompts(self, image: Image.Image) -> str:
        r"""A convenient method to extract a text prompt from an image.

        This is called if the user does not provide a background prompt but
        only the background image. We use BLIP-2 to automatically generate
        prompts.

        Args:
            image (Image.Image): A PIL image.

        Returns:
            A single string of text prompt.
        """
        question = 'Question: What are in the image? Answer:'
        inputs = self.i2t_processor(image, question, return_tensors='pt')
        out = self.i2t_model.generate(**inputs, max_new_tokens=77)
        prompt = self.i2t_processor.decode(out[0], skip_special_tokens=True).strip()
        return prompt

    @torch.no_grad()
    def encode_imgs(
        self,
        imgs: torch.Tensor,
        generator: Optional[torch.Generator] = None,
        add_noise: bool = False,
    ) -> torch.Tensor:
        r"""A wrapper function for the VAE encoder of the latent diffusion model.

        Args:
            imgs (torch.Tensor): An image to get StableDiffusion latents for.
                Expected shape: (B, 3, H, W). Expected pixel scale: [0, 1].
            generator (Optional[torch.Generator]): Seed for the KL-Autoencoder.
            add_noise (bool): Turn this on for a noisy latent.

        Returns:
            An image latent embedding with 1/8 size (depending on the
            autoencoder). Shape: (B, 4, H//8, W//8).
        """
        def _retrieve_latents(
            encoder_output: torch.Tensor,
            generator: Optional[torch.Generator] = None,
            sample_mode: str = 'sample',
        ):
            if hasattr(encoder_output, 'latent_dist') and sample_mode == 'sample':
                return encoder_output.latent_dist.sample(generator)
            elif hasattr(encoder_output, 'latent_dist') and sample_mode == 'argmax':
                return encoder_output.latent_dist.mode()
            elif hasattr(encoder_output, 'latents'):
                return encoder_output.latents
            else:
                raise AttributeError('[ERROR] Could not access latents of provided encoder_output')

        imgs = 2 * imgs - 1
        latents = self.vae.config.scaling_factor * _retrieve_latents(self.vae.encode(imgs), generator=generator)
        if add_noise:
            latents = self.alpha_prod_t_sqrt[0] * latents + self.beta_prod_t_sqrt[0] * self.init_noise[0]
        return latents

    @torch.no_grad()
    def decode_latents(self, latents: torch.Tensor) -> torch.Tensor:
        r"""A wrapper function for the VAE decoder of the latent diffusion model.

        Args:
            latents (torch.Tensor): An image latent to get the associated
                images for. Expected shape: (B, 4, H//8, W//8).

        Returns:
            A decoded image with 8x the latent size (depending on the
            autoencoder). Shape: (B, 3, H, W).
        """
        latents = 1 / self.vae.config.scaling_factor * latents
        imgs = self.vae.decode(latents).sample
        imgs = (imgs / 2 + 0.5).clip_(0, 1)
        return imgs

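    # NOTE: A minimal round-trip sketch for the two VAE wrappers above, using
    # a hypothetical [0, 1]-scaled RGB batch:
    #
    #     imgs = torch.rand(1, 3, self.height, self.width, dtype=self.dtype, device=self.device)
    #     latents = self.encode_imgs(imgs)      # (1, 4, H//8, W//8)
    #     recon = self.decode_latents(latents)  # (1, 3, H, W), clipped to [0, 1]
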
    @torch.no_grad()
    def update_background(
        self,
        image: Optional[Image.Image] = None,
        prompt: Optional[str] = None,
        negative_prompt: Optional[str] = None,
    ) -> bool:
        flag_changed = False
        if image is not None:
            image_ = image.resize((self.width, self.height))
            prompt = self.get_text_prompts(image_) if prompt is None else prompt
            negative_prompt = '' if negative_prompt is None else negative_prompt
            embed = self.pipe.encode_prompt(
                prompt=[prompt],
                device=self.device,
                num_images_per_prompt=1,
                do_classifier_free_guidance=(self.guidance_scale > 1.0),
                negative_prompt=[negative_prompt],
            )  # ((1, 77, 768): cond, (1, 77, 768): uncond)

            self.state['background'].image = image
            self.state['background'].latent = (
                self.encode_imgs(T.ToTensor()(image_)[None].to(self.device, self.dtype))
            )  # Input: (1, 3, H, W) -> latent: (1, 4, H//8, W//8).
            self.state['background'].prompt = prompt
            self.state['background'].negative_prompt = negative_prompt
            self.state['background'].embed = embed

            if self.bootstrap_steps[0] > 0:
                mix_ratio = self.bootstrap_mix_ratios[:, None, None, None]
                self.bootstrap_latent = mix_ratio * self.white + (1.0 - mix_ratio) * self.state['background'].latent

            self.ready_checklist['background_registered'] = True
            flag_changed = True
        else:
            if not self.ready_checklist['background_registered']:
                print('[WARNING] Register a background image first! Request ignored.')
                return False

            if prompt is not None:
                self.background.prompt = prompt
                flag_changed = True
            if negative_prompt is not None:
                self.background.negative_prompt = negative_prompt
                flag_changed = True
            if flag_changed:
                self.background.embed = self.pipe.encode_prompt(
                    prompt=[self.background.prompt],
                    device=self.device,
                    num_images_per_prompt=1,
                    do_classifier_free_guidance=(self.guidance_scale > 1.0),
                    negative_prompt=[self.background.negative_prompt],
                )  # ((1, 77, 768): cond, (1, 77, 768): uncond)

        self.ready_checklist['flushed'] = not flag_changed
        return flag_changed

    @torch.no_grad()
    def process_mask(
        self,
        masks: Optional[Union[torch.Tensor, Image.Image, List[Image.Image]]] = None,
        strength: Optional[Union[torch.Tensor, float]] = None,
        std: Optional[Union[torch.Tensor, float]] = None,
    ) -> Tuple[torch.Tensor]:
        r"""Fast preprocessing of masks for region-based generation with
        fine-grained controls.

        Mask preprocessing is done in four steps:
            1. Resizing: Resize the masks to the specified width and height by
                nearest neighbor interpolation.
            2. (Optional) Ordering: Masks with higher indices are considered
                to cover the masks with smaller indices. Covered masks are
                decayed in their alpha values by the specified factor of
                `preprocess_mask_cover_alpha`.
            3. Blurring: Gaussian blur is applied to each mask with the
                specified standard deviation (isotropic). This results in a
                gradual increase of the masked region as the timesteps evolve,
                naturally blending the foreground and the predesignated
                background. Not strictly required if you want to produce
                images from scratch without a background.
            4. Quantization: Split the real-valued masks with values in
                [0, 1] into predefined noise levels for each quantized
                scheduling step of the diffusion sampler. For example, if the
                diffusion model sampler has noise levels of [0.9977, 0.9912,
                0.9735, 0.8499, 0.5840], which are the default noise levels of
                this module with schedule [0, 4, 12, 25, 37], the masks are
                split into binary masks whose values are greater than these
                levels. This results in a gradual increase of the mask region
                as the timesteps increase. Details are described in our paper
                at https://arxiv.org/pdf/2403.09055.pdf.

        On the three modes of `mask_type`:
            `self.mask_type` is predefined at the initialization stage of this
            pipeline. Three possible modes are available: 'discrete',
            'semi-continuous', and 'continuous'. These define the mask
            quantization modes we use. Basically, this (subtly) controls the
            smoothness of foreground-background blending. The continuous mode
            produces nonbinary masks to further blend foreground and
            background latents by linearly interpolating between them. The
            semi-continuous mode only applies continuous masks at the last
            step of the LCM sampler. Due to the large step size of the LCM
            scheduler, we find that our continuous blending helps generate
            seamless inpainting and editing results.

        Args:
            masks (Union[torch.Tensor, Image.Image, List[Image.Image]]): Masks.
            strength (Optional[Union[torch.Tensor, float]]): Mask strength
                that overrides the default value. A globally multiplied factor
                for the mask at the initial stage of processing. Can be
                applied separately for each mask.
            std (Optional[Union[torch.Tensor, float]]): Standard deviation of
                the mask-blurring Gaussian kernel. Overrides the default
                value. Can be applied separately for each mask.

        Returns: A tuple of tensors.
            - masks: Preprocessed (ordered, blurred, and quantized) binary/
                nonbinary masks (see the explanation on `mask_type` above) for
                region-based image synthesis.
            - strengths: Mask strengths, returned for caching.
            - std: Mask blur standard deviations, returned for caching.
            - original_masks: Original masks, returned for caching.
        """
        if masks is None:
            kwargs = {'dtype': self.dtype, 'device': self.device}
            original_masks = torch.zeros((0, 1, self.latent_height, self.latent_width), dtype=self.dtype)
            masks = torch.zeros((0, self.batch_size, 1, self.latent_height, self.latent_width), **kwargs)
            strength = torch.zeros((0,), **kwargs)
            std = torch.zeros((0,), **kwargs)
            return masks, strength, std, original_masks

        if isinstance(masks, Image.Image):
            masks = [masks]
        if isinstance(masks, (tuple, list)):
            # Assumes a white background for Image.Image;
            # inverted boolean masks with shape (1, 1, H, W) for torch.Tensor.
            masks = torch.cat([
                # (T.ToTensor()(mask.resize((self.width, self.height), Image.NEAREST)) < 0.5)[None, :1]
                (1.0 - T.ToTensor()(mask.resize((self.width, self.height), Image.BILINEAR)))[None, :1]
                for mask in masks
            ], dim=0).float().clip_(0, 1)
        original_masks = masks
        masks = masks.float().to(self.device)

        # The background mask alpha is decayed by the specified factor where foreground masks cover it.
        if self.preprocess_mask_cover_alpha > 0:
            masks = torch.stack([
                torch.where(
                    masks[i + 1:].sum(dim=0) > 0,
                    mask * self.preprocess_mask_cover_alpha,
                    mask,
                ) if i < len(masks) - 1 else mask
                for i, mask in enumerate(masks)
            ], dim=0)

        if std is None:
            std = self.default_mask_std
        if isinstance(std, (int, float)):
            std = [std] * len(masks)
        if isinstance(std, (list, tuple)):
            std = torch.as_tensor(std, dtype=torch.float, device=self.device)

        # Mask preprocessing parameters are fetched from the default settings.
        if strength is None:
            strength = self.default_mask_strength
        if isinstance(strength, (int, float)):
            strength = [strength] * len(masks)
        if isinstance(strength, (list, tuple)):
            strength = torch.as_tensor(strength, dtype=torch.float, device=self.device)

        if (std > 0).any():
            std = torch.where(std > 0, std, 1e-5)
            masks = gaussian_lowpass(masks, std)
        # NOTE: This `strength` aligns with `denoising strength`. However, with LCM, using strength < 0.96
        #       gives unpleasant results.
        masks = masks * strength[:, None, None, None]
        masks = masks.unsqueeze(1).repeat(1, self.noise_lvs.shape[1], 1, 1, 1)

        if self.mask_type == 'discrete':
            # Discrete mode.
            masks = masks > self.noise_lvs
        elif self.mask_type == 'semi-continuous':
            # Semi-continuous mode (continuous at the last step only).
            masks = torch.cat((
                masks[:, :-1] > self.noise_lvs[:, :-1],
                (
                    (masks[:, -1:] - self.next_noise_lvs[:, -1:])
                    / (self.noise_lvs[:, -1:] - self.next_noise_lvs[:, -1:])
                ).clip_(0, 1),
            ), dim=1)
        elif self.mask_type == 'continuous':
            # Continuous mode: Has exactly the same `1` coverage as the discrete mode, but the mask
            #                  gradually decreases after the discrete mode boundary to become `0` at
            #                  the next lower threshold.
            masks = ((masks - self.next_noise_lvs) / (self.noise_lvs - self.next_noise_lvs)).clip_(0, 1)

        # NOTE: Post-processing mask strength does not align with the conventional 'denoising_strength'.
        #       However, fine-grained mask alpha channel tuning is available in this form.
        # masks = masks * strength[None, :, None, None, None]

        masks = rearrange(masks.float(), 'p t () h w -> (p t) () h w')
        masks = F.interpolate(masks, size=(self.latent_height, self.latent_width), mode='nearest')
        masks = rearrange(masks.to(self.dtype), '(p t) () h w -> p t () h w', p=len(std))
        return masks, strength, std, original_masks

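    # NOTE: A tiny worked example of the quantization above, assuming the
    # default noise levels [0.9977, 0.9912, 0.9735, 0.8499, 0.5840] from the
    # docstring. A blurred mask pixel of value 0.9 becomes
    #     [0, 0, 0, 1, 1]     in 'discrete' mode (0.9 exceeds only 0.8499 and 0.5840),
    #     [0, 0, 0.41, 1, 1]  in 'continuous' mode, where
    #     0.41 = (0.9 - 0.8499) / (0.9735 - 0.8499), clipped to [0, 1].
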
    @torch.no_grad()
    def update_layers(
        self,
        prompts: Union[str, List[str]],
        negative_prompts: Optional[Union[str, List[str]]] = None,
        suffix: Optional[str] = None,  #', background is ',
        prompt_strengths: Optional[Union[torch.Tensor, float, List[float]]] = None,
        masks: Optional[Union[torch.Tensor, Image.Image, List[Image.Image]]] = None,
        mask_strengths: Optional[Union[torch.Tensor, float, List[float]]] = None,
        mask_stds: Optional[Union[torch.Tensor, float, List[float]]] = None,
    ) -> None:
        if not self.ready_checklist['background_registered']:
            print('[WARNING] Register a background image first! Request ignored.')
            return

        ### Register prompts

        if isinstance(prompts, str):
            prompts = [prompts]
        if negative_prompts is None:
            negative_prompts = ''
        if isinstance(negative_prompts, str):
            negative_prompts = [negative_prompts]
        fg_prompt = [p + suffix + self.background.prompt if suffix is not None else p for p in prompts]
        self.prompts = fg_prompt
        self.negative_prompts = negative_prompts
        p = self.num_layers

        e = self.pipe.encode_prompt(
            prompt=fg_prompt,
            device=self.device,
            num_images_per_prompt=1,
            do_classifier_free_guidance=(self.guidance_scale > 1.0),
            negative_prompt=negative_prompts,
        )  # (p, 77, 768)

        if prompt_strengths is None:
            prompt_strengths = self.default_prompt_strength
        if isinstance(prompt_strengths, (int, float)):
            prompt_strengths = [prompt_strengths] * p
        if isinstance(prompt_strengths, (list, tuple)):
            prompt_strengths = torch.as_tensor(prompt_strengths, dtype=self.dtype, device=self.device)
        self.prompt_strengths = prompt_strengths

        s = prompt_strengths[:, None, None]
        self.prompt_embeds = torch.lerp(self.background.embed[0], e[0], s).repeat(self.batch_size, 1, 1)  # (T * p, 77, 768)
        if self.guidance_scale > 1.0 and self.cfg_type in ('initialize', 'full'):
            b = self.batch_size if self.cfg_type == 'full' else self.frame_bff_size
            uncond_prompt_embeds = torch.lerp(self.background.embed[1], e[1], s).repeat(b, 1, 1)  # (T * p, 77, 768)
            self.prompt_embeds = torch.cat([uncond_prompt_embeds, self.prompt_embeds], dim=0)  # (2 * T * p, 77, 768)

        self.sub_timesteps_tensor_ = self.sub_timesteps_tensor.repeat_interleave(p)  # (T * p,)
        self.init_noise_ = self.init_noise.repeat_interleave(p, dim=0)  # (T * p, 4, h, w)
        self.stock_noise_ = self.stock_noise.repeat_interleave(p, dim=0)  # (T * p, 4, h, w)
        self.c_out_ = self.c_out.repeat_interleave(p, dim=0)  # (T * p, 1, 1, 1)
        self.c_skip_ = self.c_skip.repeat_interleave(p, dim=0)  # (T * p, 1, 1, 1)
        self.beta_prod_t_sqrt_ = self.beta_prod_t_sqrt.repeat_interleave(p, dim=0)  # (T * p, 1, 1, 1)
        self.alpha_prod_t_sqrt_ = self.alpha_prod_t_sqrt.repeat_interleave(p, dim=0)  # (T * p, 1, 1, 1)

        ### Register new masks

        if isinstance(masks, Image.Image):
            masks = [masks]
        n = len(masks) if masks is not None else 0

        # Modification.
        masks, mask_strengths, mask_stds, original_masks = self.process_mask(masks, mask_strengths, mask_stds)

        self.counts = masks.sum(dim=0)  # (T, 1, h, w)
        self.bg_mask = (1 - self.counts).clip_(0, 1)  # (T, 1, h, w)
        self.masks = masks  # (p, T, 1, h, w)
        self.mask_strengths = mask_strengths  # (p,)
        self.mask_stds = mask_stds  # (p,)
        self.original_masks = original_masks  # (p, 1, h, w)

        if p > n:
            # Add more masks: counts and bg_masks are not changed, but only masks are changed.
            self.masks = torch.cat((
                self.masks,
                torch.zeros(
                    (p - n, self.batch_size, 1, self.latent_height, self.latent_width),
                    dtype=self.dtype,
                    device=self.device,
                ),
            ), dim=0)
            print(f'[WARNING] Detected more prompts ({p}) than masks ({n}). '
                  'Automatically adding blank masks for the additional prompts.')
        elif p < n:
            # Warn the user to add more prompts.
            print(f'[WARNING] Detected more masks ({n}) than prompts ({p}). '
                  'Additional masks are ignored until more prompts are provided.')

        self.ready_checklist['layers_ready'] = True
        self.ready_checklist['flushed'] = False

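    # NOTE: The layer prompt embeddings above interpolate between the
    # background embedding and each layer's own embedding: prompt_strength
    # 0.0 keeps the pure background prompt and 1.0 keeps the pure layer
    # prompt. A sketch of the idea with hypothetical tensors:
    #
    #     e_mix = torch.lerp(e_background, e_layer, 0.95)  # (1, 77, 768)
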
    @torch.no_grad()
    def update_single_layer(
        self,
        idx: Optional[int] = None,
        prompt: Optional[str] = None,
        negative_prompt: Optional[str] = None,
        suffix: Optional[str] = None,  #', background is ',
        prompt_strength: Optional[float] = None,
        mask: Optional[Union[torch.Tensor, Image.Image]] = None,
        mask_strength: Optional[float] = None,
        mask_std: Optional[float] = None,
    ) -> None:

        ### Possible input combinations and expected behaviors

        # The module will consider a layer, a pair of (prompt, mask), to be 'active' only if a prompt
        # is registered. A blank mask will be assigned if no mask is provided for an 'active' layer.
        # The layers should be in either of the ('active', 'inactive') states. 'inactive' layers will
        # not receive any input unless equipped with a prompt. 'active' layers receive any input and
        # modify their states accordingly. In the actual implementation, only the 'active' layers are
        # stored and can be accessed through the fields. The value len(self.prompts) = self.num_layers
        # is the number of 'active' layers.

        # If no background is registered, the layers should all be 'inactive'.
        if not self.ready_checklist['background_registered']:
            print('[WARNING] Register a background image first! Request ignored.')
            return

        # The first layer-create request should carry a prompt. If only a mask is drawn without a
        # prompt, the request is simply ignored--the user will update their request soon.
        if self.num_layers == 0:
            if prompt is not None:
                self.update_layers(
                    prompts=prompt,
                    negative_prompts=negative_prompt,
                    suffix=suffix,
                    prompt_strengths=prompt_strength,
                    masks=mask,
                    mask_strengths=mask_strength,
                    mask_stds=mask_std,
                )
            return

        # Invalid request indices -> considered as a layer add request.
        if idx is None or idx > self.num_layers or idx < 0:
            idx = self.num_layers

        # Two modes for layer edits: 'append mode' and 'edit mode'. 'append mode' appends a new
        # layer at the end of the layers list. 'edit mode' modifies the internal variables for the
        # given index. 'append mode' is selected by the request index and strictly requires a
        # prompt input.
        is_appending = idx == self.num_layers
        if is_appending and prompt is None:
            print(f'[WARNING] Creating a new prompt at index ({idx}) but found no prompt. Request ignored.')
            return

        ### Register prompts

        # | prompt    | neg_prompt | append mode (idx==len)  | edit mode (idx<len)  |
        # | --------- | ---------- | ----------------------- | -------------------- |
        # | given     | given      | append new prompt embed | replace prompt embed |
        # | given     | not given  | append new prompt embed | replace prompt embed |
        # | not given | given      | NOT ALLOWED             | replace prompt embed |
        # | not given | not given  | NOT ALLOWED             | do nothing           |

        # | prompt_strength | append mode (idx==len) | edit mode (idx<len)                            |
        # | --------------- | ---------------------- | ---------------------------------------------- |
        # | given           | use given strength     | use given strength                             |
        # | not given       | use default strength   | replace strength / if no existing, use default |

        p = self.num_layers

        flag_prompt_edited = (
            prompt is not None or
            negative_prompt is not None or
            prompt_strength is not None
        )

        if flag_prompt_edited:
            is_double_cond = self.guidance_scale > 1.0 and self.cfg_type in ('initialize', 'full')

            # Synchronize the internal state.

            # We have asserted that prompt is not None if the mode is 'appending'.
            if prompt is not None:
                if suffix is not None:
                    prompt = prompt + suffix + self.background.prompt
                if is_appending:
                    self.prompts.append(prompt)
                else:
                    self.prompts[idx] = prompt

            if negative_prompt is not None:
                if is_appending:
                    self.negative_prompts.append(negative_prompt)
                else:
                    self.negative_prompts[idx] = negative_prompt
            elif is_appending:
                # Make sure that negative prompts are well specified.
                self.negative_prompts.append('')

            if is_appending:
                if prompt_strength is None:
                    prompt_strength = self.default_prompt_strength
                self.prompt_strengths = torch.cat((
                    self.prompt_strengths,
                    torch.as_tensor([prompt_strength], dtype=self.dtype, device=self.device),
                ), dim=0)
            elif prompt_strength is not None:
                self.prompt_strengths[idx] = prompt_strength

            # Edit the currently stored prompt embeddings.

            if is_double_cond:
                uncond_prompt_embed_, prompt_embed_ = torch.chunk(self.prompt_embeds, 2, dim=0)
                uncond_prompt_embed_ = rearrange(uncond_prompt_embed_, '(t p) c1 c2 -> t p c1 c2', p=p)
                prompt_embed_ = rearrange(prompt_embed_, '(t p) c1 c2 -> t p c1 c2', p=p)
            else:
                uncond_prompt_embed_ = None
                prompt_embed_ = rearrange(self.prompt_embeds, '(t p) c1 c2 -> t p c1 c2', p=p)

            e = self.pipe.encode_prompt(
                prompt=self.prompts[idx],
                device=self.device,
                num_images_per_prompt=1,
                do_classifier_free_guidance=(self.guidance_scale > 1.0),
                negative_prompt=self.negative_prompts[idx],
            )  # (1, 77, 768), (1, 77, 768)

            s = self.prompt_strengths[idx]
            t = prompt_embed_.shape[0]
            prompt_embed = torch.lerp(self.background.embed[0], e[0], s)[None].repeat(t, 1, 1, 1)  # (t, 1, 77, 768)
            if is_double_cond:
                uncond_prompt_embed = torch.lerp(self.background.embed[1], e[1], s)[None].repeat(t, 1, 1, 1)  # (t, 1, 77, 768)

            if is_appending:
                prompt_embed_ = torch.cat((prompt_embed_, prompt_embed), dim=1)
                if is_double_cond:
                    uncond_prompt_embed_ = torch.cat((uncond_prompt_embed_, uncond_prompt_embed), dim=1)
            else:
                prompt_embed_[:, idx:(idx + 1)] = prompt_embed
                if is_double_cond:
                    uncond_prompt_embed_[:, idx:(idx + 1)] = uncond_prompt_embed

            self.prompt_embeds = rearrange(prompt_embed_, 't p c1 c2 -> (t p) c1 c2')
            if is_double_cond:
                uncond_prompt_embeds = rearrange(uncond_prompt_embed_, 't p c1 c2 -> (t p) c1 c2')
                self.prompt_embeds = torch.cat([uncond_prompt_embeds, self.prompt_embeds], dim=0)  # (2 * T * p, 77, 768)

            self.ready_checklist['flushed'] = False

            if is_appending:
                p = self.num_layers
                self.sub_timesteps_tensor_ = self.sub_timesteps_tensor.repeat_interleave(p)  # (T * p,)
                self.init_noise_ = self.init_noise.repeat_interleave(p, dim=0)  # (T * p, 4, h, w)
                self.stock_noise_ = self.stock_noise.repeat_interleave(p, dim=0)  # (T * p, 4, h, w)
                self.c_out_ = self.c_out.repeat_interleave(p, dim=0)  # (T * p, 1, 1, 1)
                self.c_skip_ = self.c_skip.repeat_interleave(p, dim=0)  # (T * p, 1, 1, 1)
                self.beta_prod_t_sqrt_ = self.beta_prod_t_sqrt.repeat_interleave(p, dim=0)  # (T * p, 1, 1, 1)
                self.alpha_prod_t_sqrt_ = self.alpha_prod_t_sqrt.repeat_interleave(p, dim=0)  # (T * p, 1, 1, 1)

        ### Register new masks

        # | mask      | std / str | append mode (idx==len)       | edit mode (idx<len)           |
        # | --------- | --------- | ---------------------------- | ----------------------------- |
        # | given     | given     | create mask with given val   | create mask with given val    |
        # | given     | not given | create mask with default val | create mask with existing val |
        # | not given | given     | create blank mask            | replace mask with given val   |
        # | not given | not given | create blank mask            | do nothing                    |

        flag_nonzero_mask = False
        if mask is not None:
            # A mask image is given -> create the mask.
            mask, strength, std, original_mask = self.process_mask(mask, mask_strength, mask_std)
+
flag_nonzero_mask = True
|
885 |
+
|
886 |
+
elif is_appending:
|
887 |
+
# No given mask & append mode -> create white mask.
|
888 |
+
mask = torch.zeros(
|
889 |
+
(1, self.batch_size, 1, self.latent_height, self.latent_width),
|
890 |
+
dtype=self.dtype,
|
891 |
+
device=self.device,
|
892 |
+
)
|
893 |
+
strength = torch.as_tensor([self.default_mask_strength], dtype=self.dtype, device=self.device)
|
894 |
+
std = torch.as_tensor([self.default_mask_std], dtype=self.dtype, device=self.device)
|
895 |
+
original_mask = torch.zeros((1, 1, self.latent_height, self.latent_width), dtype=self.dtype)
|
896 |
+
|
897 |
+
elif mask_std is not None or mask_strength is not None:
|
898 |
+
# No given mask & edit mode & given std / str -> replace existing mask with given std / str.
|
899 |
+
if mask_std is None:
|
900 |
+
mask_std = self.mask_stds[idx:(idx + 1)]
|
901 |
+
if mask_strength is None:
|
902 |
+
mask_strength = self.mask_strengths[idx:(idx + 1)]
|
903 |
+
mask, strength, std, original_mask = self.process_mask(
|
904 |
+
self.original_masks[idx:(idx + 1)], mask_strength, mask_std)
|
905 |
+
flag_nonzero_mask = True
|
906 |
+
|
907 |
+
else:
|
908 |
+
# No given mask & no given std & edit mode -> Do nothing.
|
909 |
+
return
|
910 |
+
|
911 |
+
if is_appending:
|
912 |
+
# Append mode.
|
913 |
+
self.masks = torch.cat((self.masks, mask), dim=0) # (p, T, 1, h, w)
|
914 |
+
self.mask_strengths = torch.cat((self.mask_strengths, strength), dim=0) # (p,)
|
915 |
+
self.mask_stds = torch.cat((self.mask_stds, std), dim=0) # (p,)
|
916 |
+
self.original_masks = torch.cat((self.original_masks, original_mask), dim=0) # (p, 1, h, w)
|
917 |
+
if flag_nonzero_mask:
|
918 |
+
self.counts = self.counts + mask[0] if hasattr(self, 'counts') else mask[0] # (T, 1, h, w)
|
919 |
+
self.bg_mask = (1 - self.counts).clip_(0, 1) # (T, 1, h, w)
|
920 |
+
else:
|
921 |
+
# Edit mode.
|
922 |
+
if flag_nonzero_mask:
|
923 |
+
self.counts = self.counts - self.masks[idx] + mask[0] # (T, 1, h, w)
|
924 |
+
self.bg_mask = (1 - self.counts).clip_(0, 1) # (T, 1, h, w)
|
925 |
+
self.masks[idx:(idx + 1)] = mask # (p, T, 1, h, w)
|
926 |
+
self.mask_strengths[idx:(idx + 1)] = strength # (p,)
|
927 |
+
self.mask_stds[idx:(idx + 1)] = std # (p,)
|
928 |
+
self.original_masks[idx:(idx + 1)] = original_mask # (p, 1, h, w)
|
929 |
+
|
930 |
+
# if flag_nonzero_mask:
|
931 |
+
# self.ready_checklist['flushed'] = False
|
932 |
+
|
933 |
+
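    # Usage sketch (editor's addition, not part of the original file): how the two modes above
    # are selected. Assuming `smd` is an instance of this class with two layers registered, an
    # out-of-range or missing index appends, while an in-range index edits in place:
    #
    #     smd.update_single_layer(idx=None, prompt='a red balloon', mask=balloon_mask)  # appends layer 2
    #     smd.update_single_layer(idx=0, mask_std=4.0)   # edit mode: re-blurs layer 0's existing mask
    #     smd.update_single_layer(idx=5)                 # treated as append; no prompt -> warned, ignored
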
    @torch.no_grad()
    def register_all(
        self,
        prompts: Union[str, List[str]],
        masks: Union[Image.Image, List[Image.Image]],
        background: Image.Image,
        background_prompt: Optional[str] = None,
        background_negative_prompt: str = '',
        negative_prompts: Union[str, List[str]] = '',
        suffix: Optional[str] = None,  # ', background is ',
        prompt_strengths: float = 1.0,
        mask_strengths: float = 1.0,
        mask_stds: Union[torch.Tensor, float] = 10.0,
    ) -> None:
        # The order of this registration should not be changed!
        self.update_background(background, background_prompt, background_negative_prompt)
        self.update_layers(prompts, negative_prompts, suffix, prompt_strengths, masks, mask_strengths, mask_stds)

    def update(
        self,
        background: Optional[Image.Image] = None,
        background_prompt: Optional[str] = None,
        background_negative_prompt: Optional[str] = None,
        idx: Optional[int] = None,
        prompt: Optional[str] = None,
        negative_prompt: Optional[str] = None,
        suffix: Optional[str] = None,
        prompt_strength: Optional[float] = None,
        mask: Optional[Union[torch.Tensor, Image.Image]] = None,
        mask_strength: Optional[float] = None,
        mask_std: Optional[float] = None,
    ) -> None:
        # For lazy updates (to solve a minor synchronization problem with Gradio).
        bq = BackgroundObject(
            image=background,
            prompt=background_prompt,
            negative_prompt=background_negative_prompt,
        )
        if not bq.is_empty:
            self.update_buffer['background'] = bq

        lq = LayerObject(
            idx=idx,
            prompt=prompt,
            negative_prompt=negative_prompt,
            suffix=suffix,
            prompt_strength=prompt_strength,
            mask=mask,
            mask_strength=mask_strength,
            mask_std=mask_std,
        )
        if not lq.is_empty:
            limit = self.update_buffer['layers'].maxlen

            # Optimize the prompt queue: override uncommitted layer requests with the same idx.
            new_q = deque(maxlen=limit)
            for _ in range(len(self.update_buffer['layers'])):
                # Check from the newest to the oldest.
                # Copy old requests only if the current query does not carry those requests.
                query = self.update_buffer['layers'].pop()
                overridden = lq.merge(query)
                if not overridden:
                    new_q.appendleft(query)
            self.update_buffer['layers'] = new_q

            if len(self.update_buffer['layers']) == limit:
                print(f'[WARNING] Maximum prompt change query limit ({limit}) is reached. '
                      f'Current query {lq} will be ignored.')
            else:
                self.update_buffer['layers'].append(lq)

    @torch.no_grad()
    def commit(self) -> None:
        flag_changed = self.is_dirty
        bq = self.update_buffer['background']
        lq = self.update_buffer['layers']
        count_bq_req = int(bq is not None and not bq.is_empty)
        count_lq_req = len(lq)

        if flag_changed:
            print(f'[INFO] Requests found: {count_bq_req} background requests '
                  f'& {count_lq_req} layer requests:\n{str(bq)}, {", ".join([str(l) for l in lq])}')

        bq = self.update_buffer['background']
        if bq is not None:
            self.update_background(**vars(bq))
            self.update_buffer['background'] = None

        while len(lq) > 0:
            l = lq.popleft()
            self.update_single_layer(**vars(l))

        if flag_changed:
            print(f'[INFO] Requests resolved: {count_bq_req} background requests '
                  f'& {count_lq_req} layer requests.')

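    # Usage sketch (editor's addition, not part of the original file): the lazy update path.
    # Rapid UI events go through `update`, which merges queued requests that target the same
    # layer index; `commit` then resolves the whole buffer once, before the next denoising call.
    #
    #     smd.update(idx=1, prompt='a gothic cathedral')   # buffered
    #     smd.update(idx=1, mask_std=8.0)                  # merged into the same queued request
    #     smd.commit()                                     # both changes applied in one pass
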
    def scheduler_step_batch(
        self,
        model_pred_batch: torch.Tensor,
        x_t_latent_batch: torch.Tensor,
        idx: Optional[int] = None,
    ) -> torch.Tensor:
        r"""Denoise-only step for the reverse diffusion scheduler.

        Args:
            model_pred_batch (torch.Tensor): Noise prediction results.
            x_t_latent_batch (torch.Tensor): Noisy latent.
            idx (Optional[int]): Instead of timesteps (in [0, 1000]-scale), use
                indices for the timesteps tensor (ranged in
                [0, len(timesteps)-1]). Specify only if a single-index, not
                stream-batched, inference is what you want.

        Returns:
            A denoised tensor with the same size as the latent.
        """
        if idx is None:
            F_theta = (x_t_latent_batch - self.beta_prod_t_sqrt_ * model_pred_batch) / self.alpha_prod_t_sqrt_
            denoised_batch = self.c_out_ * F_theta + self.c_skip_ * x_t_latent_batch
        else:
            F_theta = (x_t_latent_batch - self.beta_prod_t_sqrt[idx] * model_pred_batch) / self.alpha_prod_t_sqrt[idx]
            denoised_batch = self.c_out[idx] * F_theta + self.c_skip[idx] * x_t_latent_batch
        return denoised_batch

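    # Editor's note (illustrative, not part of the original file): with a_t = alpha_prod_t and
    # b_t = 1 - a_t, the method above computes the consistency-model-style update
    #
    #     F_theta(x_t) = (x_t - sqrt(b_t) * eps_theta(x_t)) / sqrt(a_t)   # predicted x_0
    #     x_hat_0 = c_out(t) * F_theta(x_t) + c_skip(t) * x_t             # skip-blended output
    #
    # No fresh noise is injected here; re-noising for the next timestep is left to the caller
    # (see `unet_step` and `__call__` below).
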
    def unet_step(
        self,
        x_t_latent: torch.Tensor,  # (T, 4, h, w)
        idx: Optional[int] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        p = self.num_layers
        x_t_latent = x_t_latent.repeat_interleave(p, dim=0)  # (T * p, 4, h, w)

        if self.bootstrap_steps[0] > 0:
            # Background bootstrapping.
            bootstrap_latent = self.scheduler.add_noise(
                self.bootstrap_latent,
                self.stock_noise,
                torch.tensor(self.sub_timesteps_tensor, device=self.device),
            )
            x_t_latent = rearrange(x_t_latent, '(t p) c h w -> p t c h w', p=p)
            bootstrap_mask = (
                self.masks * self.bootstrap_steps[None, :, None, None, None]
                + (1.0 - self.bootstrap_steps[None, :, None, None, None])
            )  # (p, t, c, h, w)
            x_t_latent = (1.0 - bootstrap_mask) * bootstrap_latent[None] + bootstrap_mask * x_t_latent
            x_t_latent = rearrange(x_t_latent, 'p t c h w -> (t p) c h w')

            # Centering.
            x_t_latent = shift_to_mask_bbox_center(x_t_latent, rearrange(self.masks, 'p t c h w -> (t p) c h w'), reverse=True)

        t_list = self.sub_timesteps_tensor_  # (T * p,)
        if self.guidance_scale > 1.0 and self.cfg_type == 'initialize':
            x_t_latent_plus_uc = torch.concat([x_t_latent[:p], x_t_latent], dim=0)  # (T * p + p, 4, h, w)
            t_list = torch.concat([t_list[:p], t_list], dim=0)  # (T * p + p,)
        elif self.guidance_scale > 1.0 and self.cfg_type == 'full':
            x_t_latent_plus_uc = torch.concat([x_t_latent, x_t_latent], dim=0)  # (2 * T * p, 4, h, w)
            t_list = torch.concat([t_list, t_list], dim=0)  # (2 * T * p,)
        else:
            x_t_latent_plus_uc = x_t_latent  # (T * p, 4, h, w)

        model_pred = self.unet(
            x_t_latent_plus_uc,  # (B, 4, h, w)
            t_list,  # (B,)
            encoder_hidden_states=self.prompt_embeds,  # (B, 77, 768)
            return_dict=False,
            # TODO: Add SDXL support.
            # added_cond_kwargs={'text_embeds': add_text_embeds, 'time_ids': add_time_ids},
        )[0]  # (B, 4, h, w)

        if self.bootstrap_steps[0] > 0:
            # Uncentering.
            bootstrap_mask = rearrange(self.masks, 'p t c h w -> (t p) c h w')
            if self.guidance_scale > 1.0 and self.cfg_type == 'initialize':
                bootstrap_mask_ = torch.concat([bootstrap_mask[:p], bootstrap_mask], dim=0)
            elif self.guidance_scale > 1.0 and self.cfg_type == 'full':
                bootstrap_mask_ = torch.concat([bootstrap_mask, bootstrap_mask], dim=0)
            else:
                bootstrap_mask_ = bootstrap_mask
            model_pred = shift_to_mask_bbox_center(model_pred, bootstrap_mask_)
            x_t_latent = shift_to_mask_bbox_center(x_t_latent, bootstrap_mask)

            # # Remove leakage (optional).
            # leak = (latent_ - bg_latent_).pow(2).mean(dim=1, keepdim=True)
            # leak_sigmoid = torch.sigmoid(leak / self.bootstrap_leak_sensitivity) * 2 - 1
            # fg_mask_ = fg_mask_ * leak_sigmoid

        ### noise_pred_text, noise_pred_uncond: (T * p, 4, h, w)
        ### self.stock_noise, init_noise: (T, 4, h, w)

        if self.guidance_scale > 1.0 and self.cfg_type == 'initialize':
            noise_pred_text = model_pred[p:]
            self.stock_noise_ = torch.concat([model_pred[:p], self.stock_noise_[p:]], dim=0)
        elif self.guidance_scale > 1.0 and self.cfg_type == 'full':
            noise_pred_uncond, noise_pred_text = model_pred.chunk(2)
        else:
            noise_pred_text = model_pred
        if self.guidance_scale > 1.0 and self.cfg_type in ('self', 'initialize'):
            noise_pred_uncond = self.stock_noise_ * self.delta

        if self.guidance_scale > 1.0 and self.cfg_type != 'none':
            model_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
        else:
            model_pred = noise_pred_text

        # Compute the previous noisy sample x_t -> x_t-1.
        denoised_batch = self.scheduler_step_batch(model_pred, x_t_latent, idx)

        if self.cfg_type in ('self', 'initialize'):
            scaled_noise = self.beta_prod_t_sqrt_ * self.stock_noise_
            delta_x = self.scheduler_step_batch(model_pred, scaled_noise, idx)

            # Do mask edit.
            alpha_next = torch.concat([self.alpha_prod_t_sqrt_[p:], torch.ones_like(self.alpha_prod_t_sqrt_[:p])], dim=0)
            delta_x = alpha_next * delta_x
            beta_next = torch.concat([self.beta_prod_t_sqrt_[p:], torch.ones_like(self.beta_prod_t_sqrt_[:p])], dim=0)
            delta_x = delta_x / beta_next
            init_noise = torch.concat([self.init_noise_[p:], self.init_noise_[:p]], dim=0)
            self.stock_noise_ = init_noise + delta_x

        p2 = len(self.t_list) - 1
        background = torch.concat([
            self.scheduler.add_noise(
                self.background.latent.repeat(p2, 1, 1, 1),
                self.stock_noise[1:],
                torch.tensor(self.t_list[1:], device=self.device),
            ),
            self.background.latent,
        ], dim=0)

        denoised_batch = rearrange(denoised_batch, '(t p) c h w -> p t c h w', p=p)
        latent = (self.masks * denoised_batch).sum(dim=0)  # (T, 4, h, w)
        latent = torch.where(self.counts > 0, latent / self.counts, latent)

        # latent = (
        #     (1 - self.bg_mask) * self.mask_strengths * latent
        #     + ((1 - self.bg_mask) * (1.0 - self.mask_strengths) + self.bg_mask) * background
        # )
        latent = (1 - self.bg_mask) * latent + self.bg_mask * background

        return latent

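    # Editor's note (illustrative, not part of the original file): U-Net batch sizes produced by
    # the branches above, for T stream timesteps and p layers.
    #
    #     cfg_type        batch to U-Net    unconditional prediction source
    #     'none'          T * p             none (no guidance mixing)
    #     'self'          T * p             recycled `stock_noise_`, scaled by `delta`
    #     'initialize'    T * p + p         first p rows of the same forward pass
    #     'full'          2 * T * p         dedicated duplicate pass, as in standard CFG
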
    @torch.no_grad()
    def __call__(
        self,
        no_decode: bool = False,
        ignore_check_ready: bool = False,
    ) -> Optional[Union[torch.Tensor, Image.Image]]:
        if not ignore_check_ready and not self.check_ready():
            return
        if not ignore_check_ready and self.is_dirty:
            print("I'm so dirty now!")
            self.commit()
            self.flush()

        latent = torch.randn((1, self.unet.config.in_channels, self.latent_height, self.latent_width),
                             dtype=self.dtype, device=self.device)  # (1, 4, h, w)
        latent = torch.cat((latent, self.x_t_latent_buffer), dim=0)  # (t, 4, h, w)
        self.stock_noise = torch.cat((self.init_noise[:1], self.stock_noise[:-1]), dim=0)  # (t, 4, h, w)
        if self.cfg_type in ('self', 'initialize'):
            self.stock_noise_ = self.stock_noise.repeat_interleave(self.num_layers, dim=0)  # (T * p, 4, h, w)

        x_0_pred_batch = self.unet_step(latent)

        latent = x_0_pred_batch[-1:]
        self.x_t_latent_buffer = (
            self.alpha_prod_t_sqrt[1:] * x_0_pred_batch[:-1]
            + self.beta_prod_t_sqrt[1:] * self.init_noise[1:]
        )

        # For pipeline flushing.
        if no_decode:
            return latent

        imgs = self.decode_latents(latent.half())  # (1, 3, H, W)
        img = T.ToPILImage()(imgs[0].cpu())
        return img

    def flush(self) -> None:
        for _ in self.t_list:
            self(True, True)
        self.ready_checklist['flushed'] = True

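    # Usage sketch (editor's addition, not part of the original file; the class name and
    # constructor arguments are assumptions based on how app.py drives this module):
    #
    #     smd = StreamMultiDiffusionPipeline(device, hf_key='KBlueLeaf/kohaku-v2.1')
    #     smd.register_all(prompts=['a tree'], masks=[tree_mask], background=bg_image)
    #     while streaming:
    #         smd.update(idx=0, mask=latest_drawn_mask)   # buffered, non-blocking
    #         frame = smd()   # commits pending edits if dirty, then emits one denoised frame
    #
    # `flush` above warms the stream batch: it runs one pass per entry of `t_list` so that every
    # timestep slot of the pipeline holds a latent consistent with the current prompts and masks.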
prompt_util.py
ADDED
@@ -0,0 +1,154 @@
from typing import Dict, List, Optional, Tuple, Union


quality_prompt_list = [
    {
        "name": "(None)",
        "prompt": "{prompt}",
        "negative_prompt": "nsfw, lowres",
    },
    {
        "name": "Standard v3.0",
        "prompt": "{prompt}, masterpiece, best quality",
        "negative_prompt": "nsfw, lowres, bad anatomy, bad hands, text, error, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality, normal quality, jpeg artifacts, signature, watermark, username, blurry, artist name",
    },
    {
        "name": "Standard v3.1",
        "prompt": "{prompt}, masterpiece, best quality, very aesthetic, absurdres",
        "negative_prompt": "nsfw, lowres, (bad), text, error, fewer, extra, missing, worst quality, jpeg artifacts, low quality, watermark, unfinished, displeasing, oldest, early, chromatic aberration, signature, extra digits, artistic error, username, scan, [abstract]",
    },
    {
        "name": "Light v3.1",
        "prompt": "{prompt}, (masterpiece), best quality, very aesthetic, perfect face",
        "negative_prompt": "nsfw, (low quality, worst quality:1.2), very displeasing, 3d, watermark, signature, ugly, poorly drawn",
    },
    {
        "name": "Heavy v3.1",
        "prompt": "{prompt}, (masterpiece), (best quality), (ultra-detailed), very aesthetic, illustration, disheveled hair, perfect composition, moist skin, intricate details",
        "negative_prompt": "nsfw, longbody, lowres, bad anatomy, bad hands, missing fingers, pubic hair, extra digit, fewer digits, cropped, worst quality, low quality, very displeasing",
    },
]

style_list = [
    {
        "name": "(None)",
        "prompt": "{prompt}",
        "negative_prompt": "",
    },
    {
        "name": "Cinematic",
        "prompt": "{prompt}, cinematic still, emotional, harmonious, vignette, highly detailed, high budget, bokeh, cinemascope, moody, epic, gorgeous, film grain, grainy",
        "negative_prompt": "nsfw, cartoon, graphic, text, painting, crayon, graphite, abstract, glitch, deformed, mutated, ugly, disfigured",
    },
    {
        "name": "Photographic",
        "prompt": "{prompt}, cinematic photo, 35mm photograph, film, bokeh, professional, 4k, highly detailed",
        "negative_prompt": "nsfw, drawing, painting, crayon, sketch, graphite, impressionist, noisy, blurry, soft, deformed, ugly",
    },
    {
        "name": "Anime",
        "prompt": "{prompt}, anime artwork, anime style, key visual, vibrant, studio anime, highly detailed",
        "negative_prompt": "nsfw, photo, deformed, black and white, realism, disfigured, low contrast",
    },
    {
        "name": "Manga",
        "prompt": "{prompt}, manga style, vibrant, high-energy, detailed, iconic, Japanese comic style",
        "negative_prompt": "nsfw, ugly, deformed, noisy, blurry, low contrast, realism, photorealistic, Western comic style",
    },
    {
        "name": "Digital Art",
        "prompt": "{prompt}, concept art, digital artwork, illustrative, painterly, matte painting, highly detailed",
        "negative_prompt": "nsfw, photo, photorealistic, realism, ugly",
    },
    {
        "name": "Pixel art",
        "prompt": "{prompt}, pixel-art, low-res, blocky, pixel art style, 8-bit graphics",
        "negative_prompt": "nsfw, sloppy, messy, blurry, noisy, highly detailed, ultra textured, photo, realistic",
    },
    {
        "name": "Fantasy art",
        "prompt": "{prompt}, ethereal fantasy concept art, magnificent, celestial, ethereal, painterly, epic, majestic, magical, fantasy art, cover art, dreamy",
        "negative_prompt": "nsfw, photographic, realistic, realism, 35mm film, dslr, cropped, frame, text, deformed, glitch, noise, noisy, off-center, deformed, cross-eyed, closed eyes, bad anatomy, ugly, disfigured, sloppy, duplicate, mutated, black and white",
    },
    {
        "name": "Neonpunk",
        "prompt": "{prompt}, neonpunk style, cyberpunk, vaporwave, neon, vibes, vibrant, stunningly beautiful, crisp, detailed, sleek, ultramodern, magenta highlights, dark purple shadows, high contrast, cinematic, ultra detailed, intricate, professional",
        "negative_prompt": "nsfw, painting, drawing, illustration, glitch, deformed, mutated, cross-eyed, ugly, disfigured",
    },
    {
        "name": "3D Model",
        "prompt": "{prompt}, professional 3d model, octane render, highly detailed, volumetric, dramatic lighting",
        "negative_prompt": "nsfw, ugly, deformed, noisy, low poly, blurry, painting",
    },
]


_style_dict = {k["name"]: (k["prompt"], k["negative_prompt"]) for k in style_list}
_quality_dict = {k["name"]: (k["prompt"], k["negative_prompt"]) for k in quality_prompt_list}


def preprocess_prompt(
    positive: str,
    negative: str = "",
    style_dict: Dict[str, Tuple[str, str]] = _quality_dict,
    style_name: str = "Standard v3.1",  # "Heavy v3.1"
    add_style: bool = True,
) -> Tuple[str, str]:
    p, n = style_dict.get(style_name, style_dict["(None)"])

    if add_style and positive.strip():
        formatted_positive = p.format(prompt=positive)
    else:
        formatted_positive = positive

    combined_negative = n
    if negative.strip():
        if combined_negative:
            combined_negative += ", " + negative
        else:
            combined_negative = negative

    return formatted_positive, combined_negative


def preprocess_prompts(
    positives: List[str],
    negatives: Optional[List[str]] = None,
    style_dict: Dict[str, Tuple[str, str]] = _style_dict,
    style_name: str = "Manga",  # "(None)"
    quality_dict: Dict[str, Tuple[str, str]] = _quality_dict,
    quality_name: str = "Standard v3.1",  # "Heavy v3.1"
    add_style: bool = True,
    add_quality_tags: bool = True,
) -> Tuple[List[str], List[str]]:
    if negatives is None:
        negatives = ['' for _ in positives]

    positives_ = []
    negatives_ = []
    for pos, neg in zip(positives, negatives):
        pos, neg = preprocess_prompt(pos, neg, quality_dict, quality_name, add_quality_tags)
        pos, neg = preprocess_prompt(pos, neg, style_dict, style_name, add_style)
        positives_.append(pos)
        negatives_.append(neg)
    return positives_, negatives_


def print_prompts(
    positives: Union[str, List[str]],
    negatives: Union[str, List[str]],
    has_background: bool = False,
) -> None:
    if isinstance(positives, str):
        positives = [positives]
    if isinstance(negatives, str):
        negatives = [negatives]

    for i, prompt in enumerate(positives):
        prefix = ((f'Prompt{i}' if i > 0 else 'Background Prompt')
                  if has_background else f'Prompt{i + 1}')
        print(prefix + ': ' + prompt)
    for i, prompt in enumerate(negatives):
        prefix = ((f'Negative Prompt{i}' if i > 0 else 'Background Negative Prompt')
                  if has_background else f'Negative Prompt{i + 1}')
        print(prefix + ': ' + prompt)
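
# Usage sketch (editor's addition, not part of the original file): quality tags are applied
# first, then the style template wraps the quality-tagged prompt.
#
#     pos, neg = preprocess_prompts(['1girl, solo'], style_name='Anime', quality_name='Standard v3.1')
#     # pos[0] == '1girl, solo, masterpiece, best quality, very aesthetic, absurdres, '
#     #           'anime artwork, anime style, key visual, vibrant, studio anime, highly detailed'
#     # neg[0] is the Anime style negative list followed by the Standard v3.1 negative list.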
requirements.txt
ADDED
@@ -0,0 +1,13 @@
torch==2.0.1
torchvision
xformers==0.0.22
einops
diffusers
transformers
huggingface_hub[torch]
Pillow
emoji
numpy
tqdm
jupyterlab
gradio @ https://gradio-builds.s3.amazonaws.com/7129aa5719aaa95a75397a83d3e1f3b72adf8050/gradio-4.26.0-py3-none-any.whl
timer/LICENSE_timer.txt
ADDED
@@ -0,0 +1,21 @@
The MIT License (MIT)

Copyright (c) 2024 Jonathan Trancozo (https://codepen.io/jtrancozo/pen/mEoEVw)

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
timer/index.html
ADDED
@@ -0,0 +1,95 @@
<!DOCTYPE html>
<html lang="en">
<head>
  <meta charset="UTF-8">
  <title>CodePen - #2 - Project Deadline - SVG animation with CSS3</title>
  <link href='https://fonts.googleapis.com/css?family=Oswald' rel='stylesheet' type='text/css'>

  <meta property="og:image" content="https://i.imgur.com/9xiPyyv.png" /><link rel="stylesheet" href="./style.css">

</head>
<body>
<!-- partial:index.partial.html -->
<div id="deadline">
  <svg preserveAspectRatio="none" id="line" viewBox="0 0 581 158" enable-background="new 0 0 581 158">
    <g id="fire">
      <rect id="mask-fire-black" x="511" y="41" width="38" height="34"/>
      <g>
        <defs>
          <rect id="mask_fire" x="511" y="41" width="38" height="34"/>
        </defs>
        <clipPath id="mask-fire_1_">
          <use xlink:href="#mask_fire" overflow="visible"/>
        </clipPath>
        <g id="group-fire" clip-path="url(#mask-fire_1_)">
          <path id="red-flame" fill="#B71342" d="M528.377,100.291c6.207,0,10.947-3.272,10.834-8.576 c-0.112-5.305-2.934-8.803-8.237-10.383c-5.306-1.581-3.838-7.9-0.79-9.707c-7.337,2.032-7.581,5.891-7.11,8.238 c0.789,3.951,7.56,4.402,5.077,9.48c-2.482,5.079-8.012,1.129-6.319-2.257c-2.843,2.233-4.78,6.681-2.259,9.703 C521.256,98.809,524.175,100.291,528.377,100.291z"/>
          <path id="yellow-flame" opacity="0.71" fill="#F7B523" d="M528.837,100.291c4.197,0,5.108-1.854,5.974-5.417 c0.902-3.724-1.129-6.207-5.305-9.931c-2.396-2.137-1.581-4.176-0.565-6.32c-4.401,1.918-3.384,5.304-2.482,6.658 c1.511,2.267,2.099,2.364,0.42,5.8c-1.679,3.435-5.42,0.764-4.275-1.527c-1.921,1.512-2.373,4.04-1.528,6.563 C522.057,99.051,525.994,100.291,528.837,100.291z"/>
          <path id="white-flame" opacity="0.81" fill="#FFFFFF" d="M529.461,100.291c-2.364,0-4.174-1.322-4.129-3.469 c0.04-2.145,1.117-3.56,3.141-4.198c2.022-0.638,1.463-3.195,0.302-3.925c2.798,0.821,2.89,2.382,2.711,3.332 c-0.301,1.597-2.883,1.779-1.938,3.834c0.912,1.975,3.286,0.938,2.409-0.913c1.086,0.903,1.826,2.701,0.864,3.924 C532.18,99.691,531.064,100.291,529.461,100.291z"/>
        </g>
      </g>
    </g>
    <g id="progress-trail">
      <path fill="#FFFFFF" d="M491.979,83.878c1.215-0.73-0.62-5.404-3.229-11.044c-2.583-5.584-5.034-10.066-7.229-8.878
        c-2.854,1.544-0.192,6.286,2.979,11.628C487.667,80.917,490.667,84.667,491.979,83.878z"/>
      <path fill="#FFFFFF" d="M571,76v-5h-23.608c0.476-9.951-4.642-13.25-4.642-13.25l-3.125,4c0,0,3.726,2.7,3.625,5.125
        c-0.071,1.714-2.711,3.18-4.962,4.125H517v5h10v24h-25v-5.666c0,0,0.839,0,2.839-0.667s6.172-3.667,4.005-6.333
        s-7.49,0.333-9.656,0.166s-6.479-1.5-8.146,1.917c-1.551,3.178,0.791,5.25,5.541,6.083l-0.065,4.5H16c-2.761,0-5,2.238-5,5v17
        c0,2.762,2.239,5,5,5h549c2.762,0,5-2.238,5-5v-17c0-2.762-2.238-5-5-5h-3V76H571z"/>
      <path fill="#FFFFFF" d="M535,65.625c1.125,0.625,2.25-1.125,2.25-1.125l11.625-22.375c0,0,0.75-0.875-1.75-2.125
        s-3.375,0.25-3.375,0.25s-8.75,21.625-9.875,23.5S533.875,65,535,65.625z"/>
    </g>
    <g>
      <defs>
        <path id="SVGID_1_" d="M484.5,75.584c-3.172-5.342-5.833-10.084-2.979-11.628c2.195-1.188,4.646,3.294,7.229,8.878
          c2.609,5.64,4.444,10.313,3.229,11.044C490.667,84.667,487.667,80.917,484.5,75.584z M571,76v-5h-23.608
          c0.476-9.951-4.642-13.25-4.642-13.25l-3.125,4c0,0,3.726,2.7,3.625,5.125c-0.071,1.714-2.711,3.18-4.962,4.125H517v5h10v24h-25
          v-5.666c0,0,0.839,0,2.839-0.667s6.172-3.667,4.005-6.333s-7.49,0.333-9.656,0.166s-6.479-1.5-8.146,1.917
          c-1.551,3.178,0.791,5.25,5.541,6.083l-0.065,4.5H16c-2.761,0-5,2.238-5,5v17c0,2.762,2.239,5,5,5h549c2.762,0,5-2.238,5-5v-17
          c0-2.762-2.238-5-5-5h-3V76H571z M535,65.625c1.125,0.625,2.25-1.125,2.25-1.125l11.625-22.375c0,0,0.75-0.875-1.75-2.125
          s-3.375,0.25-3.375,0.25s-8.75,21.625-9.875,23.5S533.875,65,535,65.625z"/>
      </defs>
      <clipPath id="SVGID_2_">
        <use xlink:href="#SVGID_1_" overflow="visible"/>
      </clipPath>
      <rect id="progress-time-fill" x="-100%" y="34" clip-path="url(#SVGID_2_)" fill="#BE002A" width="586" height="103"/>
    </g>

    <g id="death-group">
      <path id="death" fill="#BE002A" d="M-46.25,40.416c-5.42-0.281-8.349,3.17-13.25,3.918c-5.716,0.871-10.583-0.918-10.583-0.918
        C-67.5,49-65.175,50.6-62.083,52c5.333,2.416,4.083,3.5,2.084,4.5c-16.5,4.833-15.417,27.917-15.417,27.917L-75.5,84.75
        c-1,12.25-20.25,18.75-20.25,18.75s39.447,13.471,46.25-4.25c3.583-9.333-1.553-16.869-1.667-22.75
        c-0.076-3.871,2.842-8.529,6.084-12.334c3.596-4.22,6.958-10.374,6.958-15.416C-38.125,43.186-39.833,40.75-46.25,40.416z
        M-40,51.959c-0.882,3.004-2.779,6.906-4.154,6.537s-0.939-4.32,0.112-7.704c0.82-2.64,2.672-5.96,3.959-5.583
        C-39.005,45.523-39.073,48.8-40,51.959z"/>
      <path id="death-arm" fill="#BE002A" d="M-53.375,75.25c0,0,9.375,2.25,11.25,0.25s2.313-2.342,3.375-2.791
        c1.083-0.459,4.375-1.75,4.292-4.75c-0.101-3.627,0.271-4.594,1.333-5.043c1.083-0.457,2.75-1.666,2.75-1.666
        s0.708-0.291,0.5-0.875s-0.791-2.125-1.583-2.959c-0.792-0.832-2.375-1.874-2.917-1.332c-0.542,0.541-7.875,7.166-7.875,7.166
        s-2.667,2.791-3.417,0.125S-49.833,61-49.833,61s-3.417,1.416-3.417,1.541s-1.25,5.834-1.25,5.834l-0.583,5.833L-53.375,75.25z"/>
      <path id="death-tool" fill="#BE002A" d="M-20.996,26.839l-42.819,91.475l1.812,0.848l38.342-81.909c0,0,8.833,2.643,12.412,7.414
        c5,6.668,4.75,14.084,4.75,14.084s4.354-7.732,0.083-17.666C-10,32.75-19.647,28.676-19.647,28.676l0.463-0.988L-20.996,26.839z"/>
    </g>
    <path id="designer-body" fill="#FEFFFE" d="M514.75,100.334c0,0,1.25-16.834-6.75-16.5c-5.501,0.229-5.583,3-10.833,1.666
      c-3.251-0.826-5.084-15.75-0.834-22c4.948-7.277,12.086-9.266,13.334-7.833c2.25,2.583-2,10.833-4.5,14.167
      c-2.5,3.333-1.833,10.416,0.5,9.916s8.026-0.141,10,2.25c3.166,3.834,4.916,17.667,4.916,17.667l0.917,2.5l-4,0.167L514.75,100.334z
      "/>

    <circle id="designer-head" fill="#FEFFFE" cx="516.083" cy="53.25" r="6.083"/>

    <g id="designer-arm-grop">
      <path id="designer-arm" fill="#FEFFFE" d="M505.875,64.875c0,0,5.875,7.5,13.042,6.791c6.419-0.635,11.833-2.791,13.458-4.041s2-3.5,0.25-3.875
        s-11.375,5.125-16,3.25c-5.963-2.418-8.25-7.625-8.25-7.625l-2,1.125L505.875,64.875z"/>
      <path id="designer-pen" fill="#FEFFFE" d="M525.75,59.084c0,0-0.423-0.262-0.969,0.088c-0.586,0.375-0.547,0.891-0.547,0.891l7.172,8.984l1.261,0.453
        l-0.104-1.328L525.75,59.084z"/>
    </g>
  </svg>

  <div class="deadline-days">
    Deadline <span class="day">7</span> <span class="days">days</span>
  </div>

</div>
<!-- partial -->
<script src='https://code.jquery.com/jquery-2.2.4.min.js'></script><script src="./script.js"></script>

</body>
</html>
timer/script.js
ADDED
@@ -0,0 +1,73 @@
// Init
var $ = jQuery;
var animationTime = 20,
    days = 7;

$(document).ready(function(){

  // timer arguments:
  //   #1 - duration of the animation in seconds,
  //   #2 - days to deadline

  $('#progress-time-fill, #death-group').css({'animation-duration': animationTime + 's'});

  var deadlineAnimation = function () {
    // Speed up the designer's writing arm as the deadline approaches.
    setTimeout(function(){
      $('#designer-arm-grop').css({'animation-duration': '1.5s'});
    }, 0);

    setTimeout(function(){
      $('#designer-arm-grop').css({'animation-duration': '1s'});
    }, 4000);

    setTimeout(function(){
      $('#designer-arm-grop').css({'animation-duration': '0.7s'});
    }, 8000);

    setTimeout(function(){
      $('#designer-arm-grop').css({'animation-duration': '0.3s'});
    }, 12000);

    setTimeout(function(){
      $('#designer-arm-grop').css({'animation-duration': '0.2s'});
    }, 15000);
  };

  function timer(totalTime, deadline) {
    var time = totalTime * 1000;
    var dayDuration = time / deadline;
    var actualDay = deadline;

    var timer = setInterval(countTime, dayDuration);

    function countTime() {
      --actualDay;
      $('.deadline-days .day').text(actualDay);

      if (actualDay == 0) {
        clearInterval(timer);
        $('.deadline-days .day').text(deadline);
      }
    }
  }

  var deadlineText = function () {
    var $el = $('.deadline-days');
    var html = '<div class="mask-red"><div class="inner">' + $el.html() + '</div></div><div class="mask-white"><div class="inner">' + $el.html() + '</div></div>';
    $el.html(html);
  };

  deadlineText();

  deadlineAnimation();
  timer(animationTime, days);

  // Restart the countdown and the arm animation on every full cycle.
  setInterval(function(){
    timer(animationTime, days);
    deadlineAnimation();

    console.log('begin interval', animationTime * 1000);

  }, animationTime * 1000);

});
timer/style.css
ADDED
@@ -0,0 +1,504 @@
/*
  Svg Projects
  Author: Jonathan Trancozo
  Language: HTML, CSS3 and SVG
  Project_version: V1
  Project_description:
  For years I saw this picture and thought "That would be amazing animated", and today I
  managed to express some of my imagination. The design was produced in Adobe Illustrator
  and exported as SVG. The animations were made with CSS3, mainly using [transform].

  See you.
*/

html {
  font-size: 1em;
  line-height: 1.4;
}

html,
body {
  height: 100%;
}

body {
  margin: 0;
  padding: 0;
  background: transparent;
}


#deadline {
  width: 581px;
  max-width: 100%;
  height: 158px;
  position: absolute;
  top: 50%;
  left: 50%;
  z-index: 1;
  transform: translate(-50%, -50%);
}

#deadline svg {
  width: 100%;
}

#progress-time-fill {
  -webkit-animation-name: progress-fill;
  animation-name: progress-fill;
  -webkit-animation-timing-function: linear;
  animation-timing-function: linear;
  -webkit-animation-iteration-count: infinite;
  animation-iteration-count: infinite;
}

/* Death */
#death-group {
  -webkit-animation-name: walk;
  animation-name: walk;
  -webkit-animation-timing-function: ease;
  animation-timing-function: ease;
  -webkit-animation-iteration-count: infinite;
  animation-iteration-count: infinite;
  transform: translateX(0);
}

#death-arm {
  -webkit-animation: move-arm 3s ease infinite;
  animation: move-arm 3s ease infinite;
  /* transform-origin: left center; */
  transform-origin: -60px 74px;
}

#death-tool {
  -webkit-animation: move-tool 3s ease infinite;
  animation: move-tool 3s ease infinite;
  transform-origin: -48px center;
}

/* Designer */

#designer-arm-grop {
  -webkit-animation: write 1.5s ease infinite;
  animation: write 1.5s ease infinite;
  transform: translate(0, 0) rotate(0deg) scale(1, 1);
  transform-origin: 90% top;
}

.deadline-timer {
  color: #fff;
  text-align: center;
  width: 200px;
  margin: 0 auto;
  position: relative;
  height: 40px;
  font-family: 'Oswald', sans-serif;
  font-size: 18pt;
  margin-top: -90px;
}

.deadline-timer .inner {
  width: 200px;
  position: relative;
  top: 0;
  left: 0;
}

.mask-red,
.mask-white {
  position: absolute;
  top: 0;
  width: 100%;
  overflow: hidden;
  height: 100%;
}

@-webkit-keyframes progress-fill {
  0% {
    x: -100%;
  }

  100% {
    x: -3%;
  }
}

@keyframes progress-fill {
  0% {
    x: -100%;
  }

  100% {
    x: -3%;
  }
}

@-webkit-keyframes walk {
  0% {
    transform: translateX(0);
  }
  6% {
    transform: translateX(0);
  }
  10% {
    transform: translateX(100px);
  }

  15% {
    transform: translateX(140px);
  }

  25% {
    transform: translateX(170px);
  }

  35% {
    transform: translateX(220px);
  }

  45% {
    transform: translateX(280px);
  }

  55% {
    transform: translateX(340px);
  }

  65% {
    transform: translateX(370px);
  }

  75% {
    transform: translateX(430px);
  }

  85% {
    transform: translateX(460px);
  }

  100% {
    transform: translateX(520px);
  }
}

@keyframes walk {
  0% {
    transform: translateX(0);
  }
  6% {
    transform: translateX(0);
  }
  10% {
    transform: translateX(100px);
  }

  15% {
    transform: translateX(140px);
  }

  25% {
    transform: translateX(170px);
  }

  35% {
    transform: translateX(220px);
  }

  45% {
    transform: translateX(280px);
  }

  55% {
    transform: translateX(340px);
  }

  65% {
    transform: translateX(370px);
  }

  75% {
    transform: translateX(430px);
  }

  85% {
    transform: translateX(460px);
  }

  100% {
    transform: translateX(520px);
  }
}

@-webkit-keyframes move-arm {
  0% {
    transform: rotate(0);
  }

  5% {
    transform: rotate(0);
  }

  9% {
    transform: rotate(40deg);
  }

  80% {
    transform: rotate(0);
  }
}

@keyframes move-arm {
  0% {
    transform: rotate(0);
  }

  5% {
    transform: rotate(0);
  }

  9% {
    transform: rotate(40deg);
  }

  80% {
    transform: rotate(0);
  }
}

@-webkit-keyframes move-tool {
  0% {
    transform: rotate(0);
  }

  5% {
    transform: rotate(0);
  }

  9% {
    transform: rotate(50deg);
  }

  80% {
    transform: rotate(0);
  }
}

@keyframes move-tool {
  0% {
    transform: rotate(0);
  }

  5% {
    transform: rotate(0);
  }

  9% {
    transform: rotate(50deg);
  }

  80% {
    transform: rotate(0);
  }
}

/* Design animations */

@-webkit-keyframes write {
  0% {
    transform: translate(0, 0) rotate(0deg) scale(1, 1);
  }

  16% {
    transform: translate(0px, 0px) rotate(5deg) scale(0.8, 1);
  }

  32% {
    transform: translate(0px, 0px) rotate(0deg) scale(1, 1);
  }

  48% {
    transform: translate(0px, 0px) rotate(6deg) scale(0.8, 1);
  }

  65% {
    transform: translate(0px, 0px) rotate(0deg) scale(1, 1);
  }

  83% {
    transform: translate(0px, 0px) rotate(4deg) scale(0.8, 1);
  }
}

@keyframes write {
  0% {
    transform: translate(0, 0) rotate(0deg) scale(1, 1);
  }

  16% {
    transform: translate(0px, 0px) rotate(5deg) scale(0.8, 1);
  }

  32% {
    transform: translate(0px, 0px) rotate(0deg) scale(1, 1);
  }

  48% {
    transform: translate(0px, 0px) rotate(6deg) scale(0.8, 1);
  }

  65% {
    transform: translate(0px, 0px) rotate(0deg) scale(1, 1);
  }

  83% {
    transform: translate(0px, 0px) rotate(4deg) scale(0.8, 1);
  }
}

@-webkit-keyframes text-red {
  0% {
    width: 0%;
  }

  100% {
    width: 98%;
  }
}

@keyframes text-red {
  0% {
    width: 0%;
  }

  100% {
    width: 98%;
  }
}

/* Flames */

/* @keyframes show-flames {
  0% {
    transform: translateY(0);
  }
  74% {
    transform: translateY(0);
  }
  80% {
    transform: translateY(-30px);
  }
  97% {
    transform: translateY(-30px);
  }
  100% {
    transform: translateY(0px);
  }
} */

@-webkit-keyframes show-flames {
  0% {
    opacity: 0;
  }
  74% {
    opacity: 0;
  }
  80% {
    opacity: 1;
  }
  99% {
    opacity: 1;
  }
  100% {
    opacity: 0;
  }
}

@keyframes show-flames {
  0% {
    opacity: 0;
  }
  74% {
    opacity: 0;
  }
  80% {
    opacity: 1;
  }
  99% {
    opacity: 1;
  }
  100% {
    opacity: 0;
  }
}

@-webkit-keyframes red-flame {
  0% {
    transform: translateY(-30px) scale(1, 1);
  }

  25% {
    transform: translateY(-30px) scale(1.1, 1.1);
  }

  75% {
    transform: translateY(-30px) scale(0.8, 0.7);
  }

  100% {
    transform: translateY(-30px) scale(1, 1);
  }
}

@keyframes red-flame {
  0% {
    transform: translateY(-30px) scale(1, 1);
  }

  25% {
    transform: translateY(-30px) scale(1.1, 1.1);
  }

  75% {
    transform: translateY(-30px) scale(0.8, 0.7);
  }

  100% {
    transform: translateY(-30px) scale(1, 1);
  }
}

@-webkit-keyframes yellow-flame {
  0% {
    transform: translateY(-30px) scale(0.8, 0.7);
  }

  50% {
    transform: translateY(-30px) scale(1.1, 1.2);
  }

  100% {
    transform: translateY(-30px) scale(1, 1);
  }
}

@keyframes yellow-flame {
  0% {
    transform: translateY(-30px) scale(0.8, 0.7);
  }

  50% {
    transform: translateY(-30px) scale(1.1, 1.2);
  }

  100% {
    transform: translateY(-30px) scale(1, 1);
  }
}
util.py
ADDED
@@ -0,0 +1,315 @@
# Copyright (c) 2024 Jaerin Lee

# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:

# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.

# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.

import concurrent.futures
import time
from typing import Any, Callable, List, Literal, Tuple, Union

from PIL import Image
import numpy as np

import torch
import torch.nn.functional as F
import torch.cuda.amp as amp
import torchvision.transforms as T
import torchvision.transforms.functional as TF

from diffusers import (
    DiffusionPipeline,
    StableDiffusionPipeline,
    StableDiffusionXLPipeline,
)


def seed_everything(seed: int) -> None:
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = True


def load_model(
    model_key: str,
    sd_version: Literal['1.5', 'xl'],
    device: torch.device,
    dtype: torch.dtype,
) -> torch.nn.Module:
    if model_key.endswith('.safetensors'):
        if sd_version == '1.5':
            pipeline = StableDiffusionPipeline
        elif sd_version == 'xl':
            pipeline = StableDiffusionXLPipeline
        else:
            raise ValueError(f'Stable Diffusion version {sd_version} not supported.')
        return pipeline.from_single_file(model_key, torch_dtype=dtype).to(device)
    try:
        return DiffusionPipeline.from_pretrained(model_key, variant='fp16', torch_dtype=dtype).to(device)
    except Exception:
        return DiffusionPipeline.from_pretrained(model_key, variant=None, torch_dtype=dtype).to(device)

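# Usage sketch (editor's addition, not part of the original file):
#
#     pipe = load_model('KBlueLeaf/kohaku-v2.1', sd_version='1.5',
#                       device=torch.device('cuda:0'), dtype=torch.float16)
#
# Hub model IDs take the fp16-variant-first path above, falling back to the default variant;
# a local `*.safetensors` checkpoint is instead routed through `from_single_file`, with the
# pipeline class picked by `sd_version`.
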
def get_cutoff(cutoff: float = None, scale: float = None) -> float:
|
69 |
+
if cutoff is not None:
|
70 |
+
return cutoff
|
71 |
+
|
72 |
+
if scale is not None and cutoff is None:
|
73 |
+
return 0.5 / scale
|
74 |
+
|
75 |
+
raise ValueError('Either one of `cutoff`, or `scale` should be specified.')
|
76 |
+
|
77 |
+
|
78 |
+
def get_scale(cutoff: float = None, scale: float = None) -> float:
|
79 |
+
if scale is not None:
|
80 |
+
return scale
|
81 |
+
|
82 |
+
if cutoff is not None and scale is None:
|
83 |
+
return 0.5 / cutoff
|
84 |
+
|
85 |
+
raise ValueError('Either one of `cutoff`, or `scale` should be specified.')
|
86 |
+
|
87 |
+
|
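# A quick sanity check (illustrative, not part of the original module): the
# two quantities are reciprocal through 0.5, so a downsampling scale of 2
# corresponds to a normalized cutoff frequency of 0.25.
def _example_cutoff_scale() -> None:
    assert get_cutoff(scale=2.0) == 0.25
    assert get_scale(cutoff=0.25) == 2.0

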
def filter_2d_by_kernel_1d(x: torch.Tensor, k: torch.Tensor) -> torch.Tensor:
    # Batched 1D kernels (a 2D `k`) are routed here by `filter_by_kernel`
    # but are not supported yet; the assert below rejects them.
    assert len(k.shape) == 1, 'Kernel should be 1-dimensional.'

    b, c, h, w = x.shape
    ks = k.shape[-1]
    k = k.view(1, 1, -1).repeat(c, 1, 1)

    # Separable filtering: convolve along the width axis, then along the
    # height axis, with replicate padding so the spatial size is preserved.
    x = x.permute(0, 2, 1, 3)
    x = x.reshape(b * h, c, w)
    x = F.pad(x, (ks // 2, (ks - 1) // 2), mode='replicate')
    x = F.conv1d(x, k, groups=c)
    x = x.reshape(b, h, c, w).permute(0, 3, 2, 1).reshape(b * w, c, h)
    x = F.pad(x, (ks // 2, (ks - 1) // 2), mode='replicate')
    x = F.conv1d(x, k, groups=c)
    x = x.reshape(b, w, c, h).permute(0, 2, 3, 1)
    return x


def filter_2d_by_kernel_2d(x: torch.Tensor, k: torch.Tensor) -> torch.Tensor:
    assert len(k.shape) in (2, 3), \
        'Kernel should be 2-dimensional, or 3-dimensional for a batch of kernels.'

    # `F.pad` takes (w_left, w_right, h_top, h_bottom), so the width axis is
    # padded by the kernel width (last dim) and the height axis by the kernel
    # height (second-to-last dim).
    x = F.pad(x, (
        k.shape[-1] // 2, (k.shape[-1] - 1) // 2,
        k.shape[-2] // 2, (k.shape[-2] - 1) // 2,
    ), mode='replicate')

    b, c, _, _ = x.shape
    if len(k.shape) == 2 or (len(k.shape) == 3 and k.shape[0] == 1):
        # One shared kernel: replicate it across channels and use grouped
        # convolution so each channel is filtered independently.
        k = k.view(1, 1, *k.shape[-2:]).repeat(c, 1, 1, 1)
        x = F.conv2d(x, k, groups=c)
    elif len(k.shape) == 3:
        assert k.shape[0] == b, \
            'The number of kernels should match the batch size.'

        # One kernel per batch item: swap batch and channel axes so the batch
        # dimension can serve as the convolution groups.
        k = k.unsqueeze(1)
        x = F.conv2d(x.permute(1, 0, 2, 3), k, groups=b).permute(1, 0, 2, 3)
    return x


@amp.autocast(False)
def filter_by_kernel(
    x: torch.Tensor,
    k: torch.Tensor,
    is_batch: bool = False,
) -> torch.Tensor:
    # Dispatch on kernel dimensionality: 1D kernels (optionally batched) go
    # through the separable path, 2D kernels (optionally batched) through
    # the direct path.
    k_dim = len(k.shape)
    if k_dim == 1 or (k_dim == 2 and is_batch):
        return filter_2d_by_kernel_1d(x, k)
    elif k_dim == 2 or (k_dim == 3 and is_batch):
        return filter_2d_by_kernel_2d(x, k)
    else:
        raise ValueError('Kernel should have 1, 2, or 3 dimensions.')


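# A minimal sketch of the dispatch above (illustrative, not part of the
# original module): a 1D box kernel runs through the separable path, a 2D box
# kernel through the direct path; both preserve the input shape thanks to the
# replicate padding.
def _example_filter_by_kernel() -> None:
    x = torch.rand(2, 3, 16, 16)
    k1 = torch.full((5,), 1 / 5)     # 1D box filter, applied separably
    k2 = torch.full((5, 5), 1 / 25)  # 2D box filter, applied directly
    assert filter_by_kernel(x, k1).shape == x.shape
    assert filter_by_kernel(x, k2).shape == x.shape

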
def gen_gauss_lowpass_filter_2d(
    std: torch.Tensor,
    window_size: int = None,
) -> torch.Tensor:
    # The Gaussian kernel size is odd in order to preserve the center; by
    # default it spans +/- 3 std of the widest Gaussian in the batch.
    if window_size is None:
        window_size = (
            2 * int(np.ceil(3 * std.max().detach().cpu().numpy())) + 1)

    # Build a (window_size, window_size) grid of coordinates centered at 0
    # and evaluate an isotropic Gaussian on it, one std per batch item.
    y = torch.arange(
        window_size, dtype=std.dtype, device=std.device
    ).view(-1, 1).repeat(1, window_size)
    grid = torch.stack((y.t(), y), dim=-1)
    grid -= 0.5 * (window_size - 1)  # (W, W)
    var = (std * std).unsqueeze(-1).unsqueeze(-1)
    distsq = (grid * grid).sum(dim=-1).unsqueeze(0).repeat(*std.shape, 1, 1)
    k = torch.exp(-0.5 * distsq / var)
    k /= k.sum(dim=(-2, -1), keepdim=True)  # Normalize to unit mass.
    return k


def gaussian_lowpass(
    x: torch.Tensor,
    std: Union[float, Tuple[float], torch.Tensor] = None,
    cutoff: Union[float, torch.Tensor] = None,
    scale: Union[float, torch.Tensor] = None,
) -> torch.Tensor:
    if std is None:
        cutoff = get_cutoff(cutoff, scale)
        std = 0.5 / (np.pi * cutoff)
    if isinstance(std, (float, int)):
        std = (std, std)
    if isinstance(std, torch.Tensor):
        # On GPU, nn.functional.conv2d with Gaussian kernels built at runtime
        # is roughly 80% faster than transforms.functional.gaussian_blur for
        # individual items; on CPU the opposite holds, but this path is not
        # expected to run on CPU.
        if len(list(s for s in std.shape if s != 1)) >= 2:
            raise NotImplementedError(
                'Anisotropic Gaussian filter is not currently available.')

        # k.shape == (B, W, W).
        k = gen_gauss_lowpass_filter_2d(std=std.view(-1))
        if k.shape[0] == 1:
            return filter_by_kernel(x, k[0], False)
        else:
            return filter_by_kernel(x, k, True)
    else:
        # The Gaussian kernel size is odd in order to preserve the center.
        window_size = tuple(2 * int(np.ceil(3 * s)) + 1 for s in std)
        return TF.gaussian_blur(x, window_size, std)


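# A minimal sketch (illustrative, not part of the original module): blur a
# batch with one Gaussian std per item through the runtime-built-kernel path.
# CPU tensors are used here for portability; the fast path targets GPU.
def _example_gaussian_lowpass() -> torch.Tensor:
    x = torch.rand(2, 3, 64, 64)
    std = torch.tensor([1.0, 4.0])  # one std per batch item
    return gaussian_lowpass(x, std=std)

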
def blend(
    fg: Union[torch.Tensor, Image.Image],
    bg: Union[torch.Tensor, Image.Image],
    mask: Union[torch.Tensor, Image.Image],
    std: float = 0.0,
) -> Image.Image:
    if not isinstance(fg, torch.Tensor):
        fg = T.ToTensor()(fg)
    if not isinstance(bg, torch.Tensor):
        bg = T.ToTensor()(bg)
    if not isinstance(mask, torch.Tensor):
        # PIL masks follow the drawing convention: dark pixels (values below
        # 0.5) mark the foreground region.
        mask = (T.ToTensor()(mask) < 0.5).float()[:1]
    if std > 0:
        # Feather the mask edge with a Gaussian so the seam is not visible.
        mask = gaussian_lowpass(mask[None], std)[0].clip_(0, 1)
    return T.ToPILImage()(fg * mask + bg * (1 - mask))


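# A usage sketch (illustrative; the file names are arbitrary examples):
# feather-blend a foreground over a background through a drawn mask, with a
# Gaussian std of 8 pixels softening the seam.
def _example_blend() -> Image.Image:
    fg = Image.open('fg.png').convert('RGB')
    bg = Image.open('bg.png').convert('RGB')
    mask = Image.open('mask.png').convert('L')  # dark pixels mark the foreground
    return blend(fg, bg, mask, std=8.0)

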
def get_panorama_views(
    panorama_height: int,
    panorama_width: int,
    window_size: int = 64,
) -> Tuple[List[Tuple[int, int, int, int]], torch.Tensor]:
    stride = window_size // 2
    num_blocks_height = (panorama_height - window_size + stride - 1) // stride + 1
    num_blocks_width = (panorama_width - window_size + stride - 1) // stride + 1
    total_num_blocks = num_blocks_height * num_blocks_width

    # 1D blending profiles: interior tiles ramp linearly up then down (`c`),
    # while the first (`f`) and last (`b`) tiles stay flat toward the edge.
    half_fwd = torch.linspace(0, 1, (window_size + 1) // 2)
    half_rev = half_fwd.flip(0)
    if window_size % 2 == 1:
        half_rev = half_rev[1:]
    c = torch.cat((half_fwd, half_rev))
    one = torch.ones_like(c)
    f = c.clone()
    f[:window_size // 2] = 1
    b = c.clone()
    b[-(window_size // 2):] = 1

    h = [one] if num_blocks_height == 1 else [f] + [c] * (num_blocks_height - 2) + [b]
    w = [one] if num_blocks_width == 1 else [f] + [c] * (num_blocks_width - 2) + [b]

    views = []
    masks = torch.zeros(total_num_blocks, panorama_height, panorama_width)  # (n, h, w)
    for i in range(total_num_blocks):
        hi, wi = i // num_blocks_width, i % num_blocks_width
        h_start = hi * stride
        h_end = min(h_start + window_size, panorama_height)
        w_start = wi * stride
        w_end = min(w_start + window_size, panorama_width)
        views.append((h_start, h_end, w_start, w_end))

        # The 2D blending weight of a tile is the outer product of its
        # vertical and horizontal 1D profiles.
        h_width = h_end - h_start
        w_width = w_end - w_start
        masks[i, h_start:h_end, w_start:w_end] = h[hi][:h_width, None] * w[wi][None, :w_width]

    # The mask weights at each pixel (`masks.sum(dim=1)` on the returned
    # tensor) sum to unity.
    return views, masks[None]  # (1, n, h, w)


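# A quick check (illustrative, not part of the original module): the tile
# masks form a partition of unity, so weighted-averaging overlapping tile
# outputs introduces no visible seams.
def _example_panorama_views() -> None:
    views, masks = get_panorama_views(64, 160, window_size=64)
    assert len(views) == masks.shape[1]
    total = masks.sum(dim=1)  # (1, h, w)
    assert torch.allclose(total, torch.ones_like(total))

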
def shift_to_mask_bbox_center(im: torch.Tensor, mask: torch.Tensor, reverse: bool = False) -> torch.Tensor:
    # Circularly shifts each image by the offset between the center of its
    # mask's bounding box and the image center; `reverse` undoes the shift.
    h, w = mask.shape[-2:]
    device = mask.device
    mask = mask.reshape(-1, h, w)
    # assert mask.shape[0] == im.shape[0]
    h_occupied = mask.sum(dim=-2) > 0
    w_occupied = mask.sum(dim=-1) > 0
    # First/last occupied columns (l, r) and rows (t, b) of each mask.
    l = torch.argmax(h_occupied * torch.arange(w, 0, -1).to(device), 1, keepdim=True).cpu()
    r = torch.argmax(h_occupied * torch.arange(w).to(device), 1, keepdim=True).cpu()
    t = torch.argmax(w_occupied * torch.arange(h, 0, -1).to(device), 1, keepdim=True).cpu()
    b = torch.argmax(w_occupied * torch.arange(h).to(device), 1, keepdim=True).cpu()
    tb = (t + b + 1) // 2
    lr = (l + r + 1) // 2
    shifts = (tb - (h // 2), lr - (w // 2))
    shifts = torch.cat(shifts, dim=1)  # (p, 2)
    if reverse:
        shifts = shifts * -1
    return torch.stack([i.roll(shifts=s.tolist(), dims=(-2, -1)) for i, s in zip(im, shifts)], dim=0)


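# A round-trip sketch (illustrative, not part of the original module): shift
# according to the mask bounding box, then undo it with `reverse=True`. The
# roll is circular, so the round trip is exact.
def _example_shift_roundtrip() -> None:
    im = torch.rand(1, 4, 64, 64)
    mask = torch.zeros(1, 1, 64, 64)
    mask[..., 4:20, 8:24] = 1
    centered = shift_to_mask_bbox_center(im, mask)
    restored = shift_to_mask_bbox_center(centered, mask, reverse=True)
    assert torch.equal(restored, im)

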
class Streamer:
    def __init__(self, fn: Callable, ema_alpha: float = 0.9) -> None:
        self.fn = fn
        self.ema_alpha = ema_alpha

        # A single worker keeps at most one generation task in flight.
        self.executor = concurrent.futures.ThreadPoolExecutor(max_workers=1)
        self.future = self.executor.submit(fn)
        self.image = None

        self.prev_exec_time = 0
        self.ema_exec_time = 0

    @property
    def throughput(self) -> float:
        return 1.0 / self.ema_exec_time if self.ema_exec_time else float('inf')

    def timed_fn(self) -> Any:
        # Wraps `fn` to keep an exponential moving average of its runtime.
        start = time.time()
        res = self.fn()
        end = time.time()
        self.prev_exec_time = end - start
        self.ema_exec_time = self.ema_exec_time * self.ema_alpha + self.prev_exec_time * (1 - self.ema_alpha)
        return res

    def __call__(self) -> Any:
        if self.future.done() or self.image is None:
            # Get the result (the new image) and start the next task. On the
            # first invocation this blocks until the initial task finishes,
            # so a previous image is always available afterwards.
            image = self.future.result()
            self.future = self.executor.submit(self.timed_fn)
            self.image = image
            return image
        else:
            # The new image is not ready yet; reuse the previous one.
            return self.image
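

# A usage sketch (illustrative; `generate_frame` is a hypothetical stand-in
# for the app's generation call): the streamer returns the latest finished
# frame immediately and keeps exactly one generation task in flight.
def _example_streamer() -> None:
    def generate_frame() -> Image.Image:
        time.sleep(0.1)  # stand-in for a diffusion step
        return Image.new('RGB', (64, 64))

    stream = Streamer(generate_frame)
    for _ in range(10):
        frame = stream()  # blocks only for the very first frame
        assert frame.size == (64, 64)
        assert stream.throughput > 0  # EMA-smoothed frames per second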