GradientGuru committed on
Commit
a731bb0
1 Parent(s): d1816c6

cache alibi_mask to accelerate training

Browse files
Files changed (1) hide show
  1. modeling_baichuan.py +9 -3
modeling_baichuan.py CHANGED
@@ -249,7 +249,8 @@ class BaichuanModel(BaichuanPreTrainedModel):
249
  self.gradient_checkpointing = config.gradient_checkpointing
250
  self.post_init()
251
  self.max_cache_pos = config.model_max_length
252
- self.first_run = True
 
253
 
254
  def get_input_embeddings(self):
255
  return self.embed_tokens
@@ -306,8 +307,13 @@ class BaichuanModel(BaichuanPreTrainedModel):
306
  if inputs_embeds is None:
307
  inputs_embeds = self.embed_tokens(input_ids)
308
 
 
 
 
 
 
 
309
 
310
- alibi_mask = self.get_alibi_mask(inputs_embeds, seq_length_with_past)
311
  if attention_mask is not None:
312
  if len(attention_mask.shape) == 2:
313
  expanded_mask = attention_mask.to(alibi_mask.dtype)
@@ -597,4 +603,4 @@ class BaichuanForCausalLM(BaichuanPreTrainedModel):
597
  self.__class__.generate = PreTrainedModel.generate # disable stream
598
  outputs = self.generate(input_ids, generation_config=generation_config)
599
  response = tokenizer.decode(outputs[0][len(input_ids[0]):], skip_special_tokens=True)
600
- return response
 
249
  self.gradient_checkpointing = config.gradient_checkpointing
250
  self.post_init()
251
  self.max_cache_pos = config.model_max_length
252
+ self.first_run = True
253
+ self.alibi_mask = None
254
 
255
  def get_input_embeddings(self):
256
  return self.embed_tokens
 
307
  if inputs_embeds is None:
308
  inputs_embeds = self.embed_tokens(input_ids)
309
 
310
+ if self.training:
311
+ if self.alibi_mask is None or self.alibi_mask.shape[-1] != seq_length_with_past:
312
+ self.alibi_mask = self.get_alibi_mask(inputs_embeds, seq_length_with_past)
313
+ alibi_mask = self.alibi_mask
314
+ else:
315
+ alibi_mask = self.get_alibi_mask(inputs_embeds, seq_length_with_past)
316
 
 
317
  if attention_mask is not None:
318
  if len(attention_mask.shape) == 2:
319
  expanded_mask = attention_mask.to(alibi_mask.dtype)
 
603
  self.__class__.generate = PreTrainedModel.generate # disable stream
604
  outputs = self.generate(input_ids, generation_config=generation_config)
605
  response = tokenizer.decode(outputs[0][len(input_ids[0]):], skip_special_tokens=True)
606
+ return response