josh-oo committed
Commit 3519726
1 Parent(s): 73c2f58

mBART + fine-tuned benjamin/gerpt2

config.json ADDED
@@ -0,0 +1,195 @@
+ {
+   "_num_labels": 3,
+   "activation_dropout": 0.0,
+   "activation_function": "gelu",
+   "add_bias_logits": false,
+   "add_final_layer_norm": true,
+   "architectures": [
+     "MLongformerEncoderDecoderForConditionalGenerationCustom"
+   ],
+   "attention_dilation": [
+     1,
+     1,
+     1,
+     1,
+     1,
+     1,
+     1,
+     1,
+     1,
+     1,
+     1,
+     1
+   ],
+   "attention_dropout": 0.0,
+   "attention_mode": "sliding_chunks",
+   "attention_probs_dropout_prob": 0.0,
+   "attention_window": [
+     512,
+     512,
+     512,
+     512,
+     512,
+     512,
+     512,
+     512,
+     512,
+     512,
+     512,
+     512
+   ],
+   "auto_map": {
+     "AutoConfig": "longformer_enc_dec_custom.MLongformerEncoderDecoderConfigCustom",
+     "AutoModelForSeq2SeqLM": "longformer_enc_dec_custom.MLongformerEncoderDecoderForConditionalGenerationCustom"
+   },
+   "autoregressive": false,
+   "bos_token_id": 0,
+   "classif_dropout": 0.0,
+   "classifier_dropout": 0.0,
+   "d_model": 1024,
+   "decoder_attention_heads": 16,
+   "decoder_config": {
+     "_name_or_path": "benjamin/gerpt2",
+     "activation_function": "gelu_new",
+     "add_cross_attention": false,
+     "architectures": [
+       "GPT2LMHeadModel"
+     ],
+     "attn_pdrop": 0.1,
+     "bad_words_ids": null,
+     "begin_suppress_tokens": null,
+     "bos_token_id": 0,
+     "chunk_size_feed_forward": 0,
+     "cross_attention_hidden_size": null,
+     "decoder_start_token_id": null,
+     "diversity_penalty": 0.0,
+     "do_sample": false,
+     "early_stopping": false,
+     "embd_pdrop": 0.1,
+     "encoder_no_repeat_ngram_size": 0,
+     "eos_token_id": 0,
+     "exponential_decay_length_penalty": null,
+     "finetuning_task": null,
+     "forced_bos_token_id": null,
+     "forced_eos_token_id": null,
+     "gradient_checkpointing": false,
+     "id2label": {
+       "0": "LABEL_0",
+       "1": "LABEL_1"
+     },
+     "initializer_range": 0.02,
+     "is_decoder": false,
+     "is_encoder_decoder": false,
+     "label2id": {
+       "LABEL_0": 0,
+       "LABEL_1": 1
+     },
+     "layer_norm_epsilon": 1e-05,
+     "length_penalty": 1.0,
+     "max_length": 20,
+     "min_length": 0,
+     "model_type": "gpt2",
+     "n_ctx": 1024,
+     "n_embd": 768,
+     "n_head": 12,
+     "n_inner": null,
+     "n_layer": 12,
+     "n_positions": 1024,
+     "no_repeat_ngram_size": 0,
+     "num_beam_groups": 1,
+     "num_beams": 1,
+     "num_return_sequences": 1,
+     "output_attentions": false,
+     "output_hidden_states": false,
+     "output_scores": false,
+     "pad_token_id": 1,
+     "prefix": null,
+     "problem_type": null,
+     "pruned_heads": {},
+     "remove_invalid_values": false,
+     "reorder_and_upcast_attn": false,
+     "repetition_penalty": 1.0,
+     "resid_pdrop": 0.1,
+     "return_dict": true,
+     "return_dict_in_generate": false,
+     "scale_attn_by_inverse_layer_idx": false,
+     "scale_attn_weights": true,
+     "sep_token_id": null,
+     "summary_activation": null,
+     "summary_first_dropout": 0.1,
+     "summary_proj_to_labels": true,
+     "summary_type": "cls_index",
+     "summary_use_proj": true,
+     "suppress_tokens": null,
+     "task_specific_params": {
+       "text-generation": {
+         "do_sample": true,
+         "max_length": 100
+       }
+     },
+     "temperature": 1.0,
+     "tf_legacy_loss": false,
+     "tie_encoder_decoder": false,
+     "tie_word_embeddings": false,
+     "tokenizer_class": null,
+     "top_k": 50,
+     "top_p": 1.0,
+     "torch_dtype": "float32",
+     "torchscript": false,
+     "transformers_version": "4.29.2",
+     "typical_p": 1.0,
+     "use_bfloat16": false,
+     "use_cache": true,
+     "vocab_size": 50258
+   },
+   "decoder_ffn_dim": 4096,
+   "decoder_layerdrop": 0.0,
+   "decoder_layers": 12,
+   "dropout": 0.1,
+   "encoder_attention_heads": 16,
+   "encoder_ffn_dim": 4096,
+   "encoder_layerdrop": 0.0,
+   "encoder_layers": 12,
+   "eos_token_id": 0,
+   "forced_eos_token_id": 2,
+   "from_mbart": false,
+   "global_attention_indices": [
+     -1
+   ],
+   "gradient_checkpointing": false,
+   "id2label": {
+     "0": "LABEL_0",
+     "1": "LABEL_1",
+     "2": "LABEL_2"
+   },
+   "init_std": 0.02,
+   "is_encoder_decoder": true,
+   "label2id": {
+     "LABEL_0": 0,
+     "LABEL_1": 1,
+     "LABEL_2": 2
+   },
+   "max_decoder_position_embeddings": 1024,
+   "max_encoder_position_embeddings": 4096,
+   "max_length": 1024,
+   "max_position_embeddings": 1024,
+   "model_type": "mbart",
+   "normalize_before": true,
+   "normalize_embedding": true,
+   "num_beams": 5,
+   "num_hidden_layers": 12,
+   "output_past": true,
+   "pad_token_id": 1,
+   "scale_embedding": true,
+   "static_position_embeddings": false,
+   "task_specific_params": {
+     "translation_en_to_ro": {
+       "decoder_start_token_id": 250020
+     }
+   },
+   "tie_word_embeddings": false,
+   "torch_dtype": "float32",
+   "transformers_version": "4.29.2",
+   "use_cache": true,
+   "vocab_size": 20031
+ }
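
Note: the auto_map entries above point both AutoConfig and AutoModelForSeq2SeqLM at classes defined in longformer_enc_dec_custom.py, which is added later in this commit, so the checkpoint is loaded with trust_remote_code=True. A minimal loading sketch under that assumption; the repository id below is a placeholder, not part of this commit:

from transformers import AutoConfig, AutoModelForSeq2SeqLM

model_id = "path/to/this/checkpoint"  # placeholder: the Hub id of this repo or a local clone
config = AutoConfig.from_pretrained(model_id, trust_remote_code=True)
model = AutoModelForSeq2SeqLM.from_pretrained(model_id, trust_remote_code=True)
# resolves to MLongformerEncoderDecoderForConditionalGenerationCustom via auto_map
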
generation_config.json ADDED
@@ -0,0 +1,10 @@
+ {
+   "_from_model_config": true,
+   "bos_token_id": 0,
+   "eos_token_id": 0,
+   "forced_eos_token_id": 2,
+   "max_length": 1024,
+   "num_beams": 5,
+   "pad_token_id": 1,
+   "transformers_version": "4.29.2"
+ }
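
With these defaults, generate() runs beam search with 5 beams for up to 1024 tokens. A hedged usage sketch continuing the loading example above: the sliding-chunks encoder attention in longformer_enc_dec_custom.py (added below) requires the input length to be a multiple of 2 x 512, so the pad_to_window_size helper from that file is used here. src_tokenizer and tgt_tokenizer are placeholders, since this commit ships no tokenizer files and the encoder (vocab 20031) and decoder (vocab 50258) use different vocabularies.

from longformer_enc_dec_custom import pad_to_window_size  # assumes a local clone of this repo

batch = src_tokenizer("Ein Beispieltext in deutscher Sprache.", return_tensors="pt")
input_ids, attention_mask = pad_to_window_size(
    batch["input_ids"], batch["attention_mask"],
    one_sided_window_size=512,  # matches attention_window in config.json
    pad_token_id=1,             # matches pad_token_id in config.json
)
generated = model.generate(
    input_ids=input_ids,
    attention_mask=attention_mask,
    num_beams=5,      # default from this file
    max_length=1024,  # default from this file
)
print(tgt_tokenizer.decode(generated[0], skip_special_tokens=True))
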
longformer_enc_dec_custom.py ADDED
@@ -0,0 +1,1108 @@
1
+ """
2
+
3
+ This code is in part adapted from AllenAI's Longformer:
4
+ https://github.com/allenai/longformer/
5
+ and in part adapted from:
6
+ https://github.com/huggingface/transformers
7
+
8
+ Author: Annette Rios ([email protected])
9
+
10
+ """
11
+ from typing import List, Optional, Tuple, Dict, Union
12
+ from torch import nn, Tensor, zeros
13
+ import torch
14
+ import math
15
+ import random
16
+ from transformers.models.mbart.modeling_mbart import MBartConfig, MBartForConditionalGeneration, MBartEncoder, MBartLearnedPositionalEmbedding, MBartEncoderLayer, MBartDecoder, MBartModel, shift_tokens_right, _expand_mask
17
+ from transformers.modeling_outputs import BaseModelOutput,Seq2SeqModelOutput
18
+ from transformers.configuration_utils import PretrainedConfig
19
+ from transformers import GPT2Model, GPT2Config, AutoModelForCausalLM,AutoConfig
20
+ from transformers.activations import ACT2FN
21
+
22
+ import torch.nn.functional as F
23
+ from transformers.models.roberta.modeling_roberta import RobertaConfig, RobertaModel, RobertaForMaskedLM
24
+
25
+ from functools import lru_cache
26
+ import os.path
27
+
28
+
29
+ class MLongformerEncoderDecoderForConditionalGenerationCustom(MBartForConditionalGeneration):
30
+ def __init__(self, config):
31
+ super(MBartForConditionalGeneration, self).__init__(config)
32
+ self.decoder_config = GPT2Config.from_dict(config.decoder_config)
33
+ self.decoder_config.add_cross_attention=True
34
+ self.config.eos_token_id = self.decoder_config.eos_token_id
35
+ #self.config.bos_token_id = 0
36
+
37
+ self.model = LongMBartModelCustom(config)
38
+ #self.register_buffer("final_logits_bias", torch.zeros((1, self.decoder_config.vocab_size)))
39
+
40
+ if self.config.from_mbart:
41
+ self.lm_head = nn.Linear(config.d_model, self.model.shared.num_embeddings, bias=False)
42
+ self.register_buffer("final_logits_bias", torch.zeros((1, self.model.shared.num_embeddings)))
43
+ else:
44
+ self.lm_head = nn.Linear(self.decoder_config.n_embd, self.decoder_config.vocab_size, bias=False)
45
+ self.register_buffer("final_logits_bias", torch.zeros((1, self.decoder_config.vocab_size)))
46
+
47
+ self.model.decoder = GPT2Model(self.decoder_config)
48
+ if config.attention_mode == 'n2':
49
+ pass # do nothing, use MBartSelfAttention instead
50
+ else:
51
+ for i, layer in enumerate(self.model.encoder.layers):
52
+ layer.self_attn = LongformerSelfAttentionForMBart(config, layer_id=i)
53
+ # Initialize weights and apply final processing
54
+ self.post_init()
55
+
56
+ def post_init(self):
57
+ super().post_init()
58
+ if not self.config.from_mbart:
59
+ self.lm_head = nn.Linear(self.decoder_config.n_embd, self.decoder_config.vocab_size, bias=False)
60
+
61
+ def _set_gradient_checkpointing(self, module, value=False):
62
+ if isinstance(module, (MBartDecoder)):
63
+ module.gradient_checkpointing = value
64
+ self.model.decoder._set_gradient_checkpointing(module, value=value)
65
+
66
+ @classmethod
67
+ def from_encoder_decoder_pretrained(
68
+ cls,
69
+ mbart_pretrained_model_name_or_path: str = None,
70
+ decoder_pretrained_model_name_or_path: str = None,
71
+ *model_args,
72
+ **kwargs
73
+ ) -> MBartForConditionalGeneration:
74
+ config = MLongformerEncoderDecoderConfigCustom.from_pretrained(mbart_pretrained_model_name_or_path)
75
+ config.from_mbart = True
76
+ config.tie_word_embeddings = False
77
+ config.decoder_config = GPT2Config.from_pretrained(decoder_pretrained_model_name_or_path).to_dict()
78
+
79
+ mbart = super().from_pretrained(mbart_pretrained_model_name_or_path, config=config)
80
+ decoder = AutoModelForCausalLM.from_pretrained(decoder_pretrained_model_name_or_path, add_cross_attention=True)
81
+
82
+ mbart.model.decoder = decoder.transformer
83
+ mbart.lm_head = decoder.lm_head
84
+ mbart.register_buffer("final_logits_bias", torch.zeros((1, decoder.config.vocab_size)))
85
+
86
+ #reinit cross attention layers
87
+ mbart.model.enc_to_dec_proj.apply(mbart.model._init_weights)
88
+ for layer in mbart.model.decoder.h:
89
+ layer.crossattention.c_attn.apply(mbart.model.decoder._init_weights)
90
+
91
+ del mbart.model.shared
92
+ return mbart
93
+
94
+
95
+ class MLongformerEncoderDecoderConfigCustom(MBartConfig):
96
+ def __init__(self, attention_window: List[int] = None, attention_dilation: List[int] = None,
97
+ autoregressive: bool = False, attention_mode: str = 'sliding_chunks',
98
+ gradient_checkpointing: bool = False, **kwargs):
99
+ """
100
+ Args:
101
+ attention_window: list of attention window sizes of length = number of layers.
102
+ window size = number of attention locations on each side.
103
+ For an effective window size of 512, use `attention_window=[256]*num_layers`
104
+ which is 256 on each side.
105
+ attention_dilation: list of attention dilation of length = number of layers.
106
+ attention dilation of `1` means no dilation.
107
+ autoregressive: do autoregressive attention, or attend to both sides
108
+ attention_mode: 'n2' for regular n^2 self-attention, 'tvm' for the TVM implementation of Longformer
109
+ self-attention, 'sliding_chunks' for another implementation of Longformer self-attention
110
+ """
111
+ super().__init__(**kwargs)
112
+ self.from_mbart = False
113
+ self.attention_window = attention_window
114
+ self.attention_dilation = attention_dilation
115
+ self.autoregressive = autoregressive
116
+ self.attention_mode = attention_mode
117
+ self.gradient_checkpointing = gradient_checkpointing
118
+ assert self.attention_mode in ['sliding_chunks', 'n2']
119
+
120
+
121
+ class LongMBartModelCustom(MBartModel):
122
+ def __init__(self, config: MBartConfig):
123
+ super().__init__(config)
124
+ del self.shared
125
+ decoder_config = GPT2Config.from_dict(config.decoder_config)
126
+
127
+ padding_idx, vocab_size = config.pad_token_id, config.vocab_size
128
+ if self.config.from_mbart:
129
+ self.shared = nn.Embedding(vocab_size, config.d_model, padding_idx)
130
+
131
+ self.encoder = LongMBartEncoder(config)
132
+ self.enc_to_dec_proj = torch.nn.Linear(config.d_model, decoder_config.n_embd)
133
+ self.act = ACT2FN[decoder_config.activation_function]
134
+ self.decoder = GPT2Model(decoder_config)
135
+
136
+ # Initialize weights and apply final processing
137
+ self.post_init()
138
+
139
+ def get_input_embeddings(self):
140
+ return self.encoder.embed_tokens
141
+
142
+ def set_input_embeddings(self, value):
143
+ self.encoder.embed_tokens = value
144
+
145
+ def forward(
146
+ self,
147
+ input_ids: torch.LongTensor = None,
148
+ attention_mask: Optional[torch.Tensor] = None,
149
+ decoder_input_ids: Optional[torch.LongTensor] = None,
150
+ decoder_attention_mask: Optional[torch.LongTensor] = None,
151
+ head_mask: Optional[torch.Tensor] = None,
152
+ decoder_head_mask: Optional[torch.Tensor] = None,
153
+ cross_attn_head_mask: Optional[torch.Tensor] = None,
154
+ encoder_outputs: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
155
+ past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
156
+ inputs_embeds: Optional[torch.FloatTensor] = None,
157
+ decoder_inputs_embeds: Optional[torch.FloatTensor] = None,
158
+ use_cache: Optional[bool] = None,
159
+ output_attentions: Optional[bool] = None,
160
+ output_hidden_states: Optional[bool] = None,
161
+ return_dict: Optional[bool] = None,
162
+ ):
163
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
164
+ output_hidden_states = (
165
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
166
+ )
167
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
168
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
169
+
170
+ # different to other models, MBart automatically creates decoder_input_ids from
171
+ # input_ids if no decoder_input_ids are provided
172
+ if decoder_input_ids is None and decoder_inputs_embeds is None:
173
+ decoder_input_ids = shift_tokens_right(input_ids, self.config.pad_token_id)
174
+
175
+ #print("input_ids: ", input_ids)
176
+ #print("input_embeds: ", inputs_embeds)
177
+ #print("decoder_input_ids: ", decoder_input_ids.shape)
178
+ #print("attention_mask: ",attention_mask.shape)
179
+
180
+ if encoder_outputs is None:
181
+ encoder_outputs = self.encoder(
182
+ input_ids=input_ids,
183
+ attention_mask=attention_mask,
184
+ head_mask=head_mask,
185
+ inputs_embeds=inputs_embeds,
186
+ output_attentions=output_attentions,
187
+ output_hidden_states=output_hidden_states,
188
+ return_dict=return_dict,
189
+ )
190
+ # If the user passed a tuple for encoder_outputs, we wrap it in a BaseModelOutput when return_dict=True
191
+ elif return_dict and not isinstance(encoder_outputs, BaseModelOutput):
192
+ encoder_outputs = BaseModelOutput(
193
+ last_hidden_state=encoder_outputs[0],
194
+ hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None,
195
+ attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None,
196
+ )
197
+
198
+ encoder_hidden_states = encoder_outputs[0]
199
+
200
+ # remove unnecessary padding positions
201
+ non_empty_mask = attention_mask.abs().sum(dim=0).bool()
202
+ encoder_hidden_states = encoder_hidden_states[:,non_empty_mask]
203
+ encoder_attention_mask = attention_mask[:,non_empty_mask]
204
+
205
+ #to remove global attention tokens (2)
206
+ encoder_attention_mask = torch.clamp(encoder_attention_mask, min=0, max=1)
207
+
208
+ encoder_hidden_states = self.enc_to_dec_proj(encoder_hidden_states)
209
+ encoder_hidden_states = self.act(encoder_hidden_states)
210
+ encoder_hidden_states = nn.functional.dropout(encoder_hidden_states, p=0.1, training=self.training)  # only applied in training mode
211
+
212
+ # decoder outputs consists of (dec_features, past_key_value, dec_hidden, dec_attn)
213
+ decoder_outputs = self.decoder(
214
+ input_ids=decoder_input_ids,
215
+ attention_mask=decoder_attention_mask,
216
+ encoder_hidden_states=encoder_hidden_states,
217
+ encoder_attention_mask=encoder_attention_mask,
218
+ head_mask=decoder_head_mask,
219
+ #cross_attn_head_mask=cross_attn_head_mask,
220
+ past_key_values=past_key_values,
221
+ inputs_embeds=decoder_inputs_embeds,
222
+ use_cache=use_cache,
223
+ output_attentions=output_attentions,
224
+ output_hidden_states=output_hidden_states,
225
+ return_dict=return_dict,
226
+ )
227
+
228
+ if not return_dict:
229
+ return decoder_outputs + encoder_outputs
230
+
231
+ return Seq2SeqModelOutput(
232
+ last_hidden_state=decoder_outputs.last_hidden_state,
233
+ past_key_values=decoder_outputs.past_key_values,
234
+ decoder_hidden_states=decoder_outputs.hidden_states,
235
+ decoder_attentions=decoder_outputs.attentions,
236
+ cross_attentions=decoder_outputs.cross_attentions,
237
+ encoder_last_hidden_state=encoder_outputs.last_hidden_state,
238
+ encoder_hidden_states=encoder_outputs.hidden_states,
239
+ encoder_attentions=encoder_outputs.attentions,
240
+ )
241
+
242
+ class MLongformerEncoderDecoderForConditionalGeneration(MBartForConditionalGeneration):
243
+ def __init__(self, config):
244
+ super(MBartForConditionalGeneration, self).__init__(config)
245
+
246
+ self.model = LongMBartModel(config)
247
+ self.register_buffer("final_logits_bias", torch.zeros((1, self.model.shared.num_embeddings)))
248
+ self.lm_head = nn.Linear(config.d_model, self.model.shared.num_embeddings, bias=False)
249
+ #print(self)
250
+
251
+ if config.attention_mode == 'n2':
252
+ pass # do nothing, use MBartSelfAttention instead
253
+ else:
254
+ for i, layer in enumerate(self.model.encoder.layers):
255
+ layer.self_attn = LongformerSelfAttentionForMBart(config, layer_id=i)
256
+ # Initialize weights and apply final processing
257
+ self.post_init()
258
+
259
+
260
+ class MLongformerEncoderDecoderConfig(MBartConfig):
261
+ def __init__(self, attention_window: List[int] = None, attention_dilation: List[int] = None,
262
+ autoregressive: bool = False, attention_mode: str = 'sliding_chunks',
263
+ gradient_checkpointing: bool = False, **kwargs):
264
+ """
265
+ Args:
266
+ attention_window: list of attention window sizes of length = number of layers.
267
+ window size = number of attention locations on each side.
268
+ For an effective window size of 512, use `attention_window=[256]*num_layers`
269
+ which is 256 on each side.
270
+ attention_dilation: list of attention dilation of length = number of layers.
271
+ attention dilation of `1` means no dilation.
272
+ autoregressive: do autoregressive attention, or attend to both sides
273
+ attention_mode: 'n2' for regular n^2 self-attention, 'tvm' for the TVM implementation of Longformer
274
+ self-attention, 'sliding_chunks' for another implementation of Longformer self-attention
275
+ """
276
+ super().__init__(**kwargs)
277
+ self.attention_window = attention_window
278
+ self.attention_dilation = attention_dilation
279
+ self.autoregressive = autoregressive
280
+ self.attention_mode = attention_mode
281
+ self.gradient_checkpointing = gradient_checkpointing
282
+ assert self.attention_mode in ['sliding_chunks', 'n2']
283
+
284
+ class LongformerSelfAttentionForMBart(nn.Module):
285
+ def __init__(self, config, layer_id):
286
+ super().__init__()
287
+ self.embed_dim = config.d_model
288
+ self.longformer_self_attn = LongformerSelfAttention(config, layer_id=layer_id)
289
+ self.output = nn.Linear(self.embed_dim, self.embed_dim)
290
+
291
+ def forward(
292
+ self,
293
+ hidden_states: Tensor, # shape (batch_size, q_len, model_size)
294
+ key_value_states: Optional[Tensor] = None, # cross-attention in transformers.models.mbart.modeling_mbart
295
+ past_key_value: Optional[Tuple[Tensor]] = None, # only for decoder
296
+ attention_mask: Optional[Tensor] = None, # shape (batch_size, k_len) -> changed in transformers.models.mbart.modeling_mbart.MBartEncoder and MBartEncoderLayer (the new mask uses bool -> global attention positions are lost, so the inverted original mask is needed)
297
+ layer_head_mask: Optional[Tensor] = None, # head dropout?
298
+ output_attentions: bool = False
299
+ ) -> Tuple[Tensor, Optional[Tensor]]:
300
+
301
+ bsz, tgt_len, embed_dim = hidden_states.size()
302
+ assert embed_dim == self.embed_dim
303
+ assert list(hidden_states.size()) == [bsz, tgt_len, embed_dim]
304
+
305
+ outputs = self.longformer_self_attn(
306
+ hidden_states,
307
+ attention_mask=attention_mask * -1, # shape (batch_size, 1, 1, key_len)
308
+ head_mask=None,
309
+ encoder_hidden_states=None,
310
+ encoder_attention_mask=None,
311
+ output_attentions=output_attentions,
312
+ )
313
+
314
+ ## new: MBart encoder expects shape (seq_len, bsz, embed_dim), no transpose needed
315
+ attn_output = self.output(outputs[0])
316
+ # new return in MBartAttention has attn_output, attn_weights_reshaped, past_key_value (only for decoder), need to return 3 values (None for past_key_value)
317
+ return (attn_output, outputs[1:] ,None) if len(outputs) == 2 else (attn_output, None, None)
318
+
319
+
320
+ class LongMBartEncoder(MBartEncoder):
321
+ """
322
+ Transformer encoder consisting of *config.encoder_layers* self attention layers. Each layer is a
323
+ [`MBartEncoderLayer`].
324
+
325
+ Args:
326
+ config: MBartConfig
327
+ embed_tokens (nn.Embedding): output embedding
328
+ """
329
+
330
+ def __init__(self, config: MBartConfig, embed_tokens: Optional[nn.Embedding] = None):
331
+ super().__init__(config)
332
+
333
+ self.dropout = config.dropout
334
+ self.layerdrop = config.encoder_layerdrop
335
+
336
+ embed_dim = config.d_model
337
+ self.padding_idx = config.pad_token_id
338
+ self.max_source_positions = config.max_encoder_position_embeddings
339
+ self.embed_scale = math.sqrt(embed_dim) if config.scale_embedding else 1.0
340
+
341
+ if embed_tokens is not None:
342
+ self.embed_tokens = embed_tokens
343
+ else:
344
+ self.embed_tokens = nn.Embedding(config.vocab_size, embed_dim, self.padding_idx)
345
+
346
+ self.embed_positions = MBartLearnedPositionalEmbedding(
347
+ self.max_source_positions,
348
+ embed_dim,
349
+ )
350
+ self.layers = nn.ModuleList([LongMBartEncoderLayer(config) for _ in range(config.encoder_layers)])
351
+ self.layernorm_embedding = nn.LayerNorm(embed_dim)
352
+ self.layer_norm = nn.LayerNorm(config.d_model)
353
+
354
+ self.gradient_checkpointing = False
355
+ # Initialize weights and apply final processing
356
+ self.post_init()
357
+
358
+ def forward(
359
+ self,
360
+ input_ids: torch.LongTensor = None,
361
+ attention_mask: Optional[torch.Tensor] = None,
362
+ head_mask: Optional[torch.Tensor] = None,
363
+ inputs_embeds: Optional[torch.FloatTensor] = None,
364
+ output_attentions: Optional[bool] = None,
365
+ output_hidden_states: Optional[bool] = None,
366
+ return_dict: Optional[bool] = None,
367
+ ) -> Union[Tuple, BaseModelOutput]:
368
+ r"""
369
+ Args:
370
+ input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
371
+ Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you
372
+ provide it.
373
+
374
+ Indices can be obtained using [`MBartTokenizer`]. See [`PreTrainedTokenizer.encode`] and
375
+ [`PreTrainedTokenizer.__call__`] for details.
376
+
377
+ [What are input IDs?](../glossary#input-ids)
378
+ attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
379
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
380
+
381
+ - 1 for tokens that are **not masked**,
382
+ - 0 for tokens that are **masked**.
383
+
384
+ [What are attention masks?](../glossary#attention-mask)
385
+ head_mask (`torch.Tensor` of shape `(encoder_layers, encoder_attention_heads)`, *optional*):
386
+ Mask to nullify selected heads of the attention modules. Mask values selected in `[0, 1]`:
387
+
388
+ - 1 indicates the head is **not masked**,
389
+ - 0 indicates the head is **masked**.
390
+
391
+ inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
392
+ Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation.
393
+ This is useful if you want more control over how to convert `input_ids` indices into associated vectors
394
+ than the model's internal embedding lookup matrix.
395
+ output_attentions (`bool`, *optional*):
396
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under
397
+ returned tensors for more detail.
398
+ output_hidden_states (`bool`, *optional*):
399
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
400
+ for more detail.
401
+ return_dict (`bool`, *optional*):
402
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
403
+ """
404
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
405
+ output_hidden_states = (
406
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
407
+ )
408
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
409
+
410
+ # retrieve input_ids and inputs_embeds
411
+ if input_ids is not None and inputs_embeds is not None:
412
+ raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
413
+ elif input_ids is not None:
414
+ input = input_ids
415
+ input_shape = input.shape
416
+ input_ids = input_ids.view(-1, input_shape[-1])
417
+ elif inputs_embeds is not None:
418
+ input = inputs_embeds[:, :, -1]
419
+ else:
420
+ raise ValueError("You have to specify either input_ids or inputs_embeds")
421
+
422
+ if inputs_embeds is None:
423
+ inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale
424
+
425
+ embed_pos = self.embed_positions(input)
426
+
427
+ hidden_states = inputs_embeds + embed_pos
428
+ hidden_states = self.layernorm_embedding(hidden_states)
429
+ hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
430
+
431
+ # expand attention_mask
432
+ longformer_attention_mask = None
433
+ if attention_mask is not None:
434
+ # need to return original, inverted mask for longformer attention, else value for global attention (=2 in given mask, will be -1) is lost
435
+ longformer_attention_mask = 1 - attention_mask
436
+ # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
437
+ attention_mask = _expand_mask(attention_mask, inputs_embeds.dtype)
438
+
439
+
440
+ encoder_states = () if output_hidden_states else None
441
+ all_attentions = () if output_attentions else None
442
+
443
+ # check if head_mask has a correct number of layers specified if desired
444
+ if head_mask is not None:
445
+ if head_mask.size()[0] != len(self.layers):
446
+ raise ValueError(
447
+ f"The head_mask should be specified for {len(self.layers)} layers, but it is for"
448
+ f" {head_mask.size()[0]}."
449
+ )
450
+ for idx, encoder_layer in enumerate(self.layers):
451
+ if output_hidden_states:
452
+ encoder_states = encoder_states + (hidden_states,)
453
+ # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
454
+ dropout_probability = random.uniform(0, 1)
455
+ if self.training and (dropout_probability < self.layerdrop): # skip the layer
456
+ layer_outputs = (None, None)
457
+ else:
458
+ if self.gradient_checkpointing and self.training:
459
+
460
+ def create_custom_forward(module):
461
+ def custom_forward(*inputs):
462
+ return module(*inputs, output_attentions)
463
+
464
+ return custom_forward
465
+
466
+ layer_outputs = torch.utils.checkpoint.checkpoint(
467
+ create_custom_forward(encoder_layer),
468
+ hidden_states,
469
+ attention_mask,
470
+ longformer_attention_mask,
471
+ (head_mask[idx] if head_mask is not None else None),
472
+ )
473
+ else:
474
+ layer_outputs = encoder_layer(
475
+ hidden_states,
476
+ attention_mask,
477
+ longformer_attention_mask,
478
+ layer_head_mask=(head_mask[idx] if head_mask is not None else None),
479
+ output_attentions=output_attentions,
480
+ )
481
+
482
+ hidden_states = layer_outputs[0]
483
+
484
+ if output_attentions:
485
+ all_attentions = all_attentions + (layer_outputs[1],)
486
+
487
+ hidden_states = self.layer_norm(hidden_states)
488
+ #print("Encoder output: ",hidden_states.shape)
489
+
490
+ if output_hidden_states:
491
+ encoder_states = encoder_states + (hidden_states,)
492
+
493
+ if not return_dict:
494
+ return tuple(v for v in [hidden_states, encoder_states, all_attentions] if v is not None)
495
+ return BaseModelOutput(
496
+ last_hidden_state=hidden_states, hidden_states=encoder_states, attentions=all_attentions
497
+ )
498
+
499
+
500
+ class LongMBartModel(MBartModel):
501
+ def __init__(self, config: MBartConfig):
502
+ super().__init__(config)
503
+
504
+ padding_idx, vocab_size = config.pad_token_id, config.vocab_size
505
+ self.shared = nn.Embedding(vocab_size, config.d_model, padding_idx)
506
+
507
+ self.encoder = LongMBartEncoder(config, self.shared)
508
+ self.decoder = MBartDecoder(config, self.shared)
509
+
510
+ # Initialize weights and apply final processing
511
+ self.post_init()
512
+
513
+ def forward(
514
+ self,
515
+ input_ids: torch.LongTensor = None,
516
+ attention_mask: Optional[torch.Tensor] = None,
517
+ decoder_input_ids: Optional[torch.LongTensor] = None,
518
+ decoder_attention_mask: Optional[torch.LongTensor] = None,
519
+ head_mask: Optional[torch.Tensor] = None,
520
+ decoder_head_mask: Optional[torch.Tensor] = None,
521
+ cross_attn_head_mask: Optional[torch.Tensor] = None,
522
+ encoder_outputs: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
523
+ past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
524
+ inputs_embeds: Optional[torch.FloatTensor] = None,
525
+ decoder_inputs_embeds: Optional[torch.FloatTensor] = None,
526
+ use_cache: Optional[bool] = None,
527
+ output_attentions: Optional[bool] = None,
528
+ output_hidden_states: Optional[bool] = None,
529
+ return_dict: Optional[bool] = None,
530
+ ) -> Union[Seq2SeqModelOutput, Tuple[torch.FloatTensor]]:
531
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
532
+ output_hidden_states = (
533
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
534
+ )
535
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
536
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
537
+
538
+ # different to other models, MBart automatically creates decoder_input_ids from
539
+ # input_ids if no decoder_input_ids are provided
540
+ if decoder_input_ids is None and decoder_inputs_embeds is None:
541
+ decoder_input_ids = shift_tokens_right(input_ids, self.config.pad_token_id)
542
+
543
+ if encoder_outputs is None:
544
+ encoder_outputs = self.encoder(
545
+ input_ids=input_ids,
546
+ attention_mask=attention_mask,
547
+ head_mask=head_mask,
548
+ inputs_embeds=inputs_embeds,
549
+ output_attentions=output_attentions,
550
+ output_hidden_states=output_hidden_states,
551
+ return_dict=return_dict,
552
+ )
553
+ # If the user passed a tuple for encoder_outputs, we wrap it in a BaseModelOutput when return_dict=True
554
+ elif return_dict and not isinstance(encoder_outputs, BaseModelOutput):
555
+ encoder_outputs = BaseModelOutput(
556
+ last_hidden_state=encoder_outputs[0],
557
+ hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None,
558
+ attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None,
559
+ )
560
+
561
+ # decoder outputs consists of (dec_features, past_key_value, dec_hidden, dec_attn)
562
+ decoder_outputs = self.decoder(
563
+ input_ids=decoder_input_ids,
564
+ attention_mask=decoder_attention_mask,
565
+ encoder_hidden_states=encoder_outputs[0],
566
+ encoder_attention_mask=attention_mask,
567
+ head_mask=decoder_head_mask,
568
+ cross_attn_head_mask=cross_attn_head_mask,
569
+ past_key_values=past_key_values,
570
+ inputs_embeds=decoder_inputs_embeds,
571
+ use_cache=use_cache,
572
+ output_attentions=output_attentions,
573
+ output_hidden_states=output_hidden_states,
574
+ return_dict=return_dict,
575
+ )
576
+
577
+ if not return_dict:
578
+ return decoder_outputs + encoder_outputs
579
+
580
+ return Seq2SeqModelOutput(
581
+ last_hidden_state=decoder_outputs.last_hidden_state,
582
+ past_key_values=decoder_outputs.past_key_values,
583
+ decoder_hidden_states=decoder_outputs.hidden_states,
584
+ decoder_attentions=decoder_outputs.attentions,
585
+ cross_attentions=decoder_outputs.cross_attentions,
586
+ encoder_last_hidden_state=encoder_outputs.last_hidden_state,
587
+ encoder_hidden_states=encoder_outputs.hidden_states,
588
+ encoder_attentions=encoder_outputs.attentions,
589
+ )
590
+
591
+ class LongMBartEncoderLayer(MBartEncoderLayer):
592
+ def __init__(self, config: MBartConfig):
593
+ super().__init__(config)
594
+
595
+ def forward(
596
+ self,
597
+ hidden_states: torch.Tensor,
598
+ attention_mask: torch.Tensor,
599
+ longformer_attention_mask: torch.Tensor,
600
+ layer_head_mask: torch.Tensor,
601
+ output_attentions: bool = False,
602
+ ) -> torch.Tensor:
603
+ """
604
+ Args:
605
+ hidden_states (`torch.FloatTensor`): input to the layer of shape *(seq_len, batch, embed_dim)*
606
+ attention_mask (`torch.FloatTensor`): attention mask of size
607
+ *(batch, 1, tgt_len, src_len)* where padding elements are indicated by very large negative values.
608
+ longformer_attention_mask (:obj:`torch.FloatTensor`): attention mask of size
609
+ `(batch, src_len)` where 0=local, -1=global, 1=padding.
610
+ layer_head_mask (`torch.FloatTensor`): mask for attention heads in a given layer of size
611
+ *(encoder_attention_heads,)*.
612
+ output_attentions (`bool`, *optional*):
613
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under
614
+ returned tensors for more detail.
615
+ """
616
+ # if longformer attention instead of mbart self attention: use special mask
617
+ if isinstance(self.self_attn, LongformerSelfAttentionForMBart):
618
+ attention_mask = longformer_attention_mask
619
+ residual = hidden_states
620
+ hidden_states = self.self_attn_layer_norm(hidden_states)
621
+ hidden_states, attn_weights, _ = self.self_attn(
622
+ hidden_states=hidden_states,
623
+ attention_mask=attention_mask,
624
+ layer_head_mask=layer_head_mask,
625
+ output_attentions=output_attentions,
626
+ )
627
+ hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
628
+ hidden_states = residual + hidden_states
629
+
630
+ residual = hidden_states
631
+ hidden_states = self.final_layer_norm(hidden_states)
632
+ hidden_states = self.activation_fn(self.fc1(hidden_states))
633
+ hidden_states = nn.functional.dropout(hidden_states, p=self.activation_dropout, training=self.training)
634
+ hidden_states = self.fc2(hidden_states)
635
+ hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
636
+ hidden_states = residual + hidden_states
637
+
638
+ if hidden_states.dtype == torch.float16 and (
639
+ torch.isinf(hidden_states).any() or torch.isnan(hidden_states).any()
640
+ ):
641
+ clamp_value = torch.finfo(hidden_states.dtype).max - 1000
642
+ hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value)
643
+
644
+ outputs = (hidden_states,)
645
+
646
+ if output_attentions:
647
+ outputs += (attn_weights,)
648
+
649
+ return outputs
650
+
651
+ class Longformer(RobertaModel):
652
+ def __init__(self, config):
653
+ super(Longformer, self).__init__(config)
654
+ if config.attention_mode == 'n2':
655
+ pass # do nothing, use BertSelfAttention instead
656
+ else:
657
+ for i, layer in enumerate(self.encoder.layer):
658
+ layer.attention.self = LongformerSelfAttention(config, layer_id=i)
659
+
660
+
661
+ class LongformerForMaskedLM(RobertaForMaskedLM):
662
+ def __init__(self, config):
663
+ super(LongformerForMaskedLM, self).__init__(config)
664
+ if config.attention_mode == 'n2':
665
+ pass # do nothing, use BertSelfAttention instead
666
+ else:
667
+ for i, layer in enumerate(self.roberta.encoder.layer):
668
+ layer.attention.self = LongformerSelfAttention(config, layer_id=i)
669
+
670
+
671
+ class LongformerConfig(RobertaConfig):
672
+ def __init__(self, attention_window: List[int] = None, attention_dilation: List[int] = None,
673
+ autoregressive: bool = False, attention_mode: str = 'sliding_chunks', **kwargs):
674
+ """
675
+ Args:
676
+ attention_window: list of attention window sizes of length = number of layers.
677
+ window size = number of attention locations on each side.
678
+ For an effective window size of 512, use `attention_window=[256]*num_layers`
679
+ which is 256 on each side.
680
+ attention_dilation: list of attention dilation of length = number of layers.
681
+ attention dilation of `1` means no dilation.
682
+ autoregressive: do autoregressive attention, or attend to both sides
683
+ attention_mode: 'n2' for regular n^2 self-attention, 'tvm' for the TVM implementation of Longformer
684
+ self-attention, 'sliding_chunks' for another implementation of Longformer self-attention
685
+ """
686
+ super().__init__(**kwargs)
687
+ self.attention_window = attention_window
688
+ self.attention_dilation = attention_dilation
689
+ self.autoregressive = autoregressive
690
+ self.attention_mode = attention_mode
691
+ assert self.attention_mode in ['sliding_chunks', 'n2', 'sliding_chunks_no_overlap']
692
+
693
+
694
+ class LongformerSelfAttention(nn.Module):
695
+ def __init__(self, config, layer_id):
696
+ super(LongformerSelfAttention, self).__init__()
697
+ if config.hidden_size % config.num_attention_heads != 0:
698
+ raise ValueError(
699
+ "The hidden size (%d) is not a multiple of the number of attention "
700
+ "heads (%d)" % (config.hidden_size, config.num_attention_heads))
701
+ self.num_heads = config.num_attention_heads
702
+ self.head_dim = int(config.hidden_size / config.num_attention_heads)
703
+ self.embed_dim = config.hidden_size
704
+
705
+ self.query = nn.Linear(config.hidden_size, self.embed_dim)
706
+ self.key = nn.Linear(config.hidden_size, self.embed_dim)
707
+ self.value = nn.Linear(config.hidden_size, self.embed_dim)
708
+
709
+ self.query_global = nn.Linear(config.hidden_size, self.embed_dim)
710
+ self.key_global = nn.Linear(config.hidden_size, self.embed_dim)
711
+ self.value_global = nn.Linear(config.hidden_size, self.embed_dim)
712
+
713
+ self.dropout = config.attention_probs_dropout_prob
714
+
715
+ self.layer_id = layer_id
716
+ self.attention_window = config.attention_window[self.layer_id]
717
+ self.attention_dilation = config.attention_dilation[self.layer_id]
718
+ self.attention_mode = config.attention_mode
719
+ self.autoregressive = config.autoregressive
720
+ assert self.attention_window > 0
721
+ assert self.attention_dilation > 0
722
+ assert self.attention_mode in ['sliding_chunks', 'sliding_chunks_no_overlap']
723
+ if self.attention_mode in ['sliding_chunks', 'sliding_chunks_no_overlap']:
724
+ assert not self.autoregressive # not supported
725
+ assert self.attention_dilation == 1 # dilation is not supported
726
+
727
+ def forward(
728
+ self,
729
+ hidden_states,
730
+ attention_mask=None,
731
+ head_mask=None,
732
+ encoder_hidden_states=None,
733
+ encoder_attention_mask=None,
734
+ output_attentions=False,
735
+ ):
736
+ '''
737
+ The `attention_mask` is changed in `BertModel.forward` from 0, 1, 2 to
738
+ -ve: no attention
739
+ 0: local attention
740
+ +ve: global attention
741
+ '''
742
+ assert encoder_hidden_states is None, "`encoder_hidden_states` is not supported and should be None"
743
+ assert encoder_attention_mask is None, "`encoder_attention_mask` is not supported and should be None"
744
+
745
+ if attention_mask is not None:
746
+ key_padding_mask = attention_mask < 0
747
+ extra_attention_mask = attention_mask > 0
748
+ remove_from_windowed_attention_mask = attention_mask != 0
749
+
750
+ num_extra_indices_per_batch = extra_attention_mask.long().sum(dim=1)
751
+ max_num_extra_indices_per_batch = num_extra_indices_per_batch.max()
752
+ if max_num_extra_indices_per_batch <= 0:
753
+ extra_attention_mask = None
754
+ else:
755
+ # To support the case of variable number of global attention in the rows of a batch,
756
+ # we use the following three selection masks to select global attention embeddings
757
+ # in a 3d tensor and pad it to `max_num_extra_indices_per_batch`
758
+ # 1) selecting embeddings that correspond to global attention
759
+ extra_attention_mask_nonzeros = extra_attention_mask.nonzero(as_tuple=True)
760
+ zero_to_max_range = torch.arange(0, max_num_extra_indices_per_batch,
761
+ device=num_extra_indices_per_batch.device)
762
+ # mask indicating which values are actually going to be padding
763
+ selection_padding_mask = zero_to_max_range < num_extra_indices_per_batch.unsqueeze(dim=-1)
764
+ # 2) location of the non-padding values in the selected global attention
765
+ selection_padding_mask_nonzeros = selection_padding_mask.nonzero(as_tuple=True)
766
+ # 3) location of the padding values in the selected global attention
767
+ selection_padding_mask_zeros = (selection_padding_mask == 0).nonzero(as_tuple=True)
768
+ else:
769
+ remove_from_windowed_attention_mask = None
770
+ extra_attention_mask = None
771
+ key_padding_mask = None
772
+
773
+ hidden_states = hidden_states.transpose(0, 1)
774
+ seq_len, bsz, embed_dim = hidden_states.size()
775
+ assert embed_dim == self.embed_dim
776
+ q = self.query(hidden_states)
777
+ k = self.key(hidden_states)
778
+ v = self.value(hidden_states)
779
+ q /= math.sqrt(self.head_dim)
780
+
781
+ q = q.view(seq_len, bsz, self.num_heads, self.head_dim).transpose(0, 1)
782
+ k = k.view(seq_len, bsz, self.num_heads, self.head_dim).transpose(0, 1)
783
+ # attn_weights = (bsz, seq_len, num_heads, window*2+1)
784
+ if self.attention_mode == "sliding_chunks":
785
+ attn_weights = sliding_chunks_matmul_qk(q, k, self.attention_window, padding_value=0)
786
+ elif self.attention_mode == "sliding_chunks_no_overlap":
787
+ attn_weights = sliding_chunks_no_overlap_matmul_qk(q, k, self.attention_window, padding_value=0)
788
+ else:
789
+ raise ValueError(f"unsupported attention_mode: {self.attention_mode}")
790
+ mask_invalid_locations(attn_weights, self.attention_window, self.attention_dilation, False)
791
+ if remove_from_windowed_attention_mask is not None:
792
+ # This implementation is fast and takes very little memory because num_heads x hidden_size = 1
793
+ # from (bsz x seq_len) to (bsz x seq_len x num_heads x hidden_size)
794
+ remove_from_windowed_attention_mask = remove_from_windowed_attention_mask.unsqueeze(dim=-1).unsqueeze(dim=-1)
795
+ # cast to float/half then replace 1's with -inf
796
+ float_mask = remove_from_windowed_attention_mask.type_as(q).masked_fill(remove_from_windowed_attention_mask, -10000.0)
797
+ repeat_size = 1 if isinstance(self.attention_dilation, int) else len(self.attention_dilation)
798
+ float_mask = float_mask.repeat(1, 1, repeat_size, 1)
799
+ ones = float_mask.new_ones(size=float_mask.size()) # tensor of ones
800
+ # diagonal mask with zeros everywhere and -inf inplace of padding
801
+ if self.attention_mode == "sliding_chunks":
802
+ d_mask = sliding_chunks_matmul_qk(ones, float_mask, self.attention_window, padding_value=0)
803
+ elif self.attention_mode == "sliding_chunks_no_overlap":
804
+ d_mask = sliding_chunks_no_overlap_matmul_qk(ones, float_mask, self.attention_window, padding_value=0)
805
+
806
+ attn_weights += d_mask
807
+ assert list(attn_weights.size())[:3] == [bsz, seq_len, self.num_heads]
808
+ assert attn_weights.size(dim=3) in [self.attention_window * 2 + 1, self.attention_window * 3]
809
+
810
+ # the extra attention
811
+ if extra_attention_mask is not None:
812
+ selected_k = k.new_zeros(bsz, max_num_extra_indices_per_batch, self.num_heads, self.head_dim)
813
+ selected_k[selection_padding_mask_nonzeros] = k[extra_attention_mask_nonzeros]
814
+ # (bsz, seq_len, num_heads, max_num_extra_indices_per_batch)
815
+ selected_attn_weights = torch.einsum('blhd,bshd->blhs', (q, selected_k))
816
+ selected_attn_weights[selection_padding_mask_zeros[0], :, :, selection_padding_mask_zeros[1]] = -10000
817
+ # concat to attn_weights
818
+ # (bsz, seq_len, num_heads, extra attention count + 2*window+1)
819
+ attn_weights = torch.cat((selected_attn_weights, attn_weights), dim=-1)
820
+ attn_weights_float = F.softmax(attn_weights, dim=-1, dtype=torch.float32) # use fp32 for numerical stability
821
+ if key_padding_mask is not None:
822
+ # softmax sometimes inserts NaN if all positions are masked, replace them with 0
823
+ attn_weights_float = torch.masked_fill(attn_weights_float, key_padding_mask.unsqueeze(-1).unsqueeze(-1), 0.0)
824
+ attn_weights = attn_weights_float.type_as(attn_weights)
825
+ attn_probs = F.dropout(attn_weights_float.type_as(attn_weights), p=self.dropout, training=self.training)
826
+ v = v.view(seq_len, bsz, self.num_heads, self.head_dim).transpose(0, 1)
827
+ attn = 0
828
+ if extra_attention_mask is not None:
829
+ selected_attn_probs = attn_probs.narrow(-1, 0, max_num_extra_indices_per_batch)
830
+ selected_v = v.new_zeros(bsz, max_num_extra_indices_per_batch, self.num_heads, self.head_dim)
831
+ selected_v[selection_padding_mask_nonzeros] = v[extra_attention_mask_nonzeros]
832
+ # use `matmul` because `einsum` crashes sometimes with fp16
833
+ # attn = torch.einsum('blhs,bshd->blhd', (selected_attn_probs, selected_v))
834
+ attn = torch.matmul(selected_attn_probs.transpose(1, 2), selected_v.transpose(1, 2).type_as(selected_attn_probs)).transpose(1, 2)
835
+ attn_probs = attn_probs.narrow(-1, max_num_extra_indices_per_batch, attn_probs.size(-1) - max_num_extra_indices_per_batch).contiguous()
836
+
837
+ if self.attention_mode == "sliding_chunks":
838
+ attn += sliding_chunks_matmul_pv(attn_probs, v, self.attention_window)
839
+ elif self.attention_mode == "sliding_chunks_no_overlap":
840
+ attn += sliding_chunks_no_overlap_matmul_pv(attn_probs, v, self.attention_window)
841
+ else:
842
+ raise ValueError(f"unsupported attention_mode: {self.attention_mode}")
843
+
844
+ attn = attn.type_as(hidden_states)
845
+ assert list(attn.size()) == [bsz, seq_len, self.num_heads, self.head_dim]
846
+ attn = attn.transpose(0, 1).reshape(seq_len, bsz, embed_dim).contiguous()
847
+
848
+ # For this case, we'll just recompute the attention for these indices
849
+ # and overwrite the attn tensor. TODO: remove the redundant computation
850
+ if extra_attention_mask is not None:
851
+ selected_hidden_states = hidden_states.new_zeros(max_num_extra_indices_per_batch, bsz, embed_dim)
852
+ selected_hidden_states[selection_padding_mask_nonzeros[::-1]] = hidden_states[extra_attention_mask_nonzeros[::-1]]
853
+
854
+ q = self.query_global(selected_hidden_states)
855
+ k = self.key_global(hidden_states)
856
+ v = self.value_global(hidden_states)
857
+ q /= math.sqrt(self.head_dim)
858
+
859
+ q = q.contiguous().view(max_num_extra_indices_per_batch, bsz * self.num_heads, self.head_dim).transpose(0, 1) # (bsz*self.num_heads, max_num_extra_indices_per_batch, head_dim)
860
+ k = k.contiguous().view(-1, bsz * self.num_heads, self.head_dim).transpose(0, 1) # bsz * self.num_heads, seq_len, head_dim)
861
+ v = v.contiguous().view(-1, bsz * self.num_heads, self.head_dim).transpose(0, 1) # bsz * self.num_heads, seq_len, head_dim)
862
+ attn_weights = torch.bmm(q, k.transpose(1, 2))
863
+ assert list(attn_weights.size()) == [bsz * self.num_heads, max_num_extra_indices_per_batch, seq_len]
864
+
865
+ attn_weights = attn_weights.view(bsz, self.num_heads, max_num_extra_indices_per_batch, seq_len)
866
+ attn_weights[selection_padding_mask_zeros[0], :, selection_padding_mask_zeros[1], :] = -10000.0
867
+ if key_padding_mask is not None:
868
+ attn_weights = attn_weights.masked_fill(
869
+ key_padding_mask.unsqueeze(1).unsqueeze(2),
870
+ -10000.0,
871
+ )
872
+ attn_weights = attn_weights.view(bsz * self.num_heads, max_num_extra_indices_per_batch, seq_len)
873
+ attn_weights_float = F.softmax(attn_weights, dim=-1, dtype=torch.float32) # use fp32 for numerical stability
874
+ attn_probs = F.dropout(attn_weights_float.type_as(attn_weights), p=self.dropout, training=self.training)
875
+ selected_attn = torch.bmm(attn_probs, v)
876
+ assert list(selected_attn.size()) == [bsz * self.num_heads, max_num_extra_indices_per_batch, self.head_dim]
877
+
878
+ selected_attn_4d = selected_attn.view(bsz, self.num_heads, max_num_extra_indices_per_batch, self.head_dim)
879
+ nonzero_selected_attn = selected_attn_4d[selection_padding_mask_nonzeros[0], :, selection_padding_mask_nonzeros[1]]
880
+ attn[extra_attention_mask_nonzeros[::-1]] = nonzero_selected_attn.view(len(selection_padding_mask_nonzeros[0]), -1).type_as(hidden_states)
881
+
882
+ context_layer = attn.transpose(0, 1) # attn shape: (seq_len, bsz, embed_dim), context_layer shape: (bsz, seq_len, embed_dim)
883
+ if output_attentions:
884
+ if extra_attention_mask is not None:
885
+ # With global attention, return global attention probabilities only
886
+ # batch_size x num_heads x max_num_global_attention_tokens x sequence_length
887
+ # which is the attention weights from tokens with global attention to all tokens
888
+ # It does not return local attention
889
+ # In case of a variable number of global attention tokens in the rows of a batch,
890
+ # attn_weights are padded with -10000.0 attention scores
891
+ attn_weights = attn_weights.view(bsz, self.num_heads, max_num_extra_indices_per_batch, seq_len)
892
+ else:
893
+ # without global attention, return local attention probabilities
894
+ # batch_size x num_heads x sequence_length x window_size
895
+ # which is the attention weights of every token attending to its neighbours
896
+ attn_weights = attn_weights.permute(0, 2, 1, 3)
897
+ outputs = (context_layer, attn_weights) if output_attentions else (context_layer,)
898
+ return outputs
899
+
900
+ def _skew(x, direction, padding_value):
901
+ '''Convert diagonals into columns (or columns into diagonals, depending on `direction`)'''
902
+ x_padded = F.pad(x, direction, value=padding_value)
903
+ x_padded = x_padded.view(*x_padded.size()[:-2], x_padded.size(-1), x_padded.size(-2))
904
+ return x_padded
905
+
906
+
907
+ def _skew2(x, padding_value):
908
+ '''shift every row 1 step to right converting columns into diagonals'''
909
+ # X = B x C x M x L
910
+ B, C, M, L = x.size()
911
+ x = F.pad(x, (0, M + 1), value=padding_value) # B x C x M x (L+M+1)
912
+ x = x.view(B, C, -1) # B x C x ML+MM+M
913
+ x = x[:, :, :-M] # B x C x ML+MM
914
+ x = x.view(B, C, M, M + L) # B x C x M x (M+L)
915
+ x = x[:, :, :, :-1]
916
+ return x
917
+
918
+
919
+ def _chunk(x, w):
920
+ '''convert into overlapping chunkings. Chunk size = 2w, overlap size = w'''
921
+
922
+ # non-overlapping chunks of size = 2w
923
+ x = x.view(x.size(0), x.size(1) // (w * 2), w * 2, x.size(2))
924
+
925
+ # use `as_strided` to make the chunks overlap with an overlap size = w
926
+ chunk_size = list(x.size())
927
+ chunk_size[1] = chunk_size[1] * 2 - 1
928
+
929
+ chunk_stride = list(x.stride())
930
+ chunk_stride[1] = chunk_stride[1] // 2
931
+ return x.as_strided(size=chunk_size, stride=chunk_stride)
932
+
933
+
934
+ def sliding_chunks_matmul_qk(q: torch.Tensor, k: torch.Tensor, w: int, padding_value: float):
935
+ '''Matrix multiplication of query x key tensors with a sliding window attention pattern.
936
+ This implementation splits the input into overlapping chunks of size 2w (e.g. 512 for pretrained Longformer)
937
+ with an overlap of size w'''
938
+ bsz, seqlen, num_heads, head_dim = q.size()
939
+ assert seqlen % (w * 2) == 0
940
+ assert q.size() == k.size()
941
+
942
+ chunks_count = seqlen // w - 1
943
+
944
+ # group bsz and num_heads dimensions into one, then chunk seqlen into chunks of size w * 2
945
+ q = q.transpose(1, 2).reshape(bsz * num_heads, seqlen, head_dim)
946
+ k = k.transpose(1, 2).reshape(bsz * num_heads, seqlen, head_dim)
947
+
948
+ chunk_q = _chunk(q, w)
949
+ chunk_k = _chunk(k, w)
950
+
951
+ # matrix multiplication
952
+ # bcxd: bsz*num_heads x chunks x 2w x head_dim
953
+ # bcyd: bsz*num_heads x chunks x 2w x head_dim
954
+ # bcxy: bsz*num_heads x chunks x 2w x 2w
955
+ chunk_attn = torch.einsum('bcxd,bcyd->bcxy', (chunk_q, chunk_k)) # multiply
956
+
957
+ # convert diagonals into columns
958
+ diagonal_chunk_attn = _skew(chunk_attn, direction=(0, 0, 0, 1), padding_value=padding_value)
959
+
960
+ # allocate space for the overall attention matrix where the chunks are combined. The last dimension
961
+ # has (w * 2 + 1) columns. The first (w) columns are the w lower triangles (attention from a word to
962
+ # w previous words). The following column is attention score from each word to itself, then
963
+ # followed by w columns for the upper triangle.
964
+
965
+ diagonal_attn = diagonal_chunk_attn.new_empty((bsz * num_heads, chunks_count + 1, w, w * 2 + 1))
966
+
967
+ # copy parts from diagonal_chunk_attn into the combined matrix of attentions
968
+ # - copying the main diagonal and the upper triangle
969
+ diagonal_attn[:, :-1, :, w:] = diagonal_chunk_attn[:, :, :w, :w + 1]
970
+ diagonal_attn[:, -1, :, w:] = diagonal_chunk_attn[:, -1, w:, :w + 1]
971
+ # - copying the lower triangle
972
+ diagonal_attn[:, 1:, :, :w] = diagonal_chunk_attn[:, :, - (w + 1):-1, w + 1:]
973
+ diagonal_attn[:, 0, 1:w, 1:w] = diagonal_chunk_attn[:, 0, :w - 1, 1 - w:]
974
+
975
+ # separate bsz and num_heads dimensions again
976
+ diagonal_attn = diagonal_attn.view(bsz, num_heads, seqlen, 2 * w + 1).transpose(2, 1)
977
+
978
+ mask_invalid_locations(diagonal_attn, w, 1, False)
979
+ return diagonal_attn
980
+
981
+
982
+ def sliding_chunks_matmul_pv(prob: torch.Tensor, v: torch.Tensor, w: int):
983
+ '''Same as sliding_chunks_matmul_qk but for prob and value tensors. It is expecting the same output
984
+ format from sliding_chunks_matmul_qk'''
985
+ bsz, seqlen, num_heads, head_dim = v.size()
986
+ assert seqlen % (w * 2) == 0
987
+ assert prob.size()[:3] == v.size()[:3]
988
+ assert prob.size(3) == 2 * w + 1
989
+ chunks_count = seqlen // w - 1
990
+ # group bsz and num_heads dimensions into one, then chunk seqlen into chunks of size 2w
991
+ chunk_prob = prob.transpose(1, 2).reshape(bsz * num_heads, seqlen // w, w, 2 * w + 1)
992
+
993
+ # group bsz and num_heads dimensions into one
994
+ v = v.transpose(1, 2).reshape(bsz * num_heads, seqlen, head_dim)
995
+
996
+ # pad seqlen with w at the beginning of the sequence and another w at the end
997
+ padded_v = F.pad(v, (0, 0, w, w), value=-1)
998
+
999
+ # chunk padded_v into chunks of size 3w and an overlap of size w
1000
+ chunk_v_size = (bsz * num_heads, chunks_count + 1, 3 * w, head_dim)
1001
+ chunk_v_stride = padded_v.stride()
1002
+ chunk_v_stride = chunk_v_stride[0], w * chunk_v_stride[1], chunk_v_stride[1], chunk_v_stride[2]
1003
+ chunk_v = padded_v.as_strided(size=chunk_v_size, stride=chunk_v_stride)
1004
+
1005
+ skewed_prob = _skew2(chunk_prob, padding_value=0)
1006
+
1007
+ context = torch.einsum('bcwd,bcdh->bcwh', (skewed_prob, chunk_v))
1008
+ return context.view(bsz, num_heads, seqlen, head_dim).transpose(1, 2)
1009
+
1010
+
1011
+ def pad_to_window_size(input_ids: torch.Tensor, attention_mask: torch.Tensor,
1012
+ one_sided_window_size: int, pad_token_id: int):
1013
+ '''A helper function to pad tokens and mask to work with the sliding_chunks implementation of Longformer self-attention.
1014
+ Input:
1015
+ input_ids = torch.Tensor(bsz x seqlen): ids of wordpieces
1016
+ attention_mask = torch.Tensor(bsz x seqlen): attention mask
1017
+ one_sided_window_size = int: window size on one side of each token
1018
+ pad_token_id = int: tokenizer.pad_token_id
1019
+ Returns
1020
+ (input_ids, attention_mask) padded to length divisible by 2 * one_sided_window_size
1021
+ '''
1022
+ w = int(2 * one_sided_window_size)
1023
+ seqlen = input_ids.size(1)
1024
+ padding_len = (w - seqlen % w) % w
1025
+ input_ids = F.pad(input_ids, (0, padding_len), value=pad_token_id)
1026
+ attention_mask = F.pad(attention_mask, (0, padding_len), value=False) # no attention on the padding tokens
1027
+ return input_ids, attention_mask
1028
+
1029
+
1030
+ # ========= "sliding_chunks_no_overlap": an alternative implementation of the sliding window attention =========
1031
+ # This implementation uses non-overlapping chunks (or blocks) of size `w` with number of local attention = 3xw
1032
+ # To make this implementation comparable to "sliding_chunks" set w such that
1033
+ # w_of_sliding_chunks_no_overlap = w_of_sliding_chunks * 2 / 3
1034
+ # For example,
1035
+ # w_of_sliding_chunks = 256 (this is one sided. Total attention size = 512)
1036
+ # w_of_sliding_chunks_no_overlap = 170 (Total attention size = 510)
1037
+ # Performance:
1038
+ # - Speed: 30% faster than "sliding_chunks"
1039
+ # - Memory: 95% of the memory usage of "sliding_chunks"
1040
+ # The windows are asymmetric: the number of attention positions on each side of a token ranges between w and 2w,
1041
+ # while "sliding_chunks" has a symmetric window around each token.
1042
+
1043
+
1044
+ def sliding_chunks_no_overlap_matmul_qk(q: torch.Tensor, k: torch.Tensor, w: int, padding_value: float):
1045
+ bsz, seqlen, num_heads, head_dim = q.size()
1046
+ assert seqlen % w == 0
1047
+ assert q.size() == k.size()
1048
+ # chunk seqlen into non-overlapping chunks of size w
1049
+ chunk_q = q.view(bsz, seqlen // w, w, num_heads, head_dim)
1050
+ chunk_k = k.view(bsz, seqlen // w, w, num_heads, head_dim)
1051
+ chunk_k_expanded = torch.stack((
1052
+ F.pad(chunk_k[:, :-1], (0, 0, 0, 0, 0, 0, 1, 0), value=0.0),
1053
+ chunk_k,
1054
+ F.pad(chunk_k[:, 1:], (0, 0, 0, 0, 0, 0, 0, 1), value=0.0),
1055
+ ), dim=-1)
1056
+ diagonal_attn = torch.einsum('bcxhd,bcyhde->bcxhey', (chunk_q, chunk_k_expanded)) # multiply
1057
+ return diagonal_attn.reshape(bsz, seqlen, num_heads, 3 * w)
1058
+
1059
+
1060
+ def sliding_chunks_no_overlap_matmul_pv(prob: torch.Tensor, v: torch.Tensor, w: int):
1061
+ bsz, seqlen, num_heads, head_dim = v.size()
1062
+ chunk_prob = prob.view(bsz, seqlen // w, w, num_heads, 3, w)
1063
+ chunk_v = v.view(bsz, seqlen // w, w, num_heads, head_dim)
1064
+ chunk_v_extended = torch.stack((
1065
+ F.pad(chunk_v[:, :-1], (0, 0, 0, 0, 0, 0, 1, 0), value=0.0),
1066
+ chunk_v,
1067
+ F.pad(chunk_v[:, 1:], (0, 0, 0, 0, 0, 0, 0, 1), value=0.0),
1068
+ ), dim=-1)
1069
+ context = torch.einsum('bcwhpd,bcdhep->bcwhe', (chunk_prob, chunk_v_extended))
1070
+ return context.reshape(bsz, seqlen, num_heads, head_dim)
1071
+
1072
+ def _get_invalid_locations_mask_fixed_dilation(seq_len: int, w: int, d: int):
1073
+ diagonals_list = []
1074
+ for j in range(-d * w, d, d):
1075
+ diagonal_mask = torch.zeros(seq_len, device='cpu', dtype=torch.uint8)
1076
+ diagonal_mask[:-j] = 1
1077
+ diagonals_list.append(diagonal_mask)
1078
+ return torch.stack(diagonals_list, dim=-1)
1079
+
1080
+ @lru_cache()
1081
+ def _get_invalid_locations_mask(w: int, d: Union[torch.Tensor,int], autoregressive: bool, device: str):
1082
+ if isinstance(d, int):
1083
+ affected_seq_len = w * d
1084
+ mask = _get_invalid_locations_mask_fixed_dilation(affected_seq_len, w, d)
1085
+ mask = mask[None, :, None, :]
1086
+ else:
1087
+ affected_seq_len = w * d.max()
1088
+ head_masks = []
1089
+ d_list = d.cpu().numpy().tolist()
1090
+ for d in d_list:
1091
+ one_head_mask = _get_invalid_locations_mask_fixed_dilation(affected_seq_len, w, d)
1092
+ head_masks.append(one_head_mask)
1093
+ mask = torch.stack(head_masks, dim=-2)
1094
+ mask = mask[None, :, :, :]
1095
+
1096
+ ending_mask = None if autoregressive else mask.flip(dims=(1, 3)).bool().to(device)
1097
+ return affected_seq_len, mask.bool().to(device), ending_mask
1098
+
1099
+ def mask_invalid_locations(input_tensor: torch.Tensor, w: int, d: Union[torch.Tensor, int], autoregressive: bool) -> torch.Tensor:
1100
+ affected_seq_len, beginning_mask, ending_mask = _get_invalid_locations_mask(w, d, autoregressive, input_tensor.device)
1101
+ seq_len = input_tensor.size(1)
1102
+ beginning_input = input_tensor[:, :affected_seq_len, :, :w+1]
1103
+ beginning_mask = beginning_mask[:, :seq_len].expand(beginning_input.size())
1104
+ beginning_input.masked_fill_(beginning_mask, -float('inf'))
1105
+ if not autoregressive:
1106
+ ending_input = input_tensor[:, -affected_seq_len:, :, -(w+1):]
1107
+ ending_mask = ending_mask[:, -seq_len:].expand(ending_input.size())
1108
+ ending_input.masked_fill_(ending_mask, -float('inf'))
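
For reference, the model class above also provides from_encoder_decoder_pretrained, which stitches a long-attention mBART encoder together with a pretrained causal LM decoder (GerPT2 here), re-initialises the cross-attention and the encoder-to-decoder projection, and drops the shared mBART embeddings. A hedged construction sketch: the encoder path is a placeholder for a long-mBART checkpoint whose config already carries attention_window and attention_mode; only benjamin/gerpt2 is actually named in this commit.

from longformer_enc_dec_custom import MLongformerEncoderDecoderForConditionalGenerationCustom

model = MLongformerEncoderDecoderForConditionalGenerationCustom.from_encoder_decoder_pretrained(
    mbart_pretrained_model_name_or_path="path/to/long-mbart-checkpoint",  # placeholder
    decoder_pretrained_model_name_or_path="benjamin/gerpt2",              # decoder named in config.json
)
model.save_pretrained("long-mbart-gerpt2")  # hypothetical output directory
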
pytorch_model.bin ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:74434cbc4348f9491b766c9acc11ac4a55a46e00f6e61d486d77e41094a386c1
+ size 1648941221