v4.46 support (#7)
Browse files- v4.46 support (5b585ec9a9cf91f67bd5696bc4df1090d08bd7fc)
- variable_cache.py +11 -9
variable_cache.py
CHANGED
@@ -32,18 +32,20 @@ class VariableCache(Cache_4_44_2, Cache):
|
|
32 |
The cache of each layer is allocated to the same gpu as the layer itself.
|
33 |
"""
|
34 |
|
35 |
-
def __init__(
|
36 |
-
|
37 |
-
|
38 |
-
|
39 |
-
|
40 |
-
|
41 |
-
|
42 |
-
|
|
|
|
|
43 |
Cache_4_44_2.__init__(self)
|
44 |
|
45 |
self.config = config
|
46 |
-
self.max_batch_size = max_batch_size
|
47 |
self.max_cache_len = config.max_position_embeddings if max_cache_len is None else max_cache_len
|
48 |
self.dtype = dtype
|
49 |
|
|
|
32 |
The cache of each layer is allocated to the same gpu as the layer itself.
|
33 |
"""
|
34 |
|
35 |
+
def __init__(
|
36 |
+
self,
|
37 |
+
config: DeciLMConfig,
|
38 |
+
batch_size: int = None,
|
39 |
+
max_cache_len: int = None,
|
40 |
+
device: torch.device = None,
|
41 |
+
dtype: torch.dtype = torch.float32,
|
42 |
+
max_batch_size: Optional[int] = None,
|
43 |
+
**kwargs: Any,
|
44 |
+
) -> None:
|
45 |
Cache_4_44_2.__init__(self)
|
46 |
|
47 |
self.config = config
|
48 |
+
self.max_batch_size = batch_size or max_batch_size
|
49 |
self.max_cache_len = config.max_position_embeddings if max_cache_len is None else max_cache_len
|
50 |
self.dtype = dtype
|
51 |
|