Spaces:
Running
on
Zero
Running
on
Zero
import math

import numpy as np
import torch
import xformers
import xformers.ops
class DummyController:
    """Pass-through attention controller.

    When this controller is installed, the patched attention layers use the
    plain (non-decomposed) attention path. Calling the controller returns its
    first positional argument unchanged.
    """

    def __init__(self):
        # Number of patched attention layers; starts at zero and is assigned
        # by register_attention_disentangled_control after patching.
        self.num_att_layers = 0

    def __call__(self, *args):
        """Return ``args[0]`` untouched (raises IndexError when called with no args)."""
        first = args[0]
        return first
class GroupedCAController:
    """Cross-attention controller that combines several key/value sets via spatial masks.

    ``mask_list[i]`` is paired by position with the i-th key/value set passed
    to :meth:`ca_forward_decom`. A query position receives the attention
    output of a key/value set only where that set's mask is non-zero. The
    per-set outputs are summed.
    """

    def __init__(self, mask_list=None):
        # mask_list: sequence of 2-D mask tensors (presumably (H, W) image-space
        # masks -- TODO confirm with callers). ``None`` marks the controller as
        # non-decomposing.
        self.mask_list = mask_list
        self.is_decom = mask_list is not None

    def mask_img_to_mask_vec(self, mask, length):
        """Resize a 2-D ``mask`` to ``(length, length)`` and flatten it.

        Uses ``torch.nn.functional.interpolate`` with its default (nearest)
        mode, so mask values are copied rather than blended. Boolean masks are
        cast to float first because interpolation does not accept ``bool``.

        Returns:
            A 1-D tensor with ``length * length`` entries.
        """
        if mask.dtype == torch.bool:
            mask = mask.float()
        resized = torch.nn.functional.interpolate(
            mask.unsqueeze(0).unsqueeze(0), (length, length)
        ).squeeze()
        return resized.flatten()

    def ca_forward_decom(self, q, k_list, v_list, scale, place_in_unet):
        """Masked cross-attention, summed over key/value sets.

        Args:
            q: queries ``[Bh, N, d]``. ``N`` is assumed to be a perfect square
                (square latent feature map).
            k_list: per-set keys, each ``[Bh, P, d]``.
            v_list: per-set values, each ``[Bh, P, d_v]``.
            scale: multiplier applied to ``q @ k^T`` before the softmax.
            place_in_unet: not used here; kept for interface compatibility.

        Returns:
            ``[Bh, N, d_v]`` tensor. For each set, the softmax attention output
            is computed with rows zeroed wherever that set's mask is exactly 0,
            and the results are summed. Returns zeros shaped like ``q`` when
            there is nothing to sum.
        """
        n_query = q.shape[1]
        # NOTE(review): assumes a square latent; for a non-square feature map
        # the mask length (side**2) will not match N.
        side = math.isqrt(n_query)
        out = None
        # zip pairs masks with key/value sets positionally; surplus entries on
        # any side are ignored.
        for mask, k, v in zip(self.mask_list, k_list, v_list):
            # Shape [1, N, 1] so it broadcasts over Bh and the P key tokens.
            # Move to q's device so masked_fill does not hit a device mismatch.
            mask_vec = self.mask_img_to_mask_vec(mask, side).to(q.device)
            mask_vec = mask_vec.unsqueeze(0).unsqueeze(-1)
            sim = torch.einsum("b i d, b j d -> b i j", q, k) * scale  # [Bh, N, P]
            attn = sim.softmax(dim=-1)
            attn = attn.masked_fill(mask_vec == 0, 0)
            masked_out = torch.einsum("b i j, b j d -> b i d", attn, v)  # [Bh, N, d_v]
            out = masked_out if out is None else out + masked_out
        if out is None:
            # Nothing to combine: return a tensor rather than the int 0 so
            # downstream reshapes keep working.
            out = torch.zeros_like(q)
        return out
def reshape_heads_to_batch_dim(self):
    """Build a function that folds attention heads into the batch dimension.

    ``self`` must expose ``num_heads``. It is read each time the returned
    function runs.

    Returns:
        ``func(tensor)`` mapping ``[B, S, H*D]`` to ``[B*H, S, D]``, where
        ``H = self.num_heads``.
    """

    def func(tensor):
        batch_size, seq_len, dim = tensor.shape
        head_size = self.num_heads
        tensor = tensor.reshape(batch_size, seq_len, head_size, dim // head_size)
        # Move heads next to the batch axis, then merge them into it.
        tensor = tensor.permute(0, 2, 1, 3).reshape(
            batch_size * head_size, seq_len, dim // head_size
        )
        return tensor

    return func
def reshape_batch_dim_to_heads(self):
    """Build a function that unfolds attention heads from the batch dimension.

    This is the inverse of the head-folding reshape. ``self`` must expose
    ``num_heads``. It is read each time the returned function runs.

    Returns:
        ``func(tensor)`` mapping ``[B*H, S, D]`` to ``[B, S, H*D]``, where
        ``H = self.num_heads``.
    """

    def func(tensor):
        batch_size, seq_len, dim = tensor.shape
        head_size = self.num_heads
        tensor = tensor.reshape(batch_size // head_size, head_size, seq_len, dim)
        # Put the sequence axis before heads, then concatenate heads on the
        # feature axis.
        tensor = tensor.permute(0, 2, 1, 3).reshape(
            batch_size // head_size, seq_len, dim * head_size
        )
        return tensor

    return func
def register_attention_disentangled_control(unet, controller):
    """Patch the cross-attention layers of ``unet`` to route through ``controller``.

    Every sub-module whose class is named ``Attention`` and whose
    ``to_k.in_features`` equals ``unet.ca_dim`` gets its ``forward`` replaced.
    Only the ``unet`` children whose names contain ``"down"``, ``"up"`` or
    ``"mid"`` are searched.

    The replaced ``forward`` behaves as follows:

    * With a ``DummyController`` (also used when ``controller`` is ``None``),
      it runs ``xformers.ops.memory_efficient_attention`` over ``x`` and
      ``encoder_hidden_states`` (or ``x`` itself when no states are given).
    * With any other controller, ``encoder_hidden_states`` must be a list or
      tuple of tensors. Each tensor is projected to its own keys/values, and
      the combination is delegated to ``controller.ca_forward_decom``.

    Side effect: sets ``controller.num_att_layers`` to the number of patched
    layers.

    Raises (from the patched forward, decomposed path):
        NotImplementedError: ``encoder_hidden_states`` is ``None``
            (self-attention decomposition is not supported).
        TypeError: ``encoder_hidden_states`` is not a list or tuple.
    """

    def ca_forward(self, place_in_unet):
        # When to_out is a ModuleList, only its first module is applied
        # (presumably the output projection -- confirm for the attention class used).
        to_out = self.to_out[0] if isinstance(self.to_out, torch.nn.ModuleList) else self.to_out

        def full_attention(x, encoder_hidden_states):
            # Plain attention; becomes self-attention when no context is given.
            context = x if encoder_hidden_states is None else encoder_hidden_states
            q = self.head_to_batch_dim(self.to_q(x))
            k = self.head_to_batch_dim(self.to_k(context))
            v = self.head_to_batch_dim(self.to_v(context))
            out = xformers.ops.memory_efficient_attention(
                q, k, v, attn_bias=None, op=None, scale=self.scale
            )
            return self.batch_to_head_dim(out)

        def decomposed_cross_attention(x, encoder_hidden_states_list):
            # Raise instead of exit() so the failure surfaces to the caller
            # rather than terminating the interpreter.
            if encoder_hidden_states_list is None:
                raise NotImplementedError("decomposing SA!")
            if not isinstance(encoder_hidden_states_list, (list, tuple)):
                raise TypeError(
                    "decomposed cross-attention expects a list of encoder hidden "
                    f"states, got {type(encoder_hidden_states_list).__name__}"
                )
            q = self.head_to_batch_dim(self.to_q(x))  # [Bh, N, d]
            k_list = []
            v_list = []
            for hidden in encoder_hidden_states_list:
                k_list.append(self.head_to_batch_dim(self.to_k(hidden)))
                v_list.append(self.head_to_batch_dim(self.to_v(hidden)))
            out = controller.ca_forward_decom(q, k_list, v_list, self.scale, place_in_unet)
            return self.batch_to_head_dim(out)

        # controller is fixed before any layer is patched, so the path can be
        # chosen once per layer instead of on every call.
        impl = full_attention if isinstance(controller, DummyController) else decomposed_cross_attention

        def forward(x, encoder_hidden_states=None, attention_mask=None):
            # NOTE(review): attention_mask is accepted for signature
            # compatibility but is ignored by both paths.
            return to_out(impl(x, encoder_hidden_states))

        return forward

    if controller is None:
        controller = DummyController()

    def register_recr(net_, count, place_in_unet):
        # Depth-first walk; a matching Attention module is patched and its own
        # children are not visited.
        if net_.__class__.__name__ == "Attention" and net_.to_k.in_features == unet.ca_dim:
            net_.forward = ca_forward(net_, place_in_unet)
            return count + 1
        if hasattr(net_, "children"):
            for child in net_.children():
                count = register_recr(child, count, place_in_unet)
        return count

    cross_att_count = 0
    for name, module in unet.named_children():
        # Substring match, checked in the order down -> up -> mid.
        place = next((p for p in ("down", "up", "mid") if p in name), None)
        if place is not None:
            cross_att_count += register_recr(module, 0, place)
    controller.num_att_layers = cross_att_count