GradientGuru
committed on
Commit
•
a731bb0
1
Parent(s):
d1816c6
cache alibi_mask to accelerate training
Browse files- modeling_baichuan.py +9 -3
modeling_baichuan.py
CHANGED
@@ -249,7 +249,8 @@ class BaichuanModel(BaichuanPreTrainedModel):
|
|
249 |
self.gradient_checkpointing = config.gradient_checkpointing
|
250 |
self.post_init()
|
251 |
self.max_cache_pos = config.model_max_length
|
252 |
-
self.first_run = True
|
|
|
253 |
|
254 |
def get_input_embeddings(self):
|
255 |
return self.embed_tokens
|
@@ -306,8 +307,13 @@ class BaichuanModel(BaichuanPreTrainedModel):
|
|
306 |
if inputs_embeds is None:
|
307 |
inputs_embeds = self.embed_tokens(input_ids)
|
308 |
|
|
|
|
|
|
|
|
|
|
|
|
|
309 |
|
310 |
-
alibi_mask = self.get_alibi_mask(inputs_embeds, seq_length_with_past)
|
311 |
if attention_mask is not None:
|
312 |
if len(attention_mask.shape) == 2:
|
313 |
expanded_mask = attention_mask.to(alibi_mask.dtype)
|
@@ -597,4 +603,4 @@ class BaichuanForCausalLM(BaichuanPreTrainedModel):
|
|
597 |
self.__class__.generate = PreTrainedModel.generate # disable stream
|
598 |
outputs = self.generate(input_ids, generation_config=generation_config)
|
599 |
response = tokenizer.decode(outputs[0][len(input_ids[0]):], skip_special_tokens=True)
|
600 |
-
return response
|
|
|
249 |
self.gradient_checkpointing = config.gradient_checkpointing
|
250 |
self.post_init()
|
251 |
self.max_cache_pos = config.model_max_length
|
252 |
+
self.first_run = True
|
253 |
+
self.alibi_mask = None
|
254 |
|
255 |
def get_input_embeddings(self):
|
256 |
return self.embed_tokens
|
|
|
307 |
if inputs_embeds is None:
|
308 |
inputs_embeds = self.embed_tokens(input_ids)
|
309 |
|
310 |
+
if self.training:
|
311 |
+
if self.alibi_mask is None or self.alibi_mask.shape[-1] != seq_length_with_past:
|
312 |
+
self.alibi_mask = self.get_alibi_mask(inputs_embeds, seq_length_with_past)
|
313 |
+
alibi_mask = self.alibi_mask
|
314 |
+
else:
|
315 |
+
alibi_mask = self.get_alibi_mask(inputs_embeds, seq_length_with_past)
|
316 |
|
|
|
317 |
if attention_mask is not None:
|
318 |
if len(attention_mask.shape) == 2:
|
319 |
expanded_mask = attention_mask.to(alibi_mask.dtype)
|
|
|
603 |
self.__class__.generate = PreTrainedModel.generate # disable stream
|
604 |
outputs = self.generate(input_ids, generation_config=generation_config)
|
605 |
response = tokenizer.decode(outputs[0][len(input_ids[0]):], skip_special_tokens=True)
|
606 |
+
return response
|