Just a few suggested CUDA upgrades:
```cpp
#include <torch/extension.h>
#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
#include <cuda_fp16.h>

// Tile width for shared-memory tiling
#define TILE_SIZE 16

// CUDA kernel for an optimized half-precision GEMM using shared-memory tiling.
// Accumulates in float to limit rounding error, then stores the result as half.
__global__ void gemm_lowbit_kernel(
    const half* __restrict__ A,
    const half* __restrict__ B,
    half* __restrict__ C,
    int M, int N, int K)
{
    // Shared-memory tiles of A and B
    __shared__ half As[TILE_SIZE][TILE_SIZE];
    __shared__ half Bs[TILE_SIZE][TILE_SIZE];

    // Row and column of the C element this thread computes
    int row = blockIdx.y * TILE_SIZE + threadIdx.y;
    int col = blockIdx.x * TILE_SIZE + threadIdx.x;

    // Accumulate the dot product in float
    float sum = 0.0f;

    // Loop over the tiles of A and B along the K dimension
    for (int t = 0; t < (K + TILE_SIZE - 1) / TILE_SIZE; ++t) {
        // Load one element of A and B into shared memory, zero-padding out-of-bounds entries
        if (row < M && (t * TILE_SIZE + threadIdx.x) < K)
            As[threadIdx.y][threadIdx.x] = A[row * K + t * TILE_SIZE + threadIdx.x];
        else
            As[threadIdx.y][threadIdx.x] = __float2half(0.0f);

        if (col < N && (t * TILE_SIZE + threadIdx.y) < K)
            Bs[threadIdx.y][threadIdx.x] = B[(t * TILE_SIZE + threadIdx.y) * N + col];
        else
            Bs[threadIdx.y][threadIdx.x] = __float2half(0.0f);

        __syncthreads(); // Wait until the whole tile is loaded

        // Partial dot product over this tile, multiplying in float for accuracy
        #pragma unroll
        for (int k = 0; k < TILE_SIZE; ++k) {
            sum += __half2float(As[threadIdx.y][k]) * __half2float(Bs[k][threadIdx.x]);
        }

        __syncthreads(); // Wait before overwriting the tiles with the next iteration
    }

    // Write the result if this thread maps to a valid C element
    if (row < M && col < N)
        C[row * N + col] = __float2half(sum);
}

// Host wrapper that launches the kernel on the current PyTorch CUDA stream
void gemm_lowbit_cuda(at::Tensor a, at::Tensor b, at::Tensor c, int M, int N, int K) {
    // The kernel assumes contiguous fp16 tensors on the GPU
    TORCH_CHECK(a.is_cuda() && b.is_cuda() && c.is_cuda(), "all tensors must be CUDA tensors");
    TORCH_CHECK(a.scalar_type() == at::kHalf && b.scalar_type() == at::kHalf && c.scalar_type() == at::kHalf,
                "all tensors must be float16");
    TORCH_CHECK(c.is_contiguous(), "output tensor c must be contiguous");
    a = a.contiguous();
    b = b.contiguous();

    // One thread per C element, TILE_SIZE x TILE_SIZE threads per block
    dim3 threads(TILE_SIZE, TILE_SIZE);
    dim3 blocks((N + TILE_SIZE - 1) / TILE_SIZE, (M + TILE_SIZE - 1) / TILE_SIZE);

    // Launch on the stream PyTorch is currently using
    cudaStream_t stream = at::cuda::getCurrentCUDAStream();
    gemm_lowbit_kernel<<<blocks, threads, 0, stream>>>(
        reinterpret_cast<const half*>(a.data_ptr<at::Half>()),
        reinterpret_cast<const half*>(b.data_ptr<at::Half>()),
        reinterpret_cast<half*>(c.data_ptr<at::Half>()),
        M, N, K);

    // Surface kernel launch errors to the caller
    cudaError_t err = cudaGetLastError();
    TORCH_CHECK(err == cudaSuccess, "CUDA kernel launch failed: ", cudaGetErrorString(err));
}
```
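
If this goes into a standard PyTorch C++/CUDA extension, the wrapper still needs a Python binding before it can be called. Here's a minimal sketch, assuming the extension is built with `torch.utils.cpp_extension` and that the exported name `gemm_lowbit` is just a placeholder you're free to rename:

```cpp
// Hypothetical binding sketch: "gemm_lowbit" is a placeholder export name,
// not something defined elsewhere in the repo.
#include <torch/extension.h>

// Implemented in the .cu file above
void gemm_lowbit_cuda(at::Tensor a, at::Tensor b, at::Tensor c, int M, int N, int K);

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    m.def("gemm_lowbit", &gemm_lowbit_cuda,
          "Half-precision tiled GEMM: C = A @ B (CUDA)");
}
```

From Python that would be called roughly as `ext.gemm_lowbit(a, b, c, M, N, K)`, with `a` of shape `(M, K)`, `b` of shape `(K, N)`, and a pre-allocated fp16 `c` of shape `(M, N)`, all on the same CUDA device.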