import copy
import importlib.metadata
import json
import os
from dataclasses import dataclass
from enum import Enum
from typing import Any, Dict, List, Optional, Union

from packaging import version

from ..utils import is_torch_available, logging


if is_torch_available():
    import torch


logger = logging.get_logger(__name__)


class QuantizationMethod(str, Enum):
    BITS_AND_BYTES = "bitsandbytes"
    GPTQ = "gptq"


@dataclass
class QuantizationConfigMixin:
    """
    Mixin class for quantization configs.
    """

    quant_method: QuantizationMethod

    @classmethod
    def from_dict(cls, config_dict, return_unused_kwargs=False, **kwargs):
        """
        Instantiates a [`QuantizationConfigMixin`] from a Python dictionary of parameters.

        Args:
            config_dict (`Dict[str, Any]`):
                Dictionary that will be used to instantiate the configuration object.
            return_unused_kwargs (`bool`, *optional*, defaults to `False`):
                Whether or not to also return the dictionary of unused keyword arguments. Used by the
                `from_pretrained` method of `PreTrainedModel`.
            kwargs (`Dict[str, Any]`):
                Additional parameters from which to initialize the configuration object.

        Returns:
            [`QuantizationConfigMixin`]: The configuration object instantiated from those parameters.
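
        Example (a minimal sketch; `foo` is a made-up keyword that no config defines, so it is returned as unused):

        ```python
        config, unused_kwargs = BitsAndBytesConfig.from_dict(
            {"load_in_8bit": True}, return_unused_kwargs=True, llm_int8_threshold=4.0, foo="bar"
        )
        # `llm_int8_threshold` is a known attribute and is set on `config`; `unused_kwargs` is {"foo": "bar"}.
        ```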
        """
        config = cls(**config_dict)

        to_remove = []
        for key, value in kwargs.items():
            if hasattr(config, key):
                setattr(config, key, value)
                to_remove.append(key)
        for key in to_remove:
            kwargs.pop(key, None)

        if return_unused_kwargs:
            return config, kwargs
        else:
            return config

    def to_json_file(self, json_file_path: Union[str, os.PathLike]):
        """
        Save this instance to a JSON file.

        Args:
            json_file_path (`str` or `os.PathLike`):
                Path to the JSON file in which this configuration instance's parameters will be saved.
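
        Example (a minimal sketch; the file name is arbitrary):

        ```python
        config = BitsAndBytesConfig(load_in_8bit=True)
        config.to_json_file("quantization_config.json")
        ```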
        """
        with open(json_file_path, "w", encoding="utf-8") as writer:
            config_dict = self.to_dict()
            json_string = json.dumps(config_dict, indent=2, sort_keys=True) + "\n"

            writer.write(json_string)

    def to_dict(self) -> Dict[str, Any]:
        """
        Serializes this instance to a Python dictionary.

        Returns:
            `Dict[str, Any]`: Dictionary of all the attributes that make up this configuration instance.
        """
        return copy.deepcopy(self.__dict__)

    def __repr__(self):
        return f"{self.__class__.__name__} {self.to_json_string()}"

    def to_json_string(self, use_diff: bool = True) -> str:
        """
        Serializes this instance to a JSON string.

        Args:
            use_diff (`bool`, *optional*, defaults to `True`):
                If set to `True`, only the difference between the config instance and the default config attributes
                is serialized to a JSON string.

        Returns:
            `str`: String containing all the attributes that make up this configuration instance in JSON format.
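
        Example (a minimal sketch; `BitsAndBytesConfig` is used because it defines `to_diff_dict`, which the
        `use_diff=True` path relies on):

        ```python
        config = BitsAndBytesConfig(load_in_8bit=True)
        print(config.to_json_string())                # only the attributes that differ from the defaults
        print(config.to_json_string(use_diff=False))  # every attribute
        ```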
        """
        if use_diff is True:
            config_dict = self.to_diff_dict()
        else:
            config_dict = self.to_dict()
        return json.dumps(config_dict, indent=2, sort_keys=True) + "\n"


@dataclass
class BitsAndBytesConfig(QuantizationConfigMixin):
    """
    This is a wrapper class covering all possible attributes and features that you can play with for a model that
    has been loaded using `bitsandbytes`.

    This replaces `load_in_8bit` or `load_in_4bit`; therefore, both options are mutually exclusive.

    Currently only supports `LLM.int8()`, `FP4`, and `NF4` quantization. If more methods are added to `bitsandbytes`,
    then more arguments will be added to this class.

    Args:
        load_in_8bit (`bool`, *optional*, defaults to `False`):
            This flag is used to enable 8-bit quantization with LLM.int8().
        load_in_4bit (`bool`, *optional*, defaults to `False`):
            This flag is used to enable 4-bit quantization by replacing the Linear layers with FP4/NF4 layers from
            `bitsandbytes`.
        llm_int8_threshold (`float`, *optional*, defaults to 6.0):
            This corresponds to the outlier threshold for outlier detection as described in the `LLM.int8(): 8-bit
            Matrix Multiplication for Transformers at Scale` paper: https://arxiv.org/abs/2208.07339. Any hidden
            state value above this threshold is considered an outlier, and the operation on those values is done in
            fp16. Values are usually normally distributed, that is, most values are in the range [-3.5, 3.5], but
            there are some exceptional systematic outliers that are very differently distributed for large models.
            These outliers are often in the interval [-60, -6] or [6, 60]. Int8 quantization works well for values of
            magnitude ~5, but beyond that, there is a significant performance penalty. A good default threshold is 6,
            but a lower threshold might be needed for more unstable models (small models, fine-tuning).
        llm_int8_skip_modules (`List[str]`, *optional*):
            An explicit list of the modules that we do not want to convert in 8-bit. This is useful for models such
            as Jukebox, which has several heads in different places and not necessarily at the last position. For
            example, for `CausalLM` models, the last `lm_head` is kept in its original `dtype`.
        llm_int8_enable_fp32_cpu_offload (`bool`, *optional*, defaults to `False`):
            This flag is used for advanced use cases and users that are aware of this feature. If you want to split
            your model into different parts and run some parts in int8 on GPU and some parts in fp32 on CPU, you can
            use this flag. This is useful for offloading large models such as `google/flan-t5-xxl`. Note that the
            int8 operations will not be run on CPU.
        llm_int8_has_fp16_weight (`bool`, *optional*, defaults to `False`):
            This flag runs LLM.int8() with 16-bit main weights. This is useful for fine-tuning as the weights do not
            have to be converted back and forth for the backward pass.
        bnb_4bit_compute_dtype (`torch.dtype` or `str`, *optional*, defaults to `torch.float32`):
            This sets the computational type, which might be different from the input type. For example, inputs might
            be fp32, but computation can be set to bf16 for speedups.
        bnb_4bit_quant_type (`str`, *optional*, defaults to `"fp4"`):
            This sets the quantization data type in the `bnb.nn.Linear4bit` layers. Options are the FP4 and NF4 data
            types, specified by `"fp4"` or `"nf4"`.
        bnb_4bit_use_double_quant (`bool`, *optional*, defaults to `False`):
            This flag is used for nested quantization, where the quantization constants from the first quantization
            are quantized again.
        kwargs (`Dict[str, Any]`, *optional*):
            Additional parameters from which to initialize the configuration object.
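
    Example (a minimal sketch; the model id and the availability of a CUDA GPU with `bitsandbytes` installed are
    assumptions, not requirements imposed by this class):

    ```python
    from transformers import AutoModelForCausalLM, BitsAndBytesConfig

    nf4_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_use_double_quant=True,
        bnb_4bit_compute_dtype="bfloat16",
    )
    model = AutoModelForCausalLM.from_pretrained("facebook/opt-350m", quantization_config=nf4_config)
    ```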
    """

    def __init__(
        self,
        load_in_8bit=False,
        load_in_4bit=False,
        llm_int8_threshold=6.0,
        llm_int8_skip_modules=None,
        llm_int8_enable_fp32_cpu_offload=False,
        llm_int8_has_fp16_weight=False,
        bnb_4bit_compute_dtype=None,
        bnb_4bit_quant_type="fp4",
        bnb_4bit_use_double_quant=False,
        **kwargs,
    ):
        self.quant_method = QuantizationMethod.BITS_AND_BYTES
        self.load_in_8bit = load_in_8bit
        self.load_in_4bit = load_in_4bit
        self.llm_int8_threshold = llm_int8_threshold
        self.llm_int8_skip_modules = llm_int8_skip_modules
        self.llm_int8_enable_fp32_cpu_offload = llm_int8_enable_fp32_cpu_offload
        self.llm_int8_has_fp16_weight = llm_int8_has_fp16_weight
        self.bnb_4bit_quant_type = bnb_4bit_quant_type
        self.bnb_4bit_use_double_quant = bnb_4bit_use_double_quant

        if bnb_4bit_compute_dtype is None:
            self.bnb_4bit_compute_dtype = torch.float32
        elif isinstance(bnb_4bit_compute_dtype, str):
            self.bnb_4bit_compute_dtype = getattr(torch, bnb_4bit_compute_dtype)
        elif isinstance(bnb_4bit_compute_dtype, torch.dtype):
            self.bnb_4bit_compute_dtype = bnb_4bit_compute_dtype
        else:
            raise ValueError("bnb_4bit_compute_dtype must be a string or a torch.dtype")

        self.post_init()

    def post_init(self):
        r"""
        Safety checker that the arguments are correct - also replaces some NoneType arguments with their default
        values.
        """
        if not isinstance(self.llm_int8_threshold, float):
            raise ValueError("llm_int8_threshold must be a float")

        if self.llm_int8_skip_modules is not None and not isinstance(self.llm_int8_skip_modules, list):
            raise ValueError("llm_int8_skip_modules must be a list of strings")

        if not isinstance(self.llm_int8_enable_fp32_cpu_offload, bool):
            raise ValueError("llm_int8_enable_fp32_cpu_offload must be a boolean")

        if not isinstance(self.llm_int8_has_fp16_weight, bool):
            raise ValueError("llm_int8_has_fp16_weight must be a boolean")

        if self.bnb_4bit_compute_dtype is not None and not isinstance(self.bnb_4bit_compute_dtype, torch.dtype):
            raise ValueError("bnb_4bit_compute_dtype must be torch.dtype")

        if not isinstance(self.bnb_4bit_quant_type, str):
            raise ValueError("bnb_4bit_quant_type must be a string")

        if not isinstance(self.bnb_4bit_use_double_quant, bool):
            raise ValueError("bnb_4bit_use_double_quant must be a boolean")

        if self.load_in_4bit and not version.parse(importlib.metadata.version("bitsandbytes")) >= version.parse(
            "0.39.0"
        ):
            raise ValueError(
                "4-bit quantization requires bitsandbytes>=0.39.0 - please upgrade your bitsandbytes version"
            )

    def is_quantizable(self):
        r"""
        Returns `True` if the model is quantizable, `False` otherwise.
        """
        return self.load_in_8bit or self.load_in_4bit

    def quantization_method(self):
        r"""
        This method returns the quantization method used for the model. If the model is not quantizable, it returns
        `None`.
        """
        if self.load_in_8bit:
            return "llm_int8"
        elif self.load_in_4bit and self.bnb_4bit_quant_type == "fp4":
            return "fp4"
        elif self.load_in_4bit and self.bnb_4bit_quant_type == "nf4":
            return "nf4"
        else:
            return None

    def to_dict(self) -> Dict[str, Any]:
        """
        Serializes this instance to a Python dictionary.

        Returns:
            `Dict[str, Any]`: Dictionary of all the attributes that make up this configuration instance.
        """
        output = copy.deepcopy(self.__dict__)
        # Serialize the torch dtype as its plain name, e.g. torch.float16 -> "float16".
        output["bnb_4bit_compute_dtype"] = str(output["bnb_4bit_compute_dtype"]).split(".")[1]

        return output

    def __repr__(self):
        config_dict = self.to_dict()
        return f"{self.__class__.__name__} {json.dumps(config_dict, indent=2, sort_keys=True)}\n"

    def to_diff_dict(self) -> Dict[str, Any]:
        """
        Removes all attributes from the config which correspond to the default config attributes for better
        readability, and serializes to a Python dictionary.

        Returns:
            `Dict[str, Any]`: Dictionary of all the attributes that make up this configuration instance.
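
        Example (a minimal sketch of the expected behaviour):

        ```python
        config = BitsAndBytesConfig(load_in_8bit=True)
        config.to_diff_dict()  # roughly {"load_in_8bit": True}: only the keys that differ from the defaults
        ```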
        """
        config_dict = self.to_dict()

        # get the default config dict
        default_config_dict = BitsAndBytesConfig().to_dict()

        serializable_config_dict = {}

        # only serialize values that differ from the default config
        for key, value in config_dict.items():
            if value != default_config_dict[key]:
                serializable_config_dict[key] = value

        return serializable_config_dict


@dataclass
class GPTQConfig(QuantizationConfigMixin):
    """
    This is a wrapper class covering all possible attributes and features that you can play with for a model that
    has been loaded using the `optimum` API for GPTQ quantization, relying on the `auto_gptq` backend.

    Args:
        bits (`int`):
            The number of bits to quantize to; supported values are (2, 3, 4, 8).
        tokenizer (`str` or `PreTrainedTokenizerBase`, *optional*):
            The tokenizer used to process the dataset. You can pass either:
            - A custom tokenizer object.
            - A string, the *model id* of a predefined tokenizer hosted inside a model repo on huggingface.co.
              Valid model ids can be located at the root-level, like `bert-base-uncased`, or namespaced under a
              user or organization name, like `dbmdz/bert-base-german-cased`.
            - A path to a *directory* containing vocabulary files required by the tokenizer, for instance saved
              using the [`~PreTrainedTokenizer.save_pretrained`] method, e.g., `./my_model_directory/`.
        dataset (`Union[List[str], str]`, *optional*):
            The dataset used for quantization. You can provide your own dataset as a list of strings, or just use one
            of the original datasets used in the GPTQ paper: ['wikitext2','c4','c4-new','ptb','ptb-new'].
        group_size (`int`, *optional*, defaults to 128):
            The group size to use for quantization. The recommended value is 128, and -1 uses per-column
            quantization.
        damp_percent (`float`, *optional*, defaults to 0.1):
            The percent of the average Hessian diagonal to use for dampening. The recommended value is 0.1.
        desc_act (`bool`, *optional*, defaults to `False`):
            Whether to quantize columns in order of decreasing activation size. Setting it to `False` can
            significantly speed up inference but the perplexity may become slightly worse. Also known as act-order.
        sym (`bool`, *optional*, defaults to `True`):
            Whether to use symmetric quantization.
        true_sequential (`bool`, *optional*, defaults to `True`):
            Whether to perform sequential quantization even within a single Transformer block. Instead of quantizing
            the entire block at once, we perform layer-wise quantization. As a result, each layer undergoes
            quantization using inputs that have passed through the previously quantized layers.
        use_cuda_fp16 (`bool`, *optional*, defaults to `False`):
            Whether or not to use the optimized CUDA kernel for fp16 models. Requires the model to be in fp16.
        model_seqlen (`int`, *optional*):
            The maximum sequence length that the model can take.
        block_name_to_quantize (`str`, *optional*):
            The Transformers block name to quantize.
        module_name_preceding_first_block (`List[str]`, *optional*):
            The layers that precede the first Transformer block.
        batch_size (`int`, *optional*, defaults to 1):
            The batch size used when processing the dataset.
        pad_token_id (`int`, *optional*):
            The pad token id. Needed to prepare the dataset when `batch_size` > 1.
        disable_exllama (`bool`, *optional*, defaults to `False`):
            Whether to disable the exllama backend. The exllama kernels are only used when `bits` = 4.
        max_input_length (`int`, *optional*):
            The maximum input length. This is needed to initialize a buffer that depends on the maximum expected
            input length. It is specific to the exllama backend with act-order.
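
    Example (a minimal sketch; the model id is an assumption, and running the quantization itself requires a CUDA GPU
    with `optimum` and `auto-gptq` installed):

    ```python
    from transformers import AutoModelForCausalLM, AutoTokenizer, GPTQConfig

    tokenizer = AutoTokenizer.from_pretrained("facebook/opt-350m")
    gptq_config = GPTQConfig(bits=4, dataset="c4", tokenizer=tokenizer)
    model = AutoModelForCausalLM.from_pretrained(
        "facebook/opt-350m", device_map="auto", quantization_config=gptq_config
    )
    ```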
    """

    def __init__(
        self,
        bits: int,
        tokenizer: Any = None,
        dataset: Optional[Union[List[str], str]] = None,
        group_size: int = 128,
        damp_percent: float = 0.1,
        desc_act: bool = False,
        sym: bool = True,
        true_sequential: bool = True,
        use_cuda_fp16: bool = False,
        model_seqlen: Optional[int] = None,
        block_name_to_quantize: Optional[str] = None,
        module_name_preceding_first_block: Optional[List[str]] = None,
        batch_size: int = 1,
        pad_token_id: Optional[int] = None,
        disable_exllama: bool = False,
        max_input_length: Optional[int] = None,
        **kwargs,
    ):
        self.quant_method = QuantizationMethod.GPTQ
        self.bits = bits
        self.tokenizer = tokenizer
        self.dataset = dataset
        self.group_size = group_size
        self.damp_percent = damp_percent
        self.desc_act = desc_act
        self.sym = sym
        self.true_sequential = true_sequential
        self.use_cuda_fp16 = use_cuda_fp16
        self.model_seqlen = model_seqlen
        self.block_name_to_quantize = block_name_to_quantize
        self.module_name_preceding_first_block = module_name_preceding_first_block
        self.batch_size = batch_size
        self.pad_token_id = pad_token_id
        self.disable_exllama = disable_exllama
        self.max_input_length = max_input_length
        self.post_init()

    def get_loading_attributes(self):
        # Only the attributes that matter at load time are forwarded to the loading code.
        attributes_dict = copy.deepcopy(self.__dict__)
        loading_attributes = ["disable_exllama", "use_cuda_fp16", "max_input_length"]
        loading_attributes_dict = {i: j for i, j in attributes_dict.items() if i in loading_attributes}
        return loading_attributes_dict

    def post_init(self):
        r"""
        Safety checker that the arguments are correct.
        """
        if self.bits not in [2, 3, 4, 8]:
            raise ValueError(f"Only support quantization to [2,3,4,8] bits but found {self.bits}")
        if self.group_size != -1 and self.group_size <= 0:
            raise ValueError("group_size must be greater than 0 or equal to -1")
        if not (0 < self.damp_percent < 1):
            raise ValueError("damp_percent must be between 0 and 1.")
        if self.dataset is not None:
            if isinstance(self.dataset, str):
                if self.dataset not in ["wikitext2", "c4", "c4-new", "ptb", "ptb-new"]:
                    raise ValueError(
                        f"""You have entered a string value for dataset. You can only choose between
                        ['wikitext2','c4','c4-new','ptb','ptb-new'], but we found {self.dataset}"""
                    )
            elif not isinstance(self.dataset, list):
                raise ValueError(
                    f"""dataset needs to be either a list of string or a value in
                    ['wikitext2','c4','c4-new','ptb','ptb-new'], but we found {self.dataset}"""
                )