Abhaykoul committed on
Commit 8a626aa
1 Parent(s): 21203f4

Delete modeling_HelpingAI.py

Files changed (1)
  1. modeling_HelpingAI.py +0 -670
modeling_HelpingAI.py DELETED
@@ -1,670 +0,0 @@
- """HelpingAI model."""
- from typing import Optional, Tuple, Union
- import math
-
- import torch
- import torch.utils.checkpoint
- from transformers import AutoModel, AutoModelForCausalLM
- from torch import nn
- from torch.nn import CrossEntropyLoss
- from transformers.modeling_outputs import (
-     BaseModelOutputWithPast,
-     CausalLMOutputWithPast,
- )
- from transformers.modeling_utils import PreTrainedModel
- from transformers.utils import logging
- from .configuration_HelpingAI import HelpingAIConfig
-
-
- logger = logging.get_logger(__name__)
-
-
- # Copied from transformers.models.bart.modeling_bart._make_causal_mask
- def _make_causal_mask(
-     input_ids_shape: torch.Size,
-     dtype: torch.dtype,
-     device: torch.device,
-     past_key_values_length: int = 0,
- ):
-     """Make causal mask used for bi-directional self-attention."""
-     batch_size, tgt_len = input_ids_shape
-     mask = torch.full((tgt_len, tgt_len), torch.finfo(torch.float16).min, device=device)
-     mask_cond = torch.arange(mask.size(-1), device=device)
-     mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)
-     mask = mask.to(dtype)
-     if past_key_values_length > 0:
-         mask = torch.cat([torch.zeros(tgt_len, past_key_values_length, dtype=dtype, device=device), mask], dim=-1)
-     return mask[None, None, :, :].expand(batch_size, 1, tgt_len, tgt_len + past_key_values_length)
-
-
- # Copied from transformers.models.bart.modeling_bart._expand_mask
- def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None):
-     """Expands attention_mask from `[batch_size, seq_len]` to `[batch_size, 1, tgt_seq_len, src_seq_len]`."""
-     batch_size, src_len = mask.size()
-     tgt_len = tgt_len if tgt_len is not None else src_len
-
-     expanded_mask = mask[:, None, None, :].expand(batch_size, 1, tgt_len, src_len).to(dtype)
-     inverted_mask = 1.0 - expanded_mask
-
-     return inverted_mask.masked_fill(
-         inverted_mask.to(torch.bool), torch.finfo(dtype).min
-     )
-
-
- class RotaryEmbedding(nn.Module):
-     def __init__(
-         self,
-         dim: int,
-         max_position_embeddings: int,
-         base: int = 10_000,
-         device: Optional[torch.device] = None,
-     ):
-         super().__init__()
-
-         self.dim = dim
-         self.max_position_embeddings = max_position_embeddings
-         self.base = base
-         inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2, device=device, dtype=torch.float32) / self.dim))
-         self.register_buffer("inv_freq", inv_freq, persistent=False)
-
-         # Build here to make `torch.jit.trace` work.
-         self._set_cos_sin_cache(
-             seq_len=max_position_embeddings, device=self.inv_freq.device, dtype=torch.get_default_dtype(),
-         )
-
-     def _set_cos_sin_cache(self, seq_len: int, device: torch.device, dtype: torch.dtype):
-         self.max_seq_len_cached = seq_len
-         t = torch.arange(self.max_seq_len_cached, device=device, dtype=torch.float32)
-
-         # Don't do einsum, it converts fp32 to fp16 under AMP
-         # freqs = torch.einsum("i,j->ij", t, self.inv_freq)
-         freqs = torch.outer(t, self.inv_freq)
-         # Different from paper, but it uses a different permutation in order to obtain the same calculation
-         emb = torch.cat((freqs, freqs), dim=-1)
-         self.register_buffer("cos_cached", emb.cos()[None, None, :, :].to(dtype), persistent=False)
-         self.register_buffer("sin_cached", emb.sin()[None, None, :, :].to(dtype), persistent=False)
-
-     def forward(self, x: torch.Tensor, seq_len: Optional[int] = None):
-         # x: [batch_size, num_heads, seq_len, head_size]
-         if seq_len > self.max_seq_len_cached:
-             self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=torch.get_default_dtype())
-         return (
-             self.cos_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
-             self.sin_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
-         )
-
-
- def rotate_half(x: torch.Tensor):
-     """Rotates half the hidden dims of the input."""
-     x1, x2 = torch.chunk(x, 2, dim=-1)
-     return torch.cat((-x2, x1), dim=-1)
-
-
- def apply_rotary_pos_emb(q, k, cos, sin, position_ids):
-     # The first two dimensions of cos and sin are always 1, so we can `squeeze` them.
-     cos = cos.squeeze(1).squeeze(0)  # [seq_len, dim]
-     sin = sin.squeeze(1).squeeze(0)  # [seq_len, dim]
-     cos = cos[position_ids].unsqueeze(1)  # [batch_size, 1, seq_len, dim]
-     sin = sin[position_ids].unsqueeze(1)  # [batch_size, 1, seq_len, dim]
-     q_embed = (q * cos) + (rotate_half(q) * sin)
-     k_embed = (k * cos) + (rotate_half(k) * sin)
-     return q_embed, k_embed
-
-
- class MLP(nn.Module):
-     def __init__(self, config: HelpingAIConfig):
-         super().__init__()
-         self.config = config
-         self.hidden_size = config.hidden_size
-         self.intermediate_size = config.intermediate_size
-         self.gate_proj = nn.Linear(config.hidden_size, config.intermediate_size, bias=False)
-         self.up_proj = nn.Linear(config.hidden_size, config.intermediate_size, bias=False)
-         self.down_proj = nn.Linear(config.intermediate_size, config.hidden_size, bias=False)
-         self.act_fn = nn.SiLU()
-
-     def forward(self, x: torch.Tensor) -> torch.Tensor:
-         return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
-
-
- def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
-     """
-     This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
-     num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
-     """
-     batch, num_key_value_heads, slen, head_dim = hidden_states.shape
-     if n_rep == 1:
-         return hidden_states
-     hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
-     return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
-
-
- class Attention(nn.Module):
-     def __init__(self, config: HelpingAIConfig):
-         super().__init__()
-         self.config = config
-         self.hidden_size = config.hidden_size
-         self.num_heads = config.num_attention_heads
-         self.head_dim = self.hidden_size // self.num_heads
-         self.num_key_value_heads = config.num_key_value_heads
-         self.num_key_value_groups = self.num_heads // self.num_key_value_heads
-         self.max_position_embeddings = config.max_position_embeddings
-
-         if (self.head_dim * self.num_heads) != self.hidden_size:
-             raise ValueError(
-                 f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
-                 f" and `num_heads`: {self.num_heads})."
-             )
-         self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
-         self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
-         self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
-         self.o_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=False)
-
-         self._init_rope()
-
-     def _init_rope(self):
-         self.rotary_ndims = int(self.head_dim * self.config.rope_pct)
-         self.rotary_emb = RotaryEmbedding(
-             self.rotary_ndims,
-             max_position_embeddings=self.config.max_position_embeddings,
-             base=self.config.rope_theta,
-         )
-
-     def forward(
-         self,
-         hidden_states: torch.FloatTensor,
-         attention_mask: torch.FloatTensor,
-         position_ids: torch.LongTensor,
-         past_key_value: Optional[Tuple[torch.Tensor]] = None,
-         output_attentions: Optional[bool] = False,
-         use_cache: Optional[bool] = False,
-     ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
-         bsz, q_len, _ = hidden_states.size()
-
-         query_states = self.q_proj(hidden_states)
-         key_states = self.k_proj(hidden_states)
-         value_states = self.v_proj(hidden_states)
-
-         query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
-         key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
-         value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
-
-         query_rot = query_states[..., : self.rotary_ndims]
-         query_pass = query_states[..., self.rotary_ndims :]
-         key_rot = key_states[..., : self.rotary_ndims]
-         key_pass = key_states[..., self.rotary_ndims :]
-
-         kv_seq_len = key_states.shape[-2]
-         if past_key_value is not None:
-             kv_seq_len += past_key_value[0].shape[-2]
-         cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
-         query_states, key_states = apply_rotary_pos_emb(query_rot, key_rot, cos, sin, position_ids)
-
-         # [batch_size, num_heads, seq_len, head_dim]
-         query_states = torch.cat((query_states, query_pass), dim=-1)
-         key_states = torch.cat((key_states, key_pass), dim=-1)
-
-         if past_key_value is not None:
-             # Reuse k, v, self_attention
-             key_states = torch.cat((past_key_value[0], key_states), dim=2)
-             value_states = torch.cat((past_key_value[1], value_states), dim=2)
-
-         past_key_value = (key_states, value_states) if use_cache else None
-
-         # Repeat k/v heads if n_kv_heads < n_heads
-         key_states = repeat_kv(key_states, self.num_key_value_groups)
-         value_states = repeat_kv(value_states, self.num_key_value_groups)
-
-         attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
-
-         if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
-             raise ValueError(
-                 f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is"
-                 f" {attn_weights.size()}"
-             )
-
-         if attention_mask is not None:
-             if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
-                 raise ValueError(
-                     f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
-                 )
-             attn_weights = attn_weights + attention_mask
-
-         # Upcast attention to fp32
-         attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
-         attn_output = torch.matmul(attn_weights, value_states)
-
-         if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
-             raise ValueError(
-                 f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
-                 f" {attn_output.size()}"
-             )
-
-         # Merge heads
-         attn_output = attn_output.transpose(1, 2).contiguous()
-         attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
-
-         # Final linear projection
-         attn_output = self.o_proj(attn_output)
-
-         if not output_attentions:
-             attn_weights = None
-
-         return attn_output, attn_weights, past_key_value
-
-
- class DecoderLayer(nn.Module):
-     def __init__(self, config: HelpingAIConfig):
-         super().__init__()
-         self.self_attn = Attention(config)
-         self.mlp = MLP(config)
-         self.input_layernorm = nn.LayerNorm(config.hidden_size, eps=config.norm_eps)
-         self.post_attention_layernorm = nn.LayerNorm(config.hidden_size, eps=config.norm_eps)
-
-     def forward(
-         self,
-         hidden_states: Optional[torch.FloatTensor],
-         attention_mask: Optional[torch.FloatTensor] = None,
-         position_ids: Optional[torch.LongTensor] = None,
-         past_key_value: Optional[Tuple[torch.Tensor]] = None,
-         output_attentions: Optional[bool] = False,
-         use_cache: Optional[bool] = False,
-     ) -> Union[Tuple[torch.Tensor], Optional[Tuple[torch.Tensor, Tuple[torch.FloatTensor, ...]]]]:
-         residual = hidden_states
-
-         hidden_states = self.input_layernorm(hidden_states)
-
-         # Self Attention
-         hidden_states, self_attn_weights, present_key_value = self.self_attn(
-             hidden_states=hidden_states,
-             attention_mask=attention_mask,
-             position_ids=position_ids,
-             past_key_value=past_key_value,
-             output_attentions=output_attentions,
-             use_cache=use_cache,
-         )
-         hidden_states = residual + hidden_states
-
-         # Fully Connected
-         residual = hidden_states
-         hidden_states = self.post_attention_layernorm(hidden_states)
-         hidden_states = self.mlp(hidden_states)
-         hidden_states = residual + hidden_states
-
-         outputs = (hidden_states,)
-
-         if output_attentions:
-             outputs += (self_attn_weights,)
-
-         if use_cache:
-             outputs += (present_key_value,)
-
-         return outputs
-
-
- class HelpingAIPreTrainedModel(PreTrainedModel):
-     """An abstract class to handle weights initialization and a simple interface
-     for downloading and loading pretrained models.
-     """
-
-     config_class = HelpingAIConfig
-     base_model_prefix = "transformer"
-     supports_gradient_checkpointing = True
-     _no_split_modules = ["DecoderLayer"]
-     _skip_keys_device_placement = "past_key_values"
-
-     def _init_weights(self, module: nn.Module):
-         """Initialize the weights"""
-         if isinstance(module, nn.Linear):
-             module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
-             if module.bias is not None:
-                 module.bias.data.zero_()
-         elif isinstance(module, nn.Embedding):
-             module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
-             if module.padding_idx is not None:
-                 module.weight.data[module.padding_idx].zero_()
-         elif isinstance(module, nn.LayerNorm):
-             module.bias.data.zero_()
-             module.weight.data.fill_(1.0)
-
-     def _set_gradient_checkpointing(self, module: nn.Module, value=False):
-         if isinstance(module, HelpingAIModel):
-             module.gradient_checkpointing = value
-
-
- class HelpingAIModel(HelpingAIPreTrainedModel):
-     def __init__(self, config: HelpingAIConfig):
-         super().__init__(config)
-         self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, config.pad_token_id)
-         self.layers = nn.ModuleList([DecoderLayer(config) for _ in range(config.num_hidden_layers)])
-         self.norm = nn.LayerNorm(config.hidden_size, eps=config.norm_eps)
-
-         self.gradient_checkpointing = False
-         # Initialize weights and apply final processing
-         self.post_init()
-
-     def get_input_embeddings(self):
-         return self.embed_tokens
-
-     def set_input_embeddings(self, value: nn.Module):
-         self.embed_tokens = value
-
-     # Copied from transformers.models.bart.modeling_bart.BartDecoder._prepare_decoder_attention_mask
-     def _prepare_decoder_attention_mask(
-         self,
-         attention_mask: torch.Tensor,
-         input_shape: torch.Size,
-         inputs_embeds: torch.Tensor,
-         past_key_values_length: int,
-     ):
-         # Create causal mask
-         # [batch_size, seq_len] -> [batch_size, 1, tgt_seq_len, src_seq_len]
-         combined_attention_mask = None
-         if input_shape[-1] > 1:
-             combined_attention_mask = _make_causal_mask(
-                 input_shape,
-                 inputs_embeds.dtype,
-                 device=inputs_embeds.device,
-                 past_key_values_length=past_key_values_length,
-             )
-
-         if attention_mask is not None:
-             # [batch_size, seq_len] -> [batch_size, 1, tgt_seq_len, src_seq_len]
-             expanded_attn_mask = _expand_mask(
-                 attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]
-             ).to(inputs_embeds.device)
-             combined_attention_mask = expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask + combined_attention_mask
-
-         return combined_attention_mask
-
-     def forward(
-         self,
-         input_ids: Optional[torch.LongTensor] = None,
-         attention_mask: Optional[torch.FloatTensor] = None,
-         position_ids: Optional[torch.LongTensor] = None,
-         past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
-         inputs_embeds: Optional[torch.FloatTensor] = None,
-         use_cache: Optional[bool] = None,
-         output_attentions: Optional[bool] = None,
-         output_hidden_states: Optional[bool] = None,
-         return_dict: Optional[bool] = None,
-     ) -> Union[Tuple, BaseModelOutputWithPast]:
-         output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
-         output_hidden_states = output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
-         use_cache = use_cache if use_cache is not None else self.config.use_cache
-
-         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
-
-         # Retrieve input_ids and inputs_embeds
-         if input_ids is not None and inputs_embeds is not None:
-             raise ValueError(
-                 "You cannot specify both input_ids and inputs_embeds at the same time"
-             )
-         elif input_ids is not None:
-             batch_size, seq_length = input_ids.shape
-         elif inputs_embeds is not None:
-             batch_size, seq_length, _ = inputs_embeds.shape
-         else:
-             raise ValueError(
-                 "You have to specify either input_ids or inputs_embeds"
-             )
-
-         seq_length_with_past = seq_length
-         past_key_values_length = 0
-
-         if past_key_values is not None:
-             past_key_values_length = past_key_values[0][0].shape[2]
-             seq_length_with_past = seq_length_with_past + past_key_values_length
-
-         if position_ids is None:
-             device = input_ids.device if input_ids is not None else inputs_embeds.device
-             position_ids = torch.arange(
-                 past_key_values_length,
-                 seq_length + past_key_values_length,
-                 dtype=torch.long,
-                 device=device,
-             )
-             position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
-         else:
-             position_ids = position_ids.view(-1, seq_length).long()
-
-         if inputs_embeds is None:
-             inputs_embeds = self.embed_tokens(input_ids)
-         # Embed positions
-         if attention_mask is None:
-             attention_mask = torch.ones(
-                 (batch_size, seq_length_with_past),
-                 dtype=torch.bool,
-                 device=inputs_embeds.device,
-             )
-         attention_mask = self._prepare_decoder_attention_mask(
-             attention_mask,
-             (batch_size, seq_length),
-             inputs_embeds,
-             past_key_values_length,
-         )
-
-         hidden_states = inputs_embeds
-
-         if self.gradient_checkpointing and self.training:
-             if use_cache:
-                 logger.warning(
-                     "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
-                 )
-                 use_cache = False
-
-         # Decoder layers
-         all_hidden_states = () if output_hidden_states else None
-         all_self_attns = () if output_attentions else None
-         next_decoder_cache = () if use_cache else None
-
-         for idx, decoder_layer in enumerate(self.layers):
-             if output_hidden_states:
-                 all_hidden_states += (hidden_states,)
-
-             past_key_value = (
-                 past_key_values[idx] if past_key_values is not None else None
-             )
-
-             if self.gradient_checkpointing and self.training:
-
-                 def create_custom_forward(module):
-                     def custom_forward(*inputs):
-                         # None for past_key_value
-                         return module(*inputs, past_key_value, output_attentions)
-
-                     return custom_forward
-
-                 layer_outputs = torch.utils.checkpoint.checkpoint(
-                     create_custom_forward(decoder_layer),
-                     hidden_states,
-                     attention_mask,
-                     position_ids,
-                 )
-             else:
-                 layer_outputs = decoder_layer(
-                     hidden_states,
-                     attention_mask=attention_mask,
-                     position_ids=position_ids,
-                     past_key_value=past_key_value,
-                     output_attentions=output_attentions,
-                     use_cache=use_cache,
-                 )
-
-             hidden_states = layer_outputs[0]
-
-             if use_cache:
-                 next_decoder_cache += (layer_outputs[2 if output_attentions else 1],)
-
-             if output_attentions:
-                 all_self_attns += (layer_outputs[1],)
-
-         hidden_states = self.norm(hidden_states)
-
-         # Add hidden states from the last decoder layer
-         if output_hidden_states:
-             all_hidden_states += (hidden_states,)
-
-         next_cache = next_decoder_cache if use_cache else None
-         if not return_dict:
-             return tuple(
-                 v
-                 for v in [hidden_states, next_cache, all_hidden_states, all_self_attns]
-                 if v is not None
-             )
-         return BaseModelOutputWithPast(
-             last_hidden_state=hidden_states,
-             past_key_values=next_cache,
-             hidden_states=all_hidden_states,
-             attentions=all_self_attns,
-         )
-
-
- class HelpingAIForCausalLM(HelpingAIPreTrainedModel):
-     _tied_weights_keys = ["lm_head.weight"]
-
-     def __init__(self, config: HelpingAIConfig):
-         super().__init__(config)
-
-         self.model = HelpingAIModel(config)
-         self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
-
-         # Initialize weights and apply final processing
-         self.post_init()
-
-     def get_input_embeddings(self):
-         return self.model.embed_tokens
-
-     def set_input_embeddings(self, value):
-         self.model.embed_tokens = value
-
-     def get_output_embeddings(self):
-         return self.lm_head
-
-     def set_output_embeddings(self, new_embeddings: nn.Module):
-         self.lm_head = new_embeddings
-
-     def get_decoder(self):
-         return self.model
-
-     def set_decoder(self, decoder):
-         self.model = decoder
-
-     def forward(
-         self,
-         input_ids: Optional[torch.LongTensor] = None,
-         attention_mask: Optional[torch.FloatTensor] = None,
-         position_ids: Optional[torch.LongTensor] = None,
-         past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
-         inputs_embeds: Optional[torch.FloatTensor] = None,
-         labels: Optional[torch.LongTensor] = None,
-         use_cache: Optional[bool] = None,
-         output_attentions: Optional[bool] = None,
-         output_hidden_states: Optional[bool] = None,
-         return_dict: Optional[bool] = None,
-     ) -> Union[Tuple, CausalLMOutputWithPast]:
-         output_attentions = (
-             output_attentions
-             if output_attentions is not None
-             else self.config.output_attentions
-         )
-         output_hidden_states = (
-             output_hidden_states
-             if output_hidden_states is not None
-             else self.config.output_hidden_states
-         )
-         return_dict = (
-             return_dict if return_dict is not None else self.config.use_return_dict
-         )
-
-         # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
-         outputs = self.model(
-             input_ids,
-             attention_mask=attention_mask,
-             position_ids=position_ids,
-             past_key_values=past_key_values,
-             inputs_embeds=inputs_embeds,
-             use_cache=use_cache,
-             output_attentions=output_attentions,
-             output_hidden_states=output_hidden_states,
-             return_dict=return_dict,
-         )
-
-         hidden_states = outputs[0]
-         logits = self.lm_head(hidden_states).float()
-
-         loss = None
-         if labels is not None:
-             # Shift so that tokens < n predict n
-             shift_logits = logits[..., :-1, :].contiguous()
-             shift_labels = labels[..., 1:].contiguous()
-             # Flatten the tokens
-             loss_fct = CrossEntropyLoss()
-             shift_logits = shift_logits.view(-1, self.config.vocab_size)
-             shift_labels = shift_labels.view(-1)
-             # Enable model parallelism
-             shift_labels = shift_labels.to(shift_logits.device)
-             loss = loss_fct(shift_logits, shift_labels)
-
-         if not return_dict:
-             output = (logits,) + outputs[1:]
-             return (loss,) + output if loss is not None else output
-
-         return CausalLMOutputWithPast(
-             loss=loss,
-             logits=logits,
-             past_key_values=outputs.past_key_values,
-             hidden_states=outputs.hidden_states,
-             attentions=outputs.attentions,
-         )
-
-     def prepare_inputs_for_generation(
-         self,
-         input_ids,
-         past_key_values: Optional[torch.Tensor] = None,
-         attention_mask: Optional[torch.Tensor] = None,
-         inputs_embeds: Optional[torch.Tensor] = None,
-         **kwargs,
-     ):
-         # Trim input_ids if past is used
-         if past_key_values and past_key_values[0] is not None:
-             input_ids = input_ids[:, -1:]
-
-         position_ids = kwargs.get("position_ids", None)
-         if attention_mask is not None and position_ids is None:
-             # Create position_ids on the fly for batch generation
-             position_ids = attention_mask.long().cumsum(-1) - 1
-             position_ids.masked_fill_(attention_mask == 0, 1)
-             if past_key_values:
-                 position_ids = position_ids[:, -1].unsqueeze(-1)
-
-         # If `inputs_embeds` are passed, we only want to use them in the 1st generation step
-         if inputs_embeds is not None and past_key_values is None:
-             model_inputs = {"inputs_embeds": inputs_embeds}
-         else:
-             model_inputs = {"input_ids": input_ids}
-
-         model_inputs.update(
-             {
-                 "attention_mask": attention_mask,
-                 "past_key_values": past_key_values,
-                 "use_cache": kwargs.get("use_cache"),
-                 "position_ids": position_ids,
-             }
-         )
-         return model_inputs
-
-     @staticmethod
-     def _reorder_cache(past_key_values, beam_idx):
-         reordered_past = ()
-         for layer_past in past_key_values:
-             reordered_past += (
-                 tuple(
-                     past_state.index_select(0, beam_idx.to(past_state.device))
-                     for past_state in layer_past
-                 ),
-             )
-         return reordered_past
-
-
- HelpingAIConfig.register_for_auto_class()
- HelpingAIForCausalLM.register_for_auto_class("AutoModelForCausalLM")
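
The two register_for_auto_class calls at the end of the deleted file hooked the custom classes into the transformers auto classes, so checkpoints shipping this file were loaded with trust_remote_code enabled. A minimal usage sketch under that assumption; the repo id below is illustrative only and not taken from this commit:

# Hypothetical loading sketch for a checkpoint that shipped this modeling file.
# The repo id is an assumption for illustration, not part of this commit.
from transformers import AutoModelForCausalLM, AutoTokenizer

repo_id = "Abhaykoul/HelpingAI"  # assumed repo id
tokenizer = AutoTokenizer.from_pretrained(repo_id)
model = AutoModelForCausalLM.from_pretrained(
    repo_id,
    trust_remote_code=True,  # required so the Hub-side HelpingAIForCausalLM class is used
)

inputs = tokenizer("Hello", return_tensors="pt")
output_ids = model.generate(**inputs, max_new_tokens=20)
print(tokenizer.decode(output_ids[0], skip_special_tokens=True))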