Comet line 296-297 to remove self.model_parallel

#72
by lbwavebo - opened
Files changed (1) hide show
  1. modeling_mpt.py +79 -51
modeling_mpt.py CHANGED
@@ -12,52 +12,47 @@ from transformers import PreTrainedModel, PreTrainedTokenizer, PreTrainedTokeniz
12
  from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
13
  from .attention import attn_bias_shape, build_attn_bias
14
  from .blocks import MPTBlock
15
- from .custom_embedding import SharedEmbedding
16
  from .norm import NORM_CLASS_REGISTRY
17
  from .configuration_mpt import MPTConfig
18
  from .adapt_tokenizer import AutoTokenizerForMOD, adapt_tokenizer_for_denoising
19
  from .hf_prefixlm_converter import add_bidirectional_mask_if_missing, convert_hf_causal_lm_to_prefix_lm
20
  from .meta_init_context import init_empty_weights
21
  from .param_init_fns import MODEL_INIT_REGISTRY, generic_param_init_fn_
22
- try:
23
- from .flash_attn_triton import flash_attn_func
24
- except:
25
- pass
26
  Tokenizer = Union[PreTrainedTokenizer, PreTrainedTokenizerFast]
27
 
28
  class MPTPreTrainedModel(PreTrainedModel):
29
  config_class = MPTConfig
30
  base_model_prefix = 'model'
31
- _no_split_modules = ['MPTBlock']
 
 
 
 
 
32
 
33
  class MPTModel(MPTPreTrainedModel):
34
 
35
  def __init__(self, config: MPTConfig):
36
  config._validate_config()
37
  super().__init__(config)
 
38
  self.attn_impl = config.attn_config['attn_impl']
39
  self.prefix_lm = config.attn_config['prefix_lm']
40
  self.attn_uses_sequence_id = config.attn_config['attn_uses_sequence_id']
41
  self.alibi = config.attn_config['alibi']
42
  self.alibi_bias_max = config.attn_config['alibi_bias_max']
43
- if config.init_device == 'mixed':
44
- if dist.get_local_rank() == 0:
45
- config.init_device = 'cpu'
46
- else:
47
- config.init_device = 'meta'
48
  if config.norm_type.lower() not in NORM_CLASS_REGISTRY.keys():
49
  norm_options = ' | '.join(NORM_CLASS_REGISTRY.keys())
50
  raise NotImplementedError(f'Requested norm type ({config.norm_type}) is not implemented within this repo (Options: {norm_options}).')
51
  norm_class = NORM_CLASS_REGISTRY[config.norm_type.lower()]
52
  self.embedding_fraction = config.embedding_fraction
53
- self.wte = SharedEmbedding(config.vocab_size, config.d_model, device=config.init_device)
54
  if not self.alibi:
55
- self.wpe = torch.nn.Embedding(config.max_seq_len, config.d_model, device=config.init_device)
56
  self.emb_drop = nn.Dropout(config.emb_pdrop)
57
  self.blocks = nn.ModuleList([MPTBlock(device=config.init_device, **config.to_dict()) for _ in range(config.n_layers)])
58
  self.norm_f = norm_class(config.d_model, device=config.init_device)
59
  if config.init_device != 'meta':
60
- print(f'You are using config.init_device={config.init_device!r}, but you can also use config.init_device="meta" with Composer + FSDP for fast initialization.')
61
  self.apply(self.param_init_fn)
62
  self.is_causal = not self.prefix_lm
63
  self._attn_bias_initialized = False
@@ -107,8 +102,7 @@ class MPTModel(MPTPreTrainedModel):
107
  if attn_bias is None:
108
  attn_bias = torch.zeros((1, 1, 1, s_k), device=device, dtype=dtype)
109
  else:
110
- _s_k = max(0, attn_bias.size(-1) - s_k)
111
- attn_bias = attn_bias[:, :, :, _s_k:]
112
  if prefix_mask is not None and attention_mask.shape != prefix_mask.shape:
113
  raise ValueError(f'attention_mask shape={attention_mask.shape} ' + f'and prefix_mask shape={prefix_mask.shape} are not equal.')
114
  min_val = torch.finfo(attn_bias.dtype).min
@@ -140,32 +134,57 @@ class MPTModel(MPTPreTrainedModel):
140
  attn_bias = attn_bias.masked_fill(cannot_attend, min_val)
141
  return attn_bias
142
 
143
- def forward(self, input_ids: torch.LongTensor, past_key_values: Optional[List[Tuple[torch.FloatTensor]]]=None, attention_mask: Optional[torch.ByteTensor]=None, prefix_mask: Optional[torch.ByteTensor]=None, sequence_id: Optional[torch.LongTensor]=None, return_dict: Optional[bool]=None, output_attentions: Optional[bool]=None, output_hidden_states: Optional[bool]=None, use_cache: Optional[bool]=None, inputs_embeds: Optional[torch.Tensor]=None):
144
  return_dict = return_dict if return_dict is not None else self.config.return_dict
145
  use_cache = use_cache if use_cache is not None else self.config.use_cache
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
146
  if attention_mask is not None:
147
  attention_mask = attention_mask.bool()
 
 
 
 
 
 
 
 
 
 
148
  if prefix_mask is not None:
149
  prefix_mask = prefix_mask.bool()
150
  if not return_dict:
151
  raise NotImplementedError('return_dict False is not implemented yet for MPT')
152
  if output_attentions:
153
- if self.attn_impl != 'torch':
154
- raise NotImplementedError('output_attentions is not implemented for MPT when using attn_impl `flash` or `triton`.')
155
- if attention_mask is not None and attention_mask[:, 0].sum() != attention_mask.shape[0] and self.training:
156
- raise NotImplementedError('MPT does not support training with left padding.')
157
  if self.prefix_lm and prefix_mask is None:
158
  raise ValueError('prefix_mask is a required argument when MPT is configured with prefix_lm=True.')
159
- if inputs_embeds is not None:
160
- raise NotImplementedError('inputs_embeds is not implemented for MPT.')
161
  if self.training:
162
  if self.attn_uses_sequence_id and sequence_id is None:
163
  raise ValueError('sequence_id is a required argument when MPT is configured with attn_uses_sequence_id=True ' + 'and the model is in train mode.')
164
  elif self.attn_uses_sequence_id is False and sequence_id is not None:
165
  warnings.warn('MPT received non-None input for `sequence_id` but is configured with attn_uses_sequence_id=False. ' + 'This input will be ignored. If you want the model to use `sequence_id`, set attn_uses_sequence_id to True.')
166
- S = input_ids.size(1)
167
  assert S <= self.config.max_seq_len, f'Cannot forward input with seq_len={S}, this model only supports seq_len<={self.config.max_seq_len}'
168
- tok_emb = self.wte(input_ids)
169
  if self.alibi:
170
  x = tok_emb
171
  else:
@@ -174,12 +193,10 @@ class MPTModel(MPTPreTrainedModel):
174
  if len(past_key_values) != self.config.n_layers:
175
  raise ValueError(f'past_key_values must provide a past_key_value for each attention ' + f'layer in the network (len(past_key_values)={len(past_key_values)!r}; self.config.n_layers={self.config.n_layers!r}).')
176
  past_position = past_key_values[0][0].size(1)
177
- if self.attn_impl == 'torch':
178
- past_position = past_key_values[0][0].size(3)
179
  if S + past_position > self.config.max_seq_len:
180
  raise ValueError(f'Cannot forward input with past sequence length {past_position} and current sequence length {S + 1}, this model only supports total sequence length <= {self.config.max_seq_len}.')
181
  pos = torch.arange(past_position, S + past_position, dtype=torch.long, device=input_ids.device).unsqueeze(0)
182
- if attention_mask is not None:
183
  pos = torch.clamp(pos - torch.cumsum((~attention_mask).to(torch.int32), dim=1)[:, past_position:], min=0)
184
  pos_emb = self.wpe(pos)
185
  x = tok_emb + pos_emb
@@ -189,27 +206,41 @@ class MPTModel(MPTPreTrainedModel):
189
  x_shrunk = x * self.embedding_fraction + x.detach() * (1 - self.embedding_fraction)
190
  assert isinstance(self.emb_drop, nn.Module)
191
  x = self.emb_drop(x_shrunk)
192
- (attn_bias, attention_mask) = self._attn_bias(device=x.device, dtype=torch.float32, attention_mask=attention_mask, prefix_mask=prefix_mask, sequence_id=sequence_id)
193
  if use_cache and past_key_values is None:
194
  past_key_values = [() for _ in range(self.config.n_layers)]
 
195
  all_hidden_states = () if output_hidden_states else None
196
- all_self_attns = () if output_attentions else None
197
  for (b_idx, block) in enumerate(self.blocks):
198
  if output_hidden_states:
199
  assert all_hidden_states is not None
200
  all_hidden_states = all_hidden_states + (x,)
201
  past_key_value = past_key_values[b_idx] if past_key_values is not None else None
202
- (x, attn_weights, past_key_value) = block(x, past_key_value=past_key_value, attn_bias=attn_bias, attention_mask=attention_mask, is_causal=self.is_causal)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
203
  if past_key_values is not None:
204
  past_key_values[b_idx] = past_key_value
205
- if output_attentions:
206
- assert all_self_attns is not None
207
- all_self_attns = all_self_attns + (attn_weights,)
208
  x = self.norm_f(x)
209
- if output_hidden_states:
210
- assert all_hidden_states is not None
211
- all_hidden_states = all_hidden_states + (x,)
212
- return BaseModelOutputWithPast(last_hidden_state=x, past_key_values=past_key_values, hidden_states=all_hidden_states, attentions=all_self_attns)
213
 
214
  def param_init_fn(self, module):
215
  init_fn_name = self.config.init_config['name']
@@ -227,13 +258,7 @@ class MPTForCausalLM(MPTPreTrainedModel):
227
  super().__init__(config)
228
  if not config.tie_word_embeddings:
229
  raise ValueError('MPTForCausalLM only supports tied word embeddings')
230
- print(f'Instantiating an MPTForCausalLM model from {__file__}')
231
  self.transformer = MPTModel(config)
232
- for child in self.transformer.children():
233
- if isinstance(child, torch.nn.ModuleList):
234
- continue
235
- if isinstance(child, torch.nn.Module):
236
- child._fsdp_wrap = True
237
  self.logit_scale = None
238
  if config.logit_scale is not None:
239
  logit_scale = config.logit_scale
@@ -262,13 +287,16 @@ class MPTForCausalLM(MPTPreTrainedModel):
262
  def get_decoder(self):
263
  return self.transformer
264
 
265
- def forward(self, input_ids: torch.LongTensor, past_key_values: Optional[List[Tuple[torch.FloatTensor]]]=None, attention_mask: Optional[torch.ByteTensor]=None, prefix_mask: Optional[torch.ByteTensor]=None, sequence_id: Optional[torch.LongTensor]=None, labels: Optional[torch.LongTensor]=None, return_dict: Optional[bool]=None, output_attentions: Optional[bool]=None, output_hidden_states: Optional[bool]=None, use_cache: Optional[bool]=None, inputs_embeds: Optional[torch.FloatTensor]=None):
266
  return_dict = return_dict if return_dict is not None else self.config.return_dict
267
  use_cache = use_cache if use_cache is not None else self.config.use_cache
268
- if inputs_embeds is not None:
269
- raise NotImplementedError('inputs_embeds has to be None (for hf/peft support).')
270
- outputs = self.transformer(input_ids=input_ids, past_key_values=past_key_values, attention_mask=attention_mask, prefix_mask=prefix_mask, sequence_id=sequence_id, return_dict=return_dict, output_attentions=output_attentions, output_hidden_states=output_hidden_states, use_cache=use_cache)
271
- logits = self.transformer.wte(outputs.last_hidden_state.to(self.transformer.wte.weight.device), True)
 
 
 
272
  if self.logit_scale is not None:
273
  if self.logit_scale == 0:
274
  warnings.warn(f'Multiplying logits by self.logit_scale={self.logit_scale!r}. This will produce uniform (uninformative) outputs.')
@@ -278,7 +306,7 @@ class MPTForCausalLM(MPTPreTrainedModel):
278
  labels = torch.roll(labels, shifts=-1)
279
  labels[:, -1] = -100
280
  loss = F.cross_entropy(logits.view(-1, logits.size(-1)), labels.to(logits.device).view(-1))
281
- return CausalLMOutputWithPast(loss=loss, logits=logits, past_key_values=outputs.past_key_values, hidden_states=outputs.hidden_states, attentions=outputs.attentions)
282
 
283
  def param_init_fn(self, module):
284
  init_fn_name = self.config.init_config['name']
@@ -320,4 +348,4 @@ class MPTForCausalLM(MPTPreTrainedModel):
320
  reordered_past = []
321
  for layer_past in past_key_values:
322
  reordered_past += [tuple((past_state.index_select(0, beam_idx) for past_state in layer_past))]
323
- return reordered_past
 
12
  from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
13
  from .attention import attn_bias_shape, build_attn_bias
14
  from .blocks import MPTBlock
 
15
  from .norm import NORM_CLASS_REGISTRY
16
  from .configuration_mpt import MPTConfig
17
  from .adapt_tokenizer import AutoTokenizerForMOD, adapt_tokenizer_for_denoising
18
  from .hf_prefixlm_converter import add_bidirectional_mask_if_missing, convert_hf_causal_lm_to_prefix_lm
19
  from .meta_init_context import init_empty_weights
20
  from .param_init_fns import MODEL_INIT_REGISTRY, generic_param_init_fn_
 
 
 
 
21
  Tokenizer = Union[PreTrainedTokenizer, PreTrainedTokenizerFast]
22
 
23
  class MPTPreTrainedModel(PreTrainedModel):
24
  config_class = MPTConfig
25
  base_model_prefix = 'model'
26
+ _no_split_modules = ["MPTBlock"]
27
+ supports_gradient_checkpointing = True
28
+
29
+ def _set_gradient_checkpointing(self, module, value=False):
30
+ if isinstance(module, MPTModel):
31
+ module.gradient_checkpointing = value
32
 
33
  class MPTModel(MPTPreTrainedModel):
34
 
35
  def __init__(self, config: MPTConfig):
36
  config._validate_config()
37
  super().__init__(config)
38
+ self.gradient_checkpointing = False
39
  self.attn_impl = config.attn_config['attn_impl']
40
  self.prefix_lm = config.attn_config['prefix_lm']
41
  self.attn_uses_sequence_id = config.attn_config['attn_uses_sequence_id']
42
  self.alibi = config.attn_config['alibi']
43
  self.alibi_bias_max = config.attn_config['alibi_bias_max']
 
 
 
 
 
44
  if config.norm_type.lower() not in NORM_CLASS_REGISTRY.keys():
45
  norm_options = ' | '.join(NORM_CLASS_REGISTRY.keys())
46
  raise NotImplementedError(f'Requested norm type ({config.norm_type}) is not implemented within this repo (Options: {norm_options}).')
47
  norm_class = NORM_CLASS_REGISTRY[config.norm_type.lower()]
48
  self.embedding_fraction = config.embedding_fraction
49
+ self.wte = nn.Embedding(config.vocab_size, config.d_model, device=config.init_device)
50
  if not self.alibi:
51
+ self.wpe = nn.Embedding(config.max_seq_len, config.d_model, device=config.init_device)
52
  self.emb_drop = nn.Dropout(config.emb_pdrop)
53
  self.blocks = nn.ModuleList([MPTBlock(device=config.init_device, **config.to_dict()) for _ in range(config.n_layers)])
54
  self.norm_f = norm_class(config.d_model, device=config.init_device)
55
  if config.init_device != 'meta':
 
56
  self.apply(self.param_init_fn)
57
  self.is_causal = not self.prefix_lm
58
  self._attn_bias_initialized = False
 
102
  if attn_bias is None:
103
  attn_bias = torch.zeros((1, 1, 1, s_k), device=device, dtype=dtype)
104
  else:
105
+ attn_bias = attn_bias[:, :, :, -s_k:]
 
106
  if prefix_mask is not None and attention_mask.shape != prefix_mask.shape:
107
  raise ValueError(f'attention_mask shape={attention_mask.shape} ' + f'and prefix_mask shape={prefix_mask.shape} are not equal.')
108
  min_val = torch.finfo(attn_bias.dtype).min
 
134
  attn_bias = attn_bias.masked_fill(cannot_attend, min_val)
135
  return attn_bias
136
 
137
+ def forward(self, input_ids: torch.LongTensor, past_key_values: Optional[List[Tuple[torch.FloatTensor]]]=None, attention_mask: Optional[torch.ByteTensor]=None, prefix_mask: Optional[torch.ByteTensor]=None, sequence_id: Optional[torch.LongTensor]=None, return_dict: Optional[bool]=None, output_attentions: Optional[bool]=None, output_hidden_states: Optional[bool]=None, use_cache: Optional[bool]=None, inputs_embeds: Optional[torch.FloatTensor] = None):
138
  return_dict = return_dict if return_dict is not None else self.config.return_dict
139
  use_cache = use_cache if use_cache is not None else self.config.use_cache
140
+ if self.gradient_checkpointing and self.training:
141
+ if use_cache:
142
+ use_cache = False
143
+ if input_ids is not None and inputs_embeds is not None:
144
+ raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time")
145
+ elif input_ids is not None:
146
+ batch_size, seq_length = input_ids.shape
147
+ elif inputs_embeds is not None:
148
+ batch_size, seq_length, _ = inputs_embeds.shape
149
+ else:
150
+ raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")
151
+
152
+ seq_length_with_past = seq_length
153
+ past_key_values_length = 0
154
+
155
+ if past_key_values is not None:
156
+ past_key_values_length = past_key_values[0][0].shape[2]
157
+ seq_length_with_past = seq_length_with_past + past_key_values_length
158
+
159
  if attention_mask is not None:
160
  attention_mask = attention_mask.bool()
161
+ else:
162
+ attention_mask = torch.ones(
163
+ (batch_size, seq_length_with_past), dtype=torch.bool, device=inputs_embeds.device
164
+ )
165
+
166
+ if inputs_embeds is None:
167
+ tok_emb = self.wte(input_ids)
168
+ else:
169
+ tok_emb = inputs_embeds
170
+
171
  if prefix_mask is not None:
172
  prefix_mask = prefix_mask.bool()
173
  if not return_dict:
174
  raise NotImplementedError('return_dict False is not implemented yet for MPT')
175
  if output_attentions:
176
+ raise NotImplementedError('output_attentions is not implemented yet for MPT')
177
+ #if attention_mask is not None and attention_mask[:, 0].sum() != attention_mask.shape[0] and self.training:
178
+ # raise NotImplementedError('MPT does not support training with left padding.')
 
179
  if self.prefix_lm and prefix_mask is None:
180
  raise ValueError('prefix_mask is a required argument when MPT is configured with prefix_lm=True.')
 
 
181
  if self.training:
182
  if self.attn_uses_sequence_id and sequence_id is None:
183
  raise ValueError('sequence_id is a required argument when MPT is configured with attn_uses_sequence_id=True ' + 'and the model is in train mode.')
184
  elif self.attn_uses_sequence_id is False and sequence_id is not None:
185
  warnings.warn('MPT received non-None input for `sequence_id` but is configured with attn_uses_sequence_id=False. ' + 'This input will be ignored. If you want the model to use `sequence_id`, set attn_uses_sequence_id to True.')
186
+ S = seq_length
187
  assert S <= self.config.max_seq_len, f'Cannot forward input with seq_len={S}, this model only supports seq_len<={self.config.max_seq_len}'
 
188
  if self.alibi:
189
  x = tok_emb
190
  else:
 
193
  if len(past_key_values) != self.config.n_layers:
194
  raise ValueError(f'past_key_values must provide a past_key_value for each attention ' + f'layer in the network (len(past_key_values)={len(past_key_values)!r}; self.config.n_layers={self.config.n_layers!r}).')
195
  past_position = past_key_values[0][0].size(1)
 
 
196
  if S + past_position > self.config.max_seq_len:
197
  raise ValueError(f'Cannot forward input with past sequence length {past_position} and current sequence length {S + 1}, this model only supports total sequence length <= {self.config.max_seq_len}.')
198
  pos = torch.arange(past_position, S + past_position, dtype=torch.long, device=input_ids.device).unsqueeze(0)
199
+ if attention_mask is not None and not self.training:
200
  pos = torch.clamp(pos - torch.cumsum((~attention_mask).to(torch.int32), dim=1)[:, past_position:], min=0)
201
  pos_emb = self.wpe(pos)
202
  x = tok_emb + pos_emb
 
206
  x_shrunk = x * self.embedding_fraction + x.detach() * (1 - self.embedding_fraction)
207
  assert isinstance(self.emb_drop, nn.Module)
208
  x = self.emb_drop(x_shrunk)
209
+ (attn_bias, attention_mask) = self._attn_bias(device=x.device, dtype=x.dtype, attention_mask=attention_mask, prefix_mask=prefix_mask, sequence_id=sequence_id)
210
  if use_cache and past_key_values is None:
211
  past_key_values = [() for _ in range(self.config.n_layers)]
212
+
213
  all_hidden_states = () if output_hidden_states else None
 
214
  for (b_idx, block) in enumerate(self.blocks):
215
  if output_hidden_states:
216
  assert all_hidden_states is not None
217
  all_hidden_states = all_hidden_states + (x,)
218
  past_key_value = past_key_values[b_idx] if past_key_values is not None else None
219
+
220
+ if self.gradient_checkpointing and self.training:
221
+
222
+ def create_custom_forward(module):
223
+ def custom_forward(*inputs):
224
+ # None for past_key_value
225
+ return module(*inputs)
226
+
227
+ return custom_forward
228
+
229
+ (x, past_key_value) = torch.utils.checkpoint.checkpoint(
230
+ create_custom_forward(block),
231
+ x,
232
+ past_key_value,
233
+ attn_bias,
234
+ attention_mask,
235
+ self.is_causal,
236
+ )
237
+ else:
238
+ (x, past_key_value) = block(x, past_key_value=past_key_value, attn_bias=attn_bias, attention_mask=attention_mask, is_causal=self.is_causal)
239
+
240
  if past_key_values is not None:
241
  past_key_values[b_idx] = past_key_value
 
 
 
242
  x = self.norm_f(x)
243
+ return BaseModelOutputWithPast(last_hidden_state=x, past_key_values=past_key_values, hidden_states=all_hidden_states)
 
 
 
244
 
245
  def param_init_fn(self, module):
246
  init_fn_name = self.config.init_config['name']
 
258
  super().__init__(config)
259
  if not config.tie_word_embeddings:
260
  raise ValueError('MPTForCausalLM only supports tied word embeddings')
 
261
  self.transformer = MPTModel(config)
 
 
 
 
 
262
  self.logit_scale = None
263
  if config.logit_scale is not None:
264
  logit_scale = config.logit_scale
 
287
  def get_decoder(self):
288
  return self.transformer
289
 
290
+ def forward(self, input_ids: torch.LongTensor, past_key_values: Optional[List[Tuple[torch.FloatTensor]]]=None, attention_mask: Optional[torch.ByteTensor]=None, prefix_mask: Optional[torch.ByteTensor]=None, sequence_id: Optional[torch.LongTensor]=None, labels: Optional[torch.LongTensor]=None, return_dict: Optional[bool]=None, output_attentions: Optional[bool]=None, output_hidden_states: Optional[bool]=None, use_cache: Optional[bool]=None, inputs_embeds: Optional[torch.FloatTensor] = None):
291
  return_dict = return_dict if return_dict is not None else self.config.return_dict
292
  use_cache = use_cache if use_cache is not None else self.config.use_cache
293
+ outputs = self.transformer(input_ids=input_ids, past_key_values=past_key_values, attention_mask=attention_mask, prefix_mask=prefix_mask, sequence_id=sequence_id, return_dict=return_dict, output_attentions=output_attentions, output_hidden_states=output_hidden_states, use_cache=use_cache, inputs_embeds=inputs_embeds)
294
+
295
+ last_hidden_state = outputs.last_hidden_state
296
+ # if self.model_parallel:
297
+ # last_hidden_state = last_hidden_state.to(self.transformer.wte.weight.device)
298
+ logits = F.linear(last_hidden_state, self.transformer.wte.weight)
299
+
300
  if self.logit_scale is not None:
301
  if self.logit_scale == 0:
302
  warnings.warn(f'Multiplying logits by self.logit_scale={self.logit_scale!r}. This will produce uniform (uninformative) outputs.')
 
306
  labels = torch.roll(labels, shifts=-1)
307
  labels[:, -1] = -100
308
  loss = F.cross_entropy(logits.view(-1, logits.size(-1)), labels.to(logits.device).view(-1))
309
+ return CausalLMOutputWithPast(loss=loss, logits=logits, past_key_values=outputs.past_key_values, hidden_states=outputs.hidden_states)
310
 
311
  def param_init_fn(self, module):
312
  init_fn_name = self.config.init_config['name']
 
348
  reordered_past = []
349
  for layer_past in past_key_values:
350
  reordered_past += [tuple((past_state.index_select(0, beam_idx) for past_state in layer_past))]
351
+ return reordered_past