Spaces:
Runtime error
Runtime error
init
Browse files- .gitignore +5 -0
- TrailBlazer/CrossAttn/BaseProc.py +291 -0
- TrailBlazer/CrossAttn/InjecterProc.py +79 -0
- TrailBlazer/CrossAttn/Utils.py +181 -0
- TrailBlazer/CrossAttn/__init__.py +1 -0
- TrailBlazer/Misc/BBox.py +93 -0
- TrailBlazer/Misc/ConfigIO.py +13 -0
- TrailBlazer/Misc/Const.py +6 -0
- TrailBlazer/Misc/Logger.py +70 -0
- TrailBlazer/Misc/Painter.py +224 -0
- TrailBlazer/Misc/__init__.py +0 -0
- TrailBlazer/Pipeline/TextToVideoSDPipelineCall.py +339 -0
- TrailBlazer/Pipeline/UNet3DConditionModelCall.py +229 -0
- TrailBlazer/Pipeline/Utils.py +144 -0
- TrailBlazer/Pipeline/__init__.py +0 -0
- TrailBlazer/README.md +1 -0
- TrailBlazer/Setting/Config.py +23 -0
- TrailBlazer/Setting/Const.py +4 -0
- TrailBlazer/Setting/__init__.py +0 -0
- TrailBlazer/__init__.py +8 -0
- app.py +415 -0
.gitignore
ADDED
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
assets
|
2 |
+
__pycache__
|
3 |
+
*.pyc
|
4 |
+
*.png
|
5 |
+
*undo*
|
TrailBlazer/CrossAttn/BaseProc.py
ADDED
@@ -0,0 +1,291 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import Dict, List, TypedDict
|
2 |
+
import numpy as np
|
3 |
+
import math
|
4 |
+
import torch
|
5 |
+
from abc import ABC, abstractmethod
|
6 |
+
from diffusers.models.attention_processor import Attention as CrossAttention
|
7 |
+
from einops import rearrange
|
8 |
+
from ..Misc import Logger as log
|
9 |
+
from ..Misc.BBox import BoundingBox
|
10 |
+
|
11 |
+
KERNEL_DIVISION = 3.
|
12 |
+
INJECTION_SCALE = 1.0
|
13 |
+
|
14 |
+
|
15 |
+
def reshape_fortran(x, shape):
|
16 |
+
""" Reshape a tensor in the fortran index. See
|
17 |
+
https://stackoverflow.com/a/63964246
|
18 |
+
"""
|
19 |
+
if len(x.shape) > 0:
|
20 |
+
x = x.permute(*reversed(range(len(x.shape))))
|
21 |
+
return x.reshape(*reversed(shape)).permute(*reversed(range(len(shape))))
|
22 |
+
|
23 |
+
|
24 |
+
def gaussian_2d(x=0, y=0, mx=0, my=0, sx=1, sy=1):
|
25 |
+
""" 2d Gaussian weight function
|
26 |
+
"""
|
27 |
+
gaussian_map = (
|
28 |
+
1
|
29 |
+
/ (2 * math.pi * sx * sy)
|
30 |
+
* torch.exp(-((x - mx) ** 2 / (2 * sx**2) + (y - my) ** 2 / (2 * sy**2)))
|
31 |
+
)
|
32 |
+
gaussian_map.div_(gaussian_map.max())
|
33 |
+
return gaussian_map
|
34 |
+
|
35 |
+
|
36 |
+
class BundleType(TypedDict):
|
37 |
+
selected_inds: List[int] # the 1-indexed indices of a subject
|
38 |
+
trailing_inds: List[int] # the 1-indexed indices of trailings
|
39 |
+
bbox: List[
|
40 |
+
float
|
41 |
+
] # four floats to determine the bounding box [left, right, top, bottom]
|
42 |
+
|
43 |
+
|
44 |
+
class CrossAttnProcessorBase:
|
45 |
+
|
46 |
+
MAX_LEN_CLIP_TOKENS = 77
|
47 |
+
DEVICE = "cuda"
|
48 |
+
|
49 |
+
def __init__(self, bundle, is_text2vidzero=False):
|
50 |
+
|
51 |
+
self.prompt = bundle["prompt_base"]
|
52 |
+
base_prompt = self.prompt.split(";")[0]
|
53 |
+
self.len_prompt = len(base_prompt.split(" "))
|
54 |
+
self.prompt_len = len(self.prompt.split(" "))
|
55 |
+
self.use_dd = False
|
56 |
+
self.use_dd_temporal = False
|
57 |
+
self.unet_chunk_size = 2
|
58 |
+
self._cross_attention_map = None
|
59 |
+
self._loss = None
|
60 |
+
self._parameters = None
|
61 |
+
self.is_text2vidzero = is_text2vidzero
|
62 |
+
bbox = None
|
63 |
+
|
64 |
+
@property
|
65 |
+
def cross_attention_map(self):
|
66 |
+
return self._cross_attention_map
|
67 |
+
|
68 |
+
@property
|
69 |
+
def loss(self):
|
70 |
+
return self._loss
|
71 |
+
|
72 |
+
@property
|
73 |
+
def parameters(self):
|
74 |
+
if type(self._parameters) == type(None):
|
75 |
+
log.warn("No parameters being initialized. Be cautious!")
|
76 |
+
return self._parameters
|
77 |
+
|
78 |
+
def __call__(
|
79 |
+
self,
|
80 |
+
attn: CrossAttention,
|
81 |
+
hidden_states,
|
82 |
+
encoder_hidden_states=None,
|
83 |
+
attention_mask=None,
|
84 |
+
):
|
85 |
+
|
86 |
+
batch_size, sequence_length, _ = hidden_states.shape
|
87 |
+
attention_mask = attn.prepare_attention_mask(
|
88 |
+
attention_mask, sequence_length, batch_size
|
89 |
+
)
|
90 |
+
#print("====================")
|
91 |
+
query = attn.to_q(hidden_states)
|
92 |
+
|
93 |
+
is_cross_attention = encoder_hidden_states is not None
|
94 |
+
if encoder_hidden_states is None:
|
95 |
+
encoder_hidden_states = hidden_states
|
96 |
+
# elif attn.cross_attention_norm:
|
97 |
+
elif attn.norm_cross:
|
98 |
+
encoder_hidden_states = attn.norm_cross(encoder_hidden_states)
|
99 |
+
|
100 |
+
key = attn.to_k(encoder_hidden_states)
|
101 |
+
value = attn.to_v(encoder_hidden_states)
|
102 |
+
|
103 |
+
def rearrange_3(tensor, f):
|
104 |
+
F, D, C = tensor.size()
|
105 |
+
return torch.reshape(tensor, (F // f, f, D, C))
|
106 |
+
|
107 |
+
def rearrange_4(tensor):
|
108 |
+
B, F, D, C = tensor.size()
|
109 |
+
return torch.reshape(tensor, (B * F, D, C))
|
110 |
+
|
111 |
+
# Cross Frame Attention
|
112 |
+
if not is_cross_attention and self.is_text2vidzero:
|
113 |
+
video_length = key.size()[0] // 2
|
114 |
+
first_frame_index = [0] * video_length
|
115 |
+
|
116 |
+
# rearrange keys to have batch and frames in the 1st and 2nd dims respectively
|
117 |
+
key = rearrange_3(key, video_length)
|
118 |
+
key = key[:, first_frame_index]
|
119 |
+
# rearrange values to have batch and frames in the 1st and 2nd dims respectively
|
120 |
+
value = rearrange_3(value, video_length)
|
121 |
+
value = value[:, first_frame_index]
|
122 |
+
|
123 |
+
# rearrange back to original shape
|
124 |
+
key = rearrange_4(key)
|
125 |
+
value = rearrange_4(value)
|
126 |
+
|
127 |
+
query = attn.head_to_batch_dim(query)
|
128 |
+
key = attn.head_to_batch_dim(key)
|
129 |
+
value = attn.head_to_batch_dim(value)
|
130 |
+
# Cross attention map
|
131 |
+
#print(query.shape, key.shape, value.shape)
|
132 |
+
attention_probs = attn.get_attention_scores(query, key)
|
133 |
+
# print(attention_probs.shape)
|
134 |
+
# torch.Size([960, 77, 64]) torch.Size([960, 256, 64]) torch.Size([960, 77, 64]) torch.Size([960, 256, 77])
|
135 |
+
# torch.Size([10240, 24, 64]) torch.Size([10240, 24, 64]) torch.Size([10240, 24, 64]) torch.Size([10240, 24, 24])
|
136 |
+
|
137 |
+
n = attention_probs.shape[0] // 2
|
138 |
+
if attention_probs.shape[-1] == CrossAttnProcessorBase.MAX_LEN_CLIP_TOKENS:
|
139 |
+
dim = int(np.sqrt(attention_probs.shape[1]))
|
140 |
+
if self.use_dd:
|
141 |
+
# self.use_dd = False
|
142 |
+
attention_probs_4d = attention_probs.view(
|
143 |
+
attention_probs.shape[0], dim, dim, attention_probs.shape[-1]
|
144 |
+
)[n:]
|
145 |
+
attention_probs_4d = self.dd_core(attention_probs_4d)
|
146 |
+
attention_probs[n:] = attention_probs_4d.reshape(
|
147 |
+
attention_probs_4d.shape[0], dim * dim, attention_probs_4d.shape[-1]
|
148 |
+
)
|
149 |
+
|
150 |
+
self._cross_attention_map = attention_probs.view(
|
151 |
+
attention_probs.shape[0], dim, dim, attention_probs.shape[-1]
|
152 |
+
)[n:]
|
153 |
+
|
154 |
+
elif (
|
155 |
+
attention_probs.shape[-1] == self.num_frames
|
156 |
+
and (attention_probs.shape[0] == 65536)
|
157 |
+
):
|
158 |
+
dim = int(np.sqrt(attention_probs.shape[0] // (2 * attn.heads)))
|
159 |
+
if self.use_dd_temporal:
|
160 |
+
# self.use_dd_temporal = False
|
161 |
+
def temporal_doit(origin_attn):
|
162 |
+
temporal_attn = reshape_fortran(
|
163 |
+
origin_attn,
|
164 |
+
(attn.heads, dim, dim, self.num_frames, self.num_frames),
|
165 |
+
)
|
166 |
+
temporal_attn = torch.transpose(temporal_attn, 1, 2)
|
167 |
+
temporal_attn = self.dd_core(temporal_attn)
|
168 |
+
# torch.Size([8, 64, 64, 24, 24])
|
169 |
+
temporal_attn = torch.transpose(temporal_attn, 1, 2)
|
170 |
+
temporal_attn = reshape_fortran(
|
171 |
+
temporal_attn,
|
172 |
+
(attn.heads * dim * dim, self.num_frames, self.num_frames),
|
173 |
+
)
|
174 |
+
return temporal_attn
|
175 |
+
|
176 |
+
|
177 |
+
# NOTE: So null text embedding for classification free guidance
|
178 |
+
# doesn't really help?
|
179 |
+
#attention_probs[n:] = temporal_doit(attention_probs[n:])
|
180 |
+
attention_probs[:n] = temporal_doit(attention_probs[:n])
|
181 |
+
|
182 |
+
self._cross_attention_map = reshape_fortran(
|
183 |
+
attention_probs[:n],
|
184 |
+
(attn.heads, dim, dim, self.num_frames, self.num_frames),
|
185 |
+
)
|
186 |
+
self._cross_attention_map = self._cross_attention_map.mean(dim=0)
|
187 |
+
self._cross_attention_map = torch.transpose(self._cross_attention_map, 0, 1)
|
188 |
+
|
189 |
+
attention_probs = torch.abs(attention_probs)
|
190 |
+
hidden_states = torch.bmm(attention_probs, value)
|
191 |
+
hidden_states = attn.batch_to_head_dim(hidden_states)
|
192 |
+
# linear proj
|
193 |
+
hidden_states = attn.to_out[0](hidden_states)
|
194 |
+
# dropout
|
195 |
+
hidden_states = attn.to_out[1](hidden_states)
|
196 |
+
return hidden_states
|
197 |
+
|
198 |
+
@abstractmethod
|
199 |
+
def dd_core(self):
|
200 |
+
"""All DD variants implement this function"""
|
201 |
+
pass
|
202 |
+
|
203 |
+
@staticmethod
|
204 |
+
def localized_weight_map(attention_probs_4d, token_inds, bbox_per_frame, scale=1):
|
205 |
+
"""Using guassian 2d distribution to generate weight map and return the
|
206 |
+
array with the same size of the attention argument.
|
207 |
+
"""
|
208 |
+
dim = int(attention_probs_4d.size()[1])
|
209 |
+
max_val = attention_probs_4d.max()
|
210 |
+
weight_map = torch.zeros_like(attention_probs_4d).half()
|
211 |
+
frame_size = attention_probs_4d.shape[0] // len(bbox_per_frame)
|
212 |
+
|
213 |
+
for i in range(len(bbox_per_frame)):
|
214 |
+
bbox_ratios = bbox_per_frame[i]
|
215 |
+
bbox = BoundingBox(dim, bbox_ratios)
|
216 |
+
# Generating the gaussian distribution map patch
|
217 |
+
x = torch.linspace(0, bbox.height, bbox.height)
|
218 |
+
y = torch.linspace(0, bbox.width, bbox.width)
|
219 |
+
x, y = torch.meshgrid(x, y, indexing="ij")
|
220 |
+
noise_patch = (
|
221 |
+
gaussian_2d(
|
222 |
+
x,
|
223 |
+
y,
|
224 |
+
mx=int(bbox.height / 2),
|
225 |
+
my=int(bbox.width / 2),
|
226 |
+
sx=float(bbox.height / KERNEL_DIVISION),
|
227 |
+
sy=float(bbox.width / KERNEL_DIVISION),
|
228 |
+
)
|
229 |
+
.unsqueeze(0)
|
230 |
+
.unsqueeze(-1)
|
231 |
+
.repeat(frame_size, 1, 1, len(token_inds))
|
232 |
+
.to(attention_probs_4d.device)
|
233 |
+
).half()
|
234 |
+
|
235 |
+
scale = attention_probs_4d.max() * INJECTION_SCALE
|
236 |
+
noise_patch.mul_(scale)
|
237 |
+
|
238 |
+
b_idx = frame_size * i
|
239 |
+
e_idx = frame_size * (i + 1)
|
240 |
+
bbox.sliced_tensor_in_bbox(weight_map)[
|
241 |
+
b_idx:e_idx, ..., token_inds
|
242 |
+
] = noise_patch
|
243 |
+
return weight_map
|
244 |
+
|
245 |
+
@staticmethod
|
246 |
+
def localized_temporal_weight_map(attention_probs_5d, bbox_per_frame, scale=1):
|
247 |
+
"""Using guassian 2d distribution to generate weight map and return the
|
248 |
+
array with the same size of the attention argument.
|
249 |
+
"""
|
250 |
+
dim = int(attention_probs_5d.size()[1])
|
251 |
+
f = attention_probs_5d.shape[-1]
|
252 |
+
max_val = attention_probs_5d.max()
|
253 |
+
weight_map = torch.zeros_like(attention_probs_5d).half()
|
254 |
+
|
255 |
+
def get_patch(bbox_at_frame, i, j, bbox_per_frame):
|
256 |
+
bbox = BoundingBox(dim, bbox_at_frame)
|
257 |
+
# Generating the gaussian distribution map patch
|
258 |
+
x = torch.linspace(0, bbox.height, bbox.height)
|
259 |
+
y = torch.linspace(0, bbox.width, bbox.width)
|
260 |
+
x, y = torch.meshgrid(x, y, indexing="ij")
|
261 |
+
noise_patch = (
|
262 |
+
gaussian_2d(
|
263 |
+
x,
|
264 |
+
y,
|
265 |
+
mx=int(bbox.height / 2),
|
266 |
+
my=int(bbox.width / 2),
|
267 |
+
sx=float(bbox.height / KERNEL_DIVISION),
|
268 |
+
sy=float(bbox.width / KERNEL_DIVISION),
|
269 |
+
)
|
270 |
+
.unsqueeze(0)
|
271 |
+
.repeat(attention_probs_5d.shape[0], 1, 1)
|
272 |
+
.to(attention_probs_5d.device)
|
273 |
+
).half()
|
274 |
+
scale = attention_probs_5d.max() * INJECTION_SCALE
|
275 |
+
noise_patch.mul_(scale)
|
276 |
+
inv_noise_patch = noise_patch - noise_patch.max()
|
277 |
+
dist = (float(abs(j - i))) / len(bbox_per_frame)
|
278 |
+
final_patch = inv_noise_patch * dist + noise_patch * (1. - dist)
|
279 |
+
#final_patch = noise_patch * (1. - dist)
|
280 |
+
#final_patch = inv_noise_patch * dist
|
281 |
+
return final_patch, bbox
|
282 |
+
|
283 |
+
|
284 |
+
for j in range(len(bbox_per_frame)):
|
285 |
+
for i in range(len(bbox_per_frame)):
|
286 |
+
patch_i, bbox_i = get_patch(bbox_per_frame[i], i, j, bbox_per_frame)
|
287 |
+
patch_j, bbox_j = get_patch(bbox_per_frame[j], i, j, bbox_per_frame)
|
288 |
+
bbox_i.sliced_tensor_in_bbox(weight_map)[..., i, j] = patch_i
|
289 |
+
bbox_j.sliced_tensor_in_bbox(weight_map)[..., i, j] = patch_j
|
290 |
+
|
291 |
+
return weight_map
|
TrailBlazer/CrossAttn/InjecterProc.py
ADDED
@@ -0,0 +1,79 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import Dict, List, TypedDict
|
2 |
+
import numpy as np
|
3 |
+
import torch
|
4 |
+
import math
|
5 |
+
|
6 |
+
from ..Misc import Logger as log
|
7 |
+
|
8 |
+
from .BaseProc import CrossAttnProcessorBase
|
9 |
+
from .BaseProc import BundleType
|
10 |
+
from ..Misc.BBox import BoundingBox
|
11 |
+
|
12 |
+
|
13 |
+
class InjecterProcessor(CrossAttnProcessorBase):
|
14 |
+
def __init__(
|
15 |
+
self,
|
16 |
+
bundle: BundleType,
|
17 |
+
bbox_per_frame: List[BoundingBox],
|
18 |
+
name: str,
|
19 |
+
strengthen_scale: float = 0.0,
|
20 |
+
weaken_scale: float = 1.0,
|
21 |
+
is_text2vidzero: bool = False,
|
22 |
+
):
|
23 |
+
super().__init__(bundle, is_text2vidzero=is_text2vidzero)
|
24 |
+
self.strengthen_scale = strengthen_scale
|
25 |
+
self.weaken_scale = weaken_scale
|
26 |
+
self.bundle = bundle
|
27 |
+
self.num_frames = len(bbox_per_frame)
|
28 |
+
self.bbox_per_frame = bbox_per_frame
|
29 |
+
self.use_weaken = True
|
30 |
+
self.name = name
|
31 |
+
|
32 |
+
def dd_core(self, attention_probs: torch.Tensor):
|
33 |
+
""" """
|
34 |
+
|
35 |
+
frame_size = attention_probs.shape[0] // self.num_frames
|
36 |
+
num_affected_frames = self.num_frames
|
37 |
+
attention_probs_copied = attention_probs.detach().clone()
|
38 |
+
|
39 |
+
token_inds = self.bundle.get("token_inds")
|
40 |
+
trailing_length = self.bundle.get("trailing_length")
|
41 |
+
trailing_inds = list(
|
42 |
+
range(self.len_prompt + 1, self.len_prompt + trailing_length + 1)
|
43 |
+
)
|
44 |
+
# NOTE: Spatial cross attention editing
|
45 |
+
if len(attention_probs.size()) == 4:
|
46 |
+
all_tokens_inds = list(set(token_inds).union(set(trailing_inds)))
|
47 |
+
strengthen_map = self.localized_weight_map(
|
48 |
+
attention_probs_copied,
|
49 |
+
token_inds=all_tokens_inds,
|
50 |
+
bbox_per_frame=self.bbox_per_frame,
|
51 |
+
)
|
52 |
+
|
53 |
+
weaken_map = torch.ones_like(strengthen_map)
|
54 |
+
zero_indices = torch.where(strengthen_map == 0)
|
55 |
+
weaken_map[zero_indices] = self.weaken_scale
|
56 |
+
|
57 |
+
# weakening
|
58 |
+
attention_probs_copied[..., all_tokens_inds] *= weaken_map[
|
59 |
+
..., all_tokens_inds
|
60 |
+
]
|
61 |
+
# strengthen
|
62 |
+
attention_probs_copied[..., all_tokens_inds] += (
|
63 |
+
self.strengthen_scale * strengthen_map[..., all_tokens_inds]
|
64 |
+
)
|
65 |
+
# NOTE: Temporal cross attention editing
|
66 |
+
elif len(attention_probs.size()) == 5:
|
67 |
+
strengthen_map = self.localized_temporal_weight_map(
|
68 |
+
attention_probs_copied,
|
69 |
+
bbox_per_frame=self.bbox_per_frame,
|
70 |
+
)
|
71 |
+
weaken_map = torch.ones_like(strengthen_map)
|
72 |
+
zero_indices = torch.where(strengthen_map == 0)
|
73 |
+
weaken_map[zero_indices] = self.weaken_scale
|
74 |
+
# weakening
|
75 |
+
attention_probs_copied *= weaken_map
|
76 |
+
# strengthen
|
77 |
+
attention_probs_copied += self.strengthen_scale * strengthen_map
|
78 |
+
|
79 |
+
return attention_probs_copied
|
TrailBlazer/CrossAttn/Utils.py
ADDED
@@ -0,0 +1,181 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import enum
|
2 |
+
import torch
|
3 |
+
import torchvision
|
4 |
+
import numpy as np
|
5 |
+
|
6 |
+
from ..Misc import Logger as log
|
7 |
+
from ..Setting import Config
|
8 |
+
|
9 |
+
import matplotlib.pyplot as plt
|
10 |
+
import matplotlib
|
11 |
+
|
12 |
+
# To avoid plt.imshow crash
|
13 |
+
matplotlib.use("Agg")
|
14 |
+
|
15 |
+
|
16 |
+
class CAttnProcChoice(enum.Enum):
|
17 |
+
INVALID = -1
|
18 |
+
BASIC = 0
|
19 |
+
|
20 |
+
|
21 |
+
def plot_activations(cross_attn, prompt, plot_with_trailings=False):
|
22 |
+
num_frames = cross_attn.shape[0]
|
23 |
+
cross_attn = cross_attn.cpu()
|
24 |
+
for i in range(num_frames):
|
25 |
+
filename = "/tmp/out.{:04d}.jpg".format(i)
|
26 |
+
plot_activation(cross_attn[i], prompt, filename, plot_with_trailings)
|
27 |
+
|
28 |
+
|
29 |
+
def plot_activation(cross_attn, prompt, filepath="", plot_with_trailings=False):
|
30 |
+
|
31 |
+
splitted_prompt = prompt.split(" ")
|
32 |
+
n = len(splitted_prompt)
|
33 |
+
start = 0
|
34 |
+
arrs = []
|
35 |
+
if plot_with_trailings:
|
36 |
+
for j in range(5):
|
37 |
+
arr = []
|
38 |
+
for i in range(start, start + n):
|
39 |
+
cross_attn_sliced = cross_attn[..., i + 1]
|
40 |
+
arr.append(cross_attn_sliced.T)
|
41 |
+
start += n
|
42 |
+
arr = np.hstack(arr)
|
43 |
+
arrs.append(arr)
|
44 |
+
arrs = np.vstack(arrs).T
|
45 |
+
else:
|
46 |
+
arr = []
|
47 |
+
for i in range(start, start + n):
|
48 |
+
print(i)
|
49 |
+
cross_attn_sliced = cross_attn[..., i + 1]
|
50 |
+
arr.append(cross_attn_sliced)
|
51 |
+
arrs = np.hstack(arr).astype(np.float32)
|
52 |
+
plt.clf()
|
53 |
+
|
54 |
+
v_min = arrs.min()
|
55 |
+
v_max = arrs.max()
|
56 |
+
n_min = 0.0
|
57 |
+
n_max = 1
|
58 |
+
|
59 |
+
arrs = (arrs - v_min) / (v_max - v_min)
|
60 |
+
arrs = (arrs * (n_max - n_min)) + n_min
|
61 |
+
|
62 |
+
plt.imshow(arrs, cmap="jet")
|
63 |
+
plt.title(prompt)
|
64 |
+
plt.colorbar(orientation="horizontal", pad=0.2)
|
65 |
+
if filepath:
|
66 |
+
plt.savefig(filepath)
|
67 |
+
log.info(f"Saved [{filepath}]")
|
68 |
+
else:
|
69 |
+
plt.show()
|
70 |
+
|
71 |
+
|
72 |
+
def get_cross_attn(
|
73 |
+
unet,
|
74 |
+
resolution=32,
|
75 |
+
target_size=64,
|
76 |
+
):
|
77 |
+
"""To get the cross attention map softmax(QK^T) from Unet.
|
78 |
+
Args:
|
79 |
+
unet (UNet2DConditionModel): unet
|
80 |
+
resolution (int): the cross attention map with specific resolution. It only supports 64, 32, 16, and 8
|
81 |
+
target_size (int): the target resolution for resizing the cross attention map
|
82 |
+
Returns:
|
83 |
+
(torch.tensor): a tensor with shape (target_size, target_size, 77)
|
84 |
+
"""
|
85 |
+
attns = []
|
86 |
+
check = [8, 16, 32, 64]
|
87 |
+
if resolution not in check:
|
88 |
+
raise ValueError(
|
89 |
+
"The cross attention resolution only support 8x8, 16x16, 32x32, and 64x64. "
|
90 |
+
"The given resolution {}x{} is not in the list. Abort.".format(
|
91 |
+
resolution, resolution
|
92 |
+
)
|
93 |
+
)
|
94 |
+
for name, module in unet.named_modules():
|
95 |
+
module_name = type(module).__name__
|
96 |
+
# NOTE: attn2 is for cross-attention while attn1 is self-attention
|
97 |
+
dim = resolution * resolution
|
98 |
+
if not hasattr(module, "processor"):
|
99 |
+
continue
|
100 |
+
if hasattr(module.processor, "cross_attention_map"):
|
101 |
+
attn = module.processor.cross_attention_map[None, ...]
|
102 |
+
attns.append(attn)
|
103 |
+
|
104 |
+
if not attns:
|
105 |
+
print("Err: Quried attns size [{}]".format(len(attns)))
|
106 |
+
return
|
107 |
+
attns = torch.cat(attns, dim=0)
|
108 |
+
attns = torch.sum(attns, dim=0)
|
109 |
+
# resized = torch.zeros([target_size, target_size, 77])
|
110 |
+
# f = torchvision.transforms.Resize(size=(64, 64))
|
111 |
+
# dim = attns.shape[1]
|
112 |
+
# print(attns.shape)
|
113 |
+
# for i in range(77):
|
114 |
+
# attn_slice = attns[..., i].view(1, dim, dim)
|
115 |
+
# resized[..., i] = f(attn_slice)[0]
|
116 |
+
return attns
|
117 |
+
|
118 |
+
|
119 |
+
def get_avg_cross_attn(unet, resolutions, resize):
|
120 |
+
"""To get the average cross attention map across its resolutions.
|
121 |
+
Args:
|
122 |
+
unet (UNet2DConditionModel): unet
|
123 |
+
resolution (list): a list of specific resolution. It only supports 64, 32, 16, and 8
|
124 |
+
target_size (int): the target resolution for resizing the cross attention map
|
125 |
+
Returns:
|
126 |
+
(torch.tensor): a tensor with shape (target_size, target_size, 77)
|
127 |
+
"""
|
128 |
+
cross_attns = []
|
129 |
+
for resolution in resolutions:
|
130 |
+
try:
|
131 |
+
cross_attns.append(get_cross_attn(unet, resolution, resize))
|
132 |
+
except:
|
133 |
+
log.warn(f"No cross-attention map with resolution [{resolution}]")
|
134 |
+
if cross_attns:
|
135 |
+
cross_attns = torch.stack(cross_attns).mean(0)
|
136 |
+
return cross_attns
|
137 |
+
|
138 |
+
|
139 |
+
def save_cross_attn(unet):
|
140 |
+
"""TODO: to save cross attn"""
|
141 |
+
for name, module in unet.named_modules():
|
142 |
+
module_name = type(module).__name__
|
143 |
+
if module_name == "CrossAttention" and "attn2" in name:
|
144 |
+
folder = "/tmp"
|
145 |
+
filepath = os.path.join(folder, name + ".pt")
|
146 |
+
torch.save(module.attn, filepath)
|
147 |
+
print(filepath)
|
148 |
+
|
149 |
+
|
150 |
+
def use_dd(unet, use=True):
|
151 |
+
for name, module in unet.named_modules():
|
152 |
+
module_name = type(module).__name__
|
153 |
+
if module_name == "CrossAttention" and "attn2" in name:
|
154 |
+
module.processor.use_dd = use
|
155 |
+
|
156 |
+
|
157 |
+
def use_dd_temporal(unet, use=True):
|
158 |
+
for name, module in unet.named_modules():
|
159 |
+
module_name = type(module).__name__
|
160 |
+
if module_name == "CrossAttention" and "attn2" in name:
|
161 |
+
module.processor.use_dd_temporal = use
|
162 |
+
|
163 |
+
|
164 |
+
def get_loss(unet):
|
165 |
+
loss = 0
|
166 |
+
total = 0
|
167 |
+
for name, module in unet.named_modules():
|
168 |
+
module_name = type(module).__name__
|
169 |
+
if module_name == "CrossAttention" and "attn2" in name:
|
170 |
+
loss += module.processor.loss
|
171 |
+
total += 1
|
172 |
+
return loss / total
|
173 |
+
|
174 |
+
|
175 |
+
def get_params(unet):
|
176 |
+
parameters = []
|
177 |
+
for name, module in unet.named_modules():
|
178 |
+
module_name = type(module).__name__
|
179 |
+
if module_name == "CrossAttention" and "attn2" in name:
|
180 |
+
parameters.append(module.processor.parameters)
|
181 |
+
return parameters
|
TrailBlazer/CrossAttn/__init__.py
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
|
TrailBlazer/Misc/BBox.py
ADDED
@@ -0,0 +1,93 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
"""
|
3 |
+
import torch
|
4 |
+
|
5 |
+
|
6 |
+
class BoundingBox:
|
7 |
+
"""A rectangular bounding box determines the directed regions."""
|
8 |
+
|
9 |
+
def __init__(self, resolution, box_ratios, margin=0.0):
|
10 |
+
"""
|
11 |
+
Args:
|
12 |
+
resolution(int): the resolution of the 2d spatial input
|
13 |
+
box_ratios(List[float]):
|
14 |
+
Returns:
|
15 |
+
"""
|
16 |
+
assert (
|
17 |
+
box_ratios[1] < box_ratios[3]
|
18 |
+
), "the boundary top ratio should be less than bottom"
|
19 |
+
assert (
|
20 |
+
box_ratios[0] < box_ratios[2]
|
21 |
+
), "the boundary left ratio should be less than right"
|
22 |
+
self.left = int((box_ratios[0] - margin) * resolution)
|
23 |
+
self.right = int((box_ratios[2] + margin) * resolution)
|
24 |
+
self.top = int((box_ratios[1] - margin) * resolution)
|
25 |
+
self.bottom = int((box_ratios[3] + margin) * resolution)
|
26 |
+
self.height = self.bottom - self.top
|
27 |
+
self.width = self.right - self.left
|
28 |
+
if self.height == 0:
|
29 |
+
self.height = 1
|
30 |
+
if self.width == 0:
|
31 |
+
self.width = 1
|
32 |
+
|
33 |
+
def sliced_tensor_in_bbox(self, tensor: torch.tensor) -> torch.tensor:
|
34 |
+
""" slicing the tensor with bbox area
|
35 |
+
|
36 |
+
Args:
|
37 |
+
tensor(torch.tensor): the original tensor in 4d
|
38 |
+
Returns:
|
39 |
+
(torch.tensor): the reduced tensor inside bbox
|
40 |
+
"""
|
41 |
+
return tensor[:, self.top : self.bottom, self.left : self.right, :]
|
42 |
+
|
43 |
+
def mask_reweight_out_bbox(
|
44 |
+
self, tensor: torch.tensor, value: float = 0.0
|
45 |
+
) -> torch.tensor:
|
46 |
+
"""reweighting value outside bbox
|
47 |
+
|
48 |
+
Args:
|
49 |
+
tensor(torch.tensor): the original tensor in 4d
|
50 |
+
value(float): reweighting factor default with 0.0
|
51 |
+
Returns:
|
52 |
+
(torch.tensor): the reweighted tensor
|
53 |
+
"""
|
54 |
+
mask = torch.ones_like(tensor).to(tensor.device) * value
|
55 |
+
mask[:, self.top : self.bottom, self.left : self.right, :] = 1
|
56 |
+
return tensor * mask
|
57 |
+
|
58 |
+
def mask_reweight_in_bbox(
|
59 |
+
self, tensor: torch.tensor, value: float = 0.0
|
60 |
+
) -> torch.tensor:
|
61 |
+
"""reweighting value within bbox
|
62 |
+
|
63 |
+
Args:
|
64 |
+
tensor(torch.tensor): the original tensor in 4d
|
65 |
+
value(float): reweighting factor default with 0.0
|
66 |
+
Returns:
|
67 |
+
(torch.tensor): the reweighted tensor
|
68 |
+
"""
|
69 |
+
mask = torch.ones_like(tensor).to(tensor.device)
|
70 |
+
mask[:, self.top : self.bottom, self.left : self.right, :] = value
|
71 |
+
return tensor * mask
|
72 |
+
|
73 |
+
def __str__(self):
|
74 |
+
"""it prints Box(L:%d, R:%d, T:%d, B:%d) for better ingestion"""
|
75 |
+
return f"Box(L:{self.left}, R:{self.right}, T:{self.top}, B:{self.bottom})"
|
76 |
+
|
77 |
+
def __rerp__(self):
|
78 |
+
""" """
|
79 |
+
return f"Box(L:{self.left}, R:{self.right}, T:{self.top}, B:{self.bottom})"
|
80 |
+
|
81 |
+
|
82 |
+
if __name__ == "__main__":
|
83 |
+
# Example: second quadrant
|
84 |
+
input_res = 32
|
85 |
+
left = 0.0
|
86 |
+
top = 0.0
|
87 |
+
right = 0.5
|
88 |
+
bottom = 0.5
|
89 |
+
box_ratios = [left, top, right, bottom]
|
90 |
+
bbox = BoundingBox(resolution=input_res, box_ratios=box_ratios)
|
91 |
+
|
92 |
+
print(bbox)
|
93 |
+
# Box(L:0, R:16, T:0, B:16)
|
TrailBlazer/Misc/ConfigIO.py
ADDED
@@ -0,0 +1,13 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import yaml
|
2 |
+
|
3 |
+
def config_loader(filepath):
|
4 |
+
data = None
|
5 |
+
with open(filepath, "r") as yamlfile:
|
6 |
+
data = yaml.load(yamlfile, Loader=yaml.FullLoader)
|
7 |
+
yamlfile.close()
|
8 |
+
return data
|
9 |
+
|
10 |
+
def config_saver(data, filepath):
|
11 |
+
with open(filepath, 'w') as yamlfile:
|
12 |
+
data1 = yaml.dump(data, yamlfile)
|
13 |
+
yamlfile.close()
|
TrailBlazer/Misc/Const.py
ADDED
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# https://okuha.com/best-stable-diffusion-prompts/
|
2 |
+
|
3 |
+
NEGATIVE_PROMPT = "bad anatomy, bad proportions, blurry, cloned face, cropped, deformed, dehydrated, disfigured, duplicate, error, extra arms, extra fingers, extra legs, extra limbs, fused fingers, gross proportions, jpeg artifacts, long neck, low quality, lowres, malformed limbs, missing arms, missing legs, morbid, mutated hands, mutation, mutilated, out of frame, poorly drawn face, poorly drawn hands, signature, text, too many fingers, ugly, username, watermark, worst quality, Amputee, Autograph, Bad anatomy, Bad illustration, Bad proportions, Beyond the borders, Blank background, Blurry, Body out of frame, Boring background, Branding, Cropped, Cut off, Deformed, Disfigured, Dismembered, Disproportioned, Distorted, Draft, Duplicate, Duplicated features, Extra arms, Extra fingers, Extra hands, Extra legs, Extra limbs, Fault, Flaw, Fused fingers, Grains, Grainy, Gross proportions, Hazy, Identifying mark, Improper scale, Incorrect physiology, Incorrect ratio, Indistinct, Kitsch, Logo, Long neck, Low quality, Low resolution, Macabre, Malformed, Mark, Misshapen, Missing arms, Missing fingers, Missing hands, Missing legs, Mistake, Morbid, Mutated hands, Mutation, Mutilated, Off-screen, Out of frame, Outside the picture, Pixelated, Poorly drawn face, Poorly drawn feet, Poorly drawn hands, Printed words, Render, Repellent, Replicate, Reproduce, Revolting dimensions, Script, Shortened, Sign, Signature, Split image, Squint, Storyboard, Text, Tiling, Trimmed, Ugly, Unfocused, Unattractive, Unnatural pose, Unreal engine, Unsightly, Watermark, Written language, Absent limbs, Additional appendages, Additional digits, Additional limbs, Altered appendages, Amputee, Asymmetric, Asymmetric ears, Bad anatomy, Bad ears, Bad eyes, Bad face, Bad proportions, Broken finger, Broken hand, Broken leg, Broken wrist, Cartoon, Cloned face, Cloned head, Collapsed eyeshadow, Combined appendages, Conjoined, Copied visage, Corpse, Cripple, Cropped head, Cross-eyed, Depressed, Desiccated, Disconnected limb, Disfigured, Dismembered, Disproportionate, Double face, Duplicated features, Eerie, Elongated throat, lowres, low quality, jpeg, artifacts, duplicate, morbid, mutilated, out of frame, extra fingers, mutated hands, poorly drawn hands, poorly drawn face, mutation, deformed, blurry, dehydrated, bad anatomy, bad proportions, extra limbs, cloned face, disfigured, gross proportions, malformed limbs, missing arms, missing legs, extra arms, extra legs, fused fingers, too many fingers, long neck, drawing, painting, crayon, sketch, graphite, impressionist, noisy, soft, extra tails"
|
4 |
+
|
5 |
+
|
6 |
+
POSITIVE_PROMPT = "; masterpiece, best quality, intricate, detailed, sharp, focused, intricate details, hyperdetailed, 8k, RAW photo,realistic style, national geography, fantasy, hyper-realistic, rich colors, realistic texture"
|
TrailBlazer/Misc/Logger.py
ADDED
@@ -0,0 +1,70 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import functools
|
2 |
+
import logging
|
3 |
+
|
4 |
+
from io import StringIO # Python3
|
5 |
+
|
6 |
+
import sys
|
7 |
+
|
8 |
+
class SilencedStdOut:
|
9 |
+
# https://stackoverflow.com/questions/65608502/is-there-a-way-to-force-any-function-to-not-be-verbose-in-python
|
10 |
+
def __enter__(self):
|
11 |
+
self.old_stdout = sys.stdout
|
12 |
+
self.result = StringIO()
|
13 |
+
sys.stdout = self.result
|
14 |
+
|
15 |
+
def __exit__(self, *args, **kwargs):
|
16 |
+
|
17 |
+
sys.stdout = self.old_stdout
|
18 |
+
result_string = self.result.getvalue() # use if you want or discard.
|
19 |
+
|
20 |
+
class CustomFormatter(logging.Formatter):
|
21 |
+
|
22 |
+
GRAY = "\x1b[38m"
|
23 |
+
YELLOW = "\x1b[33m"
|
24 |
+
CYAN = "\x1b[36m"
|
25 |
+
RED = "\x1b[31m"
|
26 |
+
BOLD_RED = "\x1b[31;1m"
|
27 |
+
RESET = "\x1b[0m"
|
28 |
+
FORMAT = "[%(asctime)s - %(name)s - %(levelname)8s] - %(message)s (%(filename)s:%(lineno)d)"
|
29 |
+
|
30 |
+
FORMATS = {
|
31 |
+
logging.DEBUG: GRAY + FORMAT + RESET,
|
32 |
+
logging.INFO: GRAY + FORMAT + RESET,
|
33 |
+
logging.WARNING: YELLOW + FORMAT + RESET,
|
34 |
+
logging.ERROR: RED + FORMAT + RESET,
|
35 |
+
logging.CRITICAL: BOLD_RED + FORMAT + RESET,
|
36 |
+
logging.DEBUG: CYAN + FORMAT + RESET,
|
37 |
+
}
|
38 |
+
|
39 |
+
def format(self, record):
|
40 |
+
log_fmt = self.FORMATS.get(record.levelno)
|
41 |
+
formatter = logging.Formatter(log_fmt)
|
42 |
+
return formatter.format(record)
|
43 |
+
|
44 |
+
# create logger with 'spam_application'
|
45 |
+
|
46 |
+
logger = logging.getLogger("TrailBlazer")
|
47 |
+
logger.handlers = []
|
48 |
+
logger.setLevel(logging.DEBUG)
|
49 |
+
# create console handler with a higher log level
|
50 |
+
console_handler = logging.StreamHandler()
|
51 |
+
console_handler.setLevel(logging.DEBUG)
|
52 |
+
|
53 |
+
console_handler.setFormatter(CustomFormatter())
|
54 |
+
logger.addHandler(console_handler)
|
55 |
+
|
56 |
+
critical = logger.critical
|
57 |
+
fatal = logger.fatal
|
58 |
+
error = logger.error
|
59 |
+
warning = logger.warning
|
60 |
+
warn = logger.warn
|
61 |
+
info = logger.info
|
62 |
+
debug = logger.debug
|
63 |
+
|
64 |
+
if __name__ == "__main__":
|
65 |
+
from DirectedDiffusion import Logger as log
|
66 |
+
log.info("info message")
|
67 |
+
log.warning("warning message")
|
68 |
+
log.error("error message")
|
69 |
+
log.debug("debug message")
|
70 |
+
log.critical("critical message")
|
TrailBlazer/Misc/Painter.py
ADDED
@@ -0,0 +1,224 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
"""
|
3 |
+
import torch
|
4 |
+
import matplotlib.pyplot as plt
|
5 |
+
import numpy as np
|
6 |
+
import torch.nn.functional as nnf
|
7 |
+
import torchvision
|
8 |
+
import einops
|
9 |
+
import matplotlib.pyplot as plt
|
10 |
+
import scipy.stats as st
|
11 |
+
from PIL import Image, ImageFont, ImageDraw
|
12 |
+
|
13 |
+
plt.rcParams["figure.figsize"] = [
|
14 |
+
float(v) * 1.5 for v in plt.rcParams["figure.figsize"]
|
15 |
+
]
|
16 |
+
|
17 |
+
|
18 |
+
class CrossAttnPainter:
|
19 |
+
|
20 |
+
def __init__(self, bundle, pipe, root="/tmp"):
|
21 |
+
self.dim = 64
|
22 |
+
self.folder =
|
23 |
+
|
24 |
+
def plot_frames(self):
|
25 |
+
folder = "/tmp"
|
26 |
+
from PIL import Image
|
27 |
+
for i, f in enumerate(video_frames):
|
28 |
+
img = Image.fromarray(f)
|
29 |
+
filepath = os.path.join(folder, "recons.{:04d}.jpg".format(i))
|
30 |
+
img.save(filepath)
|
31 |
+
|
32 |
+
|
33 |
+
def plot_spatial_attn(self):
|
34 |
+
|
35 |
+
arr = (
|
36 |
+
pipe.unet.up_blocks[1]
|
37 |
+
.attentions[0]
|
38 |
+
.transformer_blocks[0]
|
39 |
+
.attn2.processor.cross_attention_map
|
40 |
+
)
|
41 |
+
heads = pipe.unet.up_blocks[1].attentions[0].transformer_blocks[0].attn2.heads
|
42 |
+
arr = torch.transpose(arr, 1, 3)
|
43 |
+
arr = nnf.interpolate(arr, size=(64, 64), mode='bicubic', align_corners=False)
|
44 |
+
arr = torch.transpose(arr, 1, 3)
|
45 |
+
arr = arr.cpu().numpy()
|
46 |
+
arr = arr.reshape(24, heads, 64, 64, 77)
|
47 |
+
arr = arr.mean(axis=1)
|
48 |
+
n = arr.shape[0]
|
49 |
+
for i in range(n):
|
50 |
+
filename = "/tmp/spatialca.{:04d}.jpg".format(i)
|
51 |
+
plt.clf()
|
52 |
+
plt.imshow(arr[i, :, :, 2], cmap="jet")
|
53 |
+
plt.gca().set_axis_off()
|
54 |
+
plt.subplots_adjust(top = 1, bottom = 0, right = 1, left = 0,
|
55 |
+
hspace = 0, wspace = 0)
|
56 |
+
plt.margins(0,0)
|
57 |
+
plt.gca().xaxis.set_major_locator(plt.NullLocator())
|
58 |
+
plt.gca().yaxis.set_major_locator(plt.NullLocator())
|
59 |
+
plt.savefig(filename, bbox_inches = 'tight',pad_inches = 0)
|
60 |
+
print(filename)
|
61 |
+
|
62 |
+
def plot_temporal_attn(self):
|
63 |
+
|
64 |
+
# arr = pipe.unet.mid_block.temp_attentions[0].transformer_blocks[0].attn2.processor.cross_attention_map
|
65 |
+
import matplotlib.pyplot as plt
|
66 |
+
import torch.nn.functional as nnf
|
67 |
+
arr = (
|
68 |
+
pipe.unet.up_blocks[2]
|
69 |
+
.temp_attentions[1]
|
70 |
+
.transformer_blocks[0]
|
71 |
+
.attn2.processor.cross_attention_map
|
72 |
+
)
|
73 |
+
#arr = pipe.unet.transformer_in.transformer_blocks[0].attn2.processor.cross_attention_map
|
74 |
+
arr = torch.transpose(arr, 0, 2).transpose(1, 3)
|
75 |
+
arr = nnf.interpolate(arr, size=(64, 64), mode="bicubic", align_corners=False)
|
76 |
+
arr = torch.transpose(arr, 0, 2).transpose(1, 3)
|
77 |
+
arr = arr.cpu().numpy()
|
78 |
+
n = arr.shape[-1]
|
79 |
+
for i in range(n-2):
|
80 |
+
filename = "/tmp/tempcaiip2.{:04d}.jpg".format(i)
|
81 |
+
plt.clf()
|
82 |
+
plt.imshow(arr[..., i+2, i], cmap="jet")
|
83 |
+
plt.gca().set_axis_off()
|
84 |
+
plt.subplots_adjust(top=1, bottom=0, right=1, left=0, hspace=0, wspace=0)
|
85 |
+
plt.margins(0, 0)
|
86 |
+
plt.gca().xaxis.set_major_locator(plt.NullLocator())
|
87 |
+
plt.gca().yaxis.set_major_locator(plt.NullLocator())
|
88 |
+
plt.savefig(filename, bbox_inches="tight", pad_inches=0)
|
89 |
+
print(filename)
|
90 |
+
|
91 |
+
|
92 |
+
|
93 |
+
|
94 |
+
|
95 |
+
|
96 |
+
|
97 |
+
|
98 |
+
|
99 |
+
|
100 |
+
def plot_latent_noise(latents, mode):
|
101 |
+
|
102 |
+
for i in range(latents.shape[0]):
|
103 |
+
tensor = latents[i].cpu()
|
104 |
+
min_val = torch.min(tensor)
|
105 |
+
max_val = torch.max(tensor)
|
106 |
+
scale = 255 * (max_val - min_val)
|
107 |
+
tensor = scale * (tensor - min_val)
|
108 |
+
tensor = tensor.type(torch.int8)
|
109 |
+
tensor = einops.rearrange(tensor, "c w h -> w h c")
|
110 |
+
if mode == "RGB":
|
111 |
+
tensor = tensor[...,:3]
|
112 |
+
mode_ = "RGB"
|
113 |
+
elif mode == "RGBA":
|
114 |
+
mode_ = "RGBA"
|
115 |
+
pass
|
116 |
+
elif mode == "GRAY":
|
117 |
+
tensor = tensor[...,0]
|
118 |
+
mode_ = "L"
|
119 |
+
|
120 |
+
x = tensor.numpy()
|
121 |
+
|
122 |
+
img = Image.fromarray(x, mode_)
|
123 |
+
img = img.resize((256, 256), resample=Image.NEAREST )
|
124 |
+
filepath = f"/tmp/out.{i:04d}.jpg"
|
125 |
+
img.save(filepath)
|
126 |
+
|
127 |
+
tensor = latents[i].cpu()
|
128 |
+
x = tensor.flatten().numpy()
|
129 |
+
x /= x.max()
|
130 |
+
plt.hist(x, density=True, bins=20, range=[-1, 1])
|
131 |
+
mn, mx = plt.xlim()
|
132 |
+
plt.xlim(mn, mx)
|
133 |
+
kde_xs = np.linspace(mn, mx, 300)
|
134 |
+
kde = st.gaussian_kde(x)
|
135 |
+
plt.plot(kde_xs, kde.pdf(kde_xs), label="PDF")
|
136 |
+
filepath = f"/tmp/hist.{i:04d}.jpg"
|
137 |
+
plt.savefig(filepath)
|
138 |
+
plt.clf()
|
139 |
+
|
140 |
+
print(i)
|
141 |
+
|
142 |
+
|
143 |
+
def plot_activation(cross_attn, prompt, filepath="", plot_with_trailings=False, n_trailing=2):
|
144 |
+
splitted_prompt = prompt.split(" ")
|
145 |
+
n = len(splitted_prompt)
|
146 |
+
start = 0
|
147 |
+
arrs = []
|
148 |
+
if plot_with_trailings:
|
149 |
+
for j in range(n_trailing):
|
150 |
+
arr = []
|
151 |
+
for i in range(start, start + n):
|
152 |
+
cross_attn_sliced = cross_attn[..., i + 1]
|
153 |
+
arr.append(cross_attn_sliced.T)
|
154 |
+
start += n
|
155 |
+
arr = np.hstack(arr)
|
156 |
+
arrs.append(arr)
|
157 |
+
arrs = np.vstack(arrs).T
|
158 |
+
else:
|
159 |
+
arr = []
|
160 |
+
for i in range(start, start + n):
|
161 |
+
cross_attn_sliced = cross_attn[..., i + 1]
|
162 |
+
arr.append(cross_attn_sliced)
|
163 |
+
arrs = np.vstack(arr)
|
164 |
+
plt.imshow(arrs, cmap="jet", vmin=0.0, vmax=.5)
|
165 |
+
plt.title(prompt)
|
166 |
+
if filepath:
|
167 |
+
plt.savefig(filepath)
|
168 |
+
else:
|
169 |
+
plt.show()
|
170 |
+
|
171 |
+
|
172 |
+
def draw_dd_metadata(img, bbox, text="", target_res=1024):
|
173 |
+
img = img.resize((target_res, target_res))
|
174 |
+
image_editable = ImageDraw.Draw(img)
|
175 |
+
|
176 |
+
for region in [bbox]:
|
177 |
+
x0 = region[0] * target_res
|
178 |
+
y0 = region[2] * target_res
|
179 |
+
x1 = region[1] * target_res
|
180 |
+
y1 = region[3] * target_res
|
181 |
+
image_editable.rectangle(xy=[x0, y0, x1, y1], outline=(255, 0, 0, 255), width=5)
|
182 |
+
if text:
|
183 |
+
font = ImageFont.truetype("./assets/JetBrainsMono-Bold.ttf", size=13)
|
184 |
+
image_editable.multiline_text(
|
185 |
+
(15, 15),
|
186 |
+
text,
|
187 |
+
(255, 255, 255, 0),
|
188 |
+
font=font,
|
189 |
+
stroke_width=2,
|
190 |
+
stroke_fill=(0, 0, 0, 255),
|
191 |
+
spacing=0,
|
192 |
+
)
|
193 |
+
return img
|
194 |
+
|
195 |
+
|
196 |
+
|
197 |
+
|
198 |
+
|
199 |
+
|
200 |
+
|
201 |
+
|
202 |
+
|
203 |
+
|
204 |
+
|
205 |
+
|
206 |
+
|
207 |
+
|
208 |
+
|
209 |
+
|
210 |
+
|
211 |
+
|
212 |
+
|
213 |
+
|
214 |
+
|
215 |
+
|
216 |
+
|
217 |
+
|
218 |
+
|
219 |
+
|
220 |
+
|
221 |
+
|
222 |
+
if __name__ == "__main__":
|
223 |
+
latents = torch.load("assets/experiments/a-cat-sitting-on-a-car_230615-144611/latents.pt")
|
224 |
+
plot_latent_noise(latents, "GRAY")
|
TrailBlazer/Misc/__init__.py
ADDED
File without changes
|
TrailBlazer/Pipeline/TextToVideoSDPipelineCall.py
ADDED
@@ -0,0 +1,339 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import inspect
|
2 |
+
from typing import Any, Callable, Dict, List, Optional, Union
|
3 |
+
|
4 |
+
import numpy as np
|
5 |
+
import torch
|
6 |
+
from transformers import CLIPTextModel, CLIPTokenizer
|
7 |
+
from dataclasses import dataclass
|
8 |
+
|
9 |
+
from diffusers.loaders import LoraLoaderMixin, TextualInversionLoaderMixin
|
10 |
+
from diffusers.models import AutoencoderKL, UNet3DConditionModel
|
11 |
+
from diffusers.models.lora import adjust_lora_scale_text_encoder
|
12 |
+
from diffusers.schedulers import KarrasDiffusionSchedulers
|
13 |
+
from diffusers.utils import (
|
14 |
+
deprecate,
|
15 |
+
logging,
|
16 |
+
replace_example_docstring,
|
17 |
+
BaseOutput,
|
18 |
+
)
|
19 |
+
from diffusers.pipelines.text_to_video_synthesis.pipeline_text_to_video_synth import (
|
20 |
+
tensor2vid,
|
21 |
+
)
|
22 |
+
|
23 |
+
from ..Misc import Logger as log
|
24 |
+
from ..Misc import Const
|
25 |
+
from .Utils import initiailization, keyframed_bbox, keyframed_prompt_embeds, use_dd, use_dd_temporal
|
26 |
+
|
27 |
+
@dataclass
|
28 |
+
class TextToVideoSDPipelineOutput(BaseOutput):
|
29 |
+
"""
|
30 |
+
Output class for text-to-video pipelines.
|
31 |
+
|
32 |
+
Args:
|
33 |
+
frames (`List[np.ndarray]` or `torch.FloatTensor`)
|
34 |
+
List of denoised frames (essentially images) as NumPy arrays of shape `(height, width, num_channels)` or as
|
35 |
+
a `torch` tensor. The length of the list denotes the video length (the number of frames).
|
36 |
+
"""
|
37 |
+
|
38 |
+
frames: Union[List[np.ndarray], torch.FloatTensor]
|
39 |
+
latents: Union[List[np.ndarray], torch.FloatTensor]
|
40 |
+
bbox_per_frame: torch.tensor
|
41 |
+
|
42 |
+
|
43 |
+
@torch.no_grad()
|
44 |
+
def text_to_video_sd_pipeline_call(
|
45 |
+
self,
|
46 |
+
bundle=None,
|
47 |
+
# prompt: Union[str, List[str]] = None,
|
48 |
+
height: Optional[int] = None,
|
49 |
+
width: Optional[int] = None,
|
50 |
+
# num_frames: int = 16,
|
51 |
+
num_inference_steps: int = 50,
|
52 |
+
# num_dd_steps: int = 0,
|
53 |
+
guidance_scale: float = 9.0,
|
54 |
+
negative_prompt: Optional[Union[str, List[str]]] = None,
|
55 |
+
eta: float = 0.0,
|
56 |
+
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
|
57 |
+
latents: Optional[torch.FloatTensor] = None,
|
58 |
+
prompt_embeds: Optional[torch.FloatTensor] = None,
|
59 |
+
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
|
60 |
+
output_type: Optional[str] = "np",
|
61 |
+
return_dict: bool = True,
|
62 |
+
callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
|
63 |
+
callback_steps: int = 1,
|
64 |
+
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
|
65 |
+
):
|
66 |
+
r"""
|
67 |
+
The call function to the pipeline for generation.
|
68 |
+
|
69 |
+
Args:
|
70 |
+
prompt (`str` or `List[str]`, *optional*):
|
71 |
+
The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`.
|
72 |
+
height (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):
|
73 |
+
The height in pixels of the generated video.
|
74 |
+
width (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):
|
75 |
+
The width in pixels of the generated video.
|
76 |
+
num_frames (`int`, *optional*, defaults to 16):
|
77 |
+
The number of video frames that are generated. Defaults to 16 frames which at 8 frames per seconds
|
78 |
+
amounts to 2 seconds of video.
|
79 |
+
num_inference_steps (`int`, *optional*, defaults to 50):
|
80 |
+
The number of denoising steps. More denoising steps usually lead to a higher quality videos at the
|
81 |
+
expense of slower inference.
|
82 |
+
guidance_scale (`float`, *optional*, defaults to 7.5):
|
83 |
+
A higher guidance scale value encourages the model to generate images closely linked to the text
|
84 |
+
`prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`.
|
85 |
+
negative_prompt (`str` or `List[str]`, *optional*):
|
86 |
+
The prompt or prompts to guide what to not include in image generation. If not defined, you need to
|
87 |
+
pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale < 1`).
|
88 |
+
num_images_per_prompt (`int`, *optional*, defaults to 1):
|
89 |
+
The number of images to generate per prompt.
|
90 |
+
eta (`float`, *optional*, defaults to 0.0):
|
91 |
+
Corresponds to parameter eta (Ξ·) from the [DDIM](https://arxiv.org/abs/2010.02502) paper. Only applies
|
92 |
+
to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers.
|
93 |
+
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
|
94 |
+
A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
|
95 |
+
generation deterministic.
|
96 |
+
latents (`torch.FloatTensor`, *optional*):
|
97 |
+
Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for video
|
98 |
+
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
|
99 |
+
tensor is generated by sampling using the supplied random `generator`. Latents should be of shape
|
100 |
+
`(batch_size, num_channel, num_frames, height, width)`.
|
101 |
+
prompt_embeds (`torch.FloatTensor`, *optional*):
|
102 |
+
Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not
|
103 |
+
provided, text embeddings are generated from the `prompt` input argument.
|
104 |
+
negative_prompt_embeds (`torch.FloatTensor`, *optional*):
|
105 |
+
Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If
|
106 |
+
not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument.
|
107 |
+
output_type (`str`, *optional*, defaults to `"np"`):
|
108 |
+
The output format of the generated video. Choose between `torch.FloatTensor` or `np.array`.
|
109 |
+
return_dict (`bool`, *optional*, defaults to `True`):
|
110 |
+
Whether or not to return a [`~pipelines.text_to_video_synthesis.TextToVideoSDPipelineOutput`] instead
|
111 |
+
of a plain tuple.
|
112 |
+
callback (`Callable`, *optional*):
|
113 |
+
A function that calls every `callback_steps` steps during inference. The function is called with the
|
114 |
+
following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
|
115 |
+
callback_steps (`int`, *optional*, defaults to 1):
|
116 |
+
The frequency at which the `callback` function is called. If not specified, the callback is called at
|
117 |
+
every step.
|
118 |
+
cross_attention_kwargs (`dict`, *optional*):
|
119 |
+
A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in
|
120 |
+
[`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
|
121 |
+
|
122 |
+
Examples:
|
123 |
+
|
124 |
+
Returns:
|
125 |
+
[`~pipelines.text_to_video_synthesis.TextToVideoSDPipelineOutput`] or `tuple`:
|
126 |
+
If `return_dict` is `True`, [`~pipelines.text_to_video_synthesis.TextToVideoSDPipelineOutput`] is
|
127 |
+
returned, otherwise a `tuple` is returned where the first element is a list with the generated frames.
|
128 |
+
"""
|
129 |
+
|
130 |
+
assert (
|
131 |
+
len(bundle["keyframe"]) >= 2
|
132 |
+
), "Must be greater than 2 keyframes. Input {} keys".format(len(bundle["keyframe"]))
|
133 |
+
|
134 |
+
assert (
|
135 |
+
bundle["keyframe"][0]["frame"] == 0
|
136 |
+
), "First keyframe must indicate frame at 0, but given {}".format(
|
137 |
+
bundle["keyframe"][0]["frame"]
|
138 |
+
)
|
139 |
+
|
140 |
+
if bundle["keyframe"][-1]["frame"] != 23:
|
141 |
+
log.info(
|
142 |
+
"It's recommended to set the last key to 23 to match"
|
143 |
+
" the sequence length 24 used in training ZeroScope"
|
144 |
+
)
|
145 |
+
|
146 |
+
for i in range(len(bundle["keyframe"]) - 1):
|
147 |
+
log.info
|
148 |
+
assert (
|
149 |
+
bundle["keyframe"][i + 1]["frame"] > bundle["keyframe"][i]["frame"]
|
150 |
+
), "The keyframe indices must be ordered in the config file, Sorry!"
|
151 |
+
|
152 |
+
bundle["prompt_base"] = bundle["keyframe"][0]["prompt"]
|
153 |
+
prompt = bundle["prompt_base"]
|
154 |
+
#prompt += Const.POSITIVE_PROMPT
|
155 |
+
num_frames = bundle["keyframe"][-1]["frame"] + 1
|
156 |
+
num_dd_spatial_steps = bundle["num_dd_spatial_steps"]
|
157 |
+
num_dd_temporal_steps = bundle["num_dd_temporal_steps"]
|
158 |
+
|
159 |
+
bbox_per_frame = keyframed_bbox(bundle)
|
160 |
+
initiailization(unet=self.unet, bundle=bundle, bbox_per_frame=bbox_per_frame)
|
161 |
+
|
162 |
+
from pprint import pprint
|
163 |
+
|
164 |
+
log.info("Experiment parameters:")
|
165 |
+
print("==========================================")
|
166 |
+
pprint(bundle)
|
167 |
+
print("==========================================")
|
168 |
+
# 0. Default height and width to unet
|
169 |
+
height = height or self.unet.config.sample_size * self.vae_scale_factor
|
170 |
+
width = width or self.unet.config.sample_size * self.vae_scale_factor
|
171 |
+
|
172 |
+
num_images_per_prompt = 1
|
173 |
+
negative_prompt = Const.NEGATIVE_PROMPT
|
174 |
+
# 1. Check inputs. Raise error if not correct
|
175 |
+
# self.check_inputs(
|
176 |
+
# prompt,
|
177 |
+
# height,
|
178 |
+
# width,
|
179 |
+
# callback_steps,
|
180 |
+
# negative_prompt,
|
181 |
+
# prompt_embeds,
|
182 |
+
# negative_prompt_embeds,
|
183 |
+
# )
|
184 |
+
|
185 |
+
# # 2. Define call parameters
|
186 |
+
if prompt is not None and isinstance(prompt, str):
|
187 |
+
batch_size = 1
|
188 |
+
elif prompt is not None and isinstance(prompt, list):
|
189 |
+
batch_size = len(prompt)
|
190 |
+
else:
|
191 |
+
batch_size = prompt_embeds.shape[0]
|
192 |
+
|
193 |
+
device = self._execution_device
|
194 |
+
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
|
195 |
+
# of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
|
196 |
+
# corresponds to doing no classifier free guidance.
|
197 |
+
do_classifier_free_guidance = guidance_scale > 1.0
|
198 |
+
|
199 |
+
# 3. Encode input prompt
|
200 |
+
text_encoder_lora_scale = (
|
201 |
+
cross_attention_kwargs.get("scale", None)
|
202 |
+
if cross_attention_kwargs is not None
|
203 |
+
else None
|
204 |
+
)
|
205 |
+
|
206 |
+
# prompt_embeds, negative_prompt_embeds = self.encode_prompt(
|
207 |
+
# prompt,
|
208 |
+
# device,
|
209 |
+
# num_images_per_prompt,
|
210 |
+
# do_classifier_free_guidance,
|
211 |
+
# negative_prompt,
|
212 |
+
# prompt_embeds=prompt_embeds,
|
213 |
+
# negative_prompt_embeds=negative_prompt_embeds,
|
214 |
+
# lora_scale=text_encoder_lora_scale,
|
215 |
+
# )
|
216 |
+
|
217 |
+
prompt_embeds, negative_prompt_embeds = keyframed_prompt_embeds(
|
218 |
+
bundle, self.encode_prompt, device
|
219 |
+
)
|
220 |
+
|
221 |
+
# For classifier free guidance, we need to do two forward passes.
|
222 |
+
# Here we concatenate the unconditional and text embeddings into a single batch
|
223 |
+
# to avoid doing two forward passes
|
224 |
+
if do_classifier_free_guidance:
|
225 |
+
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
|
226 |
+
|
227 |
+
# 4. Prepare timesteps
|
228 |
+
self.scheduler.set_timesteps(num_inference_steps, device=device)
|
229 |
+
timesteps = self.scheduler.timesteps
|
230 |
+
|
231 |
+
# 5. Prepare latent variables
|
232 |
+
num_channels_latents = self.unet.config.in_channels
|
233 |
+
latents = self.prepare_latents(
|
234 |
+
batch_size * num_images_per_prompt,
|
235 |
+
num_channels_latents,
|
236 |
+
num_frames,
|
237 |
+
height,
|
238 |
+
width,
|
239 |
+
prompt_embeds.dtype,
|
240 |
+
device,
|
241 |
+
generator,
|
242 |
+
latents,
|
243 |
+
)
|
244 |
+
|
245 |
+
# 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
|
246 |
+
extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
|
247 |
+
|
248 |
+
# 7. Denoising loop
|
249 |
+
num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
|
250 |
+
|
251 |
+
latents_at_steps = []
|
252 |
+
|
253 |
+
with self.progress_bar(total=num_inference_steps) as progress_bar:
|
254 |
+
for i, t in enumerate(timesteps):
|
255 |
+
# expand the latents if we are doing classifier free guidance
|
256 |
+
latent_model_input = (
|
257 |
+
torch.cat([latents] * 2) if do_classifier_free_guidance else latents
|
258 |
+
)
|
259 |
+
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
|
260 |
+
|
261 |
+
# predict the noise residual
|
262 |
+
if i < (num_dd_spatial_steps):
|
263 |
+
use_dd(self.unet, True)
|
264 |
+
|
265 |
+
if i < (num_dd_temporal_steps):
|
266 |
+
use_dd_temporal(self.unet, True)
|
267 |
+
|
268 |
+
noise_pred = self.unet(
|
269 |
+
latent_model_input,
|
270 |
+
t,
|
271 |
+
encoder_hidden_states=prompt_embeds,
|
272 |
+
cross_attention_kwargs=cross_attention_kwargs,
|
273 |
+
return_dict=False,
|
274 |
+
)[0]
|
275 |
+
|
276 |
+
use_dd(self.unet, False)
|
277 |
+
use_dd_temporal(self.unet, False)
|
278 |
+
|
279 |
+
# perform guidance
|
280 |
+
if do_classifier_free_guidance:
|
281 |
+
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
|
282 |
+
noise_pred = noise_pred_uncond + guidance_scale * (
|
283 |
+
noise_pred_text - noise_pred_uncond
|
284 |
+
)
|
285 |
+
|
286 |
+
# reshape latents
|
287 |
+
bsz, channel, frames, width, height = latents.shape
|
288 |
+
latents = latents.permute(0, 2, 1, 3, 4).reshape(
|
289 |
+
bsz * frames, channel, width, height
|
290 |
+
)
|
291 |
+
noise_pred = noise_pred.permute(0, 2, 1, 3, 4).reshape(
|
292 |
+
bsz * frames, channel, width, height
|
293 |
+
)
|
294 |
+
|
295 |
+
# compute the previous noisy sample x_t -> x_t-1
|
296 |
+
latents = self.scheduler.step(
|
297 |
+
noise_pred, t, latents, **extra_step_kwargs
|
298 |
+
).prev_sample
|
299 |
+
|
300 |
+
# if i==num_dd_steps:
|
301 |
+
# print("PF!", latents.shape)
|
302 |
+
# n = latents.shape[0]
|
303 |
+
# for f in range(n):
|
304 |
+
# latents[f] = torch.roll(latents[f], -f, dims=-1)
|
305 |
+
|
306 |
+
# reshape latents back
|
307 |
+
latents = (
|
308 |
+
latents[None, :]
|
309 |
+
.reshape(bsz, frames, channel, width, height)
|
310 |
+
.permute(0, 2, 1, 3, 4)
|
311 |
+
)
|
312 |
+
latents_at_steps.append(latents)
|
313 |
+
|
314 |
+
# call the callback, if provided
|
315 |
+
if i == len(timesteps) - 1 or (
|
316 |
+
(i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0
|
317 |
+
):
|
318 |
+
progress_bar.update()
|
319 |
+
if callback is not None and i % callback_steps == 0:
|
320 |
+
callback(i, t, latents)
|
321 |
+
|
322 |
+
if output_type == "latent":
|
323 |
+
return TextToVideoSDPipelineOutput(frames=latents)
|
324 |
+
|
325 |
+
video_tensor = self.decode_latents(latents)
|
326 |
+
|
327 |
+
if output_type == "pt":
|
328 |
+
video = video_tensor
|
329 |
+
else:
|
330 |
+
video = tensor2vid(video_tensor)
|
331 |
+
|
332 |
+
# Offload all models
|
333 |
+
self.maybe_free_model_hooks()
|
334 |
+
|
335 |
+
if not return_dict:
|
336 |
+
return (video,)
|
337 |
+
|
338 |
+
latents_at_steps = torch.cat(latents_at_steps)
|
339 |
+
return TextToVideoSDPipelineOutput(frames=video, latents=latents_at_steps, bbox_per_frame=bbox_per_frame)
|
TrailBlazer/Pipeline/UNet3DConditionModelCall.py
ADDED
@@ -0,0 +1,229 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2023 Alibaba DAMO-VILAB and The HuggingFace Team. All rights reserved.
|
2 |
+
# Copyright 2023 The ModelScope Team.
|
3 |
+
#
|
4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
5 |
+
# you may not use this file except in compliance with the License.
|
6 |
+
# You may obtain a copy of the License at
|
7 |
+
#
|
8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
9 |
+
#
|
10 |
+
# Unless required by applicable law or agreed to in writing, software
|
11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 |
+
# See the License for the specific language governing permissions and
|
14 |
+
# limitations under the License.
|
15 |
+
from dataclasses import dataclass
|
16 |
+
from typing import Any, Dict, List, Optional, Tuple, Union
|
17 |
+
|
18 |
+
import torch
|
19 |
+
import torch.nn as nn
|
20 |
+
import torch.utils.checkpoint
|
21 |
+
|
22 |
+
from diffusers.configuration_utils import ConfigMixin, register_to_config
|
23 |
+
from diffusers.loaders import UNet2DConditionLoadersMixin
|
24 |
+
from diffusers.utils import BaseOutput, logging
|
25 |
+
from diffusers.models.attention_processor import (
|
26 |
+
ADDED_KV_ATTENTION_PROCESSORS,
|
27 |
+
CROSS_ATTENTION_PROCESSORS,
|
28 |
+
AttentionProcessor,
|
29 |
+
AttnAddedKVProcessor,
|
30 |
+
AttnProcessor,
|
31 |
+
)
|
32 |
+
from diffusers.models.embeddings import TimestepEmbedding, Timesteps
|
33 |
+
from diffusers.models.modeling_utils import ModelMixin
|
34 |
+
from diffusers.models.transformer_temporal import TransformerTemporalModel
|
35 |
+
from diffusers.models.unet_3d_blocks import (
|
36 |
+
CrossAttnDownBlock3D,
|
37 |
+
CrossAttnUpBlock3D,
|
38 |
+
DownBlock3D,
|
39 |
+
UNetMidBlock3DCrossAttn,
|
40 |
+
UpBlock3D,
|
41 |
+
get_down_block,
|
42 |
+
get_up_block,
|
43 |
+
)
|
44 |
+
from diffusers.models.unet_3d_condition import UNet3DConditionOutput
|
45 |
+
|
46 |
+
|
47 |
+
|
48 |
+
def unet3d_condition_model_forward(
|
49 |
+
self,
|
50 |
+
sample: torch.FloatTensor,
|
51 |
+
timestep: Union[torch.Tensor, float, int],
|
52 |
+
encoder_hidden_states: torch.Tensor,
|
53 |
+
class_labels: Optional[torch.Tensor] = None,
|
54 |
+
timestep_cond: Optional[torch.Tensor] = None,
|
55 |
+
attention_mask: Optional[torch.Tensor] = None,
|
56 |
+
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
|
57 |
+
down_block_additional_residuals: Optional[Tuple[torch.Tensor]] = None,
|
58 |
+
mid_block_additional_residual: Optional[torch.Tensor] = None,
|
59 |
+
return_dict: bool = True,
|
60 |
+
) -> Union[UNet3DConditionOutput, Tuple]:
|
61 |
+
r"""
|
62 |
+
The [`UNet3DConditionModel`] forward method.
|
63 |
+
|
64 |
+
Args:
|
65 |
+
sample (`torch.FloatTensor`):
|
66 |
+
The noisy input tensor with the following shape `(batch, num_frames, channel, height, width`.
|
67 |
+
timestep (`torch.FloatTensor` or `float` or `int`): The number of timesteps to denoise an input.
|
68 |
+
encoder_hidden_states (`torch.FloatTensor`):
|
69 |
+
The encoder hidden states with shape `(batch, sequence_length, feature_dim)`.
|
70 |
+
return_dict (`bool`, *optional*, defaults to `True`):
|
71 |
+
Whether or not to return a [`~models.unet_3d_condition.UNet3DConditionOutput`] instead of a plain
|
72 |
+
tuple.
|
73 |
+
cross_attention_kwargs (`dict`, *optional*):
|
74 |
+
A kwargs dictionary that if specified is passed along to the [`AttnProcessor`].
|
75 |
+
|
76 |
+
Returns:
|
77 |
+
[`~models.unet_3d_condition.UNet3DConditionOutput`] or `tuple`:
|
78 |
+
If `return_dict` is True, an [`~models.unet_3d_condition.UNet3DConditionOutput`] is returned, otherwise
|
79 |
+
a `tuple` is returned where the first element is the sample tensor.
|
80 |
+
"""
|
81 |
+
# By default samples have to be AT least a multiple of the overall upsampling factor.
|
82 |
+
# The overall upsampling factor is equal to 2 ** (# num of upsampling layears).
|
83 |
+
# However, the upsampling interpolation output size can be forced to fit any upsampling size
|
84 |
+
# on the fly if necessary.
|
85 |
+
default_overall_up_factor = 2**self.num_upsamplers
|
86 |
+
|
87 |
+
# upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor`
|
88 |
+
forward_upsample_size = False
|
89 |
+
upsample_size = None
|
90 |
+
|
91 |
+
if any(s % default_overall_up_factor != 0 for s in sample.shape[-2:]):
|
92 |
+
logger.info("Forward upsample size to force interpolation output size.")
|
93 |
+
forward_upsample_size = True
|
94 |
+
|
95 |
+
# prepare attention_mask
|
96 |
+
if attention_mask is not None:
|
97 |
+
attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0
|
98 |
+
attention_mask = attention_mask.unsqueeze(1)
|
99 |
+
|
100 |
+
# 1. time
|
101 |
+
timesteps = timestep
|
102 |
+
if not torch.is_tensor(timesteps):
|
103 |
+
# TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
|
104 |
+
# This would be a good case for the `match` statement (Python 3.10+)
|
105 |
+
is_mps = sample.device.type == "mps"
|
106 |
+
if isinstance(timestep, float):
|
107 |
+
dtype = torch.float32 if is_mps else torch.float64
|
108 |
+
else:
|
109 |
+
dtype = torch.int32 if is_mps else torch.int64
|
110 |
+
timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
|
111 |
+
elif len(timesteps.shape) == 0:
|
112 |
+
timesteps = timesteps[None].to(sample.device)
|
113 |
+
|
114 |
+
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
|
115 |
+
num_frames = sample.shape[2]
|
116 |
+
timesteps = timesteps.expand(sample.shape[0])
|
117 |
+
|
118 |
+
t_emb = self.time_proj(timesteps)
|
119 |
+
|
120 |
+
# timesteps does not contain any weights and will always return f32 tensors
|
121 |
+
# but time_embedding might actually be running in fp16. so we need to cast here.
|
122 |
+
# there might be better ways to encapsulate this.
|
123 |
+
t_emb = t_emb.to(dtype=self.dtype)
|
124 |
+
|
125 |
+
emb = self.time_embedding(t_emb, timestep_cond)
|
126 |
+
emb = emb.repeat_interleave(repeats=num_frames, dim=0)
|
127 |
+
# encoder_hidden_states = encoder_hidden_states.repeat_interleave(repeats=num_frames, dim=0)
|
128 |
+
# print(encoder_hidden_states.shape)
|
129 |
+
# quit()
|
130 |
+
|
131 |
+
# 2. pre-process
|
132 |
+
sample = sample.permute(0, 2, 1, 3, 4).reshape((sample.shape[0] * num_frames, -1) + sample.shape[3:])
|
133 |
+
sample = self.conv_in(sample)
|
134 |
+
|
135 |
+
sample = self.transformer_in(
|
136 |
+
sample,
|
137 |
+
num_frames=num_frames,
|
138 |
+
cross_attention_kwargs=cross_attention_kwargs,
|
139 |
+
return_dict=False,
|
140 |
+
)[0]
|
141 |
+
|
142 |
+
# 3. down
|
143 |
+
down_block_res_samples = (sample,)
|
144 |
+
for downsample_block in self.down_blocks:
|
145 |
+
if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention:
|
146 |
+
sample, res_samples = downsample_block(
|
147 |
+
hidden_states=sample,
|
148 |
+
temb=emb,
|
149 |
+
encoder_hidden_states=encoder_hidden_states,
|
150 |
+
attention_mask=attention_mask,
|
151 |
+
num_frames=num_frames,
|
152 |
+
cross_attention_kwargs=cross_attention_kwargs,
|
153 |
+
)
|
154 |
+
else:
|
155 |
+
sample, res_samples = downsample_block(hidden_states=sample, temb=emb, num_frames=num_frames)
|
156 |
+
|
157 |
+
down_block_res_samples += res_samples
|
158 |
+
|
159 |
+
if down_block_additional_residuals is not None:
|
160 |
+
new_down_block_res_samples = ()
|
161 |
+
|
162 |
+
for down_block_res_sample, down_block_additional_residual in zip(
|
163 |
+
down_block_res_samples, down_block_additional_residuals
|
164 |
+
):
|
165 |
+
down_block_res_sample = down_block_res_sample + down_block_additional_residual
|
166 |
+
new_down_block_res_samples += (down_block_res_sample,)
|
167 |
+
|
168 |
+
down_block_res_samples = new_down_block_res_samples
|
169 |
+
|
170 |
+
# 4. mid
|
171 |
+
if self.mid_block is not None:
|
172 |
+
sample = self.mid_block(
|
173 |
+
sample,
|
174 |
+
emb,
|
175 |
+
encoder_hidden_states=encoder_hidden_states,
|
176 |
+
attention_mask=attention_mask,
|
177 |
+
num_frames=num_frames,
|
178 |
+
cross_attention_kwargs=cross_attention_kwargs,
|
179 |
+
)
|
180 |
+
|
181 |
+
if mid_block_additional_residual is not None:
|
182 |
+
sample = sample + mid_block_additional_residual
|
183 |
+
|
184 |
+
# 5. up
|
185 |
+
for i, upsample_block in enumerate(self.up_blocks):
|
186 |
+
is_final_block = i == len(self.up_blocks) - 1
|
187 |
+
|
188 |
+
res_samples = down_block_res_samples[-len(upsample_block.resnets) :]
|
189 |
+
down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)]
|
190 |
+
|
191 |
+
# if we have not reached the final block and need to forward the
|
192 |
+
# upsample size, we do it here
|
193 |
+
if not is_final_block and forward_upsample_size:
|
194 |
+
upsample_size = down_block_res_samples[-1].shape[2:]
|
195 |
+
|
196 |
+
if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention:
|
197 |
+
sample = upsample_block(
|
198 |
+
hidden_states=sample,
|
199 |
+
temb=emb,
|
200 |
+
res_hidden_states_tuple=res_samples,
|
201 |
+
encoder_hidden_states=encoder_hidden_states,
|
202 |
+
upsample_size=upsample_size,
|
203 |
+
attention_mask=attention_mask,
|
204 |
+
num_frames=num_frames,
|
205 |
+
cross_attention_kwargs=cross_attention_kwargs,
|
206 |
+
)
|
207 |
+
else:
|
208 |
+
sample = upsample_block(
|
209 |
+
hidden_states=sample,
|
210 |
+
temb=emb,
|
211 |
+
res_hidden_states_tuple=res_samples,
|
212 |
+
upsample_size=upsample_size,
|
213 |
+
num_frames=num_frames,
|
214 |
+
)
|
215 |
+
|
216 |
+
# 6. post-process
|
217 |
+
if self.conv_norm_out:
|
218 |
+
sample = self.conv_norm_out(sample)
|
219 |
+
sample = self.conv_act(sample)
|
220 |
+
|
221 |
+
sample = self.conv_out(sample)
|
222 |
+
|
223 |
+
# reshape to (batch, channel, framerate, width, height)
|
224 |
+
sample = sample[None, :].reshape((-1, num_frames) + sample.shape[1:]).permute(0, 2, 1, 3, 4)
|
225 |
+
|
226 |
+
if not return_dict:
|
227 |
+
return (sample,)
|
228 |
+
|
229 |
+
return UNet3DConditionOutput(sample=sample)
|
TrailBlazer/Pipeline/Utils.py
ADDED
@@ -0,0 +1,144 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import inspect
|
2 |
+
from typing import Any, Callable, Dict, List, Optional, Union
|
3 |
+
|
4 |
+
import numpy as np
|
5 |
+
import torch
|
6 |
+
from transformers import CLIPTextModel, CLIPTokenizer
|
7 |
+
from dataclasses import dataclass
|
8 |
+
|
9 |
+
from diffusers.loaders import LoraLoaderMixin, TextualInversionLoaderMixin
|
10 |
+
from diffusers.models import AutoencoderKL, UNet3DConditionModel
|
11 |
+
from diffusers.models.lora import adjust_lora_scale_text_encoder
|
12 |
+
from diffusers.schedulers import KarrasDiffusionSchedulers
|
13 |
+
from diffusers.utils import (
|
14 |
+
deprecate,
|
15 |
+
logging,
|
16 |
+
replace_example_docstring,
|
17 |
+
BaseOutput,
|
18 |
+
)
|
19 |
+
from diffusers.utils.torch_utils import randn_tensor
|
20 |
+
from diffusers.pipeline_utils import DiffusionPipeline
|
21 |
+
from diffusers.pipelines.text_to_video_synthesis.pipeline_text_to_video_synth import (
|
22 |
+
tensor2vid,
|
23 |
+
)
|
24 |
+
from ..CrossAttn.InjecterProc import InjecterProcessor
|
25 |
+
from ..Misc import Logger as log
|
26 |
+
from ..Misc import Const
|
27 |
+
|
28 |
+
|
29 |
+
|
30 |
+
|
31 |
+
def use_dd_temporal(unet, use=True):
|
32 |
+
""" To determine using the temporal attention editing at a step
|
33 |
+
"""
|
34 |
+
for name, module in unet.named_modules():
|
35 |
+
module_name = type(module).__name__
|
36 |
+
if module_name == "Attention" and "attn2" in name:
|
37 |
+
module.processor.use_dd_temporal = use
|
38 |
+
|
39 |
+
|
40 |
+
def use_dd(unet, use=True):
|
41 |
+
""" To determine using the spatial attention editing at a step
|
42 |
+
"""
|
43 |
+
for name, module in unet.named_modules():
|
44 |
+
module_name = type(module).__name__
|
45 |
+
# if module_name == "CrossAttention" and "attn2" in name:
|
46 |
+
if module_name == "Attention" and "attn2" in name:
|
47 |
+
module.processor.use_dd = use
|
48 |
+
|
49 |
+
|
50 |
+
def initiailization(unet, bundle, bbox_per_frame):
|
51 |
+
log.info("Intialization")
|
52 |
+
|
53 |
+
for name, module in unet.named_modules():
|
54 |
+
module_name = type(module).__name__
|
55 |
+
if module_name == "Attention" and "attn2" in name:
|
56 |
+
if "temp_attentions" in name:
|
57 |
+
processor = InjecterProcessor(
|
58 |
+
bundle=bundle,
|
59 |
+
bbox_per_frame=bbox_per_frame,
|
60 |
+
strengthen_scale=bundle["temp_strengthen_scale"],
|
61 |
+
weaken_scale=bundle["temp_weaken_scale"],
|
62 |
+
is_text2vidzero=False,
|
63 |
+
name=name,
|
64 |
+
)
|
65 |
+
else:
|
66 |
+
processor = InjecterProcessor(
|
67 |
+
bundle=bundle,
|
68 |
+
bbox_per_frame=bbox_per_frame,
|
69 |
+
strengthen_scale=bundle["spatial_strengthen_scale"],
|
70 |
+
weaken_scale=bundle["spatial_weaken_scale"],
|
71 |
+
is_text2vidzero=False,
|
72 |
+
name=name,
|
73 |
+
)
|
74 |
+
module.processor = processor
|
75 |
+
# print(name)
|
76 |
+
log.info("Initialized")
|
77 |
+
|
78 |
+
|
79 |
+
def keyframed_prompt_embeds(bundle, encode_prompt_func, device):
|
80 |
+
num_frames = bundle["keyframe"][-1]["frame"] + 1
|
81 |
+
keyframe = bundle["keyframe"]
|
82 |
+
f = lambda start, end, index: (1 - index) * start + index * end
|
83 |
+
n = len(keyframe)
|
84 |
+
keyed_prompt_embeds = []
|
85 |
+
for i in range(n - 1):
|
86 |
+
if i == 0:
|
87 |
+
start_fr = keyframe[i]["frame"]
|
88 |
+
else:
|
89 |
+
start_fr = keyframe[i]["frame"] + 1
|
90 |
+
end_fr = keyframe[i + 1]["frame"]
|
91 |
+
|
92 |
+
start_prompt = keyframe[i]["prompt"] + Const.POSITIVE_PROMPT
|
93 |
+
end_prompt = keyframe[i + 1]["prompt"] + Const.POSITIVE_PROMPT
|
94 |
+
clip_length = end_fr - start_fr + 1
|
95 |
+
|
96 |
+
start_prompt_embeds, _ = encode_prompt_func(
|
97 |
+
start_prompt,
|
98 |
+
device=device,
|
99 |
+
num_images_per_prompt=1,
|
100 |
+
do_classifier_free_guidance=True,
|
101 |
+
negative_prompt=Const.NEGATIVE_PROMPT,
|
102 |
+
)
|
103 |
+
|
104 |
+
end_prompt_embeds, negative_prompt_embeds = encode_prompt_func(
|
105 |
+
end_prompt,
|
106 |
+
device=device,
|
107 |
+
num_images_per_prompt=1,
|
108 |
+
do_classifier_free_guidance=True,
|
109 |
+
negative_prompt=Const.NEGATIVE_PROMPT,
|
110 |
+
)
|
111 |
+
|
112 |
+
for fr in range(clip_length):
|
113 |
+
index = float(fr) / (clip_length - 1)
|
114 |
+
keyed_prompt_embeds.append(f(start_prompt_embeds, end_prompt_embeds, index))
|
115 |
+
assert len(keyed_prompt_embeds) == num_frames
|
116 |
+
|
117 |
+
return torch.cat(keyed_prompt_embeds), negative_prompt_embeds.repeat_interleave(
|
118 |
+
num_frames, dim=0
|
119 |
+
)
|
120 |
+
|
121 |
+
|
122 |
+
def keyframed_bbox(bundle):
|
123 |
+
|
124 |
+
keyframe = bundle["keyframe"]
|
125 |
+
bbox_per_frame = []
|
126 |
+
f = lambda start, end, index: (1 - index) * start + index * end
|
127 |
+
n = len(keyframe)
|
128 |
+
for i in range(n - 1):
|
129 |
+
if i == 0:
|
130 |
+
start_fr = keyframe[i]["frame"]
|
131 |
+
else:
|
132 |
+
start_fr = keyframe[i]["frame"] + 1
|
133 |
+
end_fr = keyframe[i + 1]["frame"]
|
134 |
+
start_bbox = keyframe[i]["bbox_ratios"]
|
135 |
+
end_bbox = keyframe[i + 1]["bbox_ratios"]
|
136 |
+
clip_length = end_fr - start_fr + 1
|
137 |
+
for fr in range(clip_length):
|
138 |
+
index = float(fr) / (clip_length - 1)
|
139 |
+
bbox = []
|
140 |
+
for j in range(4):
|
141 |
+
bbox.append(f(start_bbox[j], end_bbox[j], index))
|
142 |
+
bbox_per_frame.append(bbox)
|
143 |
+
|
144 |
+
return bbox_per_frame
|
TrailBlazer/Pipeline/__init__.py
ADDED
File without changes
|
TrailBlazer/README.md
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
# TrailBlazer - Codebase
|
TrailBlazer/Setting/Config.py
ADDED
@@ -0,0 +1,23 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import os
|
3 |
+
|
4 |
+
DEVICE = "cuda"
|
5 |
+
GUIDANCE_SCALE = 7.5
|
6 |
+
WIDTH = 512
|
7 |
+
HEIGHT = 512
|
8 |
+
NUM_BACKWARD_STEPS = 50
|
9 |
+
STEPS = 50
|
10 |
+
DTYPE = torch.float16
|
11 |
+
|
12 |
+
MODEL_HOME = f"{os.path.expanduser('~')}/Workspace/Project/Models"
|
13 |
+
|
14 |
+
NEGATIVE_PROMPT = "longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit, fewer difits, cropped, worst quality, low quality, deformed body, bloated, ugly, unrealistic"
|
15 |
+
POSITIVE_PROMPT = "best quality, extremely detailed, HD, ultra-realistic, 8K, HQ, masterpiece, trending on artstation, art, smooth"
|
16 |
+
|
17 |
+
|
18 |
+
SD_V1_5_ID = "runwayml/stable-diffusion-v1-5"
|
19 |
+
SD_V1_5_PATH = f"{MODEL_HOME}/{SD_V1_5_ID}"
|
20 |
+
CNET_CANNY_ID = "lllyasviel/sd-controlnet-canny"
|
21 |
+
CNET_CANNY_PATH = f"{MODEL_HOME}/{CNET_CANNY_ID}"
|
22 |
+
CNET_OPENPOSE_ID = "lllyasviel/sd-controlnet-openpose"
|
23 |
+
CNET_OPENPOSE_PATH = f"{MODEL_HOME}/{CNET_OPENPOSE_ID}"
|
TrailBlazer/Setting/Const.py
ADDED
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
1 |
+
RECONS_NAME = "recons.jpg"
|
2 |
+
LATENTS_NAME = "latents.pt"
|
3 |
+
CATTN_NAME = "cattn.pt"
|
4 |
+
CATTN_VIZ_NAME = "cattn.jpg"
|
TrailBlazer/Setting/__init__.py
ADDED
File without changes
|
TrailBlazer/__init__.py
ADDED
@@ -0,0 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# # VideoDiffusion
|
2 |
+
# from .Pipeline.Dumnmy import DummyPipeline
|
3 |
+
# from .Pipeline.Standard import StandardPipeline
|
4 |
+
# from .Pipeline.ControlNet import ControlNetPipeline
|
5 |
+
# from .Pipeline.Img2Img import Img2ImgPipeline
|
6 |
+
# from .Pipeline.Video import VideoPipeline
|
7 |
+
|
8 |
+
# from .Pipeline.TestMayaNoise import TestMayaNoisePipeline
|
app.py
ADDED
@@ -0,0 +1,415 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import sys
|
2 |
+
import os
|
3 |
+
import torch
|
4 |
+
import gradio as gr
|
5 |
+
import numpy as np
|
6 |
+
from PIL import Image, ImageOps, ImageDraw, ImageFont, ImageColor
|
7 |
+
from urllib.request import urlopen
|
8 |
+
|
9 |
+
root = os.path.dirname(os.path.abspath(__file__))
|
10 |
+
static = os.path.join(root, "static")
|
11 |
+
|
12 |
+
from diffusers import DiffusionPipeline, DPMSolverMultistepScheduler
|
13 |
+
from diffusers.pipelines import TextToVideoSDPipeline
|
14 |
+
from diffusers.utils import export_to_video
|
15 |
+
from TrailBlazer.Misc import ConfigIO
|
16 |
+
from TrailBlazer.Misc import Logger as log
|
17 |
+
from TrailBlazer.Pipeline.TextToVideoSDPipelineCall import (
|
18 |
+
text_to_video_sd_pipeline_call,
|
19 |
+
)
|
20 |
+
from TrailBlazer.Pipeline.UNet3DConditionModelCall import (
|
21 |
+
unet3d_condition_model_forward,
|
22 |
+
)
|
23 |
+
|
24 |
+
TextToVideoSDPipeline.__call__ = text_to_video_sd_pipeline_call
|
25 |
+
from diffusers.models.unet_3d_condition import UNet3DConditionModel
|
26 |
+
|
27 |
+
unet3d_condition_model_forward_copy = UNet3DConditionModel.forward
|
28 |
+
UNet3DConditionModel.forward = unet3d_condition_model_forward
|
29 |
+
|
30 |
+
|
31 |
+
from diffusers.utils import export_to_video
|
32 |
+
|
33 |
+
model_id = "cerspense/zeroscope_v2_576w"
|
34 |
+
model_path = model_id
|
35 |
+
pipe = DiffusionPipeline.from_pretrained(model_path, torch_dtype=torch.float16)
|
36 |
+
pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config)
|
37 |
+
pipe.enable_model_cpu_offload()
|
38 |
+
|
39 |
+
def core(bundle):
|
40 |
+
|
41 |
+
generator = torch.Generator().manual_seed(int(bundle["seed"]))
|
42 |
+
result = pipe(
|
43 |
+
bundle=bundle,
|
44 |
+
height=512,
|
45 |
+
width=512,
|
46 |
+
generator=generator,
|
47 |
+
num_inference_steps=40,
|
48 |
+
)
|
49 |
+
return result.frames
|
50 |
+
|
51 |
+
|
52 |
+
def clear_btn_fn():
|
53 |
+
return "", "", "", ""
|
54 |
+
|
55 |
+
|
56 |
+
def gen_btn_fn(
|
57 |
+
prompts,
|
58 |
+
bboxes,
|
59 |
+
frames,
|
60 |
+
word_prompt_indices,
|
61 |
+
trailing_length,
|
62 |
+
n_spatial_steps,
|
63 |
+
n_temporal_steps,
|
64 |
+
spatial_strengthen_scale,
|
65 |
+
spatial_weaken_scale,
|
66 |
+
temporal_strengthen_scale,
|
67 |
+
temporal_weaken_scale,
|
68 |
+
rand_seed,
|
69 |
+
):
|
70 |
+
|
71 |
+
bundle = {}
|
72 |
+
bundle["trailing_length"] = trailing_length
|
73 |
+
bundle["num_dd_spatial_steps"] = n_spatial_steps
|
74 |
+
bundle["num_dd_temporal_steps"] = n_temporal_steps
|
75 |
+
bundle["num_frames"] = 24
|
76 |
+
bundle["seed"] = rand_seed
|
77 |
+
bundle["spatial_strengthen_scale"] = spatial_strengthen_scale
|
78 |
+
bundle["spatial_weaken_scale"] = spatial_weaken_scale
|
79 |
+
bundle["temp_strengthen_scale"] = temporal_strengthen_scale
|
80 |
+
bundle["temp_weaken_scale"] = temporal_weaken_scale
|
81 |
+
bundle["token_inds"] = [int(v) for v in word_prompt_indices.split(",")]
|
82 |
+
|
83 |
+
bundle["keyframe"] = []
|
84 |
+
frames = frames.split(";")
|
85 |
+
bboxes = bboxes.split(";")
|
86 |
+
if ";" in prompts:
|
87 |
+
prompts = prompts.split(";")
|
88 |
+
else:
|
89 |
+
prompts = [prompts for i in range(len(frames))]
|
90 |
+
|
91 |
+
assert (
|
92 |
+
len(frames) == len(bboxes) == len(prompts)
|
93 |
+
), "Inconsistent number of keyframes in the given inputs."
|
94 |
+
|
95 |
+
frames.pop()
|
96 |
+
bboxes.pop()
|
97 |
+
prompts.pop()
|
98 |
+
|
99 |
+
|
100 |
+
|
101 |
+
for i in range(len(frames)):
|
102 |
+
keyframe = {}
|
103 |
+
keyframe["bbox_ratios"] = [float(v) for v in bboxes[i].split(",")]
|
104 |
+
keyframe["frame"] = int(frames[i])
|
105 |
+
keyframe["prompt"] = prompts[i]
|
106 |
+
bundle["keyframe"].append(keyframe)
|
107 |
+
print(bundle)
|
108 |
+
result = core(bundle)
|
109 |
+
path = export_to_video(result)
|
110 |
+
return path
|
111 |
+
|
112 |
+
|
113 |
+
def save_mask(inputs):
|
114 |
+
layers = inputs["layers"]
|
115 |
+
if not layers:
|
116 |
+
return inputs["background"]
|
117 |
+
mask = layers[0]
|
118 |
+
new_image = Image.new("RGBA", mask.size, color="white")
|
119 |
+
new_image.paste(mask, mask=mask)
|
120 |
+
new_image = new_image.convert("RGB")
|
121 |
+
print("SAve")
|
122 |
+
return ImageOps.invert(new_image)
|
123 |
+
|
124 |
+
|
125 |
+
def out_label_cb(im):
|
126 |
+
layers = im["layers"]
|
127 |
+
if not isinstance(layers, list):
|
128 |
+
layers = [layers]
|
129 |
+
|
130 |
+
img = None
|
131 |
+
text = "Bboxes: "
|
132 |
+
for idx, layer in enumerate(layers):
|
133 |
+
mask = np.array(layer).sum(axis=-1)
|
134 |
+
ys, xs = np.where(mask != 0)
|
135 |
+
h, w = mask.shape
|
136 |
+
if not list(xs) or not list(ys):
|
137 |
+
continue
|
138 |
+
x_min = np.min(xs)
|
139 |
+
x_max = np.max(xs)
|
140 |
+
y_min = np.min(ys)
|
141 |
+
y_max = np.max(ys)
|
142 |
+
|
143 |
+
text += "{:.2f},{:.2f},{:.2f},{:.2f}".format(
|
144 |
+
x_min * 1.0 / w, y_min * 1.0 / h, x_max * 1.0 / w, y_max * 1.0 / h
|
145 |
+
)
|
146 |
+
text += ";\n"
|
147 |
+
return text
|
148 |
+
|
149 |
+
|
150 |
+
def out_board_cb(im):
|
151 |
+
|
152 |
+
layers = im["layers"]
|
153 |
+
if not isinstance(layers, list):
|
154 |
+
layers = [layers]
|
155 |
+
|
156 |
+
img = None
|
157 |
+
for idx, layer in enumerate(layers):
|
158 |
+
mask = np.array(layer).sum(axis=-1)
|
159 |
+
ys, xs = np.where(mask != 0)
|
160 |
+
|
161 |
+
if not list(xs) or not list(ys):
|
162 |
+
continue
|
163 |
+
|
164 |
+
h, w = mask.shape
|
165 |
+
if not img:
|
166 |
+
img = Image.new("RGBA", (w, h))
|
167 |
+
x_min = np.min(xs)
|
168 |
+
x_max = np.max(xs)
|
169 |
+
y_min = np.min(ys)
|
170 |
+
y_max = np.max(ys)
|
171 |
+
|
172 |
+
# output
|
173 |
+
shape = [(x_min, y_min), (x_max, y_max)]
|
174 |
+
colors = list(ImageColor.colormap.keys())
|
175 |
+
draw = ImageDraw.Draw(img)
|
176 |
+
draw.rectangle(shape, outline=colors[idx], width=5)
|
177 |
+
text = "Bbox#{}".format(idx)
|
178 |
+
font = ImageFont.load_default()
|
179 |
+
draw.text((x_max - 0.5 * (x_max - x_min), y_max), text, font=font, align="left")
|
180 |
+
|
181 |
+
return img
|
182 |
+
|
183 |
+
|
184 |
+
with gr.Blocks(
|
185 |
+
analytics_enabled=False,
|
186 |
+
title="TrailBlazer Demo",
|
187 |
+
) as main:
|
188 |
+
|
189 |
+
description = """
|
190 |
+
<h1 align="center" style="font-size: 48px">TrailBlazer: Trajectory Control for Diffusion-Based Video Generation</h1>
|
191 |
+
<h4 align="center" style="margin: 0;">If you like our project, please give us a star β¨ at our Huggingface space, and our Github repository.</h4>
|
192 |
+
<br>
|
193 |
+
<span align="center" style="font-size: 18px">
|
194 |
+
[<a href="https://hohonu-vicml.github.io/Trailblazer.Page/" target="_blank">Project Page</a>]
|
195 |
+
[<a href="http://arxiv.org/abs/2401.00896" target="_blank">Paper</a>]
|
196 |
+
[<a href="https://github.com/hohonu-vicml/Trailblazer" target="_blank">GitHub</a>]
|
197 |
+
[<a href="https://www.youtube.com/watch?v=kEN-32wN-xQ" target="_blank">Project Video</a>]
|
198 |
+
[<a href="https://www.youtube.com/watch?v=P-PSkS7sNco" target="_blank">Result Video</a>]
|
199 |
+
</span>
|
200 |
+
</p>
|
201 |
+
<p>
|
202 |
+
<strong>Usage:</strong> Our Gradio app is implemented based on our executable script CmdTrailBlazer in our github repository. Please see our general information below for a quick guidance, as well as the hints within the app widgets.
|
203 |
+
<ul>
|
204 |
+
<li>Basic: The bounding box (bbox) is the tuple of four floats for the rectangular corners: left, top, right, bottom in the normalized ratio. The Word prompt indices is a list of 1-indexed numbers determining the prompt word.</li>
|
205 |
+
<li>Advanced Options: We also offer some key parameters to adjust the synthesis result. Please see our paper for more information about the ablations.</li>
|
206 |
+
</ul>
|
207 |
+
</p>
|
208 |
+
"""
|
209 |
+
gr.HTML(description)
|
210 |
+
|
211 |
+
with gr.Row():
|
212 |
+
with gr.Column(scale=2):
|
213 |
+
with gr.Row():
|
214 |
+
with gr.Tab("Main"):
|
215 |
+
text_prompt_tb = gr.Textbox(
|
216 |
+
interactive=True, label="Keyframe: Prompt"
|
217 |
+
)
|
218 |
+
bboxes_tb = gr.Textbox(interactive=True, label="Keyframe: Bboxes")
|
219 |
+
frame_tb = gr.Textbox(
|
220 |
+
interactive=True, label="Keyframe: frame indices"
|
221 |
+
)
|
222 |
+
with gr.Row():
|
223 |
+
word_prompt_indices_tb = gr.Textbox(
|
224 |
+
interactive=True, label="Word prompt indices:"
|
225 |
+
)
|
226 |
+
text = "Hint: Each keyframe ends with <strong>SEMICOLON</strong>, and <strong>COMMA</strong> for separating each value in the keyframe. The prompt field can be a single prompt without semicolon, or multiple prompts ended semicolon. One can use the SketchPadHelper tab to help to design the bboxes field."
|
227 |
+
gr.HTML(text)
|
228 |
+
with gr.Row():
|
229 |
+
clear_btn = gr.Button(value="Clear")
|
230 |
+
gen_btn = gr.Button(value="Generate")
|
231 |
+
|
232 |
+
with gr.Accordion("Advanced Options", open=False):
|
233 |
+
text = "Hint: This default value should be sufficient for most tasks. However, it's important to note that our approach is currently implemented on ZeroScope, and its performance may be influenced by the model's characteristics. We plan to conduct experiments on different models in the future."
|
234 |
+
gr.HTML(text)
|
235 |
+
with gr.Row():
|
236 |
+
trailing_length = gr.Slider(
|
237 |
+
minimum=0,
|
238 |
+
maximum=30,
|
239 |
+
step=1,
|
240 |
+
value=13,
|
241 |
+
interactive=True,
|
242 |
+
label="#Trailing",
|
243 |
+
)
|
244 |
+
n_spatial_steps = gr.Slider(
|
245 |
+
minimum=0,
|
246 |
+
maximum=30,
|
247 |
+
step=1,
|
248 |
+
value=5,
|
249 |
+
interactive=True,
|
250 |
+
label="#Spatial edits",
|
251 |
+
)
|
252 |
+
n_temporal_steps = gr.Slider(
|
253 |
+
minimum=0,
|
254 |
+
maximum=30,
|
255 |
+
step=1,
|
256 |
+
value=5,
|
257 |
+
interactive=True,
|
258 |
+
label="#Temporal edits",
|
259 |
+
)
|
260 |
+
with gr.Row():
|
261 |
+
spatial_strengthen_scale = gr.Slider(
|
262 |
+
minimum=0,
|
263 |
+
maximum=2,
|
264 |
+
step=0.01,
|
265 |
+
value=0.15,
|
266 |
+
interactive=True,
|
267 |
+
label="Spatial Strengthen Scale",
|
268 |
+
)
|
269 |
+
spatial_weaken_scale = gr.Slider(
|
270 |
+
minimum=0,
|
271 |
+
maximum=1,
|
272 |
+
step=0.01,
|
273 |
+
value=0.001,
|
274 |
+
interactive=True,
|
275 |
+
label="Spatial Weaken Scale",
|
276 |
+
)
|
277 |
+
temporal_strengthen_scale = gr.Slider(
|
278 |
+
minimum=0,
|
279 |
+
maximum=2,
|
280 |
+
step=0.01,
|
281 |
+
value=0.15,
|
282 |
+
interactive=True,
|
283 |
+
label="Temporal Strengthen Scale",
|
284 |
+
)
|
285 |
+
temporal_weaken_scale = gr.Slider(
|
286 |
+
minimum=0,
|
287 |
+
maximum=1,
|
288 |
+
step=0.01,
|
289 |
+
value=0.001,
|
290 |
+
interactive=True,
|
291 |
+
label="Temporal Weaken Scale",
|
292 |
+
)
|
293 |
+
|
294 |
+
with gr.Row():
|
295 |
+
guidance_scale = gr.Slider(
|
296 |
+
minimum=0,
|
297 |
+
maximum=50,
|
298 |
+
step=0.5,
|
299 |
+
value=7.5,
|
300 |
+
interactive=True,
|
301 |
+
label="Guidance Scale",
|
302 |
+
)
|
303 |
+
rand_seed = gr.Slider(
|
304 |
+
minimum=0,
|
305 |
+
maximum=523451232531,
|
306 |
+
step=1,
|
307 |
+
value=0,
|
308 |
+
interactive=True,
|
309 |
+
label="Seed",
|
310 |
+
)
|
311 |
+
|
312 |
+
with gr.Tab("SketchPadHelper"):
|
313 |
+
with gr.Row():
|
314 |
+
user_board = gr.ImageMask(type="pil", label="Draw me")
|
315 |
+
out_board = gr.Image(type="pil", label="Processed bbox")
|
316 |
+
user_board.change(
|
317 |
+
out_board_cb, inputs=[user_board], outputs=[out_board]
|
318 |
+
)
|
319 |
+
with gr.Row():
|
320 |
+
text = "Hint: Utilize a black pen with the Draw Button to create a ``rough'' bbox. When you press the green ``Save Changes'' Button, the app calculates the minimum and maximum boundaries. Each ``Layer'', located at the bottom left of the pad, corresponds to one bounding box. Copy the returned value to the bbox textfield in the main tab."
|
321 |
+
gr.HTML(text)
|
322 |
+
with gr.Row():
|
323 |
+
out_label = gr.Label(label="Converted bboxes string")
|
324 |
+
user_board.change(
|
325 |
+
out_label_cb, inputs=[user_board], outputs=[out_label]
|
326 |
+
)
|
327 |
+
|
328 |
+
with gr.Column(scale=1):
|
329 |
+
gr.HTML(
|
330 |
+
'<span style="font-size: 20px; font-weight: bold">Generated Images</span>'
|
331 |
+
)
|
332 |
+
with gr.Row():
|
333 |
+
out_gen_1 = gr.Video(visible=True, show_label=False)
|
334 |
+
|
335 |
+
with gr.Row():
|
336 |
+
gr.Examples(
|
337 |
+
examples=[
|
338 |
+
[
|
339 |
+
"A clown fish swimming in a coral reef",
|
340 |
+
"0.5,0.35,1.0,0.65; 0.0,0.35,0.5,0.65;",
|
341 |
+
"0; 24;",
|
342 |
+
"1,2,3",
|
343 |
+
"123451232531",
|
344 |
+
"assets/gradio/fish-RL.mp4",
|
345 |
+
],
|
346 |
+
[
|
347 |
+
"A cat is running on the grass",
|
348 |
+
"0.0,0.35,0.4,0.65; 0.6,0.35,1.0,0.65; 0.0,0.35,0.4,0.65;"
|
349 |
+
"0.6,0.35,1.0,0.65; 0.0,0.35,0.4,0.65;",
|
350 |
+
"0; 6; 12; 18; 24;",
|
351 |
+
"1,2",
|
352 |
+
"123451232530",
|
353 |
+
"assets/gradio/cat-LRLR.mp4",
|
354 |
+
],
|
355 |
+
[
|
356 |
+
"A fish swimming in the ocean",
|
357 |
+
"0.0,0.0,0.1,0.1; 0.5,0.5,1.0,1.0;",
|
358 |
+
"0; 24;",
|
359 |
+
"1, 2",
|
360 |
+
"0",
|
361 |
+
"assets/gradio/fish-TL2BR.mp4"
|
362 |
+
],
|
363 |
+
[
|
364 |
+
"A tiger walking alone down the street",
|
365 |
+
"0.0,0.0,0.1,0.1; 0.5,0.5,1.0,1.0;",
|
366 |
+
"0; 24;",
|
367 |
+
"1, 2",
|
368 |
+
"0",
|
369 |
+
"assets/gradio/tiger-TL2BR.mp4"
|
370 |
+
],
|
371 |
+
[
|
372 |
+
"A white cat walking on the grass; A yellow dog walking on the grass;",
|
373 |
+
"0.7,0.4,1.0,0.65; 0.0,0.4,0.3,0.65;",
|
374 |
+
"0; 24;",
|
375 |
+
"1,2,3",
|
376 |
+
"123451232531",
|
377 |
+
"assets/gradio/Cat2Dog.mp4",
|
378 |
+
],
|
379 |
+
],
|
380 |
+
inputs=[text_prompt_tb, bboxes_tb, frame_tb, word_prompt_indices_tb, rand_seed,out_gen_1],
|
381 |
+
outputs=None,
|
382 |
+
fn=None,
|
383 |
+
cache_examples=False,
|
384 |
+
)
|
385 |
+
|
386 |
+
clear_btn.click(
|
387 |
+
clear_btn_fn,
|
388 |
+
inputs=[],
|
389 |
+
outputs=[text_prompt_tb, bboxes_tb, frame_tb, word_prompt_indices_tb],
|
390 |
+
queue=False,
|
391 |
+
)
|
392 |
+
|
393 |
+
gen_btn.click(
|
394 |
+
gen_btn_fn,
|
395 |
+
inputs=[
|
396 |
+
text_prompt_tb,
|
397 |
+
bboxes_tb,
|
398 |
+
frame_tb,
|
399 |
+
word_prompt_indices_tb,
|
400 |
+
trailing_length,
|
401 |
+
n_spatial_steps,
|
402 |
+
n_temporal_steps,
|
403 |
+
spatial_strengthen_scale,
|
404 |
+
spatial_weaken_scale,
|
405 |
+
temporal_strengthen_scale,
|
406 |
+
temporal_weaken_scale,
|
407 |
+
rand_seed,
|
408 |
+
],
|
409 |
+
outputs=[out_gen_1],
|
410 |
+
queue=False,
|
411 |
+
)
|
412 |
+
|
413 |
+
|
414 |
+
if __name__ == "__main__":
|
415 |
+
main.launch(share=False)
|