update kernels
kernels/cache_autogptq_cuda_256.cpp → cache_autogptq_cuda_256.cpp
RENAMED (file without changes)

kernels/cache_autogptq_cuda_kernel_256.cu → cache_autogptq_cuda_kernel_256.cu
RENAMED (file without changes)
kernels/cpp_kernels.py → cpp_kernels.py
RENAMED (source paths updated for the new file locations)

@@ -50,6 +50,6 @@ def _cpp_extention_load_helper(name, sources, extra_cuda_flags):
 
 extra_flags = []
 
-cache_autogptq_cuda_256_sources = ["./kernels/cache_autogptq_cuda_256.cpp",
-                                   "./kernels/cache_autogptq_cuda_kernel_256.cu"]
+cache_autogptq_cuda_256_sources = ["./cache_autogptq_cuda_256.cpp",
+                                   "./cache_autogptq_cuda_kernel_256.cu"]
 cache_autogptq_cuda_256 = _cpp_extention_load_helper("cache_autogptq_cuda_256", cache_autogptq_cuda_256_sources, extra_flags)
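For context, _cpp_extention_load_helper in cpp_kernels.py presumably wraps PyTorch's JIT C++ extension loader, so the only change needed after the rename is pointing the source list at the new locations. A minimal sketch of such a helper is below; the use of torch.utils.cpp_extension.load and the verbose flag are assumptions for illustration, not code taken from this repository.

from torch.utils import cpp_extension

def _cpp_extention_load_helper(name, sources, extra_cuda_flags):
    # JIT-compiles the listed .cpp/.cu sources into an importable module.
    # After this commit the relative paths resolve to the repository root
    # instead of the old kernels/ subdirectory.
    return cpp_extension.load(
        name=name,
        sources=sources,
        extra_cuda_cflags=extra_cuda_flags,
        verbose=True,
    )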
modeling_qwen.py
CHANGED
@@ -32,11 +32,6 @@ except ImportError:
     rearrange = None
 from torch import nn
 
-try:
-    from kernels.cpp_kernels import cache_autogptq_cuda_256
-except ImportError:
-    cache_autogptq_cuda_256 = None
-
 SUPPORT_CUDA = torch.cuda.is_available()
 SUPPORT_BF16 = SUPPORT_CUDA and torch.cuda.is_bf16_supported()
 SUPPORT_FP16 = SUPPORT_CUDA and torch.cuda.get_device_capability(0)[0] >= 7
@@ -294,14 +289,21 @@
         self.cache_qmax = torch.tensor(torch.iinfo(torch.uint8).max, dtype=cache_dtype)
         self.cache_qmin = torch.tensor(torch.iinfo(torch.uint8).min, dtype=cache_dtype)
 
+        if config.use_cache_quantization and config.use_cache_kernel:
+            from .cpp_kernels import cache_autogptq_cuda_256
+            try:
+                self.cache_kernels = cache_autogptq_cuda_256
+            except ImportError:
+                self.cache_kernels = None
+
     def _attn(self, query, key, value, registered_causal_mask, attention_mask=None, head_mask=None):
         device = query.device
         if self.use_cache_quantization:
             qk, qk_scale, qk_zero = key
-            if self.use_cache_kernel and cache_autogptq_cuda_256 is not None:
+            if self.use_cache_kernel and self.cache_kernels is not None:
                 shape = query.shape[:-1] + (qk.shape[-2],)
                 attn_weights = torch.zeros(shape, dtype=torch.float16, device=device)
-                cache_autogptq_cuda_256.vecquant8matmul_batched_faster_old(
+                self.cache_kernels.vecquant8matmul_batched_faster_old(
                     query.contiguous() if query.dtype == torch.float16 else query.to(torch.float16).contiguous(),
                     qk.transpose(-1, -2).contiguous(),
                     attn_weights,
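The kernel call above multiplies the fp16 query against the quantized key cache and writes the attention scores into the pre-allocated attn_weights buffer. A pure-PyTorch sketch of the intended math follows; the affine (code - zero) * scale dequantization and the helper name are assumptions for illustration, and the actual CUDA kernel may expect a packed storage layout rather than plain integer codes.

import torch

def attn_scores_from_quantized_keys(query, qk, qk_scale, qk_zero):
    # Stand-in for vecquant8matmul_batched_faster_old: dequantize the key
    # codes, then compute query @ key^T. The output shape matches
    # query.shape[:-1] + (qk.shape[-2],), as in the diff above.
    key = (qk.to(torch.float16) - qk_zero) * qk_scale
    return torch.matmul(query.to(torch.float16), key.transpose(-1, -2))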
@@ -353,10 +355,10 @@
 
         if self.use_cache_quantization:
             qv, qv_scale, qv_zero = value
-            if self.use_cache_kernel and cache_autogptq_cuda_256 is not None:
+            if self.use_cache_kernel and self.cache_kernels is not None:
                 shape = attn_weights.shape[:-1] + (query.shape[-1],)
                 attn_output = torch.zeros(shape, dtype=torch.float16, device=device)
-                cache_autogptq_cuda_256.vecquant8matmul_batched_column_compression_faster_old(
+                self.cache_kernels.vecquant8matmul_batched_column_compression_faster_old(
                     attn_weights.contiguous() if attn_weights.dtype == torch.float16 else attn_weights.to(torch.float16).contiguous(),
                     qv.contiguous(),  # dtype: int32
                     attn_output,
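The value-side call mirrors the key-side one: the fp16 attention weights are multiplied against the quantized value cache into the pre-allocated attn_output buffer. Under the same assumed dequantization scheme as the sketch above:

def attn_output_from_quantized_values(attn_weights, qv, qv_scale, qv_zero):
    # Stand-in for vecquant8matmul_batched_column_compression_faster_old,
    # using the same assumed (code - zero) * scale scheme.
    value = (qv.to(torch.float16) - qv_zero) * qv_scale
    return torch.matmul(attn_weights.to(torch.float16), value)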
@@ -1022,15 +1024,6 @@
         if config.use_flash_attn:
             _import_flash_attn()
 
-
-        if hasattr(config, 'use_cache_quantization') and config.use_cache_quantization:
-            config.use_flash_attn = False
-        if hasattr(config, 'use_cache_kernel') and config.use_cache_kernel:
-            try:
-                from kernels.cpp_kernels import cache_autogptq_cuda_256
-            except ImportError:
-                cache_autogptq_cuda_256 = None
-
         self.transformer = QWenModel(config)
         self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
 
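Taken together, the commit moves the quantized-KV-cache CUDA sources to the repository root and defers their import until a QWenAttention module is constructed with both cache flags enabled, instead of importing kernels.cpp_kernels at module load time. A rough usage sketch is below; the model ID and the way the flags are passed through from_pretrained are assumptions for illustration, not part of this commit.

from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained(
    "Qwen/Qwen-7B-Chat-Int4",      # assumed model ID, for illustration only
    trust_remote_code=True,         # runs modeling_qwen.py / cpp_kernels.py from the repo
    use_cache_quantization=True,    # store the KV cache as quantized uint8 tensors
    use_cache_kernel=True,          # use the JIT-built cache_autogptq_cuda_256 kernels
).eval()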