Add config_class
Browse files- modeling_bit_llama.py +2 -0
modeling_bit_llama.py
CHANGED
@@ -63,6 +63,8 @@ class BitLlamaDecoderLayer(LlamaDecoderLayer):
|
|
63 |
self.mlp = BitLlamaMLP(config)
|
64 |
|
65 |
class BitLlamaModel(LlamaModel):
|
|
|
|
|
66 |
def __init__(self, config: BitLlamaConfig):
|
67 |
super().__init__(config)
|
68 |
self.layers = nn.ModuleList(
|
|
|
63 |
self.mlp = BitLlamaMLP(config)
|
64 |
|
65 |
class BitLlamaModel(LlamaModel):
|
66 |
+
config_class = BitLlamaConfig
|
67 |
+
|
68 |
def __init__(self, config: BitLlamaConfig):
|
69 |
super().__init__(config)
|
70 |
self.layers = nn.ModuleList(
|