Update modeling_zhinao.py

#2
Files changed (1)
  1. modeling_zhinao.py +12 -13
modeling_zhinao.py CHANGED
@@ -748,6 +748,17 @@ class ZhinaoForCausalLM(ZhinaoPreTrainedModel):
 
     def __init__(self, config):
         super().__init__(config)
+        if config.use_flash_attn == "auto":
+            if flash_attn_varlen_func:
+                if config.bf16 or config.fp16:
+                    logger.warn("Try importing flash-attention.")
+                    config.use_flash_attn = True
+                else:
+                    config.use_flash_attn = False
+                    logger.warn("Flash attention will be disabled because it does NOT support fp32.")
+            else:
+                config.use_flash_attn = False
+                logger.warn("Please install FlashAttention first, " "e.g., with pip install flash-attn")
         self.model = ZhinaoModel(config)
         self.vocab_size = config.vocab_size
         self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
@@ -761,19 +772,7 @@ class ZhinaoForCausalLM(ZhinaoPreTrainedModel):
         if config.fp16:
             self.model.half()
             self.lm_head.half()
-            self.linear.half()
-
-        if config.use_flash_attn == "auto":
-            if flash_attn_varlen_func:
-                if config.bf16 or config.fp16:
-                    logger.warn("Try importing flash-attention.")
-                    config.use_flash_attn = True
-                else:
-                    config.use_flash_attn = False
-                    logger.warn("Flash attention will be disabled because it does NOT support fp32.")
-            else:
-                config.use_flash_attn = False
-                logger.warn("Please install FlashAttention first, " "e.g., with pip install flash-attn")
+            self.linear.half()
 
         self.post_init()
 
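The moved block resolves `config.use_flash_attn` from the string "auto" to a concrete boolean before `self.model = ZhinaoModel(config)` is built, instead of after the fp16/bf16 casts. Below is a minimal sketch of why that ordering can matter, assuming the attention layers inside `ZhinaoModel` read the flag at construction time; `DemoConfig`, `DemoAttention`, and `resolve_flash_attn` are hypothetical stand-ins, and the repo's logging is omitted.

```python
# Sketch only: stand-in names, not the actual ZhinaoModel internals.
try:
    from flash_attn import flash_attn_varlen_func  # optional dependency
except ImportError:
    flash_attn_varlen_func = None


class DemoConfig:
    def __init__(self, use_flash_attn="auto", fp16=True, bf16=False):
        self.use_flash_attn = use_flash_attn
        self.fp16 = fp16
        self.bf16 = bf16


class DemoAttention:
    """Stand-in for an attention layer that reads the flag when it is built."""
    def __init__(self, config):
        # If the flag were still the string "auto" here (which is truthy),
        # the layer could not tell whether flash attention is actually usable.
        self.use_flash = config.use_flash_attn is True


def resolve_flash_attn(config):
    """Same decision as the moved block: "auto" -> True/False before building layers."""
    if config.use_flash_attn == "auto":
        if flash_attn_varlen_func:
            # flash-attn only supports fp16/bf16, never fp32
            config.use_flash_attn = bool(config.bf16 or config.fp16)
        else:
            config.use_flash_attn = False
    return config


layer = DemoAttention(resolve_flash_attn(DemoConfig()))
print(layer.use_flash)  # True only if flash-attn is installed and fp16/bf16 is set
```

With the previous ordering, `ZhinaoModel(config)` was constructed while `config.use_flash_attn` could still hold the string "auto"; after this change the model sees the resolved boolean.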