from collections import OrderedDict

import clip
import torch
from transformers import DebertaV2Tokenizer

from model.deberta_moe import DebertaV2ForMaskedLM
from VideoLoader import VideoLoader


def get_mask(lengths, max_length):
    """Computes a batch of padding masks given batched lengths."""
    mask = (
        torch.arange(max_length).unsqueeze(1) < lengths
    ).transpose(0, 1).long()
    return mask
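
# Quick illustration of get_mask (example values chosen here, not taken from
# the repo): a batch with one clip of true length 2, padded to max_length 4,
# gets ones over the real frames and zeros over the padding:
#   >>> get_mask(torch.tensor([2]), 4)
#   tensor([[1, 1, 0, 0]])
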
class Infer:
    def __init__(self, device):
        pretrained_ckpt = torch.load("ckpts/model.pth", map_location="cpu")
        args = pretrained_ckpt["args"]
        args.n_ans = 2
        args.max_tokens = 256
        self.args = args
        self.clip_model = clip.load("ViT-L/14", device=device)[0]
        self.tokenizer = DebertaV2Tokenizer.from_pretrained(
            "ckpts/deberta-v2-xlarge", local_files_only=True
        )

        self.model = DebertaV2ForMaskedLM.from_pretrained(
            features_dim=args.features_dim if args.use_video else 0,
            max_feats=args.max_feats,
            freeze_lm=args.freeze_lm,
            freeze_mlm=args.freeze_mlm,
            ft_ln=args.ft_ln,
            ds_factor_attn=args.ds_factor_attn,
            ds_factor_ff=args.ds_factor_ff,
            dropout=args.dropout,
            n_ans=args.n_ans,
            freeze_last=args.freeze_last,
            pretrained_model_name_or_path="ckpts/deberta-v2-xlarge",
            local_files_only=True,
            add_video_feat=args.add_video_feat,
            freeze_ad=args.freeze_ad,
        )

        # The fine-tuned weights were saved from a DataParallel/DistributedDataParallel
        # wrapper, so strip the "module." prefix before loading them.
        new_state_dict = OrderedDict()
        for k, v in pretrained_ckpt["model"].items():
            new_state_dict[k.replace("module.", "")] = v
        self.model.load_state_dict(new_state_dict, strict=False)
        self.model.eval()
        self.model.to(device)
        self.device = device

        self.video_loader = VideoLoader()
        self.set_answer()

    def _get_clip_feature(self, video):
        # Frame features come from CLIP's image encoder; no gradients are
        # needed at inference time.
        with torch.no_grad():
            feat = self.clip_model.encode_image(video.to(self.device))
        return feat

    def set_answer(self):
        # Tokenize the two answer words; each must map to a single token id so
        # the masked position can score them directly.
        tok_yes = torch.tensor(
            self.tokenizer(
                "Yes",
                add_special_tokens=False,
                max_length=1,
                truncation=True,
                padding="max_length",
            )["input_ids"],
            dtype=torch.long,
        )
        tok_no = torch.tensor(
            self.tokenizer(
                "No",
                add_special_tokens=False,
                max_length=1,
                truncation=True,
                padding="max_length",
            )["input_ids"],
            dtype=torch.long,
        )

        a2tok = torch.stack([tok_yes, tok_no])
        self.model.set_answer_embeddings(
            a2tok.to(self.model.device), freeze_last=self.args.freeze_last
        )

    def generate(self, text, candidates, video_path):
        video, video_len = self.video_loader(video_path)
        video = self._get_clip_feature(video).unsqueeze(0).float()
        # Mask over the 10 sampled frames, plus one extra position for the
        # aggregated video feature the model prepends (add_video_feat).
        video_mask = get_mask(video_len, 10)
        video_mask = torch.cat(
            [torch.ones((1, 1), dtype=video_mask.dtype), video_mask], dim=1
        )
        logits_list = []

        question = text.strip().capitalize()
        if not question.endswith("?"):
            question += "?"

        for candidate in candidates:
            prompt = (
                f"Question: {question} Is it '{candidate}'? "
                f"{self.tokenizer.mask_token}. Subtitles: "
            ).strip()
            encoded = self.tokenizer(
                prompt,
                add_special_tokens=True,
                max_length=self.args.max_tokens,
                padding="longest",
                truncation=True,
                return_tensors="pt",
            )
            input_ids = encoded["input_ids"].to(self.device)
            attention_mask = encoded["attention_mask"].to(self.device)

            output = self.model(
                video=video.to(self.device),
                video_mask=video_mask.to(self.device),
                input_ids=input_ids,
                attention_mask=attention_mask,
            )
            logits = output["logits"]

            # The 10 frame tokens plus 1 aggregated video token are prepended
            # to the text, so text logits start at offset 11; keep only the
            # logits at the [MASK] position.
            delay = 11
            logits = logits[:, delay : input_ids.size(1) + delay][
                input_ids == self.tokenizer.mask_token_id
            ]
            # Probability that the masked token is "Yes" (answer id 0).
            logits_list.append(logits.softmax(-1)[:, 0])

        probs = torch.stack(logits_list, 1)  # (1, n_candidates) "Yes" probabilities
        if probs.shape[1] == 1:
            # Single candidate: treat the score as a yes/no decision.
            preds = probs.round().long().squeeze(1)
        else:
            preds = probs.max(1).indices
        return candidates[int(preds)]
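

if __name__ == "__main__":
    # Minimal usage sketch. The question, candidates, and video path below are
    # illustrative placeholders, not files shipped with the repo; only the
    # checkpoints referenced in Infer.__init__ are assumed to exist under ckpts/.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    infer = Infer(device)
    answer = infer.generate(
        text="what is the man holding",  # hypothetical question
        candidates=["a guitar", "a phone", "a book"],
        video_path="demo/video.mp4",  # hypothetical clip
    )
    print(answer)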