# Copyright (c) Tencent Inc. All rights reserved.
import itertools
from typing import List, Optional, Sequence, Tuple

import torch
from torch import Tensor
from torch.nn.modules.batchnorm import _BatchNorm

from mmengine.model import BaseModule
from mmyolo.registry import MODELS
from mmdet.utils import ConfigType, OptMultiConfig
from transformers import AutoModel, AutoTokenizer, CLIPTextConfig
from transformers import CLIPTextModelWithProjection as CLIPTP

@MODELS.register_module()
class HuggingVisionBackbone(BaseModule):

    def __init__(self,
                 model_name: str,
                 out_indices: Sequence[int] = (0, 1, 2, 3),
                 norm_eval: bool = True,
                 frozen_modules: Sequence[str] = (),
                 init_cfg: OptMultiConfig = None) -> None:
        super().__init__(init_cfg=init_cfg)

        self.norm_eval = norm_eval
        self.frozen_modules = frozen_modules
        self.out_indices = out_indices
        self.model = AutoModel.from_pretrained(model_name)

        self._freeze_modules()

    def forward(self, image: Tensor) -> Tuple[Tensor]:
        encoded_dict = self.model(pixel_values=image,
                                  output_hidden_states=True)
        hidden_states = encoded_dict.hidden_states
        # Swin-style models expose `reshaped_hidden_states` in NCHW layout;
        # fall back to the plain hidden states otherwise.
        img_feats = encoded_dict.get('reshaped_hidden_states', hidden_states)
        img_feats = [img_feats[i] for i in self.out_indices]
        return tuple(img_feats)

    def _freeze_modules(self):
        # Freeze every submodule whose qualified name starts with one of
        # the configured prefixes.
        for name, module in self.model.named_modules():
            for frozen_name in self.frozen_modules:
                if name.startswith(frozen_name):
                    module.eval()
                    for param in module.parameters():
                        param.requires_grad = False
                    break

    def train(self, mode=True):
        super().train(mode)
        self._freeze_modules()

        if mode and self.norm_eval:
            for m in self.modules():
                # trick: eval() only affects BatchNorm layers
                if isinstance(m, _BatchNorm):
                    m.eval()

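# Usage sketch (illustrative, not part of the original file): building the
# vision backbone directly. The checkpoint name is only an example of a
# Hugging Face model that exposes `reshaped_hidden_states`.
#
#   backbone = HuggingVisionBackbone(
#       model_name='microsoft/swin-base-patch4-window7-224',
#       frozen_modules=('embeddings', ))
#   feats = backbone(torch.randn(1, 3, 224, 224))  # tuple of feature maps
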
@MODELS.register_module()
class HuggingCLIPLanguageBackbone(BaseModule):

    def __init__(self,
                 model_name: str,
                 frozen_modules: Sequence[str] = (),
                 dropout: float = 0.0,
                 training_use_cache: bool = False,
                 init_cfg: OptMultiConfig = None) -> None:
        super().__init__(init_cfg=init_cfg)

        self.frozen_modules = frozen_modules
        self.training_use_cache = training_use_cache
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        clip_config = CLIPTextConfig.from_pretrained(model_name,
                                                     attention_dropout=dropout)
        self.model = CLIPTP.from_pretrained(model_name, config=clip_config)

        self._freeze_modules()

    def forward_cache(self, text: List[List[str]]) -> Tensor:
        # Compute the text embeddings once and reuse them on later calls;
        # only valid when the vocabulary is fixed across batches.
        if not hasattr(self, "cache"):
            self.cache = self.forward_text(text)
        return self.cache

    def forward(self, text: List[List[str]]) -> Tensor:
        # `forward_cache` can be used instead when the vocabulary is fixed;
        # by default the embeddings are recomputed on every call.
        return self.forward_text(text)

    def forward_tokenizer(self, texts):
        # Tokenize once and cache the result for subsequent calls.
        if not hasattr(self, 'text'):
            text = list(itertools.chain(*texts))
            text = self.tokenizer(text=text, return_tensors='pt', padding=True)
            self.text = text.to(device=self.model.device)
        return self.text

    def forward_text(self, text: List[List[str]]) -> Tensor:
        num_per_batch = [len(t) for t in text]
        assert max(num_per_batch) == min(num_per_batch), (
            'number of sequences must be equal for all samples in a batch')
        # flatten List[List[str]] -> List[str] before tokenizing
        text = list(itertools.chain(*text))
        text = self.tokenizer(text=text, return_tensors='pt', padding=True)
        text = text.to(device=self.model.device)
        txt_outputs = self.model(**text)
        txt_feats = txt_outputs.text_embeds
        # L2-normalize, then restore the (batch, num_texts, dim) layout
        txt_feats = txt_feats / txt_feats.norm(p=2, dim=-1, keepdim=True)
        txt_feats = txt_feats.reshape(-1, num_per_batch[0],
                                      txt_feats.shape[-1])
        return txt_feats

    def _freeze_modules(self):
        if len(self.frozen_modules) == 0:
            # nothing to freeze
            return
        if self.frozen_modules[0] == "all":
            self.model.eval()
            for _, module in self.model.named_modules():
                module.eval()
                for param in module.parameters():
                    param.requires_grad = False
            return
        for name, module in self.model.named_modules():
            for frozen_name in self.frozen_modules:
                if name.startswith(frozen_name):
                    module.eval()
                    for param in module.parameters():
                        param.requires_grad = False
                    break

    def train(self, mode=True):
        super().train(mode)
        self._freeze_modules()

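# Usage sketch (illustrative): the input is a nested list with one prompt
# list per image, all of equal length. The CLIP checkpoint name is an
# assumption; any CLIP text model with a projection head works.
#
#   text_model = HuggingCLIPLanguageBackbone(
#       model_name='openai/clip-vit-base-patch32',
#       frozen_modules=('all', ))
#   txt_feats = text_model([['person', 'dog'], ['person', 'dog']])
#   # txt_feats: (2, 2, 512), L2-normalized per embedding
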
@MODELS.register_module()
class PseudoLanguageBackbone(BaseModule):
    """Pseudo Language Backbone.

    Serves pre-computed text embeddings instead of running a text encoder.

    Args:
        text_embed_path (str): path to the text embedding file, a dict
            mapping each text to its embedding ({text: embed}).
        test_embed_path (str, optional): separate embedding file used at
            test time; defaults to ``text_embed_path``.
    """

    def __init__(self,
                 text_embed_path: str = "",
                 test_embed_path: Optional[str] = None,
                 init_cfg: OptMultiConfig = None):
        super().__init__(init_cfg)
        # {text: embed}
        self.text_embed = torch.load(text_embed_path, map_location='cpu')
        if test_embed_path is None:
            self.test_embed = self.text_embed
        else:
            self.test_embed = torch.load(test_embed_path, map_location='cpu')
        # dummy buffer used only to track the module's device
        self.register_buffer("buff", torch.zeros([1]))

    def forward_cache(self, text: List[List[str]]) -> Tensor:
        if not hasattr(self, "cache"):
            self.cache = self.forward_text(text)
        return self.cache

    def forward(self, text: List[List[str]]) -> Tensor:
        if self.training:
            return self.forward_text(text)
        else:
            return self.forward_cache(text)

    def forward_text(self, text: List[List[str]]) -> Tensor:
        num_per_batch = [len(t) for t in text]
        assert max(num_per_batch) == min(num_per_batch), (
            'number of sequences must be equal for all samples in a batch')
        text = list(itertools.chain(*text))
        if self.training:
            text_embed_dict = self.text_embed
        else:
            text_embed_dict = self.test_embed
        # look up by the category name before any "/" suffix
        text_embeds = torch.stack(
            [text_embed_dict[x.split("/")[0]] for x in text])
        # detach from autograd and force float on the module's device
        text_embeds = text_embeds.to(
            self.buff.device).requires_grad_(False).float()
        text_embeds = text_embeds.reshape(-1, num_per_batch[0],
                                          text_embeds.shape[-1])
        return text_embeds

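# Sketch (an assumption, not part of the original file): producing an
# offline embedding file in the {text: embed} format expected by
# `text_embed_path`, using the CLIP backbone above.
#
#   model = HuggingCLIPLanguageBackbone('openai/clip-vit-base-patch32')
#   names = ['person', 'dog', 'cat']
#   with torch.no_grad():
#       embeds = model.forward_text([names]).squeeze(0)
#   torch.save({n: e.cpu() for n, e in zip(names, embeds)}, 'text_embed.pth')
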
@MODELS.register_module()
class MultiModalYOLOBackbone(BaseModule):

    def __init__(self,
                 image_model: ConfigType,
                 text_model: ConfigType,
                 init_cfg: OptMultiConfig = None) -> None:
        super().__init__(init_cfg)

        self.image_model = MODELS.build(image_model)
        self.text_model = MODELS.build(text_model)

    def forward(self, image: Tensor,
                text: List[List[str]]) -> Tuple[Tuple[Tensor], Tensor]:
        img_feats = self.image_model(image)
        txt_feats = self.text_model(text)
        return img_feats, txt_feats
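
# Usage sketch (illustrative): wiring both sub-backbones together from config
# dicts via the MODELS registry; the model names and values are assumptions.
#
#   backbone = MultiModalYOLOBackbone(
#       image_model=dict(
#           type='HuggingVisionBackbone',
#           model_name='microsoft/swin-base-patch4-window7-224'),
#       text_model=dict(
#           type='HuggingCLIPLanguageBackbone',
#           model_name='openai/clip-vit-base-patch32',
#           frozen_modules=('all', )))
#   img_feats, txt_feats = backbone(images, [['person', 'dog']] * len(images))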