|
|
|
|
|
|
|
|
|
|
|
|
|
////////////////////////////////////////////////////////////////////////////////////////////////// |
|
////////////////////////////////////////////////////////////////////////////////////////////////// |
|
|
|
std::vector<at::Tensor> index_max_kernel( |
|
at::Tensor index_vals, // [batch_size, 32, num_block] |
|
at::Tensor indices, // [batch_size, num_block], |
|
int A_num_block, |
|
int B_num_block |
|
) { |
|
int batch_size = indices.size(0); |
|
int num_block = indices.size(1); |
|
|
|
at::Tensor max_vals = at::zeros({batch_size, A_num_block * 32}, index_vals.options()); |
|
at::Tensor max_vals_scatter = at::zeros({batch_size, 32, num_block}, index_vals.options()); |
|
|
|
dim3 threads(256); |
|
dim3 blocks(batch_size); |
|
int shared_mem = A_num_block * 32 * sizeof(float); |
|
|
|
index_max_cuda_kernel<<<blocks, threads, shared_mem>>>( |
|
index_vals.data_ptr<float>(), |
|
indices.data_ptr<int>(), |
|
max_vals.data_ptr<float>(), |
|
max_vals_scatter.data_ptr<float>(), |
|
batch_size, |
|
A_num_block, |
|
B_num_block, |
|
num_block |
|
); |
|
|
|
return {max_vals, max_vals_scatter}; |
|
} |
|
|
|
at::Tensor mm_to_sparse_kernel( |
|
at::Tensor dense_A, // [batch_size, A_num_block, dim, 32] |
|
at::Tensor dense_B, // [batch_size, B_num_block, dim, 32] |
|
at::Tensor indices // [batch_size, num_block] |
|
) { |
|
int batch_size = dense_A.size(0); |
|
int A_num_block = dense_A.size(1); |
|
int B_num_block = dense_B.size(1); |
|
int dim = dense_A.size(2); |
|
int num_block = indices.size(1); |
|
|
|
at::Tensor sparse_C = at::zeros({batch_size, num_block, 32, 32}, dense_A.options()); |
|
|
|
dim3 threads(64, 4); |
|
dim3 blocks(num_block / 4, batch_size); |
|
|
|
mm_to_sparse_cuda_kernel<<<blocks, threads>>>( |
|
dense_A.data_ptr<float>(), |
|
dense_B.data_ptr<float>(), |
|
indices.data_ptr<int>(), |
|
sparse_C.data_ptr<float>(), |
|
batch_size, |
|
A_num_block, |
|
B_num_block, |
|
dim, |
|
num_block |
|
); |
|
|
|
return sparse_C; |
|
} |
|
|
|
at::Tensor sparse_dense_mm_kernel( |
|
at::Tensor sparse_A, // [batch_size, num_block, 32, 32] |
|
at::Tensor indices, // [batch_size, num_block] |
|
at::Tensor dense_B, // [batch_size, B_num_block, dim, 32] |
|
int A_num_block |
|
) { |
|
int batch_size = sparse_A.size(0); |
|
int num_block = sparse_A.size(1); |
|
int B_num_block = dense_B.size(1); |
|
int dim = dense_B.size(2); |
|
|
|
at::Tensor dense_C = at::zeros({batch_size, A_num_block, dim, 32}, dense_B.options()); |
|
|
|
dim3 threads(128, 2); |
|
dim3 blocks(num_block / 2, batch_size); |
|
|
|
sparse_dense_mm_cuda_kernel<<<blocks, threads>>>( |
|
sparse_A.data_ptr<float>(), |
|
indices.data_ptr<int>(), |
|
dense_B.data_ptr<float>(), |
|
dense_C.data_ptr<float>(), |
|
batch_size, |
|
A_num_block, |
|
B_num_block, |
|
dim, |
|
num_block |
|
); |
|
|
|
return dense_C; |
|
} |
|
|
|
at::Tensor reduce_sum_kernel( |
|
at::Tensor sparse_A, // [batch_size, num_block, 32, 32] |
|
at::Tensor indices, // [batch_size, num_block] |
|
int A_num_block, |
|
int B_num_block |
|
) { |
|
int batch_size = sparse_A.size(0); |
|
int num_block = sparse_A.size(1); |
|
|
|
at::Tensor dense_C = at::zeros({batch_size, A_num_block, 32}, sparse_A.options()); |
|
|
|
dim3 threads(32, 4); |
|
dim3 blocks(num_block / 4, batch_size); |
|
|
|
reduce_sum_cuda_kernel<<<blocks, threads>>>( |
|
sparse_A.data_ptr<float>(), |
|
indices.data_ptr<int>(), |
|
dense_C.data_ptr<float>(), |
|
batch_size, |
|
A_num_block, |
|
B_num_block, |
|
num_block |
|
); |
|
|
|
return dense_C; |
|
} |
|
|
|
at::Tensor scatter_kernel( |
|
at::Tensor dense_A, // [batch_size, A_num_block, 32] |
|
at::Tensor indices, // [batch_size, num_block] |
|
int B_num_block |
|
) { |
|
int batch_size = dense_A.size(0); |
|
int A_num_block = dense_A.size(1); |
|
int num_block = indices.size(1); |
|
|
|
at::Tensor sparse_C = at::zeros({batch_size, num_block, 32, 32}, dense_A.options()); |
|
|
|
dim3 threads(32, 4); |
|
dim3 blocks(num_block / 4, batch_size); |
|
|
|
scatter_cuda_kernel<<<blocks, threads>>>( |
|
dense_A.data_ptr<float>(), |
|
indices.data_ptr<int>(), |
|
sparse_C.data_ptr<float>(), |
|
batch_size, |
|
A_num_block, |
|
B_num_block, |
|
num_block |
|
); |
|
|
|
return sparse_C; |
|
} |
|
|