|
from typing import List, Optional, Tuple

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import BartForConditionalGeneration, BartTokenizer, BartConfig

from MusicCaps.modules import AudioEncoder
|
|
|
class BartCaptionModel(nn.Module):
    """Audio captioning model: an ``AudioEncoder`` feeding a BART encoder-decoder.

    The audio input is embedded by ``AudioEncoder``; those embeddings are given
    to the BART encoder as ``inputs_embeds`` and the BART decoder is trained
    (``forward``) or sampled (``generate``) to produce caption text.

    Note:
        Only the configuration and the tokenizer are loaded from ``bart_type``.
        The BART network is constructed from that configuration, not via
        ``from_pretrained``, so its weights are freshly initialised.
    """

    # Repetition penalty used by the nucleus-sampling branch of ``generate``.
    # NOTE(review): that branch has always ignored the ``repetition_penalty``
    # argument and used this fixed value; kept as-is so existing callers see
    # no change. Confirm whether the argument should be honoured instead.
    _NUCLEUS_REPETITION_PENALTY = 1.1

    def __init__(
        self,
        n_mels: int = 128,
        num_of_conv: int = 6,
        sr: int = 16000,
        duration: int = 10,
        max_length: int = 128,
        label_smoothing: float = 0.1,
        bart_type: str = "facebook/bart-base",
        audio_dim: int = 768,
    ):
        """Build the tokenizer, the BART network and the audio encoder.

        Args:
            n_mels: Forwarded to ``AudioEncoder`` as ``n_mels``.
            num_of_conv: Number of convolution layers; ``num_of_conv - 1`` of
                them are counted as strided when sizing the encoder context.
            sr: Audio sample rate in Hz.
            duration: Clip length in seconds; ``sr * duration`` is the number
                of samples the context length is derived from.
            max_length: Maximum token length captions are truncated to during
                training (see ``forward_decoder``).
            label_smoothing: Label smoothing factor for the cross-entropy loss.
            bart_type: HuggingFace model id supplying the BART configuration
                and tokenizer.
            audio_dim: Forwarded to ``AudioEncoder`` as ``audio_dim``.
        """
        super().__init__()

        bart_config = BartConfig.from_pretrained(bart_type)
        self.tokenizer = BartTokenizer.from_pretrained(bart_type)
        # Built from the config only: no pretrained weights are loaded here.
        self.bart = BartForConditionalGeneration(bart_config)

        # Derive how many positions the audio encoder has to provide.
        self.n_sample = sr * duration
        self.hop_length = int(0.01 * sr)  # hop of 1% of a second, in samples
        self.n_frames = int(self.n_sample // self.hop_length)
        self.num_of_stride_conv = num_of_conv - 1
        # Frame count divided by 2 per strided conv, plus one.
        self.n_ctx = int(self.n_frames // 2**self.num_of_stride_conv) + 1
        self.audio_encoder = AudioEncoder(
            n_mels=n_mels,
            n_ctx=self.n_ctx,
            audio_dim=audio_dim,
            text_dim=self.bart.config.hidden_size,
            num_of_stride_conv=self.num_of_stride_conv,
        )

        self.max_length = max_length
        # -100 marks positions (padding) that must not contribute to the loss.
        self.loss_fct = nn.CrossEntropyLoss(label_smoothing=label_smoothing, ignore_index=-100)

    @property
    def device(self) -> torch.device:
        """Device of the model's first parameter."""
        # next() avoids materialising the full parameter list.
        return next(self.parameters()).device

    def shift_tokens_right(
        self,
        input_ids: torch.Tensor,
        pad_token_id: Optional[int],
        decoder_start_token_id: int,
    ) -> torch.Tensor:
        """Shift input ids one token to the right.

        The first column becomes ``decoder_start_token_id``, the last token of
        every row is dropped, and any ``-100`` (ignored-label marker) left in
        the result is replaced by ``pad_token_id``. ``input_ids`` itself is not
        modified.

        Args:
            input_ids: 2-D tensor of token ids (rows are sequences).
            pad_token_id: Id written wherever the shifted tensor holds ``-100``.
            decoder_start_token_id: Id placed at position 0 of every row.

        Returns:
            A new tensor with the same shape and dtype as ``input_ids``.

        Raises:
            ValueError: If ``pad_token_id`` is ``None``.
        """
        # Validate before allocating anything.
        if pad_token_id is None:
            raise ValueError("self.model.config.pad_token_id has to be defined.")

        shifted_input_ids = input_ids.new_zeros(input_ids.shape)
        shifted_input_ids[:, 1:] = input_ids[:, :-1].clone()
        shifted_input_ids[:, 0] = decoder_start_token_id

        # -100 is only meaningful as a loss label, never as a decoder input.
        shifted_input_ids.masked_fill_(shifted_input_ids == -100, pad_token_id)
        return shifted_input_ids

    def _encode_audio(self, audio):
        """Embed ``audio`` and run the BART encoder on the embeddings.

        Returns:
            ``(encoder_output, audio_embs)`` where ``encoder_output`` is the
            BART encoder's output object (``return_dict=True``) and
            ``audio_embs`` is what ``AudioEncoder`` produced.
        """
        audio_embs = self.audio_encoder(audio)
        encoder_output = self.bart.model.encoder(
            input_ids=None,
            inputs_embeds=audio_embs,
            return_dict=True,
        )
        return encoder_output, audio_embs

    def forward_encoder(self, audio) -> Tuple[torch.Tensor, torch.Tensor]:
        """Encode audio for training.

        Args:
            audio: Input accepted by ``AudioEncoder`` (format is defined by
                that module).

        Returns:
            ``(encoder_outputs, audio_embs)``: the BART encoder's last hidden
            state and the raw ``AudioEncoder`` embeddings.
        """
        encoder_output, audio_embs = self._encode_audio(audio)
        return encoder_output["last_hidden_state"], audio_embs

    def forward_decoder(self, text, encoder_outputs: torch.Tensor) -> torch.Tensor:
        """Compute the captioning loss for ``text`` given encoded audio.

        Args:
            text: Caption string(s) passed to the tokenizer; truncated to
                ``self.max_length`` tokens and padded to the longest caption.
            encoder_outputs: Last hidden state returned by ``forward_encoder``.

        Returns:
            Scalar label-smoothed cross-entropy loss; padding is ignored.
        """
        tokens = self.tokenizer(
            text,
            padding='longest',
            truncation=True,
            max_length=self.max_length,
            return_tensors="pt",
        )
        input_ids = tokens["input_ids"].to(self.device)
        attention_mask = tokens["attention_mask"].to(self.device)

        # Padding positions become -100 so the loss (ignore_index=-100) skips them.
        decoder_targets = input_ids.masked_fill(
            input_ids == self.tokenizer.pad_token_id, -100
        )

        # Decoder input is the target sequence shifted right by one position.
        decoder_input_ids = self.shift_tokens_right(
            decoder_targets,
            self.bart.config.pad_token_id,
            self.bart.config.decoder_start_token_id,
        )

        decoder_outputs = self.bart(
            decoder_input_ids=decoder_input_ids,
            decoder_attention_mask=attention_mask,
            encoder_outputs=(encoder_outputs,),
            return_dict=True,
        )
        lm_logits = decoder_outputs["logits"]

        # Flatten using the logits' own vocabulary dimension: the tokenizer's
        # vocab_size need not equal the model's output size (added tokens,
        # resized embeddings), and a mismatch would misalign logits and labels.
        return self.loss_fct(
            lm_logits.view(-1, lm_logits.size(-1)),
            decoder_targets.view(-1),
        )

    def forward(self, audio, text) -> torch.Tensor:
        """Return the training loss for a batch of ``(audio, text)`` pairs."""
        encoder_outputs, _ = self.forward_encoder(audio)
        return self.forward_decoder(text, encoder_outputs)

    @torch.no_grad()
    def generate(
        self,
        samples,
        use_nucleus_sampling: bool = False,
        num_beams: int = 5,
        max_length: int = 128,
        min_length: int = 2,
        top_p: float = 0.9,
        repetition_penalty: float = 1.0,
    ) -> List[str]:
        """Generate captions for a batch of audio inputs.

        Runs without gradient tracking: only decoded strings are returned, so
        no autograd graph is needed for the encoder passes.

        Args:
            samples: Audio batch accepted by ``AudioEncoder``.
            use_nucleus_sampling: If true, sample with top-p (nucleus)
                sampling; otherwise use beam search.
            num_beams: Beam width (beam search only).
            max_length: Maximum generated sequence length.
            min_length: Minimum generated sequence length.
            top_p: Cumulative probability mass kept (nucleus sampling only).
            repetition_penalty: Repetition penalty for beam search.
                NOTE(review): ignored when ``use_nucleus_sampling`` is true;
                that path uses a fixed penalty of 1.1.

        Returns:
            One decoded caption per batch item, special tokens stripped.
        """
        encoder_outputs, _ = self._encode_audio(samples)
        batch_size = encoder_outputs['last_hidden_state'].size(0)

        # Every sequence starts from the decoder start token with a length-1 mask.
        input_ids = torch.full(
            (batch_size, 1),
            self.bart.config.decoder_start_token_id,
            dtype=torch.long,
            device=self.device,
        )
        decoder_attention_mask = torch.ones(
            (batch_size, 1), dtype=torch.long, device=self.device
        )

        if use_nucleus_sampling:
            outputs = self.bart.generate(
                input_ids=None,
                attention_mask=None,
                decoder_input_ids=input_ids,
                decoder_attention_mask=decoder_attention_mask,
                encoder_outputs=encoder_outputs,
                max_length=max_length,
                min_length=min_length,
                do_sample=True,
                top_p=top_p,
                num_return_sequences=1,
                repetition_penalty=self._NUCLEUS_REPETITION_PENALTY,
            )
        else:
            outputs = self.bart.generate(
                input_ids=None,
                attention_mask=None,
                decoder_input_ids=input_ids,
                decoder_attention_mask=decoder_attention_mask,
                encoder_outputs=encoder_outputs,
                head_mask=None,
                decoder_head_mask=None,
                inputs_embeds=None,
                decoder_inputs_embeds=None,
                use_cache=None,
                output_attentions=None,
                output_hidden_states=None,
                max_length=max_length,
                min_length=min_length,
                num_beams=num_beams,
                repetition_penalty=repetition_penalty,
            )

        return self.tokenizer.batch_decode(outputs, skip_special_tokens=True)
|
|