GinnM committed on
Commit
5c054bd
1 Parent(s): 02e8d56

Upload ProPrimeForMaskedLM

Files changed (3)
  1. config.json +7 -0
  2. model.safetensors +3 -0
  3. modeling_proprime.py +1179 -0
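
Note: the config.json diff below registers the custom class via "architectures" and "auto_map" (modeling_proprime.ProPrimeForMaskedLM), so the checkpoint can be loaded through the Auto API with remote code enabled. A minimal loading sketch, assuming this commit is published under the AI4protein/proprime_650M repo id that appears in PROPRIME_PRETRAINED_MODEL_ARCHIVE_LIST further down; the repo id is an assumption, not stated by the commit itself.

from transformers import AutoModelForMaskedLM

# trust_remote_code is required so auto_map can resolve
# modeling_proprime.ProPrimeForMaskedLM from this repository.
model = AutoModelForMaskedLM.from_pretrained(
    "AI4protein/proprime_650M",  # assumed repo id
    trust_remote_code=True,
)
model.eval()
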
config.json CHANGED
@@ -1,5 +1,11 @@
 {
+  "architectures": [
+    "ProPrimeForMaskedLM"
+  ],
   "attention_probs_dropout_prob": 0.0,
+  "auto_map": {
+    "AutoModelForMaskedLM": "modeling_proprime.ProPrimeForMaskedLM"
+  },
   "emb_layer_norm_before": false,
   "flash_attention": true,
   "hidden_dropout_prob": 0.0,
@@ -15,6 +21,7 @@
   "pad_token_id": 1,
   "position_embedding_type": "rotary",
   "token_dropout": true,
+  "torch_dtype": "float32",
   "transformers_version": "4.36.2",
   "use_cache": true,
   "vocab_size": 33
model.safetensors ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:fc27a3f563758e3a1445c31e0280c1672b51a07c6199ab5056c2c32ed9c12844
size 2604245372
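
Note: this file is a Git LFS pointer, not the weights themselves; the actual ~2.6 GB safetensors blob is addressed by the SHA-256 oid above. A small sketch for checking a downloaded copy against the pointer; the local path is an assumption.

import hashlib
from pathlib import Path

path = Path("model.safetensors")  # assumed local download path
assert path.stat().st_size == 2604245372, "size does not match the LFS pointer"

sha256 = hashlib.sha256()
with path.open("rb") as f:
    for chunk in iter(lambda: f.read(1 << 20), b""):  # hash in 1 MiB chunks
        sha256.update(chunk)
assert sha256.hexdigest() == (
    "fc27a3f563758e3a1445c31e0280c1672b51a07c6199ab5056c2c32ed9c12844"
)
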
modeling_proprime.py ADDED
@@ -0,0 +1,1179 @@
import math
from typing import List, Optional, Tuple, Union
import torch.nn.functional as F
import torch
import torch.utils.checkpoint
from torch import nn
from torch.nn import CrossEntropyLoss
from dataclasses import dataclass
from transformers.modeling_outputs import (
    BaseModelOutputWithPastAndCrossAttentions,
    BaseModelOutputWithPoolingAndCrossAttentions,
    MaskedLMOutput,
    ModelOutput,
)
from transformers.modeling_utils import (
    PreTrainedModel,
    find_pruneable_heads_and_indices,
    prune_linear_layer,
)
from transformers.utils import logging
from ProPrime.configuration_proprime import ProPrimeConfig
from torch.nn.functional import scaled_dot_product_attention

logger = logging.get_logger(__name__)


PROPRIME_PRETRAINED_MODEL_ARCHIVE_LIST = [
    "AI4protein/proprime_650M",
]


def rotate_half(x):
    return torch.cat((-x[..., x.shape[-1] // 2 :], x[..., : x.shape[-1] // 2]), dim=-1)


def apply_rotary_pos_emb(x, cos, sin):
    cos = cos[:, :, : x.shape[-2], :]
    sin = sin[:, :, : x.shape[-2], :]
    return (x * cos) + (rotate_half(x) * sin)


def gelu(x):
    return x * 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0)))


class RotaryEmbedding(torch.nn.Module):
    def __init__(self, dim: int):
        super().__init__()
        # Generate and save the inverse frequency buffer (non trainable)
        inv_freq = 1.0 / (
            10000 ** (torch.arange(0, dim, 2, dtype=torch.int64).float() / dim)
        )
        self.register_buffer("inv_freq", inv_freq)

        self._seq_len_cached = None
        self._cos_cached = None
        self._sin_cached = None

    def _update_cos_sin_tables(self, x, seq_dimension=2):
        seq_len = x.shape[seq_dimension]

        # Reset the tables if the sequence length has changed,
        # or if we're on a new device (possibly due to tracing for instance)
        if seq_len != self._seq_len_cached or self._cos_cached.device != x.device:
            self._seq_len_cached = seq_len
            t = torch.arange(x.shape[seq_dimension], device=x.device).type_as(
                self.inv_freq
            )
            freqs = torch.outer(t, self.inv_freq)
            emb = torch.cat((freqs, freqs), dim=-1).to(x.device)

            self._cos_cached = emb.cos()[None, None, :, :]
            self._sin_cached = emb.sin()[None, None, :, :]

        return self._cos_cached, self._sin_cached

    def forward(
        self, q: torch.Tensor, k: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        self._cos_cached, self._sin_cached = self._update_cos_sin_tables(
            k, seq_dimension=-2
        )

        return (
            apply_rotary_pos_emb(q, self._cos_cached, self._sin_cached),
            apply_rotary_pos_emb(k, self._cos_cached, self._sin_cached),
        )


class ProPrimeEmbeddings(nn.Module):

    def __init__(self, config):
        super().__init__()
        self.word_embeddings = nn.Embedding(
            config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id
        )

        if config.emb_layer_norm_before:
            self.layer_norm = nn.LayerNorm(
                config.hidden_size, eps=config.layer_norm_eps
            )
        else:
            self.layer_norm = None
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        self.position_embedding_type = getattr(
            config, "position_embedding_type", "absolute"
        )
        self.register_buffer(
            "position_ids",
            torch.arange(config.max_position_embeddings).expand((1, -1)),
            persistent=False,
        )

        self.padding_idx = config.pad_token_id
        if self.position_embedding_type == "absolute":
            self.position_embeddings = nn.Embedding(
                config.max_position_embeddings,
                config.hidden_size,
                padding_idx=self.padding_idx,
            )
        self.token_dropout = config.token_dropout
        self.mask_token_id = config.mask_token_id

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        position_ids=None,
        inputs_embeds=None,
        past_key_values_length=0,
    ):
        if position_ids is None:
            if input_ids is not None:
                position_ids = create_position_ids_from_input_ids(
                    input_ids, self.padding_idx, past_key_values_length
                )
            else:
                position_ids = self.create_position_ids_from_inputs_embeds(
                    inputs_embeds
                )

        if inputs_embeds is None:
            inputs_embeds = self.word_embeddings(input_ids)

        embeddings = inputs_embeds

        if self.token_dropout:
            embeddings = embeddings.masked_fill(
                (input_ids == self.mask_token_id).unsqueeze(-1), 0.0
            )
            mask_ratio_train = 0.15 * 0.8
            src_lengths = attention_mask.sum(-1)
            mask_ratio_observed = (input_ids == self.mask_token_id).sum(
                -1
            ).float() / src_lengths
            embeddings = (
                embeddings
                * (1 - mask_ratio_train)
                / (1 - mask_ratio_observed)[:, None, None]
            ).to(embeddings.dtype)

        if self.position_embedding_type == "absolute":
            position_embeddings = self.position_embeddings(position_ids)
            embeddings = embeddings + position_embeddings

        if self.layer_norm is not None:
            embeddings = self.layer_norm(embeddings)
        if attention_mask is not None:
            embeddings = (embeddings * attention_mask.unsqueeze(-1)).to(
                embeddings.dtype
            )
        # Matt: I think this line was copied incorrectly from BERT, disabling it for now.
        # embeddings = self.dropout(embeddings)
        return embeddings

    def create_position_ids_from_inputs_embeds(self, inputs_embeds):
        input_shape = inputs_embeds.size()[:-1]
        sequence_length = input_shape[1]

        position_ids = torch.arange(
            self.padding_idx + 1,
            sequence_length + self.padding_idx + 1,
            dtype=torch.long,
            device=inputs_embeds.device,
        )
        return position_ids.unsqueeze(0).expand(input_shape)


class ProPrimeSelfAttention(nn.Module):
    def __init__(self, config, position_embedding_type=None):
        super().__init__()
        if config.hidden_size % config.num_attention_heads != 0 and not hasattr(
            config, "embedding_size"
        ):
            raise ValueError(
                f"The hidden size ({config.hidden_size}) is not a multiple of the number of attention "
                f"heads ({config.num_attention_heads})"
            )

        self.num_attention_heads = config.num_attention_heads
        self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
        self.all_head_size = self.num_attention_heads * self.attention_head_size

        self.query = nn.Linear(config.hidden_size, self.all_head_size)
        self.key = nn.Linear(config.hidden_size, self.all_head_size)
        self.value = nn.Linear(config.hidden_size, self.all_head_size)

        self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
        self.position_embedding_type = position_embedding_type or getattr(
            config, "position_embedding_type", "absolute"
        )
        self.rotary_embeddings = None
        if (
            self.position_embedding_type == "relative_key"
            or self.position_embedding_type == "relative_key_query"
        ):
            self.max_position_embeddings = config.max_position_embeddings
            self.distance_embedding = nn.Embedding(
                2 * config.max_position_embeddings - 1, self.attention_head_size
            )
        elif self.position_embedding_type == "rotary":
            self.rotary_embeddings = RotaryEmbedding(dim=self.attention_head_size)
        self.flash_attention = config.flash_attention
        self.is_decoder = config.is_decoder
        self.config = config

    def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor:
        new_x_shape = x.size()[:-1] + (
            self.num_attention_heads,
            self.attention_head_size,
        )
        x = x.view(new_x_shape)
        return x.permute(0, 2, 1, 3)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.FloatTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        encoder_hidden_states: Optional[torch.FloatTensor] = None,
        encoder_attention_mask: Optional[torch.FloatTensor] = None,
        past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
        output_attentions: Optional[bool] = False,
    ) -> Tuple[torch.Tensor]:
        mixed_query_layer = self.query(hidden_states)

        # If this is instantiated as a cross-attention module, the keys
        # and values come from an encoder; the attention mask needs to be
        # such that the encoder's padding tokens are not attended to.
        is_cross_attention = encoder_hidden_states is not None

        if is_cross_attention and past_key_value is not None:
            # reuse k,v, cross_attentions
            key_layer = past_key_value[0]
            value_layer = past_key_value[1]
            attention_mask = encoder_attention_mask
        elif is_cross_attention:
            key_layer = self.transpose_for_scores(self.key(encoder_hidden_states))
            value_layer = self.transpose_for_scores(self.value(encoder_hidden_states))
            attention_mask = encoder_attention_mask
        elif past_key_value is not None:
            key_layer = self.transpose_for_scores(self.key(hidden_states))
            value_layer = self.transpose_for_scores(self.value(hidden_states))
            key_layer = torch.cat([past_key_value[0], key_layer], dim=2)
            value_layer = torch.cat([past_key_value[1], value_layer], dim=2)
        else:
            key_layer = self.transpose_for_scores(self.key(hidden_states))
            value_layer = self.transpose_for_scores(self.value(hidden_states))

        query_layer = self.transpose_for_scores(mixed_query_layer)

        query_layer = query_layer * self.attention_head_size**-0.5

        if self.is_decoder:
            past_key_value = (key_layer, value_layer)

        if self.position_embedding_type == "rotary":
            query_layer, key_layer = self.rotary_embeddings(query_layer, key_layer)

        if not self.flash_attention:
            # Take the dot product between "query" and "key" to get the raw attention scores.
            attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))

            if (
                self.position_embedding_type == "relative_key"
                or self.position_embedding_type == "relative_key_query"
            ):
                seq_length = hidden_states.size()[1]
                position_ids_l = torch.arange(
                    seq_length, dtype=torch.long, device=hidden_states.device
                ).view(-1, 1)
                position_ids_r = torch.arange(
                    seq_length, dtype=torch.long, device=hidden_states.device
                ).view(1, -1)
                distance = position_ids_l - position_ids_r
                positional_embedding = self.distance_embedding(
                    distance + self.max_position_embeddings - 1
                )
                positional_embedding = positional_embedding.to(
                    dtype=query_layer.dtype
                )  # fp16 compatibility

                if self.position_embedding_type == "relative_key":
                    relative_position_scores = torch.einsum(
                        "bhld,lrd->bhlr", query_layer, positional_embedding
                    )
                    attention_scores = attention_scores + relative_position_scores
                elif self.position_embedding_type == "relative_key_query":
                    relative_position_scores_query = torch.einsum(
                        "bhld,lrd->bhlr", query_layer, positional_embedding
                    )
                    relative_position_scores_key = torch.einsum(
                        "bhrd,lrd->bhlr", key_layer, positional_embedding
                    )
                    attention_scores = (
                        attention_scores
                        + relative_position_scores_query
                        + relative_position_scores_key
                    )

            if attention_mask is not None:
                attention_scores = attention_scores + attention_mask

            # Normalize the attention scores to probabilities.
            attention_probs = nn.functional.softmax(attention_scores, dim=-1)

            # This is actually dropping out entire tokens to attend to, which might
            # seem a bit unusual, but is taken from the original Transformer paper.
            attention_probs = self.dropout(attention_probs)

            # Mask heads if we want to
            if head_mask is not None:
                attention_probs = attention_probs * head_mask

            context_layer = torch.matmul(attention_probs, value_layer)
        else:
            if self.training:
                context_layer = scaled_dot_product_attention(
                    query_layer,
                    key_layer,
                    value_layer,
                    attn_mask=attention_mask,
                    dropout_p=self.config.attention_probs_dropout_prob,
                    scale=1,  # we have query_layer = query_layer * self.attention_head_size**-0.5
                )
            else:
                context_layer = scaled_dot_product_attention(
                    query_layer,
                    key_layer,
                    value_layer,
                    attn_mask=attention_mask,
                    scale=1,  # we have query_layer = query_layer * self.attention_head_size**-0.5
                )

        context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
        new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
        context_layer = context_layer.view(new_context_layer_shape)

        outputs = (
            (context_layer, attention_probs) if output_attentions else (context_layer,)
        )

        if self.is_decoder:
            outputs = outputs + (past_key_value,)
        return outputs


class ProPrimeSelfOutput(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, hidden_states, input_tensor):
        hidden_states = self.dense(hidden_states)
        hidden_states = self.dropout(hidden_states)
        hidden_states = hidden_states + input_tensor
        return hidden_states


class ProPrimeAttention(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.self = ProPrimeSelfAttention(config)
        self.output = ProPrimeSelfOutput(config)
        self.pruned_heads = set()
        self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)

    def prune_heads(self, heads):
        if len(heads) == 0:
            return
        heads, index = find_pruneable_heads_and_indices(
            heads,
            self.self.num_attention_heads,
            self.self.attention_head_size,
            self.pruned_heads,
        )

        # Prune linear layers
        self.self.query = prune_linear_layer(self.self.query, index)
        self.self.key = prune_linear_layer(self.self.key, index)
        self.self.value = prune_linear_layer(self.self.value, index)
        self.output.dense = prune_linear_layer(self.output.dense, index, dim=1)

        # Update hyper params and store pruned heads
        self.self.num_attention_heads = self.self.num_attention_heads - len(heads)
        self.self.all_head_size = (
            self.self.attention_head_size * self.self.num_attention_heads
        )
        self.pruned_heads = self.pruned_heads.union(heads)

    def forward(
        self,
        hidden_states,
        attention_mask=None,
        head_mask=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        past_key_value=None,
        output_attentions=False,
    ):
        hidden_states_ln = self.LayerNorm(hidden_states)
        self_outputs = self.self(
            hidden_states_ln,
            attention_mask,
            head_mask,
            encoder_hidden_states,
            encoder_attention_mask,
            past_key_value,
            output_attentions,
        )
        attention_output = self.output(self_outputs[0], hidden_states)
        outputs = (attention_output,) + self_outputs[
            1:
        ]  # add attentions if we output them
        return outputs


class ProPrimeIntermediate(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.intermediate_size)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        hidden_states = self.dense(hidden_states)
        hidden_states = gelu(hidden_states)
        return hidden_states


class ProPrimeOutput(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, hidden_states, input_tensor):
        hidden_states = self.dense(hidden_states)
        hidden_states = self.dropout(hidden_states)
        hidden_states = hidden_states + input_tensor
        return hidden_states


class ProPrimeLayer(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.chunk_size_feed_forward = config.chunk_size_feed_forward
        self.seq_len_dim = 1
        self.attention = ProPrimeAttention(config)
        self.is_decoder = config.is_decoder
        self.add_cross_attention = config.add_cross_attention
        if self.add_cross_attention:
            if not self.is_decoder:
                raise RuntimeError(
                    f"{self} should be used as a decoder model if cross attention is added"
                )
            self.crossattention = ProPrimeAttention(config)
        self.intermediate = ProPrimeIntermediate(config)
        self.output = ProPrimeOutput(config)
        self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)

    def forward(
        self,
        hidden_states,
        attention_mask=None,
        head_mask=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        past_key_value=None,
        output_attentions=False,
    ):
        # decoder uni-directional self-attention cached key/values tuple is at positions 1,2
        self_attn_past_key_value = (
            past_key_value[:2] if past_key_value is not None else None
        )
        self_attention_outputs = self.attention(
            hidden_states,
            attention_mask,
            head_mask,
            output_attentions=output_attentions,
            past_key_value=self_attn_past_key_value,
        )
        attention_output = self_attention_outputs[0]

        # if decoder, the last output is tuple of self-attn cache
        if self.is_decoder:
            outputs = self_attention_outputs[1:-1]
            present_key_value = self_attention_outputs[-1]
        else:
            outputs = self_attention_outputs[
                1:
            ]  # add self attentions if we output attention weights

        cross_attn_present_key_value = None
        if self.is_decoder and encoder_hidden_states is not None:
            if not hasattr(self, "crossattention"):
                raise AttributeError(
                    f"If `encoder_hidden_states` are passed, {self} has to be instantiated"
                    " with cross-attention layers by setting `config.add_cross_attention=True`"
                )

            # cross_attn cached key/values tuple is at positions 3,4 of past_key_value tuple
            cross_attn_past_key_value = (
                past_key_value[-2:] if past_key_value is not None else None
            )
            cross_attention_outputs = self.crossattention(
                attention_output,
                attention_mask,
                head_mask,
                encoder_hidden_states,
                encoder_attention_mask,
                cross_attn_past_key_value,
                output_attentions,
            )
            attention_output = cross_attention_outputs[0]
            outputs = (
                outputs + cross_attention_outputs[1:-1]
            )  # add cross attentions if we output attention weights

            # add cross-attn cache to positions 3,4 of present_key_value tuple
            cross_attn_present_key_value = cross_attention_outputs[-1]
            present_key_value = present_key_value + cross_attn_present_key_value

        layer_output = self.feed_forward_chunk(attention_output)

        outputs = (layer_output,) + outputs

        # if decoder, return the attn key/values as the last output
        if self.is_decoder:
            outputs = outputs + (present_key_value,)
        return outputs

    def feed_forward_chunk(self, attention_output):
        attention_output_ln = self.LayerNorm(attention_output)
        intermediate_output = self.intermediate(attention_output_ln)
        layer_output = self.output(intermediate_output, attention_output)
        return layer_output


class ProPrimeEncoder(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.config = config
        self.layer = nn.ModuleList(
            [ProPrimeLayer(config) for _ in range(config.num_hidden_layers)]
        )
        self.emb_layer_norm_after = nn.LayerNorm(
            config.hidden_size, eps=config.layer_norm_eps
        )
        self.gradient_checkpointing = False

    def forward(
        self,
        hidden_states,
        attention_mask=None,
        head_mask=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        past_key_values=None,
        use_cache=None,
        output_attentions=False,
        output_hidden_states=False,
        return_dict=True,
    ):
        if self.gradient_checkpointing and self.training:
            if use_cache:
                logger.warning_once(
                    "`use_cache=True` is incompatible with `config.gradient_checkpointing=True`. Setting "
                    "`use_cache=False`..."
                )
                use_cache = False
        all_hidden_states = () if output_hidden_states else None
        all_self_attentions = () if output_attentions else None
        all_cross_attentions = (
            () if output_attentions and self.config.add_cross_attention else None
        )

        next_decoder_cache = () if use_cache else None
        for i, layer_module in enumerate(self.layer):
            if output_hidden_states:
                all_hidden_states = all_hidden_states + (hidden_states,)

            layer_head_mask = head_mask[i] if head_mask is not None else None
            past_key_value = past_key_values[i] if past_key_values is not None else None

            if self.gradient_checkpointing and self.training:
                layer_outputs = self._gradient_checkpointing_func(
                    layer_module.__call__,
                    hidden_states,
                    attention_mask,
                    layer_head_mask,
                    encoder_hidden_states,
                    encoder_attention_mask,
                    past_key_value,
                    output_attentions,
                )
            else:
                layer_outputs = layer_module(
                    hidden_states,
                    attention_mask,
                    layer_head_mask,
                    encoder_hidden_states,
                    encoder_attention_mask,
                    past_key_value,
                    output_attentions,
                )

            hidden_states = layer_outputs[0]
            if use_cache:
                next_decoder_cache = next_decoder_cache + (layer_outputs[-1],)
            if output_attentions:
                all_self_attentions = all_self_attentions + (layer_outputs[1],)
                if self.config.add_cross_attention:
                    all_cross_attentions = all_cross_attentions + (layer_outputs[2],)

        if self.emb_layer_norm_after:
            hidden_states = self.emb_layer_norm_after(hidden_states)

        if output_hidden_states:
            all_hidden_states = all_hidden_states + (hidden_states,)

        if not return_dict:
            return tuple(
                v
                for v in [
                    hidden_states,
                    next_decoder_cache,
                    all_hidden_states,
                    all_self_attentions,
                    all_cross_attentions,
                ]
                if v is not None
            )
        return BaseModelOutputWithPastAndCrossAttentions(
            last_hidden_state=hidden_states,
            past_key_values=next_decoder_cache,
            hidden_states=all_hidden_states,
            attentions=all_self_attentions,
            cross_attentions=all_cross_attentions,
        )


class ProPrimePreTrainedModel(PreTrainedModel):
    """
    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
    models.
    """

    config_class = ProPrimeConfig
    base_model_prefix = "proprime"
    supports_gradient_checkpointing = True
    _no_split_modules = [
        "ProPrimeLayer",
        "ProPrimeEmbeddings",
    ]

    # Copied from transformers.models.bert.modeling_bert.BertPreTrainedModel._init_weights
    def _init_weights(self, module):
        """Initialize the weights"""
        if isinstance(module, nn.Linear):
            # Slightly different from the TF version which uses truncated_normal for initialization
            # cf https://github.com/pytorch/pytorch/pull/5617
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()
        elif isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)


class ProPrimeModel(ProPrimePreTrainedModel):
    def __init__(self, config, add_pooling_layer=True):
        super().__init__(config)
        self.config = config
        self.embeddings = ProPrimeEmbeddings(config)
        self.encoder = ProPrimeEncoder(config)
        self.post_init()

    def get_input_embeddings(self):
        return self.embeddings.word_embeddings

    def set_input_embeddings(self, value):
        self.embeddings.word_embeddings = value

    def _prune_heads(self, heads_to_prune):
        """
        Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
        class PreTrainedModel
        """
        for layer, heads in heads_to_prune.items():
            self.encoder.layer[layer].attention.prune_heads(heads)

    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        encoder_attention_mask: Optional[torch.Tensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple[torch.Tensor], BaseModelOutputWithPoolingAndCrossAttentions]:
        output_attentions = (
            output_attentions
            if output_attentions is not None
            else self.config.output_attentions
        )
        output_hidden_states = (
            output_hidden_states
            if output_hidden_states is not None
            else self.config.output_hidden_states
        )
        return_dict = (
            return_dict if return_dict is not None else self.config.use_return_dict
        )

        if self.config.is_decoder:
            use_cache = use_cache if use_cache is not None else self.config.use_cache
        else:
            use_cache = False

        if input_ids is not None and inputs_embeds is not None:
            raise ValueError(
                "You cannot specify both input_ids and inputs_embeds at the same time"
            )
        elif input_ids is not None:
            self.warn_if_padding_and_no_attention_mask(input_ids, attention_mask)
            input_shape = input_ids.size()
        elif inputs_embeds is not None:
            input_shape = inputs_embeds.size()[:-1]
        else:
            raise ValueError("You have to specify either input_ids or inputs_embeds")

        batch_size, seq_length = input_shape
        device = input_ids.device if input_ids is not None else inputs_embeds.device

        # past_key_values_length
        past_key_values_length = (
            past_key_values[0][0].shape[2] if past_key_values is not None else 0
        )

        if attention_mask is None:
            attention_mask = torch.ones(
                ((batch_size, seq_length + past_key_values_length)), device=device
            )

        extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(
            attention_mask, input_shape
        )

        if self.config.is_decoder and encoder_hidden_states is not None:
            encoder_batch_size, encoder_sequence_length, _ = (
                encoder_hidden_states.size()
            )
            encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)
            if encoder_attention_mask is None:
                encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device)
            encoder_extended_attention_mask = self.invert_attention_mask(
                encoder_attention_mask
            )
        else:
            encoder_extended_attention_mask = None

        head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)

        embedding_output = self.embeddings(
            input_ids=input_ids,
            position_ids=position_ids,
            attention_mask=attention_mask,
            inputs_embeds=inputs_embeds,
            past_key_values_length=past_key_values_length,
        )
        encoder_outputs = self.encoder(
            embedding_output,
            attention_mask=extended_attention_mask,
            head_mask=head_mask,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_extended_attention_mask,
            past_key_values=past_key_values,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        sequence_output = encoder_outputs[0]

        return BaseModelOutputWithPoolingAndCrossAttentions(
            last_hidden_state=sequence_output,
            past_key_values=encoder_outputs.past_key_values,
            hidden_states=encoder_outputs.hidden_states,
            attentions=encoder_outputs.attentions,
            cross_attentions=encoder_outputs.cross_attentions,
        )


class ProPrimeForMaskedLM(ProPrimePreTrainedModel):
    _tied_weights_keys = ["lm_head.decoder.weight"]

    def __init__(self, config):
        super().__init__(config)

        if config.is_decoder:
            logger.warning(
                "If you want to use `ProPrimeForMaskedLM` make sure `config.is_decoder=False` for "
                "bi-directional self-attention."
            )

        self.pro_prime = ProPrimeModel(config, add_pooling_layer=False)
        self.lm_head = ProPrimeLMHead(config)
        self.init_weights()

    def get_input_embeddings(self):
        return self.pro_prime.embeddings.word_embeddings

    def get_output_embeddings(self):
        return self.lm_head.decoder

    def set_output_embeddings(self, new_embeddings):
        self.lm_head.decoder = new_embeddings

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        encoder_hidden_states: Optional[torch.FloatTensor] = None,
        encoder_attention_mask: Optional[torch.Tensor] = None,
        labels: Optional[torch.LongTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, MaskedLMOutput]:
        return_dict = (
            return_dict if return_dict is not None else self.config.use_return_dict
        )

        outputs = self.pro_prime(
            input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_attention_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        sequence_output = outputs[0]
        prediction_scores = self.lm_head(sequence_output)

        masked_lm_loss = None
        if labels is not None:
            loss_fct = CrossEntropyLoss()

            labels = labels.to(prediction_scores.device)
            masked_lm_loss = loss_fct(
                prediction_scores.view(-1, self.config.vocab_size), labels.view(-1)
            )

        if not return_dict:
            output = (prediction_scores,) + outputs[2:]
            return (
                ((masked_lm_loss,) + output) if masked_lm_loss is not None else output
            )

        return MaskedLMOutput(
            loss=masked_lm_loss,
            logits=prediction_scores,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )


class ProPrimeLMHead(nn.Module):

    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)

        self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
        self.bias = nn.Parameter(torch.zeros(config.vocab_size))

    def forward(self, features, **kwargs):
        x = self.dense(features)
        x = gelu(x)
        x = self.layer_norm(x)

        # project back to size of vocabulary with bias
        x = self.decoder(x) + self.bias
        return x


def create_position_ids_from_input_ids(
    input_ids, padding_idx, past_key_values_length=0
):
    """
    Replace non-padding symbols with their position numbers. Position numbers begin at padding_idx+1. Padding symbols
    are ignored. This is modified from fairseq's `utils.make_positions`.

    Args:
        x: torch.Tensor x:

    Returns: torch.Tensor
    """
    # The series of casts and type-conversions here are carefully balanced to both work with ONNX export and XLA.
    mask = input_ids.ne(padding_idx).int()
    incremental_indices = (
        torch.cumsum(mask, dim=1).type_as(mask) + past_key_values_length
    ) * mask
    return incremental_indices.long() + padding_idx


# POOLING_HEAD
class MaskedConv1d(nn.Conv1d):
    """A masked 1-dimensional convolution layer.

    Takes the same arguments as torch.nn.Conv1D, except that the padding is set automatically.

    Shape:
        Input: (N, L, in_channels)
        input_mask: (N, L, 1), optional
        Output: (N, L, out_channels)
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: int,
        stride: int = 1,
        dilation: int = 1,
        groups: int = 1,
        bias: bool = True,
    ):
        """
        :param in_channels: input channels
        :param out_channels: output channels
        :param kernel_size: the kernel width
        :param stride: filter shift
        :param dilation: dilation factor
        :param groups: perform depth-wise convolutions
        :param bias: adds learnable bias to output
        """
        padding = dilation * (kernel_size - 1) // 2
        super().__init__(
            in_channels,
            out_channels,
            kernel_size,
            stride=stride,
            dilation=dilation,
            groups=groups,
            bias=bias,
            padding=padding,
        )

    def forward(self, x, input_mask=None):
        if input_mask is not None:
            x = x * input_mask
        return super().forward(x.transpose(1, 2)).transpose(1, 2)


class Attention1d(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.layer = MaskedConv1d(config.hidden_size, 1, 1)
        self.out = nn.Linear(config.hidden_size, config.hidden_size)

    def forward(self, x, input_mask=None):
        batch_size = x.shape[0]
        attn = self.layer(x)
        attn = attn.view(batch_size, -1)
        if input_mask is not None:
            attn = attn.masked_fill_(
                ~input_mask.view(batch_size, -1).bool(), float("-inf")
            )
        attn = F.softmax(attn, dim=-1).view(batch_size, -1, 1)
        out = (attn * x).sum(dim=1)
        out = self.out(out)
        return out


class FFN1d(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size)
        self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size)
        self.act = nn.GELU()

    def forward(self, x):
        x = self.fc1(x)
        x = self.act(x)
        x = self.fc2(x)
        return x


class Attention1dPooling(nn.Module):
    """Outputs of the model with the attention1d"""

    def __init__(
        self, config
    ):  # [batch x sequence(751) x embedding (1280)] --> [batch x embedding] --> [batch x 1]
        super(Attention1dPooling, self).__init__()
        self.attention1d = Attention1d(config)
        self.ffn = FFN1d(config)
        # self.norm1 = nn.BatchNorm1d(config.hidden_size)
        # self.norm2 = nn.BatchNorm1d(config.hidden_size)
        self.dropout1 = nn.Dropout(config.hidden_dropout_prob)
        self.dropout2 = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, x, input_mask):
        attn_out = self.attention1d(x, input_mask=input_mask.unsqueeze(-1))
        x = self.dropout1(attn_out)
        # x = self.norm1(x)
        ffn_out = self.ffn(x)
        x = x + self.dropout2(ffn_out)
        # x = self.norm2(x)
        return x


@dataclass
class MaskedLMOutput(ModelOutput):
    loss: Optional[torch.FloatTensor] = None
    logits: torch.FloatTensor = None
    sequence_hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
    hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
    attentions: Optional[Tuple[torch.FloatTensor, ...]] = None


class ProPrimeMV(ProPrimePreTrainedModel):
    _tied_weights_keys = ["lm_head.decoder.weight"]

    def __init__(self, config):
        super().__init__(config)
        self.pro_prime = ProPrimeModel(config, add_pooling_layer=False)
        self.lm_head = ProPrimeLMHead(config)
        self.sequence_pooling = Attention1dPooling(config)
        self.value_projection = nn.Sequential(
            nn.Linear(config.hidden_size, config.hidden_size),
            nn.Tanh(),
            nn.Linear(config.hidden_size, 1),
        )
        self.init_weights()

    def get_input_embeddings(self):
        return self.pro_prime.embeddings.word_embeddings

    def get_output_embeddings(self):
        return self.lm_head.decoder

    def set_output_embeddings(self, new_embeddings):
        self.lm_head.decoder = new_embeddings

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        encoder_hidden_states: Optional[torch.FloatTensor] = None,
        encoder_attention_mask: Optional[torch.Tensor] = None,
        labels: Optional[torch.LongTensor] = None,
        values: Optional[torch.FloatTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, MaskedLMOutput]:
        return_dict = (
            return_dict if return_dict is not None else self.config.use_return_dict
        )

        outputs = self.pro_prime(
            input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_attention_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        sequence_output = outputs[0]
        prediction_scores = self.lm_head(sequence_output)

        masked_lm_loss = None
        if labels is not None:
            loss_fct = CrossEntropyLoss()

            labels = labels.to(prediction_scores.device)
            masked_lm_loss = loss_fct(
                prediction_scores.view(-1, self.config.vocab_size), labels.view(-1)
            )

        if not return_dict:
            output = (prediction_scores,) + outputs[2:]
            return (
                ((masked_lm_loss,) + output) if masked_lm_loss is not None else output
            )

        sequence_states = self.sequence_pooling(sequence_output, attention_mask)
        predicted_values = self.value_projection(sequence_states)
        values = values.to(predicted_values.dtype)
        values = values.reshape(-1, 1)
        value_loss = nn.MSELoss()(predicted_values, values)

        return MaskedLMOutput(
            loss=masked_lm_loss + value_loss,
            logits=prediction_scores,
            hidden_states=outputs.hidden_states,
            sequence_hidden_states=sequence_states,
            attentions=outputs.attentions,
        )


ProPrimeModel.register_for_auto_class("AutoModel")
ProPrimeForMaskedLM.register_for_auto_class("AutoModelForMaskedLM")


if __name__ == "__main__":
    from ProPrime.tokenization_proprime import ProPrimeTokenizer
    from transformers.models.esm import EsmForMaskedLM

    tokenizer = ProPrimeTokenizer("ProPrime/vocab.txt")
    config = ProPrimeConfig()
    model = ProPrimeMV(config)
    model.eval()
    s = [
        "MSFSHJGIOSJGKLOSJGSLKJWRPRQR",
        "MSRPRQR",
        "MSFSHJKLOSJGSLKJWRPRQR",
        "MSFSHJGIOSJGKLOSJG",
    ]
    input_ids = tokenizer(s, return_tensors="pt", padding=True).input_ids[:1, ]
    attention_mask = tokenizer(s, return_tensors="pt", padding=True).attention_mask[:1, ]
    values = torch.tensor([1.0, ])

    print(
        model.forward(
            input_ids=input_ids,
            attention_mask=attention_mask,
            values=values,
            labels=input_ids,
        )
    )