wuzhiying committed · Commit 4d1208f · Parent: ddad89a

sync base to chat

Files changed (1): modeling_baichuan.py (+14, -14)
modeling_baichuan.py CHANGED

@@ -528,7 +528,6 @@ class BaichuanForCausalLM(BaichuanPreTrainedModel):
         self.model = BaichuanModel(config)
 
         self.lm_head = NormHead(config.hidden_size, config.vocab_size, bias=False)
-        #if hasattr(config, "quantization_config") and config.quantization_config['load_in_4bit']:
         if hasattr(config, "quantization_config") and isinstance(config.quantization_config, dict) and config.quantization_config.get('load_in_4bit', False):
             try:
                 from .quantizer import quantize_offline, init_model_weight_int4
@@ -609,22 +608,23 @@ class BaichuanForCausalLM(BaichuanPreTrainedModel):
             model_file = os.path.join(pretrained_model_name_or_path, 'pytorch_model.bin')
             state_dict = torch.load(model_file, map_location="cpu")
             model.is_quantized = True
-
+
             device_map = kwargs.pop("device_map", None)
             torch_dtype = kwargs.pop("torch_dtype", None)
 
-
-
-
-
-
-
-
-
-
-
-
-
+            if device_map is not None:
+                kwargs = {"no_split_module_classes": model._no_split_modules}
+                target_dtype = CustomDtype.INT4
+                max_memory = get_balanced_memory(
+                    model,
+                    dtype=target_dtype,
+                    low_zero=(device_map == "balanced_low_0"),
+                    max_memory=None,
+                    **kwargs,
+                )
+                kwargs["max_memory"] = max_memory
+                device_map = infer_auto_device_map(model, dtype=target_dtype, **kwargs)
+
             model = init_model_weight_int4(config, model, state_dict)
 
             # Set model in evaluation mode to deactivate DropOut modules by default
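The first hunk deletes the commented-out guard and keeps the hardened check. As a minimal sketch of what the stricter condition buys, here is the same predicate pulled out into a standalone function, with a hypothetical Cfg stand-in for the model config (neither is part of the commit):

class Cfg:
    pass

def wants_int4(config):
    # Mirrors the check this commit keeps: each clause guards the next, so a
    # missing attribute, a non-dict value, or an absent key falls through to
    # False instead of raising.
    return (
        hasattr(config, "quantization_config")
        and isinstance(config.quantization_config, dict)
        and config.quantization_config.get("load_in_4bit", False)
    )

cfg = Cfg()
assert not wants_int4(cfg)                         # no quantization_config attribute
cfg.quantization_config = None
assert not wants_int4(cfg)                         # attribute present but not a dict
cfg.quantization_config = {"load_in_8bit": True}
assert not wants_int4(cfg)                         # dict, but no load_in_4bit key
cfg.quantization_config = {"load_in_4bit": True}
assert wants_int4(cfg)

# The deleted variant indexed the dict directly, so the None case above would
# raise TypeError and the missing-key case KeyError.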
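The second hunk syncs in the base model's multi-GPU placement path for 4-bit loading: when the caller passes a device_map, the code budgets per-device memory for packed INT4 weights with accelerate's get_balanced_memory, then derives the concrete module placement with infer_auto_device_map. A hedged sketch of the same pattern outside the class; TinyNet is a made-up stand-in for BaichuanForCausalLM, and the imports assume a recent accelerate release that exports CustomDtype.INT4:

import torch.nn as nn
from accelerate.utils import CustomDtype, get_balanced_memory, infer_auto_device_map

class TinyNet(nn.Module):
    # Hypothetical stand-in, small enough to map on any machine.
    def __init__(self):
        super().__init__()
        self.layers = nn.ModuleList(nn.Linear(1024, 1024) for _ in range(8))
        self.head = nn.Linear(1024, 32000)

model = TinyNet()

# Same flow as the diff: size a balanced per-device budget for 4-bit weights,
# then let accelerate assign modules under that budget.
kwargs = {"no_split_module_classes": []}  # the real code passes model._no_split_modules
target_dtype = CustomDtype.INT4           # packed 4-bit: ~0.5 bytes per parameter
max_memory = get_balanced_memory(
    model,
    dtype=target_dtype,
    low_zero=False,    # the diff enables this when device_map == "balanced_low_0"
    max_memory=None,   # None: probe whatever accelerators are visible
    **kwargs,
)
kwargs["max_memory"] = max_memory
device_map = infer_auto_device_map(model, dtype=target_dtype, **kwargs)
print(device_map)  # e.g. {'layers.0': 0, ..., 'head': 1} on a two-GPU box

Budgeting with the packed dtype is the point of the change: sized against fp16 or fp32 tensors, the inferred map would reserve several times more memory than the quantized weights actually need and could spill layers to CPU for no reason.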