Update cache_utils.py
cache_utils.py (CHANGED): +550 -48
Before (deleted lines marked with -):

@@ -1,12 +1,21 @@
 from dataclasses import dataclass
-from typing import Any, Dict, List, Optional, Tuple

 import torch

 from transformers.configuration_utils import PretrainedConfig
-from transformers.utils import logging


 logger = logging.get_logger(__name__)


@@ -44,6 +53,7 @@ class Cache:

     def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
         """Returns the sequence length of the cached states. A layer index can be optionally passed."""
         raise NotImplementedError("Make sure to implement `get_seq_length` in a subclass.")

     def get_max_length(self) -> Optional[int]:
@@ -61,6 +71,14 @@ class Cache:
             return max_length - new_seq_length
         return previous_seq_length

     @property
     def seen_tokens(self):
         logger.warning_once(
@@ -73,6 +91,201 @@ class Cache:
         return None


 class DynamicCache(Cache):
     """
     A cache that grows dynamically as more tokens are generated. This is the default for generative models.
@@ -150,6 +363,7 @@ class DynamicCache(Cache):

     def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
         """Returns the sequence length of the cached states. A layer index can be optionally passed."""
         if len(self.key_cache) <= layer_idx:
             return 0
         return self.key_cache[layer_idx].shape[-2]
@@ -158,14 +372,6 @@ class DynamicCache(Cache):
         """Returns the maximum sequence length of the cached states. DynamicCache does not have a maximum length."""
         return None

-    def reorder_cache(self, beam_idx: torch.LongTensor):
-        """Reorders the cache for beam search, given the selected beam indices."""
-        for layer_idx in range(len(self.key_cache)):
-            device = self.key_cache[layer_idx].device
-            self.key_cache[layer_idx] = self.key_cache[layer_idx].index_select(0, beam_idx.to(device))
-            device = self.value_cache[layer_idx].device
-            self.value_cache[layer_idx] = self.value_cache[layer_idx].index_select(0, beam_idx.to(device))
-
     def to_legacy_cache(self) -> Tuple[Tuple[torch.Tensor], Tuple[torch.Tensor]]:
         """Converts the `DynamicCache` instance into the its equivalent in the legacy cache format."""
         legacy_cache = ()
@@ -184,6 +390,168 @@ class DynamicCache(Cache):
         return cache


 class SinkCache(Cache):
     """
     A cache that as described in the [Attention Sinks paper](https://arxiv.org/abs/2309.17453). It allows the model to
@@ -205,7 +573,9 @@ class SinkCache(Cache):
         self.value_cache: List[torch.Tensor] = []
         self.window_length = window_length
         self.num_sink_tokens = num_sink_tokens
-        self.
         self._seen_tokens = 0  # Used in `generate` to keep tally of how many tokens the cache has seen

     @staticmethod
@@ -223,7 +593,7 @@ class SinkCache(Cache):
     def _get_rerotation_cos_sin(
         self, key_states: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor
     ) -> Tuple[torch.Tensor, torch.Tensor]:
-        if key_states.shape[-2] not in self.
             # Upcast to float32 temporarily for better accuracy
             cos = cos.to(torch.float32)
             sin = sin.to(torch.float32)
@@ -236,14 +606,15 @@ class SinkCache(Cache):
             rerotation_cos = original_cos * shifted_cos + original_sin * shifted_sin
             rerotation_sin = -original_sin * shifted_cos + original_cos * shifted_sin

-            self.
                 rerotation_cos.to(key_states.dtype).unsqueeze(0),
                 rerotation_sin.to(key_states.dtype).unsqueeze(0),
             )
-        return self.

     def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
         """Returns the sequence length of the cached states. A layer index can be optionally passed."""
         # Workaround to make 'key_states.shape[-2] + past_key_value.get_seq_length(self.layer_idx)' <= window_length
         if len(self.key_cache) <= layer_idx:
             return 0
@@ -289,6 +660,21 @@ class SinkCache(Cache):
         if layer_idx == 0:
             self._seen_tokens += key_states.shape[-2]

         # [bsz, num_heads, seq_len, head_dim]
         if len(self.key_cache) <= layer_idx:
             # Empty cache
@@ -309,7 +695,7 @@ class SinkCache(Cache):
             # On RoPE models, we need to recompute the Key rotation as the tokens are shifted
             if using_rope:
                 rerotation_cos, rerotation_sin = self._get_rerotation_cos_sin(
-                    key_states,
                 )
                 if partial_rotation_size is not None:
                     keys_to_keep, keys_pass = (
@@ -332,14 +718,6 @@ class SinkCache(Cache):

         return self.key_cache[layer_idx], self.value_cache[layer_idx]

-    def reorder_cache(self, beam_idx: torch.LongTensor):
-        """Reorders the cache for beam search, given the selected beam indices."""
-        for layer_idx in range(len(self.key_cache)):
-            device = self.key_cache[layer_idx].device
-            self.key_cache[layer_idx] = self.key_cache[layer_idx].index_select(0, beam_idx.to(device))
-            device = self.value_cache[layer_idx].device
-            self.value_cache[layer_idx] = self.value_cache[layer_idx].index_select(0, beam_idx.to(device))
-

 class StaticCache(Cache):
     """
@@ -347,8 +725,7 @@ class StaticCache(Cache):

     Parameters:
         config (`PretrainedConfig):
-            The configuration file defining the
-            required to initialize the static cache.
         max_batch_size (`int`):
             The maximum batch size with which the model will be used.
         max_cache_len (`int`):
@@ -373,9 +750,18 @@ class StaticCache(Cache):
             config.num_attention_heads if config.num_key_value_heads is None else config.num_key_value_heads
         )

         cache_shape = (max_batch_size, self.num_key_value_heads, self.max_cache_len, self.head_dim)
-
-

     def update(
         self,
@@ -394,42 +780,158 @@ class StaticCache(Cache):
             value_states (`torch.Tensor`):
                 The new value states to cache.
             layer_idx (`int`):
-                The index of the layer to cache the states for.
             cache_kwargs (`Dict[str, Any]`, `optional`):
-                Additional arguments for the cache subclass. The `StaticCache`
-                to know how

         Return:
             A tuple containing the updated key and value states.
         """
-
-        k_out = self.key_cache
-        v_out = self.value_cache

-        k_out[:, :,
-        v_out[:, :,

         return k_out, v_out

     def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
-        """Returns the sequence length of the cached states that were seen by the model.
         # Occupied cache == any slot in the 3rd dim (sequence length) holds a non-zero value. To save on compute, let's
         # limit the check to the first batch member and head dimension.
-        # TODO:
-
-        return (self.key_cache[0, 0].any(dim=-1)).sum()

     def get_max_length(self) -> Optional[int]:
-        """Returns the maximum sequence length of the cached states.
         return self.max_cache_len

-    def
-        """
-
-
-
-

-
-
         return None
After (added lines marked with +):

@@ -1,12 +1,21 @@
+import copy
+import json
+import os
 from dataclasses import dataclass
+from typing import Any, Dict, List, Optional, Tuple, Union

 import torch

 from transformers.configuration_utils import PretrainedConfig
+from transformers.utils import is_hqq_available, is_quanto_available, logging


+if is_quanto_available():
+    from quanto import QBitsTensor, qint2, qint4
+
+if is_hqq_available():
+    from hqq.core.quantize import Quantizer as HQQQuantizer
+
 logger = logging.get_logger(__name__)


@@ -44,6 +53,7 @@ class Cache:

     def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
         """Returns the sequence length of the cached states. A layer index can be optionally passed."""
+        # TODO: deprecate this function in favor of `cache_position`
         raise NotImplementedError("Make sure to implement `get_seq_length` in a subclass.")

     def get_max_length(self) -> Optional[int]:
@@ -61,6 +71,14 @@ class Cache:
             return max_length - new_seq_length
         return previous_seq_length

+    def reorder_cache(self, beam_idx: torch.LongTensor):
+        """Reorders the cache for beam search, given the selected beam indices."""
+        for layer_idx in range(len(self.key_cache)):
+            device = self.key_cache[layer_idx].device
+            self.key_cache[layer_idx] = self.key_cache[layer_idx].index_select(0, beam_idx.to(device))
+            device = self.value_cache[layer_idx].device
+            self.value_cache[layer_idx] = self.value_cache[layer_idx].index_select(0, beam_idx.to(device))
+
     @property
     def seen_tokens(self):
         logger.warning_once(
@@ -73,6 +91,201 @@ class Cache:
         return None

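The `reorder_cache` helper that previously lived on each cache subclass now sits on the base `Cache` class above. A minimal illustration of what it does during beam search (the tensor shapes below are invented for the example):

```python
import torch

# Beam reordering is just an index_select over the batch/beam dimension of every cached tensor.
key_cache_layer = torch.randn(4, 2, 5, 8)  # [num_beams, num_heads, seq_len, head_dim]
beam_idx = torch.tensor([0, 0, 2, 3])      # beam 1 was dropped, beam 0 was duplicated

reordered = key_cache_layer.index_select(0, beam_idx)
print(reordered.shape)  # torch.Size([4, 2, 5, 8]); rows 0 and 1 now both hold beam 0's cache
```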
+@dataclass
+class CacheConfig:
+    """
+    Base class for cache configs
+    """
+
+    cache_implementation: None
+
+    @classmethod
+    def from_dict(cls, config_dict, **kwargs):
+        """
+        Constructs a CacheConfig instance from a dictionary of parameters.
+        Args:
+            config_dict (Dict[str, Any]): Dictionary containing configuration parameters.
+            **kwargs: Additional keyword arguments to override dictionary values.
+        Returns:
+            CacheConfig: Instance of CacheConfig constructed from the dictionary.
+        """
+        config = cls(**config_dict)
+        to_remove = []
+        for key, value in kwargs.items():
+            if hasattr(config, key):
+                setattr(config, key, value)
+                to_remove.append(key)
+        for key in to_remove:
+            kwargs.pop(key, None)
+        return config
+
+    # Copied from transformers.utils.quantization_config.QuantizationConfigMixin.to_json_file
+    def to_json_file(self, json_file_path: Union[str, os.PathLike]):
+        """
+        Save this instance to a JSON file.
+
+        Args:
+            json_file_path (`str` or `os.PathLike`):
+                Path to the JSON file in which this configuration instance's parameters will be saved.
+            use_diff (`bool`, *optional*, defaults to `True`):
+                If set to `True`, only the difference between the config instance and the default
+                `QuantizationConfig()` is serialized to JSON file.
+        """
+        with open(json_file_path, "w", encoding="utf-8") as writer:
+            config_dict = self.to_dict()
+            json_string = json.dumps(config_dict, indent=2, sort_keys=True) + "\n"
+
+            writer.write(json_string)
+
+    # Copied from transformers.utils.quantization_config.QuantizationConfigMixin.to_dict
+    def to_dict(self) -> Dict[str, Any]:
+        """
+        Serializes this instance to a Python dictionary. Returns:
+            `Dict[str, Any]`: Dictionary of all the attributes that make up this configuration instance.
+        """
+        return copy.deepcopy(self.__dict__)
+
+    # Copied from transformers.utils.quantization_config.QuantizationConfigMixin.__iter__
+    def __iter__(self):
+        """allows `dict(obj)` for situations where obj may be a dict or QuantizationConfigMixin"""
+        for attr, value in copy.deepcopy(self.__dict__).items():
+            yield attr, value
+
+    # Copied from transformers.utils.quantization_config.QuantizationConfigMixin.__repr__
+    def __repr__(self):
+        return f"{self.__class__.__name__} {self.to_json_string()}"
+
+    def to_json_string(self):
+        """
+        Serializes this instance to a JSON formatted string.
+        Returns:
+            str: JSON formatted string representing the configuration instance.
+        """
+        return json.dumps(self.__dict__, indent=2) + "\n"
+
+    # Copied from transformers.utils.quantization_config.QuantizationConfigMixin.update
+    def update(self, **kwargs):
+        """
+        Updates attributes of this class instance with attributes from `kwargs` if they match existing attributes,
+        returning all the unused kwargs.
+
+        Args:
+            kwargs (`Dict[str, Any]`):
+                Dictionary of attributes to tentatively update this class.
+
+        Returns:
+            `Dict[str, Any]`: Dictionary containing all the key-value pairs that were not used to update the instance.
+        """
+        to_remove = []
+        for key, value in kwargs.items():
+            if hasattr(self, key):
+                setattr(self, key, value)
+                to_remove.append(key)
+
+        # Remove all the attributes that were updated, without modifying the input dict
+        unused_kwargs = {key: value for key, value in kwargs.items() if key not in to_remove}
+        return unused_kwargs
+
+
+@dataclass
+class QuantizedCacheConfig(CacheConfig):
+    """
+    Configuration class for quantized cache settings.
+
+    Attributes:
+        backend (`str`, *optional*, defaults to `"quanto"`):
+            Backend to use when performing quantization. Can be one of [`quanto`, `HQQ`].
+        nbits (`Optional[int]`, *optional*, defaults to 4):
+            Number of bits, can be 2 or 4 for the `quanto` backend and one of [1, 2, 3, 4, 8] for the `HQQ` backend.
+        axis_key (`int`, *optional*, defaults to 0):
+            Axis over which to perform grouping for the key tensors. Can be [0, -1] for the `quanto` backend and [0, 1] for the `HQQ` backend.
+        axis_value (`int`, *optional*, defaults to 0):
+            Axis over which to perform grouping for the value tensors. Can be [0, -1] for the `quanto` backend and [0, 1] for the `HQQ` backend.
+        q_group_size (`Optional[int]`, *optional*, defaults to 64):
+            Size of the quantization group, should be a divisor of the model's hidden dimension.
+            Defaults to 64.
+        residual_length (`Optional[int]`, *optional*, defaults to 128):
+            Length of the residual cache which will always be stored in original precision.
+            Defaults to 128.
+        compute_dtype (`torch.dtype`, *optional*, defaults to `torch.float16`):
+            The default dtype used for computations in the model. Keys and Values will be cast to this dtype after dequantization.
+        device (`str`, *optional*, defaults to `"cpu"`):
+            Device on which to perform computations, should be the same as the model's device.
+    """
+
+    def __init__(
+        self,
+        backend: str = "quanto",
+        nbits: Optional[int] = 4,
+        axis_key: Optional[int] = 0,
+        axis_value: Optional[int] = 0,
+        q_group_size: Optional[int] = 64,
+        residual_length: Optional[int] = 128,
+        compute_dtype: Optional[torch.dtype] = torch.float16,
+        device: Optional[str] = "cpu",
+    ):
+        self.backend = backend
+        self.nbits = nbits
+        self.axis_key = axis_key
+        self.axis_value = axis_value
+        self.q_group_size = q_group_size
+        self.residual_length = residual_length
+        self.compute_dtype = compute_dtype
+        self.device = device
+
+    def validate(self):
+        """Validates if the arguments passed are correct"""
+
+        incorrect_arg_msg = (
+            "Some of the keys in `cache_config` are defined incorrectly. `{key}` should be {correct_value}` "
+            "but found {found_value}"
+        )
+        # Check that the values are reasonable in general (nbits, axis)
+        # Later in QuantizedCache init we check if they are supported for that particular backend
+        if self.nbits not in [1, 2, 3, 4, 8]:
+            raise ValueError(
+                incorrect_arg_msg.format(
+                    key="nbits",
+                    correct_value="2 or 4 or 8",
+                    found_value=self.nbits,
+                ),
+            )
+        if self.q_group_size <= 0:
+            raise ValueError(
+                incorrect_arg_msg.format(
+                    key="q_group_size",
+                    correct_value="a positive integer",
+                    found_value=self.q_group_size,
+                ),
+            )
+        if self.residual_length < 0:
+            raise ValueError(
+                incorrect_arg_msg.format(
+                    key="residual_length",
+                    correct_value="a positive integer",
+                    found_value=self.residual_length,
+                ),
+            )
+
+        if self.axis_key not in [0, 1, -1]:
+            raise ValueError(
+                incorrect_arg_msg.format(
+                    key="axis_key",
+                    correct_value="`1` or `0`, `-1`",
+                    found_value=self.axis_key,
+                ),
+            )
+
+        if self.axis_value not in [0, 1, -1]:
+            raise ValueError(
+                incorrect_arg_msg.format(
+                    key="axis_value",
+                    correct_value="`1` or `0` or `-1`",
+                    found_value=self.axis_value,
+                ),
+            )
+
+
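A short usage sketch for the new config class, assuming it is importable from `transformers.cache_utils` once this change lands (the concrete values are only examples):

```python
from transformers.cache_utils import QuantizedCacheConfig

cfg = QuantizedCacheConfig(backend="quanto", nbits=4, q_group_size=64, residual_length=128)
cfg.validate()                       # raises ValueError for out-of-range nbits / axes / sizes
print(cfg.to_dict()["nbits"])        # 4
unused = cfg.update(nbits=2, foo=1)  # known keys are applied, unknown keys are returned
print(cfg.nbits, unused)             # 2 {'foo': 1}
```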
 class DynamicCache(Cache):
     """
     A cache that grows dynamically as more tokens are generated. This is the default for generative models.
@@ -150,6 +363,7 @@ class DynamicCache(Cache):

     def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
         """Returns the sequence length of the cached states. A layer index can be optionally passed."""
+        # TODO: deprecate this function in favor of `cache_position`
         if len(self.key_cache) <= layer_idx:
             return 0
         return self.key_cache[layer_idx].shape[-2]
@@ -158,14 +372,6 @@ class DynamicCache(Cache):
         """Returns the maximum sequence length of the cached states. DynamicCache does not have a maximum length."""
         return None

     def to_legacy_cache(self) -> Tuple[Tuple[torch.Tensor], Tuple[torch.Tensor]]:
         """Converts the `DynamicCache` instance into the its equivalent in the legacy cache format."""
         legacy_cache = ()
@@ -184,6 +390,168 @@ class DynamicCache(Cache):
         return cache

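For reference, a hedged sketch of the legacy-format round trip that `to_legacy_cache` in the context above refers to; `DynamicCache.from_legacy_cache` is assumed to be available in the same file (it is not shown in this diff), and the shapes are made up:

```python
import torch
from transformers.cache_utils import DynamicCache

# Legacy `past_key_values`: one (key, value) pair of tensors per layer.
legacy = tuple(
    (torch.randn(1, 2, 5, 8), torch.randn(1, 2, 5, 8))  # [bsz, num_heads, seq_len, head_dim]
    for _ in range(2)
)
cache = DynamicCache.from_legacy_cache(legacy)
print(cache.get_seq_length())                 # 5
roundtrip = cache.to_legacy_cache()
print(len(roundtrip), roundtrip[0][0].shape)  # 2 torch.Size([1, 2, 5, 8])
```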
+class QuantizedCache(DynamicCache):
+    """
+    A quantizer cache similar to what is described in the [KIVI: A Tuning-Free Asymmetric 2bit Quantization for KV Cache paper](https://arxiv.org/abs/2402.02750).
+    It allows the model to generate longer sequence length without allocating too much memory for Key and Value cache by applying quantization.
+
+    The cache has two types of storage, one for original precision and one for the quantized cache. A `residual length` is set as a maximum capacity for the
+    original precision cache. When the length goes beyond maximum capacity, the original precision cache is discarded and moved into the quantized cache. The
+    quantization is done per-channel with a set `q_group_size` for both Keys and Values, in contrast to what was described in the paper.
+
+    It stores Keys and Values as a list of quantized tensors (tuples in case we need to store metadata), one for each layer. Additionally, it stores the Key and
+    Value in original precision states as a list of tensors, one for each layer. The size of each tensor
+    is `[batch_size, num_heads, seq_len - residual_length, head_dim]`
+    """
+
+    def __init__(self, cache_config: QuantizedCacheConfig) -> None:
+        self._quantized_key_cache: List[torch.Tensor] = []
+        self._quantized_value_cache: List[torch.Tensor] = []
+
+        self.nbits = cache_config.nbits
+        self.residual_length = cache_config.residual_length
+        self.q_group_size = cache_config.q_group_size
+        self.axis_key = cache_config.axis_key
+        self.axis_value = cache_config.axis_value
+        self.compute_dtype = cache_config.compute_dtype
+        self.device = cache_config.device
+
+        super().__init__()
+
+    def update(
+        self,
+        key_states: torch.Tensor,
+        value_states: torch.Tensor,
+        layer_idx: int,
+        cache_kwargs: Optional[Dict[str, Any]] = None,
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        # Update the number of seen tokens
+        if layer_idx == 0:
+            self._seen_tokens += key_states.shape[-2]
+
+        if len(self.key_cache) <= layer_idx:
+            self._quantized_key_cache.append(self._quantize(key_states.contiguous(), axis=self.axis_key))
+            self._quantized_value_cache.append(self._quantize(value_states.contiguous(), axis=self.axis_value))
+            self.key_cache.append(torch.zeros(0, dtype=key_states.dtype, device=key_states.device))
+            self.value_cache.append(torch.zeros(0, dtype=key_states.dtype, device=key_states.device))
+            keys_to_return, values_to_return = key_states, value_states
+        else:
+            dequant_key = self._dequantize(self._quantized_key_cache[layer_idx])
+            dequant_value = self._dequantize(self._quantized_value_cache[layer_idx])
+            keys_to_return = [dequant_key, self.key_cache[layer_idx], key_states]
+            values_to_return = [dequant_value, self.value_cache[layer_idx], value_states]
+
+            keys_to_return = torch.cat(keys_to_return, dim=-2)
+            values_to_return = torch.cat(values_to_return, dim=-2)
+            if (
+                self.key_cache[layer_idx].dim() == 4
+                and self.key_cache[layer_idx].shape[-2] + 1 >= self.residual_length
+            ):
+                self._quantized_key_cache[layer_idx] = self._quantize(keys_to_return.contiguous(), axis=self.axis_key)
+                self._quantized_value_cache[layer_idx] = self._quantize(
+                    values_to_return.contiguous(), axis=self.axis_value
+                )
+                self.key_cache[layer_idx] = torch.zeros(0, dtype=key_states.dtype, device=key_states.device)
+                self.value_cache[layer_idx] = torch.zeros(0, dtype=key_states.dtype, device=key_states.device)
+            else:
+                self.key_cache[layer_idx] = torch.cat([self.key_cache[layer_idx], key_states], dim=-2)
+                self.value_cache[layer_idx] = torch.cat([self.value_cache[layer_idx], value_states], dim=-2)
+
+        return keys_to_return, values_to_return
+
+    def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
+        """Returns the sequence length of the cached states. A layer index can be optionally passed."""
+        if len(self.key_cache) <= layer_idx:
+            return 0
+        # since we cannot get the seq_length of each layer directly and rely on `_seen_tokens` which is
+        # updated every "layer_idx" == 0, this is a hack to get the actual seq_length for the given layer_idx
+        # this part of code otherwise fails when used to verify attn_weight shape in some models
+        return self._seen_tokens if layer_idx == 0 else self._seen_tokens - 1
+
+    def _quantize(self, tensor, axis):
+        """Quantizes a key/value using a defined quantization method."""
+        raise NotImplementedError("Make sure to implement `_quantize` in a subclass.")
+
+    def _dequantize(self, q_tensor):
+        """Dequantizes back the tensor that was quantized by `self._quantize()`"""
+        raise NotImplementedError("Make sure to implement `_dequantize` in a subclass.")
+
+
+class QuantoQuantizedCache(QuantizedCache):
+    """
+    Quantized Cache class that uses `quanto` as a backend to perform quantization. Current implementation supports `int2` and `int4` dtypes only.
+
+    Parameters:
+        cache_config (`QuantizedCacheConfig`):
+            A configuration containing all the arguments to be used by the quantizer, including axis, qtype and group size.
+    """
+
+    def __init__(self, cache_config: CacheConfig) -> None:
+        super().__init__(cache_config)
+        if self.nbits not in [2, 4]:
+            raise ValueError(f"`nbits` for `quanto` backend has to be one of [`2`, `4`] but got {self.nbits}")
+
+        if self.axis_key not in [0, -1]:
+            raise ValueError(f"`axis_key` for `quanto` backend has to be one of [`0`, `-1`] but got {self.axis_key}")
+
+        if self.axis_value not in [0, -1]:
+            raise ValueError(
+                f"`axis_value` for `quanto` backend has to be one of [`0`, `-1`] but got {self.axis_value}"
+            )
+
+        self.qtype = qint4 if self.nbits == 4 else qint2
+
+    def _quantize(self, tensor, axis):
+        qtensor = QBitsTensor.quantize(tensor, axis=axis, qtype=self.qtype, group_size=self.q_group_size)
+        return qtensor
+
+    def _dequantize(self, qtensor):
+        return qtensor.dequantize()
+
+
+class HQQQuantizedCache(QuantizedCache):
+    """
+    Quantized Cache class that uses `HQQ` as a backend to perform quantization. Current implementation supports `int2`, `int4`, `int8` dtypes.
+
+    Parameters:
+        cache_config (`QuantizedCacheConfig`):
+            A configuration containing all the arguments to be used by the quantizer, including axis, qtype and group size.
+    """
+
+    def __init__(self, cache_config: CacheConfig) -> None:
+        super().__init__(cache_config)
+        if self.nbits not in [1, 2, 3, 4, 8]:
+            raise ValueError(
+                f"`nbits` for `HQQ` backend has to be one of [`1`, `2`, `3`, `4`, `8`] but got {self.nbits}"
+            )
+
+        if self.axis_key not in [0, 1]:
+            raise ValueError(f"`axis_key` for `HQQ` backend has to be one of [`0`, `1`] but got {self.axis_key}")
+
+        if self.axis_value not in [0, 1]:
+            raise ValueError(f"`axis_value` for `HQQ` backend has to be one of [`0`, `1`] but got {self.axis_value}")
+
+        self.quantizer = HQQQuantizer
+
+    def _quantize(self, tensor, axis):
+        qtensor, meta = self.quantizer.quantize(
+            tensor,
+            axis=axis,
+            device=self.device,
+            compute_dtype=self.compute_dtype,
+            nbits=self.nbits,
+            group_size=self.q_group_size,
+        )
+        meta["compute_dtype"] = self.compute_dtype
+        self.quantizer.cuda(qtensor, meta=meta, device=self.device)  # Move to device and cast to dtype
+        return qtensor, meta
+
+    def _dequantize(self, qtensor):
+        quant_tensor, meta = qtensor
+        tensor = self.quantizer.dequantize(quant_tensor, meta)
+        return tensor
+
+
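A toy sketch of the residual-cache mechanics implemented by `QuantizedCache.update` above. The subclass below uses an identity "quantizer" so it runs without quanto or HQQ installed; the shapes, the residual length, and the import path are illustrative assumptions:

```python
import torch
from transformers.cache_utils import QuantizedCache, QuantizedCacheConfig

class NoOpQuantizedCache(QuantizedCache):
    def _quantize(self, tensor, axis):
        return tensor  # stand-in for a real quantizer

    def _dequantize(self, q_tensor):
        return q_tensor

cache = NoOpQuantizedCache(QuantizedCacheConfig(residual_length=4))
k = torch.randn(1, 2, 3, 8)  # [bsz, num_heads, seq_len, head_dim]
v = torch.randn(1, 2, 3, 8)

# Prefill: the whole prompt goes straight into the "quantized" store.
keys, values = cache.update(k, v, layer_idx=0)
# One decoding step: the new token lands in the full-precision residual cache, and the
# returned tensors are the concatenation of dequantized past + residual + new states.
keys, values = cache.update(k[:, :, :1], v[:, :, :1], layer_idx=0)
print(keys.shape)  # torch.Size([1, 2, 4, 8])
```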
 class SinkCache(Cache):
     """
     A cache that as described in the [Attention Sinks paper](https://arxiv.org/abs/2309.17453). It allows the model to
@@ -205,7 +573,9 @@ class SinkCache(Cache):
         self.value_cache: List[torch.Tensor] = []
         self.window_length = window_length
         self.num_sink_tokens = num_sink_tokens
+        self.cos_sin_rerotation_cache = {}
+        self._cos_cache = None
+        self._sin_cache = None
         self._seen_tokens = 0  # Used in `generate` to keep tally of how many tokens the cache has seen

     @staticmethod
@@ -223,7 +593,7 @@ class SinkCache(Cache):
     def _get_rerotation_cos_sin(
         self, key_states: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor
     ) -> Tuple[torch.Tensor, torch.Tensor]:
+        if key_states.shape[-2] not in self.cos_sin_rerotation_cache:
             # Upcast to float32 temporarily for better accuracy
             cos = cos.to(torch.float32)
             sin = sin.to(torch.float32)
@@ -236,14 +606,15 @@ class SinkCache(Cache):
             rerotation_cos = original_cos * shifted_cos + original_sin * shifted_sin
             rerotation_sin = -original_sin * shifted_cos + original_cos * shifted_sin

+            self.cos_sin_rerotation_cache[key_states.shape[-2]] = (
                 rerotation_cos.to(key_states.dtype).unsqueeze(0),
                 rerotation_sin.to(key_states.dtype).unsqueeze(0),
             )
+        return self.cos_sin_rerotation_cache[key_states.shape[-2]]

     def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
         """Returns the sequence length of the cached states. A layer index can be optionally passed."""
+        # TODO: deprecate this function in favor of `cache_position`
         # Workaround to make 'key_states.shape[-2] + past_key_value.get_seq_length(self.layer_idx)' <= window_length
         if len(self.key_cache) <= layer_idx:
             return 0
@@ -289,6 +660,21 @@ class SinkCache(Cache):
         if layer_idx == 0:
             self._seen_tokens += key_states.shape[-2]

+        # Update the sin/cos cache, which holds sin/cos values for all possible positions
+        if using_rope and layer_idx == 0:
+            # BC: some models still pass `sin`/`cos` with 2 dims. In those models, they are the full sin/cos. Remove
+            # after all RoPE models have a llama-like cache utilization.
+            if cos.dim() == 2:
+                self._cos_cache = cos
+                self._sin_cache = sin
+            else:
+                if self._cos_cache is None:
+                    self._cos_cache = cos[0, ...]
+                    self._sin_cache = sin[0, ...]
+                elif self._cos_cache.shape[0] < self.window_length:
+                    self._cos_cache = torch.cat([self._cos_cache, cos[0, ...]], dim=0)
+                    self._sin_cache = torch.cat([self._sin_cache, sin[0, ...]], dim=0)
+
         # [bsz, num_heads, seq_len, head_dim]
         if len(self.key_cache) <= layer_idx:
             # Empty cache
@@ -309,7 +695,7 @@ class SinkCache(Cache):
             # On RoPE models, we need to recompute the Key rotation as the tokens are shifted
             if using_rope:
                 rerotation_cos, rerotation_sin = self._get_rerotation_cos_sin(
+                    key_states, self._cos_cache[: self.window_length], self._sin_cache[: self.window_length]
                 )
                 if partial_rotation_size is not None:
                     keys_to_keep, keys_pass = (
@@ -332,14 +718,6 @@ class SinkCache(Cache):

         return self.key_cache[layer_idx], self.value_cache[layer_idx]


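For reference, the two `rerotation_*` lines kept as context above are the angle-difference identities; they rotate a cached key from its original position angle to the angle of the position it is shifted to, in a single RoPE application:

```latex
\cos(\theta_{\text{shifted}} - \theta_{\text{original}})
    = \cos\theta_{\text{original}}\cos\theta_{\text{shifted}} + \sin\theta_{\text{original}}\sin\theta_{\text{shifted}}
\sin(\theta_{\text{shifted}} - \theta_{\text{original}})
    = -\sin\theta_{\text{original}}\cos\theta_{\text{shifted}} + \cos\theta_{\text{original}}\sin\theta_{\text{shifted}}
```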
 class StaticCache(Cache):
     """
@@ -347,8 +725,7 @@ class StaticCache(Cache):

     Parameters:
         config (`PretrainedConfig):
+            The configuration file defining the shape-related attributes required to initialize the static cache.
         max_batch_size (`int`):
             The maximum batch size with which the model will be used.
         max_cache_len (`int`):
@@ -373,9 +750,18 @@ class StaticCache(Cache):
             config.num_attention_heads if config.num_key_value_heads is None else config.num_key_value_heads
         )

+        self.key_cache: List[torch.Tensor] = []
+        self.value_cache: List[torch.Tensor] = []
         cache_shape = (max_batch_size, self.num_key_value_heads, self.max_cache_len, self.head_dim)
+        for _ in range(config.num_hidden_layers):
+            # Note: `mark_static_address` is used to tag the cache as a fixed data pointer, preventing cuda graph
+            # breaks when updating the cache.
+            new_layer_key_cache = torch.zeros(cache_shape, dtype=self.dtype, device=device)
+            new_layer_value_cache = torch.zeros(cache_shape, dtype=self.dtype, device=device)
+            torch._dynamo.mark_static_address(new_layer_key_cache)
+            torch._dynamo.mark_static_address(new_layer_value_cache)
+            self.key_cache.append(new_layer_key_cache)
+            self.value_cache.append(new_layer_value_cache)

     def update(
         self,
@@ -394,42 +780,158 @@ class StaticCache(Cache):
             value_states (`torch.Tensor`):
                 The new value states to cache.
             layer_idx (`int`):
+                The index of the layer to cache the states for.
             cache_kwargs (`Dict[str, Any]`, `optional`):
+                Additional arguments for the cache subclass. The `StaticCache` needs the `cache_position` input
+                to know where to write in the cache.

         Return:
             A tuple containing the updated key and value states.
         """
+        cache_position = cache_kwargs.get("cache_position")
+        k_out = self.key_cache[layer_idx]
+        v_out = self.value_cache[layer_idx]

+        k_out[:, :, cache_position] = key_states
+        v_out[:, :, cache_position] = value_states

         return k_out, v_out

     def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
+        """Returns the sequence length of the cached states that were seen by the model."""
         # Occupied cache == any slot in the 3rd dim (sequence length) holds a non-zero value. To save on compute, let's
         # limit the check to the first batch member and head dimension.
+        # TODO: deprecate this function in favor of `cache_position`
+        return (self.key_cache[layer_idx][0, 0].any(dim=-1)).sum()

     def get_max_length(self) -> Optional[int]:
+        """Returns the maximum sequence length of the cached states."""
         return self.max_cache_len

+    def reset(self):
+        """Resets the cache values while preserving the objects"""
+        for layer_idx in range(len(self.key_cache)):
+            # In-place ops prevent breaking the static address
+            self.key_cache[layer_idx].zero_()
+            self.value_cache[layer_idx].zero_()
+
+
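A usage sketch for the per-layer static cache, assuming the imports below resolve inside the transformers package; the config values and shapes are made up for the example:

```python
import torch
from transformers import LlamaConfig
from transformers.cache_utils import StaticCache

config = LlamaConfig(num_hidden_layers=2, hidden_size=64, num_attention_heads=4, num_key_value_heads=4)
cache = StaticCache(config, max_batch_size=1, max_cache_len=16, device="cpu", dtype=torch.float16)

head_dim = config.hidden_size // config.num_attention_heads
k = torch.randn(1, config.num_key_value_heads, 3, head_dim, dtype=torch.float16)
v = torch.randn_like(k)

# Write positions 0..2 of layer 0, then read the occupied length back.
cache.update(k, v, layer_idx=0, cache_kwargs={"cache_position": torch.arange(3)})
print(cache.get_seq_length(0))  # tensor(3)
cache.reset()                   # zeroes every layer in place, keeping the static tensors alive
```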
+class SlidingWindowCache(Cache):
+    """
+    Sliding Window Cache class to be used with `torch.compile` for models like Mistral that support sliding window attention.
+    Every time when we try to update the cache, we compute the `indices` based on `cache_position >= self.config.sliding_window_size - 1`,
+    if true (which means the cache can not hold all the old key value states and new states together because of the sliding window constraint),
+    we need to do a cycle shift based on `indices` to replace the oldest states by the new key value states passed in.
+
+    The `to_shift` is only true once we are above sliding_window_size. Thus with `sliding_window_size==64`:

+    indices = (slicing + to_shift[-1].int()-1) % self.config.sliding_window_size
+    tensor([ 1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17, 18,
+            19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36,
+            37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54,
+            55, 56, 57, 58, 59, 60, 61, 62, 63,  0])
+
+    We overwrite the cache using these, then we always write at cache_position (clamped to `sliding_window_size`)
+
+    Parameters:
+        config (`PretrainedConfig):
+            The configuration file defining the shape-related attributes required to initialize the static cache.
+        max_batch_size (`int`):
+            The maximum batch size with which the model will be used.
+        max_cache_len (`int`):
+            The maximum sequence length with which the model will be used.
+        device (`torch.device`):
+            The device on which the cache should be initialized. Should be the same as the layer.
+        dtype (*optional*, defaults to `torch.float32`):
+            The default `dtype` to use when initializing the layer.
+    """
+
+    def __init__(self, config: PretrainedConfig, max_batch_size: int, max_cache_len: int, device, dtype=None) -> None:
+        if not hasattr(config, "sliding_window") or config.sliding_window is None:
+            raise ValueError(
+                "Setting `cache_implementation` to 'sliding_window' requires the model config supporting "
+                "sliding window attention, please check if there is a `sliding_window` field in the model "
+                "config and it's not set to None."
+            )
+
+        super().__init__()
+        self.max_batch_size = max_batch_size
+        # take the minimum of max_cache_len and config.sliding_window so that we allocate less memory
+        # when we do short-sentence generation
+        self.max_cache_len = config.max_position_embeddings if max_cache_len is None else max_cache_len
+        self.model_sliding_window_size = config.sliding_window
+        self.sliding_window_size = min(self.max_cache_len, self.model_sliding_window_size)
+        # Some model define a custom `head_dim` != config.hidden_size // config.num_attention_heads
+        self.head_dim = (
+            config.head_dim if hasattr(config, "head_dim") else config.hidden_size // config.num_attention_heads
+        )
+
+        self.dtype = dtype if dtype is not None else torch.float32
+        self.num_key_value_heads = (
+            config.num_attention_heads if config.num_key_value_heads is None else config.num_key_value_heads
+        )
+
+        cache_shape = (
+            config.num_hidden_layers,
+            max_batch_size,
+            self.num_key_value_heads,
+            self.sliding_window_size,
+            self.head_dim,
+        )
+
+        self.key_cache = torch.zeros(cache_shape, dtype=self.dtype, device=device)
+        self.value_cache = torch.zeros(cache_shape, dtype=self.dtype, device=device)
+
+        torch._dynamo.mark_static_address(self.key_cache)
+        torch._dynamo.mark_static_address(self.value_cache)
+
+    def update(
+        self,
+        key_states: torch.Tensor,
+        value_states: torch.Tensor,
+        layer_idx: int,
+        cache_kwargs: Optional[Dict[str, Any]] = None,
+    ) -> Tuple[torch.Tensor]:
+        cache_position = cache_kwargs.get("cache_position")
+        k_out = self.key_cache[layer_idx]
+        v_out = self.value_cache[layer_idx]
+
+        # assume this only happens in prefill phase when prompt length > sliding_window_size
+        if cache_position.shape[0] > self.sliding_window_size:
+            k_out = key_states[:, :, -self.sliding_window_size :, :]
+            v_out = value_states[:, :, -self.sliding_window_size :, :]
+            self.key_cache[layer_idx] = k_out
+            self.value_cache[layer_idx] = v_out
+            # we should return the whole states instead of k_out, v_out to take the whole prompt
+            # into consideration when building kv cache instead of just throwing away tokens outside of the window
+            return key_states, value_states
+
+        slicing = torch.ones(self.sliding_window_size, dtype=torch.long, device=value_states.device).cumsum(0)
+        cache_position = cache_position.clamp(0, self.sliding_window_size - 1)
+        to_shift = cache_position >= self.sliding_window_size - 1
+        indices = (slicing + to_shift[-1].int() - 1) % self.sliding_window_size
+
+        k_out = k_out[:, :, indices]
+        v_out = v_out[:, :, indices]
+
+        k_out[:, :, cache_position] = key_states
+        v_out[:, :, cache_position] = value_states
+
+        self.key_cache[layer_idx] = k_out
+        self.value_cache[layer_idx] = v_out
+
+        return k_out, v_out
+
+    def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
+        # assume this will be called only in the first generation step
+        # `cache_position` will be used in other cases
+        return 0
+
+    def get_max_length(self) -> Optional[int]:
+        # in theory there is no limit because the sliding window size is fixed
+        # no matter how long the sentence is
         return None
+
+    def reset(self):
+        self.key_cache.zero_()
+        self.value_cache.zero_()
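A standalone illustration of the cycle shift performed by `SlidingWindowCache.update`, using a made-up window of 8 instead of the 64 shown in the docstring:

```python
import torch

sliding_window_size = 8
cache_position = torch.tensor([10])  # decoding step 10, already past the window
cache_position = cache_position.clamp(0, sliding_window_size - 1)

slicing = torch.ones(sliding_window_size, dtype=torch.long).cumsum(0)
to_shift = cache_position >= sliding_window_size - 1
indices = (slicing + to_shift[-1].int() - 1) % sliding_window_size

print(indices)         # tensor([1, 2, 3, 4, 5, 6, 7, 0]) -> every slot rolls left by one
print(cache_position)  # tensor([7]) -> the new token is then written into the last slot
```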