|
|
|
|
|
|
|
|
|
|
|
import torch |
|
import torch.nn as nn |
|
from einops import rearrange |
|
from torch import Tensor |
|
|
|
from transformers.models.xlm_roberta.modeling_xlm_roberta import create_position_ids_from_input_ids |
|
|
|
|
|
class XLMRobertaEmbeddings(nn.Module):
    """Sum of word embeddings with optional position and token-type embeddings.

    The position table is only created when ``max_position_embeddings > 0`` and
    the token-type table only when ``type_vocab_size > 0``.
    """

    # A batch with a tuple ``task_type`` is cut at ``batch_size // 9``: the
    # leading slice uses ``task_type[0]`` and the rest uses ``task_type[1]``.
    # NOTE(review): presumably one query followed by eight documents per
    # group -- confirm against the caller that builds these batches.
    _TASK_SPLIT_DIVISOR = 9

    def __init__(
        self,
        embed_dim,
        vocab_size,
        max_position_embeddings,
        type_vocab_size,
        padding_idx=None,
        device=None,
        dtype=None,
    ):
        """Create the embedding tables.

        Args:
            embed_dim: size of every embedding vector.
            vocab_size: number of rows in the word embedding table.
            max_position_embeddings: number of rows in the position table;
                if <= 0, there are no position embeddings.
            type_vocab_size: number of rows in the token-type table;
                if <= 0, there are no token type embeddings.
            padding_idx: forwarded to the word ``nn.Embedding``; also used in
                ``forward`` to derive position ids when none are given.
            device: forwarded to every ``nn.Embedding``.
            dtype: forwarded to every ``nn.Embedding``.
        """
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()
        self.word_embeddings = nn.Embedding(
            vocab_size, embed_dim, padding_idx=padding_idx, **factory_kwargs
        )
        self.max_position_embeddings = max_position_embeddings
        self.type_vocab_size = type_vocab_size
        if self.max_position_embeddings > 0:
            self.position_embeddings = nn.Embedding(
                max_position_embeddings, embed_dim, **factory_kwargs
            )
        if self.type_vocab_size > 0:
            self.token_type_embeddings = nn.Embedding(
                type_vocab_size, embed_dim, **factory_kwargs
            )

    def forward(self, input_ids, position_ids=None, token_type_ids=None, task_type=None) -> Tensor:
        """Embed ``input_ids`` and add position / token-type embeddings.

        Args:
            input_ids: (batch, seqlen)
            position_ids: (batch, seqlen); when omitted (and position
                embeddings exist) they are derived from ``input_ids`` with
                ``create_position_ids_from_input_ids``.
            token_type_ids: (batch, seqlen), or any shape that broadcasts
                against it; when omitted, a ``(seqlen,)`` tensor of zeros is
                used.
            task_type: ``None``, a single value, or a 2-tuple. A non-``None``
                value is passed as the ``task_type=`` keyword to the word and
                token-type embedding modules. With a tuple, the batch size
                must be a multiple of 9; the first ``batch // 9`` rows use
                ``task_type[0]`` and the remaining rows use ``task_type[1]``.
                NOTE(review): a plain ``nn.Embedding`` does not accept
                ``task_type``; presumably these modules are wrapped elsewhere
                (e.g. by a LoRA layer) before it is used -- confirm.

        Returns:
            Tensor of shape (batch, seqlen, embed_dim).

        Raises:
            AssertionError: if ``task_type`` is a tuple and the batch size is
                not a multiple of 9.
        """
        batch, seqlen = input_ids.shape
        mixed_tasks = isinstance(task_type, tuple)
        # Only the single-task path forwards the kwarg conditionally, so that
        # un-wrapped embedding modules keep working when task_type is None.
        lora_kwargs = {}
        if not mixed_tasks and task_type is not None:
            lora_kwargs = {'task_type': task_type}

        split = 0
        if mixed_tasks:
            assert batch % self._TASK_SPLIT_DIVISOR == 0, (
                "batch size must be a multiple of 9 when task_type is a tuple"
            )
            split = batch // self._TASK_SPLIT_DIVISOR
            head = self.word_embeddings(input_ids[:split], task_type=task_type[0])
            tail = self.word_embeddings(input_ids[split:], task_type=task_type[1])
            embeddings = torch.cat((head, tail), dim=0)
        else:
            embeddings = self.word_embeddings(input_ids, **lora_kwargs)

        if self.max_position_embeddings > 0:
            if position_ids is None:
                position_ids = create_position_ids_from_input_ids(
                    input_ids, padding_idx=self.word_embeddings.padding_idx
                ).to(input_ids.device)
            embeddings = embeddings + self.position_embeddings(position_ids)

        if self.type_vocab_size > 0:
            if token_type_ids is None:
                # 1-D on purpose: broadcasts over the batch dimension.
                token_type_ids = torch.zeros(
                    seqlen, dtype=torch.long, device=input_ids.device
                )
            if mixed_tasks:
                # Per-sample ids must be cut at the same boundary as the
                # embeddings; otherwise the (batch, seqlen, dim) lookup cannot
                # be added to the shorter slices. Broadcastable ids (1-D or a
                # leading dim of 1) are shared by both slices.
                per_sample = (
                    token_type_ids.dim() == 2 and token_type_ids.shape[0] == batch
                )
                if per_sample:
                    head_ids, tail_ids = token_type_ids[:split], token_type_ids[split:]
                else:
                    head_ids = tail_ids = token_type_ids
                head = embeddings[:split] + self.token_type_embeddings(
                    head_ids, task_type=task_type[0]
                )
                tail = embeddings[split:] + self.token_type_embeddings(
                    tail_ids, task_type=task_type[1]
                )
                embeddings = torch.cat((head, tail), dim=0)
            else:
                embeddings = embeddings + self.token_type_embeddings(
                    token_type_ids, **lora_kwargs
                )
        return embeddings
|
|