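"""
Convert original BLIP checkpoints (https://github.com/salesforce/BLIP) into the
Hugging Face `transformers` Blip* model classes; `convert_blip_checkpoint` is
the entry point.
"""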
import argparse
import re

import requests
import torch
from models.blip import blip_decoder
from models.blip_itm import blip_itm
from models.blip_vqa import blip_vqa
from PIL import Image
from torchvision import transforms
from torchvision.transforms.functional import InterpolationMode

from transformers import (
    BertTokenizer,
    BlipConfig,
    BlipForConditionalGeneration,
    BlipForImageTextRetrieval,
    BlipForQuestionAnswering,
)


def load_demo_image(image_size, device):
    img_url = "https://storage.googleapis.com/sfr-vision-language-research/BLIP/demo.jpg"
    raw_image = Image.open(requests.get(img_url, stream=True).raw).convert("RGB")

    # Bicubic resize to a square image and normalization with the statistics used by BLIP's preprocessing.
    transform = transforms.Compose(
        [
            transforms.Resize((image_size, image_size), interpolation=InterpolationMode.BICUBIC),
            transforms.ToTensor(),
            transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)),
        ]
    )
    image = transform(raw_image).unsqueeze(0).to(device)
    return image


def rename_key(key):
    """Map a key of the original BLIP state dict onto the transformers naming scheme."""
    if "visual_encoder" in key:
        key = re.sub("visual_encoder*", "vision_model.encoder", key)
    if "blocks" in key:
        key = re.sub(r"blocks", "layers", key)
    if "attn" in key:
        key = re.sub(r"attn", "self_attn", key)
    if "norm1" in key:
        key = re.sub(r"norm1", "layer_norm1", key)
    if "norm2" in key:
        key = re.sub(r"norm2", "layer_norm2", key)
    if "encoder.norm" in key:
        key = re.sub(r"encoder.norm", "post_layernorm", key)
    if "encoder.patch_embed.proj" in key:
        key = re.sub(r"encoder.patch_embed.proj", "embeddings.patch_embedding", key)

    if "encoder.pos_embed" in key:
        key = re.sub(r"encoder.pos_embed", "embeddings.position_embedding", key)
    if "encoder.cls_token" in key:
        key = re.sub(r"encoder.cls_token", "embeddings.class_embedding", key)

    if "self_attn" in key:
        key = re.sub(r"self_attn.proj", "self_attn.projection", key)

    return key
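# Example of the resulting mapping:
#   "visual_encoder.blocks.0.attn.proj.weight" -> "vision_model.encoder.layers.0.self_attn.projection.weight"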
|
|
|
|
|
@torch.no_grad()
def convert_blip_checkpoint(pytorch_dump_folder_path, config_path=None):
    """
    Copy/paste/tweak the original BLIP weights into the transformers design.
    """
    if config_path is not None:
        config = BlipConfig.from_pretrained(config_path)
    else:
        config = BlipConfig(projection_dim=512, text_config={}, vision_config={})

    hf_model = BlipForConditionalGeneration(config).eval()

    # Original BLIP checkpoint to convert (a local path or a URL understood by the BLIP loaders).
    model_url = "model_base_capfilt_large.pth"

    image_size = 384
    image = load_demo_image(image_size=image_size, device="cpu")
    tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
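    # Captioning model: a hedged sketch of the conversion step implied by the
    # blip_decoder import and the hf_model created above. It assumes the file
    # named in model_url (the BLIP base capfilt captioning checkpoint) is
    # available locally, or that model_url is replaced with a full URL.
    pt_model = blip_decoder(pretrained=model_url, image_size=image_size, vit="base")
    pt_model.eval()

    # Rename every key of the original state dict to the transformers scheme.
    caption_state_dict = pt_model.state_dict()
    for key in caption_state_dict.copy():
        value = caption_state_dict.pop(key)
        caption_state_dict[rename_key(key)] = value

    # Assumes the renamed keys match BlipForConditionalGeneration one-to-one.
    hf_model.load_state_dict(caption_state_dict)

    # Sanity check: caption the demo image.
    caption_ids = hf_model.generate(image)
    print(tokenizer.decode(caption_ids[0], skip_special_tokens=True))

    if pytorch_dump_folder_path is not None:
        hf_model.save_pretrained(pytorch_dump_folder_path)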
|
    # VQA model: convert the same checkpoint through BLIP's VQA wrapper.
    vqa_model = blip_vqa(pretrained=model_url, image_size=image_size, vit="base")
    vqa_model.eval()

    # Rename every key of the original state dict to the transformers scheme.
    modified_state_dict = vqa_model.state_dict()
    for key in modified_state_dict.copy():
        value = modified_state_dict.pop(key)
        renamed_key = rename_key(key)
        modified_state_dict[renamed_key] = value

    hf_vqa_model = BlipForQuestionAnswering(config)

    # Drop renamed keys that BlipForQuestionAnswering does not expect, and print
    # how many of its parameters the converted checkpoint covers.
    hf_keys = hf_vqa_model.state_dict()
    offset_keys = [key for key in modified_state_dict if key not in hf_keys]
    print(len([key for key in hf_keys if key in modified_state_dict]))
    for key in offset_keys:
        modified_state_dict.pop(key)

    # Strict load: raises if any expected parameter is still missing.
    hf_vqa_model.load_state_dict(modified_state_dict)

    question = ["How many dogs are in this image?"]
    question_input_ids = tokenizer(question, return_tensors="pt").input_ids

    answer = hf_vqa_model.generate(question_input_ids, image)
    print(tokenizer.decode(answer[0]))

    if pytorch_dump_folder_path is not None:
        hf_vqa_model.save_pretrained(pytorch_dump_folder_path + "_vqa")
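    # Image-text matching (ITM) model: a hedged sketch mirroring the steps above,
    # implied by the blip_itm and BlipForImageTextRetrieval imports. The checkpoint
    # path below is an assumption (a local copy of the BLIP base retrieval
    # checkpoint) and is not defined anywhere else in this script.
    itm_checkpoint = "model_base_retrieval_coco.pth"  # assumed local path
    itm_model = blip_itm(pretrained=itm_checkpoint, image_size=image_size, vit="base")
    itm_model.eval()

    modified_state_dict = itm_model.state_dict()
    for key in modified_state_dict.copy():
        value = modified_state_dict.pop(key)
        modified_state_dict[rename_key(key)] = value

    hf_itm_model = BlipForImageTextRetrieval(config)
    # Assumes the renamed keys match BlipForImageTextRetrieval one-to-one.
    hf_itm_model.load_state_dict(modified_state_dict)
    hf_itm_model.eval()

    text = ["a picture of a woman and a dog on the beach"]
    itm_input_ids = tokenizer(text, return_tensors="pt", padding="max_length", truncation=True, max_length=35).input_ids

    # use_itm_head=True returns match/no-match logits; softmax over them gives the match probability.
    itm_out = hf_itm_model(itm_input_ids, image, use_itm_head=True)
    print(torch.nn.functional.softmax(itm_out[0], dim=1)[:, 1].item())

    if pytorch_dump_folder_path is not None:
        hf_itm_model.save_pretrained(pytorch_dump_folder_path + "_itm")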
|
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--pytorch_dump_folder_path", default=None, type=str, help="Path to the output PyTorch model.")
    parser.add_argument("--config_path", default=None, type=str, help="Path to hf config.json of model to convert")
    args = parser.parse_args()

    convert_blip_checkpoint(args.pytorch_dump_folder_path, args.config_path)
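# Example invocation (the script filename and paths are illustrative):
#   python convert_blip_checkpoint.py --pytorch_dump_folder_path ./blip-converted --config_path ./blip_config.json
# save_pretrained then writes the converted models to ./blip-converted (captioning),
# ./blip-converted_vqa and ./blip-converted_itm.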
|
|