add batch_size attribute to VariableCache (#15)
Browse files- add batch_size attribute to VariableCache (371f97675dd2a63a873f4f9eb71909b6a428eaac)
- variable_cache.py +3 -2
variable_cache.py
CHANGED
@@ -34,18 +34,19 @@ class VariableCache(Cache_4_44_2, Cache):
|
|
34 |
|
35 |
def __init__(
|
36 |
self,
|
|
|
37 |
config: DeciLMConfig,
|
38 |
batch_size: int = None,
|
39 |
max_cache_len: int = None,
|
40 |
-
device: torch.device = None,
|
41 |
dtype: torch.dtype = torch.float32,
|
42 |
max_batch_size: Optional[int] = None,
|
43 |
**kwargs: Any,
|
44 |
) -> None:
|
45 |
Cache_4_44_2.__init__(self)
|
46 |
|
47 |
-
self.config = config
|
48 |
self.max_batch_size = batch_size or max_batch_size
|
|
|
49 |
self.max_cache_len = config.max_position_embeddings if max_cache_len is None else max_cache_len
|
50 |
self.dtype = dtype
|
51 |
|
|
|
34 |
|
35 |
def __init__(
|
36 |
self,
|
37 |
+
*, # key-word only, no positional args allowed to avoid mix-ups with newer transformers versions
|
38 |
config: DeciLMConfig,
|
39 |
batch_size: int = None,
|
40 |
max_cache_len: int = None,
|
|
|
41 |
dtype: torch.dtype = torch.float32,
|
42 |
max_batch_size: Optional[int] = None,
|
43 |
**kwargs: Any,
|
44 |
) -> None:
|
45 |
Cache_4_44_2.__init__(self)
|
46 |
|
47 |
+
self.config = deepcopy(config)
|
48 |
self.max_batch_size = batch_size or max_batch_size
|
49 |
+
self.batch_size = self.max_batch_size
|
50 |
self.max_cache_len = config.max_position_embeddings if max_cache_len is None else max_cache_len
|
51 |
self.dtype = dtype
|
52 |
|