修改 quantization.py 中待量化权重的移动逻辑

#47
Files changed (1) hide show
  1. quantization.py +6 -5
quantization.py CHANGED
@@ -125,8 +125,9 @@ class QuantizedLinear(torch.nn.Module):
125
  def __init__(self, weight_bit_width: int, weight, bias=None, device="cpu", dtype=None, empty_init=False, *args,
126
  **kwargs):
127
  super().__init__()
 
 
128
  self.weight_bit_width = weight_bit_width
129
-
130
  shape = weight.shape
131
 
132
  if weight is None or empty_init:
@@ -154,7 +155,7 @@ def quantize(model, weight_bit_width, empty_init=False, device=None):
154
  for layer in model.layers:
155
  layer.self_attention.query_key_value = QuantizedLinear(
156
  weight_bit_width=weight_bit_width,
157
- weight=layer.self_attention.query_key_value.weight.to(torch.cuda.current_device()),
158
  bias=layer.self_attention.query_key_value.bias,
159
  dtype=layer.self_attention.query_key_value.weight.dtype,
160
  device=layer.self_attention.query_key_value.weight.device if device is None else device,
@@ -162,7 +163,7 @@ def quantize(model, weight_bit_width, empty_init=False, device=None):
162
  )
163
  layer.self_attention.dense = QuantizedLinear(
164
  weight_bit_width=weight_bit_width,
165
- weight=layer.self_attention.dense.weight.to(torch.cuda.current_device()),
166
  bias=layer.self_attention.dense.bias,
167
  dtype=layer.self_attention.dense.weight.dtype,
168
  device=layer.self_attention.dense.weight.device if device is None else device,
@@ -170,7 +171,7 @@ def quantize(model, weight_bit_width, empty_init=False, device=None):
170
  )
171
  layer.mlp.dense_h_to_4h = QuantizedLinear(
172
  weight_bit_width=weight_bit_width,
173
- weight=layer.mlp.dense_h_to_4h.weight.to(torch.cuda.current_device()),
174
  bias=layer.mlp.dense_h_to_4h.bias,
175
  dtype=layer.mlp.dense_h_to_4h.weight.dtype,
176
  device=layer.mlp.dense_h_to_4h.weight.device if device is None else device,
@@ -178,7 +179,7 @@ def quantize(model, weight_bit_width, empty_init=False, device=None):
178
  )
179
  layer.mlp.dense_4h_to_h = QuantizedLinear(
180
  weight_bit_width=weight_bit_width,
181
- weight=layer.mlp.dense_4h_to_h.weight.to(torch.cuda.current_device()),
182
  bias=layer.mlp.dense_4h_to_h.bias,
183
  dtype=layer.mlp.dense_4h_to_h.weight.dtype,
184
  device=layer.mlp.dense_4h_to_h.weight.device if device is None else device,
 
125
  def __init__(self, weight_bit_width: int, weight, bias=None, device="cpu", dtype=None, empty_init=False, *args,
126
  **kwargs):
127
  super().__init__()
128
+ assert str(weight.device).startswith('cuda'), 'The weights that need to be quantified should be on the CUDA device'
129
+
130
  self.weight_bit_width = weight_bit_width
 
131
  shape = weight.shape
132
 
133
  if weight is None or empty_init:
 
155
  for layer in model.layers:
156
  layer.self_attention.query_key_value = QuantizedLinear(
157
  weight_bit_width=weight_bit_width,
158
+ weight=layer.self_attention.query_key_value.weight,
159
  bias=layer.self_attention.query_key_value.bias,
160
  dtype=layer.self_attention.query_key_value.weight.dtype,
161
  device=layer.self_attention.query_key_value.weight.device if device is None else device,
 
163
  )
164
  layer.self_attention.dense = QuantizedLinear(
165
  weight_bit_width=weight_bit_width,
166
+ weight=layer.self_attention.dense.weight,
167
  bias=layer.self_attention.dense.bias,
168
  dtype=layer.self_attention.dense.weight.dtype,
169
  device=layer.self_attention.dense.weight.device if device is None else device,
 
171
  )
172
  layer.mlp.dense_h_to_4h = QuantizedLinear(
173
  weight_bit_width=weight_bit_width,
174
+ weight=layer.mlp.dense_h_to_4h.weight,
175
  bias=layer.mlp.dense_h_to_4h.bias,
176
  dtype=layer.mlp.dense_h_to_4h.weight.dtype,
177
  device=layer.mlp.dense_h_to_4h.weight.device if device is None else device,
 
179
  )
180
  layer.mlp.dense_4h_to_h = QuantizedLinear(
181
  weight_bit_width=weight_bit_width,
182
+ weight=layer.mlp.dense_4h_to_h.weight,
183
  bias=layer.mlp.dense_4h_to_h.bias,
184
  dtype=layer.mlp.dense_4h_to_h.weight.dtype,
185
  device=layer.mlp.dense_4h_to_h.weight.device if device is None else device,