Upload cache_utils.py
cache_utils.py  ADDED  (+435 -0)
@@ -0,0 +1,435 @@
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Tuple

import torch

from .configuration_utils import PretrainedConfig
from .utils import logging


logger = logging.get_logger(__name__)


@dataclass
class Cache:
    """
    Base, abstract class for all caches. The actual data structure is specific to each subclass.
    """

    def update(
        self,
        key_states: torch.Tensor,
        value_states: torch.Tensor,
        layer_idx: int,
        cache_kwargs: Optional[Dict[str, Any]] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Updates the cache with the new `key_states` and `value_states` for the layer `layer_idx`.

        Parameters:
            key_states (`torch.Tensor`):
                The new key states to cache.
            value_states (`torch.Tensor`):
                The new value states to cache.
            layer_idx (`int`):
                The index of the layer to cache the states for.
            cache_kwargs (`Dict[str, Any]`, `optional`):
                Additional arguments for the cache subclass. These are specific to each subclass and allow new types of
                cache to be created.

        Return:
            A tuple containing the updated key and value states.
        """
        raise NotImplementedError("Make sure to implement `update` in a subclass.")

    def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
        """Returns the sequence length of the cached states. A layer index can be optionally passed."""
        raise NotImplementedError("Make sure to implement `get_seq_length` in a subclass.")

    def get_max_length(self) -> Optional[int]:
        """Returns the maximum sequence length of the cached states, if there is any."""
        raise NotImplementedError("Make sure to implement `get_max_length` in a subclass.")

    def get_usable_length(self, new_seq_length: int, layer_idx: Optional[int] = 0) -> int:
        """Given the sequence length of the new inputs, returns the usable length of the cache."""
        # Cache without size limit -> all of the cache is usable
        # Cache with size limit -> if the length of the cache plus the length of the new inputs is larger than the
        # maximum cache length, we will need to evict part of the cache (and thus not all of the cache is usable)
        max_length = self.get_max_length()
        previous_seq_length = self.get_seq_length(layer_idx)
        if max_length is not None and previous_seq_length + new_seq_length > max_length:
            return max_length - new_seq_length
        return previous_seq_length

    @property
    def seen_tokens(self):
        logger.warning_once(
            "The `seen_tokens` attribute is deprecated and will be removed in v4.41. Use the `cache_position` "
            "model input instead."
        )
        if hasattr(self, "_seen_tokens"):
            return self._seen_tokens
        else:
            return None

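
# Worked example (illustrative, assuming a hypothetical bounded subclass): with `get_max_length() == 8`
# and 6 tokens already cached, `get_usable_length(new_seq_length=4)` returns 8 - 4 = 4, i.e. only 4 of
# the 6 cached tokens remain usable and the subclass is expected to evict the rest; an unbounded cache
# (`get_max_length()` returning `None`) simply returns the full cached length of 6.
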

class DynamicCache(Cache):
    """
    A cache that grows dynamically as more tokens are generated. This is the default for generative models.

    It stores the Key and Value states as a list of tensors, one for each layer. The expected shape for each tensor is
    `[batch_size, num_heads, seq_len, head_dim]`.
    """

    def __init__(self) -> None:
        self.key_cache: List[torch.Tensor] = []
        self.value_cache: List[torch.Tensor] = []
        self._seen_tokens = 0  # Used in `generate` to keep a tally of how many tokens the cache has seen

    def __getitem__(self, layer_idx: int) -> List[Tuple[torch.Tensor]]:
        """
        Support for backwards-compatible `past_key_value` indexing, e.g. `past_key_value[0][0].shape[2]` to get the
        sequence length.
        """
        if layer_idx < len(self):
            return (self.key_cache[layer_idx], self.value_cache[layer_idx])
        else:
            raise KeyError(f"Cache only has {len(self)} layers, attempted to access layer with index {layer_idx}")

    def __iter__(self):
        """
        Support for backwards-compatible `past_key_value` iteration, e.g. `for x in past_key_value:` to iterate over
        keys and values.
        """
        for layer_idx in range(len(self)):
            yield (self.key_cache[layer_idx], self.value_cache[layer_idx])

    def __len__(self):
        """
        Support for backwards-compatible `past_key_value` length, e.g. `len(past_key_value)`. This value corresponds
        to the number of layers in the model.
        """
        return len(self.key_cache)

    def update(
        self,
        key_states: torch.Tensor,
        value_states: torch.Tensor,
        layer_idx: int,
        cache_kwargs: Optional[Dict[str, Any]] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Updates the cache with the new `key_states` and `value_states` for the layer `layer_idx`.

        Parameters:
            key_states (`torch.Tensor`):
                The new key states to cache.
            value_states (`torch.Tensor`):
                The new value states to cache.
            layer_idx (`int`):
                The index of the layer to cache the states for.
            cache_kwargs (`Dict[str, Any]`, `optional`):
                Additional arguments for the cache subclass. No additional arguments are used in `DynamicCache`.

        Return:
            A tuple containing the updated key and value states.
        """
        # Update the number of seen tokens
        if layer_idx == 0:
            self._seen_tokens += key_states.shape[-2]

        # Update the cache
        if len(self.key_cache) <= layer_idx:
            self.key_cache.append(key_states)
            self.value_cache.append(value_states)
        else:
            self.key_cache[layer_idx] = torch.cat([self.key_cache[layer_idx], key_states], dim=-2)
            self.value_cache[layer_idx] = torch.cat([self.value_cache[layer_idx], value_states], dim=-2)

        return self.key_cache[layer_idx], self.value_cache[layer_idx]

    def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
        """Returns the sequence length of the cached states. A layer index can be optionally passed."""
        if len(self.key_cache) <= layer_idx:
            return 0
        return self.key_cache[layer_idx].shape[-2]

    def get_max_length(self) -> Optional[int]:
        """Returns the maximum sequence length of the cached states. DynamicCache does not have a maximum length."""
        return None

    def reorder_cache(self, beam_idx: torch.LongTensor):
        """Reorders the cache for beam search, given the selected beam indices."""
        for layer_idx in range(len(self.key_cache)):
            device = self.key_cache[layer_idx].device
            self.key_cache[layer_idx] = self.key_cache[layer_idx].index_select(0, beam_idx.to(device))
            device = self.value_cache[layer_idx].device
            self.value_cache[layer_idx] = self.value_cache[layer_idx].index_select(0, beam_idx.to(device))

    def to_legacy_cache(self) -> Tuple[Tuple[torch.Tensor], Tuple[torch.Tensor]]:
        """Converts the `DynamicCache` instance into its equivalent in the legacy cache format."""
        legacy_cache = ()
        for layer_idx in range(len(self)):
            legacy_cache += ((self.key_cache[layer_idx], self.value_cache[layer_idx]),)
        return legacy_cache

    @classmethod
    def from_legacy_cache(cls, past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None) -> "DynamicCache":
        """Converts a cache in the legacy cache format into an equivalent `DynamicCache`."""
        cache = cls()
        if past_key_values is not None:
            for layer_idx in range(len(past_key_values)):
                key_states, value_states = past_key_values[layer_idx]
                cache.update(key_states, value_states, layer_idx)
        return cache

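
# Usage sketch (illustrative; tensor shapes are toy values):
#
#     cache = DynamicCache()
#     key = torch.randn(1, 4, 3, 16)    # [batch_size, num_heads, seq_len, head_dim]
#     value = torch.randn(1, 4, 3, 16)
#     cache.update(key, value, layer_idx=0)   # first call for a layer appends the tensors
#     cache.update(key, value, layer_idx=0)   # later calls concatenate along the seq_len dimension
#     assert cache.get_seq_length(0) == 6
#     legacy = cache.to_legacy_cache()                     # tuple of per-layer (key, value) tuples
#     restored = DynamicCache.from_legacy_cache(legacy)    # round-trips back to a DynamicCache
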

class SinkCache(Cache):
    """
    A cache, as described in the [Attention Sinks paper](https://arxiv.org/abs/2309.17453). It allows the model to
    generate beyond the length of its context window, without losing fluency in the conversation. As it discards past
    tokens, the model will lose the ability to generate tokens that depend on the context that was discarded.

    It stores the Key and Value states as a list of tensors, one for each layer. The expected shape for each tensor is
    `[batch_size, num_heads, seq_len, head_dim]`.

    Parameters:
        window_length (`int`):
            The length of the context window.
        num_sink_tokens (`int`):
            The number of sink tokens. See the original paper for more information.
    """

    def __init__(self, window_length: int, num_sink_tokens: int) -> None:
        self.key_cache: List[torch.Tensor] = []
        self.value_cache: List[torch.Tensor] = []
        self.window_length = window_length
        self.num_sink_tokens = num_sink_tokens
        self.cos_sin_cache = {}
        self._seen_tokens = 0  # Used in `generate` to keep a tally of how many tokens the cache has seen

    @staticmethod
    def _rotate_half(x):
        x1 = x[..., : x.shape[-1] // 2]
        x2 = x[..., x.shape[-1] // 2 :]
        return torch.cat((-x2, x1), dim=-1)

    def _apply_key_rotary_pos_emb(
        self, key_states: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor
    ) -> torch.Tensor:
        rotated_key_states = (key_states * cos) + (self._rotate_half(key_states) * sin)
        return rotated_key_states

    def _get_rerotation_cos_sin(
        self, key_states: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        if key_states.shape[-2] not in self.cos_sin_cache:
            # Upcast to float32 temporarily for better accuracy
            cos = cos.to(torch.float32)
            sin = sin.to(torch.float32)

            # Compute the cos and sin required for back- and forward-rotating to one position earlier in the sequence
            original_cos = cos[self.num_sink_tokens + key_states.shape[-2] :]
            shifted_cos = cos[self.num_sink_tokens : -key_states.shape[-2]]
            original_sin = sin[self.num_sink_tokens + key_states.shape[-2] :]
            shifted_sin = sin[self.num_sink_tokens : -key_states.shape[-2]]
            rerotation_cos = original_cos * shifted_cos + original_sin * shifted_sin
            rerotation_sin = -original_sin * shifted_cos + original_cos * shifted_sin

            self.cos_sin_cache[key_states.shape[-2]] = (
                rerotation_cos.to(key_states.dtype).unsqueeze(0),
                rerotation_sin.to(key_states.dtype).unsqueeze(0),
            )
        return self.cos_sin_cache[key_states.shape[-2]]

    def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
        """Returns the sequence length of the cached states. A layer index can be optionally passed."""
        # Workaround to make 'key_states.shape[-2] + past_key_value.get_seq_length(self.layer_idx)' <= window_length
        if len(self.key_cache) <= layer_idx:
            return 0
        return self.key_cache[layer_idx].shape[-2]

    def get_max_length(self) -> Optional[int]:
        """Returns the maximum sequence length of the cached states."""
        return self.window_length

    def update(
        self,
        key_states: torch.Tensor,
        value_states: torch.Tensor,
        layer_idx: int,
        cache_kwargs: Optional[Dict[str, Any]] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Updates the cache with the new `key_states` and `value_states` for the layer `layer_idx`.

        Parameters:
            key_states (`torch.Tensor`):
                The new key states to cache.
            value_states (`torch.Tensor`):
                The new value states to cache.
            layer_idx (`int`):
                The index of the layer to cache the states for.
            cache_kwargs (`Dict[str, Any]`, `optional`):
                Additional arguments for the cache subclass. The following arguments can be used in `SinkCache`: `sin`,
                `cos` and `partial_rotation_size`. These arguments are used with models using RoPE, to recompute the
                rotation as the tokens are shifted.

        Return:
            A tuple containing the updated key and value states.
        """
        # Optional kwargs for `SinkCache` -- needed on models using RoPE. `partial_rotation_size` is used on models
        # with partially rotated position embeddings, like Phi or Persimmon.
        sin = cache_kwargs.get("sin")
        cos = cache_kwargs.get("cos")
        partial_rotation_size = cache_kwargs.get("partial_rotation_size")
        using_rope = cos is not None and sin is not None

        # Update the number of seen tokens
        if layer_idx == 0:
            self._seen_tokens += key_states.shape[-2]

        # [bsz, num_heads, seq_len, head_dim]
        if len(self.key_cache) <= layer_idx:
            # Empty cache
            self.key_cache.append(key_states)
            self.value_cache.append(value_states)

        elif key_states.shape[-2] + self.get_seq_length(layer_idx) < self.window_length:
            # Growing cache
            self.key_cache[layer_idx] = torch.cat([self.key_cache[layer_idx], key_states], dim=-2)
            self.value_cache[layer_idx] = torch.cat([self.value_cache[layer_idx], value_states], dim=-2)

        else:
            # Shifting cache
            keys_to_keep = self.key_cache[layer_idx][
                :, :, -self.window_length + self.num_sink_tokens + key_states.shape[-2] :
            ]

            # On RoPE models, we need to recompute the Key rotation as the tokens are shifted
            if using_rope:
                rerotation_cos, rerotation_sin = self._get_rerotation_cos_sin(
                    key_states, cos[: self.window_length], sin[: self.window_length]
                )
                if partial_rotation_size is not None:
                    keys_to_keep, keys_pass = (
                        keys_to_keep[..., :partial_rotation_size],
                        keys_to_keep[..., partial_rotation_size:],
                    )
                keys_to_keep = self._apply_key_rotary_pos_emb(keys_to_keep, rerotation_cos, rerotation_sin)
                if partial_rotation_size is not None:
                    keys_to_keep = torch.cat((keys_to_keep, keys_pass), dim=-1)

            # Concatenate sink tokens, shifted & rotated tokens (if needed), and new tokens
            sink_keys = self.key_cache[layer_idx][:, :, : self.num_sink_tokens]
            self.key_cache[layer_idx] = torch.cat([sink_keys, keys_to_keep, key_states], dim=-2)

            sink_values = self.value_cache[layer_idx][:, :, : self.num_sink_tokens]
            values_to_keep = self.value_cache[layer_idx][
                :, :, -self.window_length + self.num_sink_tokens + value_states.shape[-2] :
            ]
            self.value_cache[layer_idx] = torch.cat([sink_values, values_to_keep, value_states], dim=-2)

        return self.key_cache[layer_idx], self.value_cache[layer_idx]

    def reorder_cache(self, beam_idx: torch.LongTensor):
        """Reorders the cache for beam search, given the selected beam indices."""
        for layer_idx in range(len(self.key_cache)):
            device = self.key_cache[layer_idx].device
            self.key_cache[layer_idx] = self.key_cache[layer_idx].index_select(0, beam_idx.to(device))
            device = self.value_cache[layer_idx].device
            self.value_cache[layer_idx] = self.value_cache[layer_idx].index_select(0, beam_idx.to(device))

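
# Usage sketch (illustrative; `cos`/`sin` would normally come from the model's rotary embedding and be
# passed by the attention layer via `cache_kwargs`):
#
#     cache = SinkCache(window_length=256, num_sink_tokens=4)
#     k_out, v_out = cache.update(key_states, value_states, layer_idx=0, cache_kwargs={"cos": cos, "sin": sin})
#
# The shifting branch relies on the angle-difference identities cos(a - b) = cos(a)cos(b) + sin(a)sin(b)
# and sin(a - b) = sin(a)cos(b) - cos(a)sin(b): `_get_rerotation_cos_sin` builds the cos/sin of the
# position delta, so keys that were rotated for their original positions can be re-rotated to their new,
# shifted positions without keeping the unrotated keys around.
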

class StaticCache(Cache):
    """
    Static Cache class to be used with `torch.compile(model)`.

    Parameters:
        config (`PretrainedConfig`):
            The configuration file defining the `max_position_embeddings`, `hidden_size` and `num_attention_heads`
            required to initialize the static cache.
        max_batch_size (`int`):
            The maximum batch size with which the model will be used.
        max_cache_len (`int`):
            The maximum sequence length with which the model will be used.
        device (`torch.device`):
            The device on which the cache should be initialized. Should be the same as the layer.
        dtype (*optional*, defaults to `torch.float32`):
            The default `dtype` to use when initializing the layer.
    """

    def __init__(self, config: PretrainedConfig, max_batch_size: int, max_cache_len: int, device, dtype=None) -> None:
        super().__init__()
        self.max_batch_size = max_batch_size
        self.max_cache_len = config.max_position_embeddings if max_cache_len is None else max_cache_len
        # Some models define a custom `head_dim` != config.hidden_size // config.num_attention_heads
        self.head_dim = (
            config.head_dim if hasattr(config, "head_dim") else config.hidden_size // config.num_attention_heads
        )

        self.dtype = dtype if dtype is not None else torch.float32
        self.num_key_value_heads = (
            config.num_attention_heads if config.num_key_value_heads is None else config.num_key_value_heads
        )

        cache_shape = (max_batch_size, self.num_key_value_heads, self.max_cache_len, self.head_dim)
        self.key_cache: torch.Tensor = torch.zeros(cache_shape, dtype=self.dtype, device=device)
        self.value_cache: torch.Tensor = torch.zeros(cache_shape, dtype=self.dtype, device=device)

    def update(
        self,
        key_states: torch.Tensor,
        value_states: torch.Tensor,
        layer_idx: int,
        cache_kwargs: Optional[Dict[str, Any]] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Updates the cache with the new `key_states` and `value_states` for the layer `layer_idx`.
        It is VERY important to index using a tensor, otherwise you introduce a copy to the device.

        Parameters:
            key_states (`torch.Tensor`):
                The new key states to cache.
            value_states (`torch.Tensor`):
                The new value states to cache.
            layer_idx (`int`):
                The index of the layer to cache the states for. Kept for backward compatibility.
            cache_kwargs (`Dict[str, Any]`, `optional`):
                Additional arguments for the cache subclass. The `StaticCache` just needs the `cache_position` input
                to know how much of the cache it should overwrite.

        Return:
            A tuple containing the updated key and value states.
        """
        new_cache_positions = cache_kwargs.get("cache_position")
        k_out = self.key_cache
        v_out = self.value_cache

        k_out[:, :, new_cache_positions] = key_states
        v_out[:, :, new_cache_positions] = value_states

        return k_out, v_out

    def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
        """Returns the sequence length of the cached states that were seen by the model. `layer_idx` kept for BC."""
        # Occupied cache == any slot in the 3rd dim (sequence length) holds a non-zero value. To save on compute, let's
        # limit the check to the first batch member and head dimension.
        # TODO: This is error prone, a filled cache may be `0.0`. Let's use a stateless integer instead, after
        # https://github.com/pytorch/pytorch/issues/120248 is fixed
        return (self.key_cache[0, 0].any(dim=-1)).sum()

    def get_max_length(self) -> Optional[int]:
        """Returns the maximum sequence length of the cached states, i.e. `max_cache_len`."""
        return self.max_cache_len

    def reorder_cache(self, beam_idx: torch.LongTensor):
        """Reorders the cache for beam search, given the selected beam indices."""
        device = self.key_cache.device
        self.key_cache = self.key_cache.index_select(0, beam_idx.to(device))
        device = self.value_cache.device
        self.value_cache = self.value_cache.index_select(0, beam_idx.to(device))

    def to_legacy_cache(self):
        """Dummy function for BC. We have to keep it because otherwise the call in the models' forward pass would break."""
        return None
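

# Usage sketch (illustrative; assumes `config` is a `PretrainedConfig` exposing the attributes read in
# `__init__`, e.g. one returned by `AutoConfig.from_pretrained(...)`):
#
#     cache = StaticCache(config, max_batch_size=1, max_cache_len=2048, device="cuda", dtype=torch.float16)
#     cache_position = torch.arange(key_states.shape[-2], device="cuda")  # positions written this step
#     k_out, v_out = cache.update(key_states, value_states, layer_idx=0,
#                                 cache_kwargs={"cache_position": cache_position})
#
# Writing through the `cache_position` index tensor updates the pre-allocated buffers in place, so tensor
# shapes stay fixed across decoding steps and `torch.compile` does not need to re-specialize.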