AoiKazama committed on
Commit 440f295
1 Parent(s): 9e5171f

Upload cache_utils.py

Files changed (1)
  1. cache_utils.py +435 -0
cache_utils.py ADDED
@@ -0,0 +1,435 @@
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Tuple

import torch

from .configuration_utils import PretrainedConfig
from .utils import logging


logger = logging.get_logger(__name__)


@dataclass
class Cache:
    """
    Base, abstract class for all caches. The actual data structure is specific to each subclass.
    """

    def update(
        self,
        key_states: torch.Tensor,
        value_states: torch.Tensor,
        layer_idx: int,
        cache_kwargs: Optional[Dict[str, Any]] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Updates the cache with the new `key_states` and `value_states` for the layer `layer_idx`.

        Parameters:
            key_states (`torch.Tensor`):
                The new key states to cache.
            value_states (`torch.Tensor`):
                The new value states to cache.
            layer_idx (`int`):
                The index of the layer to cache the states for.
            cache_kwargs (`Dict[str, Any]`, `optional`):
                Additional arguments for the cache subclass. These are specific to each subclass and allow new types of
                cache to be created.

        Return:
            A tuple containing the updated key and value states.
        """
        raise NotImplementedError("Make sure to implement `update` in a subclass.")

    def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
        """Returns the sequence length of the cached states. A layer index can be optionally passed."""
        raise NotImplementedError("Make sure to implement `get_seq_length` in a subclass.")

    def get_max_length(self) -> Optional[int]:
        """Returns the maximum sequence length of the cached states, if there is any."""
        raise NotImplementedError("Make sure to implement `get_max_length` in a subclass.")

    def get_usable_length(self, new_seq_length: int, layer_idx: Optional[int] = 0) -> int:
        """Given the sequence length of the new inputs, returns the usable length of the cache."""
        # Cache without size limit -> all of the cache is usable
        # Cache with size limit -> if the cache length plus the length of the new inputs is larger than the maximum
        # cache length, we will need to evict part of the cache (and thus not all of the cache is usable)
        max_length = self.get_max_length()
        previous_seq_length = self.get_seq_length(layer_idx)
        if max_length is not None and previous_seq_length + new_seq_length > max_length:
            return max_length - new_seq_length
        return previous_seq_length

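    # Worked example (illustrative): with a size-limited cache where `get_max_length()` is 4096, a cached length of
    # 4000 and 200 new tokens, 4000 + 200 > 4096, so only 4096 - 200 = 3896 cached positions remain usable; the
    # older entries are expected to be evicted by the subclass's `update`.
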
    @property
    def seen_tokens(self):
        logger.warning_once(
            "The `seen_tokens` attribute is deprecated and will be removed in v4.41. Use the `cache_position` "
            "model input instead."
        )
        if hasattr(self, "_seen_tokens"):
            return self._seen_tokens
        else:
            return None


class DynamicCache(Cache):
    """
    A cache that grows dynamically as more tokens are generated. This is the default for generative models.

    It stores the Key and Value states as a list of tensors, one for each layer. The expected shape for each tensor is
    `[batch_size, num_heads, seq_len, head_dim]`.
    """

    def __init__(self) -> None:
        self.key_cache: List[torch.Tensor] = []
        self.value_cache: List[torch.Tensor] = []
        self._seen_tokens = 0  # Used in `generate` to keep a tally of how many tokens the cache has seen

    def __getitem__(self, layer_idx: int) -> List[Tuple[torch.Tensor]]:
        """
        Support for backwards-compatible `past_key_value` indexing, e.g. `past_key_value[0][0].shape[2]` to get the
        sequence length.
        """
        if layer_idx < len(self):
            return (self.key_cache[layer_idx], self.value_cache[layer_idx])
        else:
            raise KeyError(f"Cache only has {len(self)} layers, attempted to access layer with index {layer_idx}")

    def __iter__(self):
        """
        Support for backwards-compatible `past_key_value` iteration, e.g. `for x in past_key_value:` to iterate over
        keys and values.
        """
        for layer_idx in range(len(self)):
            yield (self.key_cache[layer_idx], self.value_cache[layer_idx])

    def __len__(self):
        """
        Support for backwards-compatible `past_key_value` length, e.g. `len(past_key_value)`. This value corresponds
        to the number of layers in the model.
        """
        return len(self.key_cache)

    def update(
        self,
        key_states: torch.Tensor,
        value_states: torch.Tensor,
        layer_idx: int,
        cache_kwargs: Optional[Dict[str, Any]] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Updates the cache with the new `key_states` and `value_states` for the layer `layer_idx`.

        Parameters:
            key_states (`torch.Tensor`):
                The new key states to cache.
            value_states (`torch.Tensor`):
                The new value states to cache.
            layer_idx (`int`):
                The index of the layer to cache the states for.
            cache_kwargs (`Dict[str, Any]`, `optional`):
                Additional arguments for the cache subclass. No additional arguments are used in `DynamicCache`.

        Return:
            A tuple containing the updated key and value states.
        """
        # Update the number of seen tokens
        if layer_idx == 0:
            self._seen_tokens += key_states.shape[-2]

        # Update the cache
        if len(self.key_cache) <= layer_idx:
            self.key_cache.append(key_states)
            self.value_cache.append(value_states)
        else:
            self.key_cache[layer_idx] = torch.cat([self.key_cache[layer_idx], key_states], dim=-2)
            self.value_cache[layer_idx] = torch.cat([self.value_cache[layer_idx], value_states], dim=-2)

        return self.key_cache[layer_idx], self.value_cache[layer_idx]

    def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
        """Returns the sequence length of the cached states. A layer index can be optionally passed."""
        if len(self.key_cache) <= layer_idx:
            return 0
        return self.key_cache[layer_idx].shape[-2]

    def get_max_length(self) -> Optional[int]:
        """Returns the maximum sequence length of the cached states. DynamicCache does not have a maximum length."""
        return None

    def reorder_cache(self, beam_idx: torch.LongTensor):
        """Reorders the cache for beam search, given the selected beam indices."""
        for layer_idx in range(len(self.key_cache)):
            device = self.key_cache[layer_idx].device
            self.key_cache[layer_idx] = self.key_cache[layer_idx].index_select(0, beam_idx.to(device))
            device = self.value_cache[layer_idx].device
            self.value_cache[layer_idx] = self.value_cache[layer_idx].index_select(0, beam_idx.to(device))

    def to_legacy_cache(self) -> Tuple[Tuple[torch.Tensor], Tuple[torch.Tensor]]:
        """Converts the `DynamicCache` instance into its equivalent in the legacy cache format."""
        legacy_cache = ()
        for layer_idx in range(len(self)):
            legacy_cache += ((self.key_cache[layer_idx], self.value_cache[layer_idx]),)
        return legacy_cache

    @classmethod
    def from_legacy_cache(cls, past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None) -> "DynamicCache":
        """Converts a cache in the legacy cache format into an equivalent `DynamicCache`."""
        cache = cls()
        if past_key_values is not None:
            for layer_idx in range(len(past_key_values)):
                key_states, value_states = past_key_values[layer_idx]
                cache.update(key_states, value_states, layer_idx)
        return cache


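# Illustrative usage sketch: a `DynamicCache` can be built from, or converted back to, the legacy tuple-of-tuples
# `past_key_values` format. Assuming `model` is a causal LM whose attention layers call `past_key_values.update(...)`:
#
#     past_key_values = DynamicCache.from_legacy_cache(legacy_past)  # or start empty with DynamicCache()
#     outputs = model(input_ids, past_key_values=past_key_values, use_cache=True)
#     legacy_past = outputs.past_key_values.to_legacy_cache()  # back to ((key, value), ...) per layer
#
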
class SinkCache(Cache):
    """
    A cache that, as described in the [Attention Sinks paper](https://arxiv.org/abs/2309.17453), allows the model to
    generate beyond the length of its context window without losing fluency in the conversation. As it discards past
    tokens, the model will lose the ability to generate tokens that depend on the context that was discarded.

    It stores the Key and Value states as a list of tensors, one for each layer. The expected shape for each tensor is
    `[batch_size, num_heads, seq_len, head_dim]`.

    Parameters:
        window_length (`int`):
            The length of the context window.
        num_sink_tokens (`int`):
            The number of sink tokens. See the original paper for more information.
    """

    def __init__(self, window_length: int, num_sink_tokens: int) -> None:
        self.key_cache: List[torch.Tensor] = []
        self.value_cache: List[torch.Tensor] = []
        self.window_length = window_length
        self.num_sink_tokens = num_sink_tokens
        self.cos_sin_cache = {}
        self._seen_tokens = 0  # Used in `generate` to keep a tally of how many tokens the cache has seen

    @staticmethod
    def _rotate_half(x):
        x1 = x[..., : x.shape[-1] // 2]
        x2 = x[..., x.shape[-1] // 2 :]
        return torch.cat((-x2, x1), dim=-1)

    def _apply_key_rotary_pos_emb(
        self, key_states: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor
    ) -> torch.Tensor:
        rotated_key_states = (key_states * cos) + (self._rotate_half(key_states) * sin)
        return rotated_key_states

    def _get_rerotation_cos_sin(
        self, key_states: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        if key_states.shape[-2] not in self.cos_sin_cache:
            # Upcast to float32 temporarily for better accuracy
            cos = cos.to(torch.float32)
            sin = sin.to(torch.float32)

            # Compute the cos and sin required for back- and forward-rotating to one position earlier in the sequence
            original_cos = cos[self.num_sink_tokens + key_states.shape[-2] :]
            shifted_cos = cos[self.num_sink_tokens : -key_states.shape[-2]]
            original_sin = sin[self.num_sink_tokens + key_states.shape[-2] :]
            shifted_sin = sin[self.num_sink_tokens : -key_states.shape[-2]]
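            # The two lines below are the angle-difference identities: assuming `cos`/`sin` are the standard RoPE
            # tables, they give cos(theta_shifted - theta_original) and sin(theta_shifted - theta_original), i.e. the
            # single rotation that moves a cached key from its original position to its new, earlier slot.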
            rerotation_cos = original_cos * shifted_cos + original_sin * shifted_sin
            rerotation_sin = -original_sin * shifted_cos + original_cos * shifted_sin

            self.cos_sin_cache[key_states.shape[-2]] = (
                rerotation_cos.to(key_states.dtype).unsqueeze(0),
                rerotation_sin.to(key_states.dtype).unsqueeze(0),
            )
        return self.cos_sin_cache[key_states.shape[-2]]

    def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
        """Returns the sequence length of the cached states. A layer index can be optionally passed."""
        # Workaround to make 'key_states.shape[-2] + past_key_value.get_seq_length(self.layer_idx)' <= window_length
        if len(self.key_cache) <= layer_idx:
            return 0
        return self.key_cache[layer_idx].shape[-2]

    def get_max_length(self) -> Optional[int]:
        """Returns the maximum sequence length of the cached states."""
        return self.window_length

    def update(
        self,
        key_states: torch.Tensor,
        value_states: torch.Tensor,
        layer_idx: int,
        cache_kwargs: Optional[Dict[str, Any]] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Updates the cache with the new `key_states` and `value_states` for the layer `layer_idx`.

        Parameters:
            key_states (`torch.Tensor`):
                The new key states to cache.
            value_states (`torch.Tensor`):
                The new value states to cache.
            layer_idx (`int`):
                The index of the layer to cache the states for.
            cache_kwargs (`Dict[str, Any]`, `optional`):
                Additional arguments for the cache subclass. The following arguments can be used in `SinkCache`: `sin`,
                `cos` and `partial_rotation_size`. These arguments are used with models using RoPE, to recompute the
                rotation as the tokens are shifted.

        Return:
            A tuple containing the updated key and value states.
        """
        # Optional kwargs for `SinkCache` -- needed on models using RoPE. `partial_rotation_size` is used on models
        # with partially rotated position embeddings, like Phi or Persimmon.
        sin = cache_kwargs.get("sin")
        cos = cache_kwargs.get("cos")
        partial_rotation_size = cache_kwargs.get("partial_rotation_size")
        using_rope = cos is not None and sin is not None

        # Update the number of seen tokens
        if layer_idx == 0:
            self._seen_tokens += key_states.shape[-2]

        # [bsz, num_heads, seq_len, head_dim]
        if len(self.key_cache) <= layer_idx:
            # Empty cache
            self.key_cache.append(key_states)
            self.value_cache.append(value_states)

        elif key_states.shape[-2] + self.get_seq_length(layer_idx) < self.window_length:
            # Growing cache
            self.key_cache[layer_idx] = torch.cat([self.key_cache[layer_idx], key_states], dim=-2)
            self.value_cache[layer_idx] = torch.cat([self.value_cache[layer_idx], value_states], dim=-2)

        else:
            # Shifting cache
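            # Keep only the most recent (window_length - num_sink_tokens - new_len) positions, so that after
            # concatenating the sink tokens, the kept tokens and the new tokens the total length is (at most)
            # window_length.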
            keys_to_keep = self.key_cache[layer_idx][
                :, :, -self.window_length + self.num_sink_tokens + key_states.shape[-2] :
            ]

            # On RoPE models, we need to recompute the Key rotation as the tokens are shifted
            if using_rope:
                rerotation_cos, rerotation_sin = self._get_rerotation_cos_sin(
                    key_states, cos[: self.window_length], sin[: self.window_length]
                )
                if partial_rotation_size is not None:
                    keys_to_keep, keys_pass = (
                        keys_to_keep[..., :partial_rotation_size],
                        keys_to_keep[..., partial_rotation_size:],
                    )
                keys_to_keep = self._apply_key_rotary_pos_emb(keys_to_keep, rerotation_cos, rerotation_sin)
                if partial_rotation_size is not None:
                    keys_to_keep = torch.cat((keys_to_keep, keys_pass), dim=-1)

            # Concatenate sink tokens, shifted & rotated tokens (if needed), and new tokens
            sink_keys = self.key_cache[layer_idx][:, :, : self.num_sink_tokens]
            self.key_cache[layer_idx] = torch.cat([sink_keys, keys_to_keep, key_states], dim=-2)

            sink_values = self.value_cache[layer_idx][:, :, : self.num_sink_tokens]
            values_to_keep = self.value_cache[layer_idx][
                :, :, -self.window_length + self.num_sink_tokens + value_states.shape[-2] :
            ]
            self.value_cache[layer_idx] = torch.cat([sink_values, values_to_keep, value_states], dim=-2)

        return self.key_cache[layer_idx], self.value_cache[layer_idx]

    def reorder_cache(self, beam_idx: torch.LongTensor):
        """Reorders the cache for beam search, given the selected beam indices."""
        for layer_idx in range(len(self.key_cache)):
            device = self.key_cache[layer_idx].device
            self.key_cache[layer_idx] = self.key_cache[layer_idx].index_select(0, beam_idx.to(device))
            device = self.value_cache[layer_idx].device
            self.value_cache[layer_idx] = self.value_cache[layer_idx].index_select(0, beam_idx.to(device))


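# Illustrative usage sketch: a `SinkCache` is passed in place of the usual cache object, assuming a RoPE model whose
# attention layers forward `sin`, `cos` (and, for partially rotated models, `partial_rotation_size`) to `update` via
# `cache_kwargs`:
#
#     past_key_values = SinkCache(window_length=1024, num_sink_tokens=4)
#     outputs = model(input_ids, past_key_values=past_key_values, use_cache=True)
#
# Only the 4 sink tokens plus the most recent tokens that fit in the 1024-token window are retained between steps.
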
class StaticCache(Cache):
    """
    Static Cache class to be used with `torch.compile(model)`.

    Parameters:
        config (`PretrainedConfig`):
            The configuration file defining the `max_position_embeddings`, `hidden_size` and `num_attention_heads`
            required to initialize the static cache.
        max_batch_size (`int`):
            The maximum batch size with which the model will be used.
        max_cache_len (`int`):
            The maximum sequence length with which the model will be used.
        device (`torch.device`):
            The device on which the cache should be initialized. Should be the same as the layer.
        dtype (*optional*, defaults to `torch.float32`):
            The default `dtype` to use when initializing the layer.
    """

    def __init__(self, config: PretrainedConfig, max_batch_size: int, max_cache_len: int, device, dtype=None) -> None:
        super().__init__()
        self.max_batch_size = max_batch_size
        self.max_cache_len = config.max_position_embeddings if max_cache_len is None else max_cache_len
        # Some models define a custom `head_dim` != config.hidden_size // config.num_attention_heads
        self.head_dim = (
            config.head_dim if hasattr(config, "head_dim") else config.hidden_size // config.num_attention_heads
        )

        self.dtype = dtype if dtype is not None else torch.float32
        self.num_key_value_heads = (
            config.num_attention_heads if config.num_key_value_heads is None else config.num_key_value_heads
        )

        cache_shape = (max_batch_size, self.num_key_value_heads, self.max_cache_len, self.head_dim)
        self.key_cache: torch.Tensor = torch.zeros(cache_shape, dtype=self.dtype, device=device)
        self.value_cache: torch.Tensor = torch.zeros(cache_shape, dtype=self.dtype, device=device)

    def update(
        self,
        key_states: torch.Tensor,
        value_states: torch.Tensor,
        layer_idx: int,
        cache_kwargs: Optional[Dict[str, Any]] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Updates the cache with the new `key_states` and `value_states` for the layer `layer_idx`.
        It is VERY important to index using a tensor, otherwise you introduce a copy to the device.

        Parameters:
            key_states (`torch.Tensor`):
                The new key states to cache.
            value_states (`torch.Tensor`):
                The new value states to cache.
            layer_idx (`int`):
                The index of the layer to cache the states for. Kept for backward compatibility.
            cache_kwargs (`Dict[str, Any]`, `optional`):
                Additional arguments for the cache subclass. The `StaticCache` only needs the `cache_position` input
                to know where to write the new key and value states in the cache.

        Return:
            A tuple containing the updated key and value states.
        """
        new_cache_positions = cache_kwargs.get("cache_position")
        k_out = self.key_cache
        v_out = self.value_cache

        k_out[:, :, new_cache_positions] = key_states
        v_out[:, :, new_cache_positions] = value_states

        return k_out, v_out

    def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
        """Returns the sequence length of the cached states that were seen by the model. `layer_idx` kept for BC."""
        # Occupied cache == any slot in the 3rd dim (sequence length) holds a non-zero value. To save on compute, let's
        # limit the check to the first batch member and head dimension.
        # TODO: This is error prone, a filled cache may be `0.0`. Let's use a stateless integer instead, after
        # https://github.com/pytorch/pytorch/issues/120248 is fixed
        return (self.key_cache[0, 0].any(dim=-1)).sum()

    def get_max_length(self) -> Optional[int]:
        """Returns the maximum sequence length of the cached states, which for `StaticCache` is `max_cache_len`."""
        return self.max_cache_len

    def reorder_cache(self, beam_idx: torch.LongTensor):
        """Reorders the cache for beam search, given the selected beam indices."""
        device = self.key_cache.device
        self.key_cache = self.key_cache.index_select(0, beam_idx.to(device))
        device = self.value_cache.device
        self.value_cache = self.value_cache.index_select(0, beam_idx.to(device))

    def to_legacy_cache(self):
        """Dummy function for BC. We have to keep it because otherwise the call in the `forward` of models would break."""
        return None
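

# Illustrative usage sketch: in this version, one `StaticCache` instance holds the pre-allocated key/value buffers for
# a single layer, so an attention layer would create it once and pass `cache_position` through `cache_kwargs`.
# Assuming a hypothetical layer with `self.layer_idx`, a `config`, and already-computed `key_states`/`value_states`:
#
#     self.past_key_value = StaticCache(config, max_batch_size=1, max_cache_len=4096, device=device, dtype=dtype)
#     ...
#     key_states, value_states = self.past_key_value.update(
#         key_states, value_states, self.layer_idx, cache_kwargs={"cache_position": cache_position}
#     )
#
# Because the buffers have a fixed shape, each decoding step runs with static shapes and can be captured by
# `torch.compile` without recompilation.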