import os
import math
import logging
import json
from pathlib import Path
from typing import Tuple, Optional, Union

import torch
import torch.nn as nn
from torch.nn import functional as F
from transformers import AutoModelForCausalLM, AutoTokenizer, GPTJForCausalLM

logging.basicConfig(level=os.environ.get("LOGLEVEL", "INFO"))
log = logging.getLogger(__name__)

def load_weights(self, Module, path, name, default_name, prev_name=None, **kwargs):
    # Instantiate the pretrained transformer named by `default_name` and attach it to the
    # owner module as attribute `name`. `path` and `prev_name` are accepted for interface
    # compatibility but are not used in this code path.
    hparams = None
    assert isinstance(default_name, str), f'invalid default transformer name: {default_name}'
    model = get_transformer_module(Module, default_name, **kwargs)
    setattr(self, name, model)
    return hparams

def get_transformer_module(Module, default_name, **kwargs):
    # GPT-J is loaded from its float16 revision with low_cpu_mem_usage to reduce the memory footprint.
    if default_name == 'EleutherAI/gpt-j-6B':
        kwargs = {**kwargs, **dict(revision="float16", torch_dtype=torch.float16, low_cpu_mem_usage=True)}
    model = Module.from_pretrained(default_name, **kwargs)
    return model

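# Illustrative call pattern (not from the original file; the argument values are assumptions):
#   hparams = load_weights(owner_module, AutoModelForCausalLM, path='', name='gpt',
#                          default_name='gpt2')
# Afterwards the language model is available as `owner_module.gpt`.
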
class MLP(nn.Module):
    def __init__(self, sizes: Tuple[int, ...], bias=True, act=nn.Tanh):
        super(MLP, self).__init__()
        self.divider = math.sqrt(sizes[-1] / sizes[0])
        layers = []
        for i in range(len(sizes) - 1):
            layers.append(nn.Linear(sizes[i], sizes[i + 1], bias=bias))
            if i < len(sizes) - 2:
                layers.append(act())
        self.model = nn.Sequential(*layers)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = x / self.divider  # scale the input for initial training stability
        x = self.model(x)
        return x

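# Illustrative sizing (the concrete numbers are assumptions, not from the original file):
# MLP((512, 3840, 7680)) maps a (batch, 512) CLIP feature to a (batch, 7680) vector,
# which the caller can reshape to (batch, prefix_length=10, n_embd=768).
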
class MlpTransformer(nn.Module):
    def __init__(self, in_dim, h_dim, out_d: Optional[int] = None, act=F.relu, dropout=0.):
        super().__init__()
        out_d = out_d if out_d is not None else in_dim
        self.fc1 = nn.Linear(in_dim, h_dim)
        self.act = act
        self.fc2 = nn.Linear(h_dim, out_d)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        x = self.fc1(x)
        x = self.act(x)
        x = self.dropout(x)
        x = self.fc2(x)
        x = self.dropout(x)
        return x

class MultiHeadAttention(nn.Module):
    def __init__(self, dim_self, dim_ref, num_heads, bias=True, dropout=0.):
        super().__init__()
        self.num_heads = num_heads
        head_dim = dim_self // num_heads
        self.scale = head_dim ** -0.5
        self.to_queries = nn.Linear(dim_self, dim_self, bias=bias)
        self.to_keys_values = nn.Linear(dim_ref, dim_self * 2, bias=bias)
        self.project = nn.Linear(dim_self, dim_self)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, y=None, mask=None):
        y = y if y is not None else x
        b, n, c = x.shape
        _, m, d = y.shape
        # b n h dh
        queries = self.to_queries(x).reshape(b, n, self.num_heads, c // self.num_heads)
        # b m 2 h dh
        keys_values = self.to_keys_values(y).reshape(b, m, 2, self.num_heads, c // self.num_heads)
        keys, values = keys_values[:, :, 0], keys_values[:, :, 1]
        attention = torch.einsum('bnhd,bmhd->bnmh', queries, keys) * self.scale
        if mask is not None:
            if mask.dim() == 2:
                mask = mask.unsqueeze(1)
            attention = attention.masked_fill(mask.unsqueeze(3), float("-inf"))
        attention = attention.softmax(dim=2)
        out = torch.einsum('bnmh,bmhd->bnhd', attention, values).reshape(b, n, c)
        out = self.project(out)
        return out, attention

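# Shape summary (for reference): queries are (b, n, h, dh), keys/values (b, m, h, dh),
# and the attention map is (b, n, m, h); softmax(dim=2) normalizes over the key positions m.
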
class TransformerLayer(nn.Module):
    def __init__(self, dim_self, dim_ref, num_heads, mlp_ratio=4., bias=False, dropout=0., act=F.relu,
                 norm_layer: nn.Module = nn.LayerNorm):
        super().__init__()
        self.norm1 = norm_layer(dim_self)
        self.attn = MultiHeadAttention(dim_self, dim_ref, num_heads, bias=bias, dropout=dropout)
        self.norm2 = norm_layer(dim_self)
        self.mlp = MlpTransformer(dim_self, int(dim_self * mlp_ratio), act=act, dropout=dropout)

    def forward_with_attention(self, x, y=None, mask=None):
        x_, attention = self.attn(self.norm1(x), y, mask)
        x = x + x_
        x = x + self.mlp(self.norm2(x))
        return x, attention

    def forward(self, x, y=None, mask=None):
        x = x + self.attn(self.norm1(x), y, mask)[0]
        x = x + self.mlp(self.norm2(x))
        return x

class Transformer(nn.Module):
    def __init__(self, dim_self: int, num_heads: int, num_layers: int, dim_ref: Optional[int] = None,
                 mlp_ratio: float = 2., act=F.relu, norm_layer: nn.Module = nn.LayerNorm, enc_dec: bool = False):
        super(Transformer, self).__init__()
        dim_ref = dim_ref if dim_ref is not None else dim_self
        self.enc_dec = enc_dec
        if enc_dec:
            num_layers = num_layers * 2
        layers = []
        for i in range(num_layers):
            if i % 2 == 0 and enc_dec:  # cross
                layers.append(TransformerLayer(dim_self, dim_ref, num_heads, mlp_ratio, act=act, norm_layer=norm_layer))
            elif enc_dec:  # self
                layers.append(TransformerLayer(dim_self, dim_self, num_heads, mlp_ratio, act=act, norm_layer=norm_layer))
            else:  # self or cross
                layers.append(TransformerLayer(dim_self, dim_ref, num_heads, mlp_ratio, act=act, norm_layer=norm_layer))
        self.layers = nn.ModuleList(layers)

    def forward_with_attention(self, x, y=None, mask=None):
        attentions = []
        for layer in self.layers:
            x, att = layer.forward_with_attention(x, y, mask)
            attentions.append(att)
        return x, attentions

    def forward(self, x, y=None, mask=None):
        for i, layer in enumerate(self.layers):
            if i % 2 == 0 and self.enc_dec:  # cross
                x = layer(x, y)
            elif self.enc_dec:  # self
                x = layer(x, x, mask)
            else:  # self or cross
                x = layer(x, y, mask)
        return x

class TransformerMapper(nn.Module):
    def __init__(self, dim_clip: int, dim_embedding: int, prefix_length: int = 10,
                 clip_length: int = 10, num_layers: int = 8):
        super(TransformerMapper, self).__init__()
        self.clip_length = clip_length
        self.transformer = Transformer(dim_embedding, 8, num_layers)
        self.linear = nn.Linear(dim_clip, clip_length * dim_embedding)
        self.prefix_const = nn.Parameter(torch.randn(prefix_length, dim_embedding), requires_grad=True)

    def forward(self, x):
        x = self.linear(x).view(x.shape[0], self.clip_length, -1)
        prefix = self.prefix_const.unsqueeze(0).expand(x.shape[0], *self.prefix_const.shape)
        prefix = torch.cat((x, prefix), dim=1)
        out = self.transformer(prefix)[:, self.clip_length:]
        return out

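# For reference: the CLIP-derived tokens occupy the first `clip_length` positions and are
# dropped from the transformer output; only the learned `prefix_const` slots are returned,
# giving a tensor of shape (batch, prefix_length, dim_embedding).
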
class ClipCap(nn.Module):
    def __init__(self, model_name, device, prefix_length: int = 10, clip_length: int = 40, prefix_size: int = 512,
                 num_layers: int = 1, model_path: str = '', fix_gpt: bool = False,
                 use_label_prefix: bool = False, label_path: str = '', label_length: int = 10,
                 use_transformer_mapper: bool = False, use_ptuning_v2: bool = False,
                 dropout: float = 0,
                 model_weight: str = '', scalar_output: bool = False):
        super(ClipCap, self).__init__()
        self.prefix_length = prefix_length
        self.prefix_size = prefix_size
        self.label_length = label_length
        self.scalar_output = scalar_output
        self.num_layers = num_layers
        self.use_transformer_mapper = use_transformer_mapper
        self.use_ptuning_v2 = use_ptuning_v2
        self.dropout = nn.Dropout(dropout)
        hparams = load_weights(self, AutoModelForCausalLM, model_weight, 'gpt', model_name,
                               prev_name='model')
        self.device = device
        self.gpt = self.gpt.to(self.device)
        config = self.gpt.config
        self.match_n_layer = getattr(config, 'n_layer', getattr(config, 'num_layers', None))  # gpt2 vs. gpt_neo
        self.match_n_head = getattr(config, 'n_head', getattr(config, 'num_heads', None))
        self.n_embd = getattr(config, 'n_embd', getattr(config, 'hidden_size', None))
        self.match_n_embd = self.n_embd // self.match_n_head
        self.clip_project = self.get_mapper()
        if Path(label_path).is_file():
            with open(label_path) as f:
                labels = json.load(f)
            self.labels = {i: v for v, i in labels.items()}
            if not use_label_prefix:
                log.info("adding label projections")
                self.label_project = nn.Sequential(
                    nn.Embedding(len(self.labels), self.prefix_size),
                    self.get_mapper()
                )
        if os.path.isfile(model_path):
            log.info(f"loading model from {model_path}")
            weight = torch.load(model_path, map_location=torch.device('cpu'))
            weight = {k[len('clip_project.'):]: v for k, v in weight.items()
                      if k.startswith('clip_project.')}
            self.clip_project.load_state_dict(weight)
        if fix_gpt:
            log.info("fixing gpt parameters")
            for param in self.gpt.parameters():
                param.requires_grad_(False)
        if self.scalar_output:
            self.gpt.lm_head = nn.Linear(self.gpt.transformer.embed_dim, 1).to(self.device)
        self.clip_project = self.clip_project.to(self.device)
        if hasattr(self, 'label_project'):
            self.label_project = self.label_project.to(self.device)

    def get_mapper(self):
        if self.use_ptuning_v2:
            total_embd = self.match_n_layer * 2 * self.n_embd
            module = MLP((self.prefix_size,
                          *[self.prefix_size
                            for i in range(self.num_layers)],
                          total_embd * self.prefix_length))
        elif self.use_transformer_mapper:
            log.info("using transformer mapper")
            module = TransformerMapper(self.prefix_size, self.n_embd,
                                       self.prefix_length, self.prefix_length, num_layers=self.num_layers)  # 8)
        else:
            module = MLP((self.prefix_size,
                          *[(self.n_embd * self.prefix_length) // 2
                            for i in range(self.num_layers)],
                          self.n_embd * self.prefix_length))
        return module

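    # Mapper output sizes, for reference: the p-tuning-v2 MLP emits match_n_layer * 2 * n_embd
    # values per prefix position (reshaped into per-layer key/value states in forward_prefix),
    # while the other two variants emit n_embd values per prefix position (used as input embeddings).
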
    def get_encoder_loss(self, input_ids: torch.Tensor, features: torch.Tensor,
                         device=None):
        # Regress the projected CLIP prefix onto the (detached) token embeddings of the
        # first `prefix_length` tokens.
        input_ids = input_ids[:, :self.prefix_length].to(device)
        embedding = self.gpt.transformer.wte(input_ids)
        features = features.to(device)
        prefix_projections = self.clip_project(features.type_as(embedding)).reshape(-1, self.prefix_length, self.n_embd)
        fct = nn.MSELoss()
        loss = fct(prefix_projections, embedding.detach())
        return loss

    def forward(self, *args, **kwargs):
        if self.use_ptuning_v2:
            return self.forward_prefix(*args, **kwargs)
        else:
            return self.forward_embedding(*args, **kwargs)

    def forward_embedding(self, input_ids: torch.Tensor, features: torch.Tensor,
                          attention_mask: Optional[torch.Tensor] = None,
                          labels: Optional[torch.Tensor] = None,
                          past_key_values=None, device=None, **kwargs):
        if device is None:
            device = self.device
        input_ids = input_ids.to(device)
        if features is not None:
            features = features.to(device)
        if attention_mask is not None:
            attention_mask = attention_mask.to(device)
        if labels is not None:
            labels = labels.to(device)
        use_labels = labels is not None and hasattr(self, 'label_project')
        embedding = self.gpt.transformer.wte(input_ids)
        embed_txt = embedding
        prefix_length = self.prefix_length
        if use_labels:
            prefix_length += self.label_length
        if past_key_values is None:
            prefix_projections = self.clip_project(features.type_as(embedding)).reshape(-1, self.prefix_length, self.n_embd)
            if use_labels:
                label_projections = self.label_project(labels.long()).reshape(-1, self.label_length, self.n_embd)
                prefix_projections = torch.cat((prefix_projections, label_projections), dim=1)
            embedding = torch.cat((prefix_projections.to(embedding.dtype), embedding), dim=1)
        if torch.is_tensor(attention_mask):
            prefix_mask = torch.ones_like(attention_mask)[:, :1].repeat(1, prefix_length)
            attention_mask = torch.cat([prefix_mask, attention_mask], dim=1)
        outputs = self.gpt(inputs_embeds=embedding, attention_mask=attention_mask,
                           past_key_values=past_key_values,
                           return_dict=True,
                           output_attentions=False,
                           output_hidden_states=True)
        if past_key_values is None:
            outputs.logits = outputs.logits[:, prefix_length:]
        return outputs

    def forward_prefix(self, input_ids: torch.Tensor, features: torch.Tensor,
                       attention_mask: Optional[torch.Tensor] = None,
                       labels: Optional[torch.Tensor] = None,
                       past_key_values=None, device=None, **kwargs):
        if device is None:
            device = self.device
        input_ids = input_ids.to(device)
        if features is not None:
            features = features.to(device)
        if attention_mask is not None:
            attention_mask = attention_mask.to(device)
        if labels is not None:
            labels = labels.to(device)
        use_labels = labels is not None and hasattr(self, 'label_project')
        prefix_length = self.prefix_length
        if use_labels:
            prefix_length += self.label_length
        if past_key_values is None:
            prefix_projections = self.clip_project(features.type_as(self.clip_project.model[0].weight))
            prefix_projections = prefix_projections.reshape(-1, self.prefix_length,
                                                            self.match_n_layer * 2, self.match_n_head, self.match_n_embd)
            if use_labels:
                label_projections = self.label_project(labels.long())
                label_projections = label_projections.reshape(-1, self.label_length,
                                                              self.match_n_layer * 2, self.match_n_head, self.match_n_embd)
                prefix_projections = torch.cat((prefix_projections, label_projections), dim=1)
            temp_control = prefix_projections
            temp_control = self.dropout(temp_control)
            past_key_values = temp_control.permute([2, 0, 3, 1, 4]).split(2)
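            # `past_key_values` is now a tuple of `match_n_layer` tensors, each of shape
            # (2, batch, match_n_head, prefix_len, match_n_embd): stacked key/value prefix
            # states, one pair per decoder layer.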
        if torch.is_tensor(attention_mask):
            prefix_mask = torch.ones_like(attention_mask)[:, :1].repeat(1, prefix_length)
            attention_mask = torch.cat([prefix_mask, attention_mask], dim=1)
        outputs = self.gpt(input_ids=input_ids, attention_mask=attention_mask,
                           past_key_values=past_key_values,
                           return_dict=True,
                           output_attentions=False,
                           output_hidden_states=True)
        if past_key_values is None:
            outputs.logits = outputs.logits[:, prefix_length:]
        return outputs

    def prepare_inputs_for_generation(self, input_ids, past=None, **kwargs):
        token_type_ids = kwargs.get("token_type_ids", None)
        # only keep the last token of input_ids if past is defined in kwargs
        if past:
            input_ids = input_ids[:, -1].unsqueeze(-1)
            if token_type_ids is not None:
                token_type_ids = token_type_ids[:, -1].unsqueeze(-1)
        attention_mask = kwargs.get("attention_mask", None)
        position_ids = kwargs.get("position_ids", None)
        features = kwargs.get("features", None)
        labels = kwargs.get("labels", None)
        if attention_mask is not None and position_ids is None:
            # create position_ids on the fly for batch generation
            position_ids = attention_mask.long().cumsum(-1) - 1
            position_ids.masked_fill_(attention_mask == 0, 1)
            if past:
                position_ids = position_ids[:, -1].unsqueeze(-1)
        else:
            position_ids = None
        return {
            "input_ids": input_ids,
            "past_key_values": past,
            "use_cache": kwargs.get("use_cache"),
            "position_ids": position_ids,
            "attention_mask": attention_mask,
            "token_type_ids": token_type_ids,
            "features": features,
            "labels": labels,
        }
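

if __name__ == "__main__":
    # Minimal smoke test (an addition for illustration, not part of the original module):
    # it exercises only the lightweight mapper stack, so no pretrained LM is downloaded.
    # The dimensions below are assumptions matching the defaults used elsewhere in this file.
    mapper = TransformerMapper(dim_clip=512, dim_embedding=768, prefix_length=10,
                               clip_length=10, num_layers=2)
    clip_features = torch.randn(2, 512)
    prefix_tokens = mapper(clip_features)
    print(prefix_tokens.shape)  # expected: torch.Size([2, 10, 768])

    mlp = MLP((512, 3840, 7680))  # same sizes as ClipCap.get_mapper with num_layers=1, n_embd=768
    print(mlp(clip_features).shape)  # expected: torch.Size([2, 7680])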