# coding=utf-8
# Copyright 2024 Nvidia Corporation. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from copy import deepcopy
from typing import Optional, Dict, Any, Tuple

import torch
from transformers.cache_utils import Cache  # used to let GenerationMixin know that we use a Cache object

from .configuration_decilm import DeciLMConfig, AttentionConfig
from .transformers_4_44_2__cache_utils import Cache as Cache_4_44_2, StaticCache


class VariableCache(Cache_4_44_2, Cache):
    """
    A Cache object that supports a different Cache implementation for every layer,
    including layers without any kv-cache.
    Implemented using a list of Cache objects, each represents a "model" with 1 layer.
    The default implementation for the layer caches is StaticCache.
    The cache of each layer is allocated to the same gpu as the layer itself.

    Layer caches are created lazily: a layer's cache is built on the first ``update`` call
    for that layer, on the device of the incoming ``key_states``. Layers whose attention is
    ``no_op`` or ``replace_with_linear`` never get a cache (their slot stays ``None``).
    """

    def __init__(self,
                 config: DeciLMConfig,
                 max_batch_size: int,
                 max_cache_len: int | None,
                 device: torch.device | str | None = None,
                 dtype: torch.dtype | None = None,
                 **kwargs: Any,
                 ):
        """
        Args:
            config: model config. ``num_hidden_layers`` sets the number of layer slots and
                ``block_configs[i].attention`` decides whether layer ``i`` gets a cache.
            max_batch_size: batch size forwarded to every per-layer StaticCache.
            max_cache_len: maximum cached sequence length; ``None`` falls back to
                ``config.max_position_embeddings``.
            device: accepted for interface compatibility but not used here; each layer cache
                is allocated on the device of the first key states it receives.
            dtype: dtype forwarded to every per-layer StaticCache.
            **kwargs: accepted and ignored.
        """
        Cache_4_44_2.__init__(self)

        self.config = config
        self.max_batch_size = max_batch_size
        self.max_cache_len = config.max_position_embeddings if max_cache_len is None else max_cache_len
        self.dtype = dtype

        # One slot per decoder layer; None until the layer's first update, and permanently
        # None for cache-less layers.
        self.layer_caches: list[Cache | None] = [None] * config.num_hidden_layers

    def update(
            self,
            key_states: torch.Tensor,
            value_states: torch.Tensor,
            layer_idx: int,
            cache_kwargs: Optional[Dict[str, Any]] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Write new key/value states into the cache of ``layer_idx`` and return the cached states.

        The layer's cache is created on first use, on ``key_states.device``. The returned
        tensors are truncated along dim 2 to the layer's current sequence length, so the
        unused tail of the pre-allocated StaticCache buffer is not returned.

        Raises:
            AssertionError: if the layer has no kv-cache (``no_op`` / ``replace_with_linear``).
        """
        layer_cache = self.layer_caches[layer_idx]
        if layer_cache is None:
            block_config = self.config.block_configs[layer_idx]
            layer_cache = self._init_layer_cache(attention_config=block_config.attention,
                                                 device=key_states.device)
            assert layer_cache is not None, "Trying to update the cache of a cache-less layer"
            self.layer_caches[layer_idx] = layer_cache

        # Every per-layer cache holds exactly one layer, so it is always addressed as layer 0.
        k_out, v_out = layer_cache.update(key_states=key_states,
                                          value_states=value_states,
                                          layer_idx=0,
                                          cache_kwargs=cache_kwargs)
        seq_len = self.get_seq_length(layer_idx)
        k_out = k_out[:, :, :seq_len, :]
        v_out = v_out[:, :, :seq_len, :]
        return k_out, v_out

    def _init_layer_cache(self,
                          attention_config: AttentionConfig,
                          device: torch.device,
                          ) -> Cache | None:
        """
        Build the cache for a single layer.

        Returns None for layers without a kv-cache (``no_op`` or ``replace_with_linear``
        attention). Otherwise returns a StaticCache built from a copy of the model config
        that describes a single layer, with ``num_key_value_heads`` derived from this layer's
        ``n_heads_in_group``.
        """
        if attention_config.no_op or attention_config.replace_with_linear:
            return None

        config = deepcopy(self.config)
        # The layer cache is a one-layer "model" that is only ever addressed as layer 0.
        # StaticCache allocates one key/value buffer pair per config.num_hidden_layers, so
        # leaving the full model depth here would make every layer allocate buffers for all
        # layers (memory growing with num_hidden_layers ** 2).
        config.num_hidden_layers = 1
        config.num_key_value_heads = self.config.num_attention_heads // attention_config.n_heads_in_group
        return StaticCache(config, self.max_batch_size, self.max_cache_len, device, self.dtype)

    def _get_first_real_cache(self) -> Cache:
        """
        Return the first allocated layer cache.

        Raises:
            ValueError: if no layer cache has been allocated yet.
        """
        for layer_cache in self.layer_caches:
            if layer_cache is not None:
                return layer_cache
        raise ValueError("No real cache found, all layer caches are None.")

    def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
        """
        Return the number of cached tokens.

        ``layer_idx`` 0 (the default) or ``None`` is treated as a model-level query. If layer 0
        has no cache, the first allocated layer cache answers instead, or 0 if none exists yet.
        For any other layer whose cache has not been allocated (or never will be), returns 0,
        since that layer holds no cached tokens.
        """
        if layer_idx is None:
            layer_idx = 0

        layer_cache = self.layer_caches[layer_idx]
        if layer_cache is None:
            if layer_idx != 0:
                return 0
            try:
                layer_cache = self._get_first_real_cache()
            except ValueError:
                return 0
        return layer_cache.get_seq_length()

    def get_max_length(self) -> Optional[int]:
        """Returns the maximum sequence length of the cached states."""
        return self.max_cache_len

    def reset(self):
        """Reset every allocated layer cache; empty / cache-less slots are skipped."""
        for layer_cache in self.layer_caches:
            if hasattr(layer_cache, "reset"):
                layer_cache.reset()