// Matrix-multiply micro-benchmark: C[N x K] = A[N x M] * B[M x K].
// Compares a shared-memory tiled kernel (`mat_mul`) against a naive
// one-global-load-per-term kernel (`dumb_mat_mul`), timing repeated
// launches for ~1 second each and reporting the average per launch.

#include <chrono>
#include <cstdio>
#include <cstdlib>
#include <random>

using namespace std::chrono;

// Shared-memory tile edge. Must equal the kernel's blockDim.x and
// blockDim.y (the host launches BLOCK_DIM x BLOCK_DIM blocks below).
#ifndef TILE_SIZE
#define TILE_SIZE 32
#endif

// Abort-on-error wrapper. Kernel launch/config errors only surface via
// cudaGetLastError(), and in-kernel faults only at the next synchronizing
// call, so every runtime call must be checked explicitly.
#define CUDA_CHECK(call)                                                    \
    do {                                                                    \
        cudaError_t err_ = (call);                                          \
        if (err_ != cudaSuccess) {                                          \
            fprintf(stderr, "CUDA error %s:%d: %s\n", __FILE__, __LINE__,   \
                    cudaGetErrorString(err_));                              \
            abort();                                                        \
        }                                                                   \
    } while (0)

// Tiled matrix multiply: C[N x K] = A[N x M] * B[M x K], row-major.
// Launch with blockDim == (TILE_SIZE, TILE_SIZE) and a grid covering
// ceil(K/TILE_SIZE) x ceil(N/TILE_SIZE).
template <typename T>
__global__ void mat_mul(T *A, T *B, T *C, int N, int M, int K) {
    __shared__ T sA[TILE_SIZE][TILE_SIZE];
    __shared__ T sB[TILE_SIZE][TILE_SIZE];

    int tx = threadIdx.x, ty = threadIdx.y;
    int row = blockIdx.y * TILE_SIZE + ty;  // row of C, in [0, N)
    int col = blockIdx.x * TILE_SIZE + tx;  // column of C, in [0, K)

    // No early return for out-of-range threads: every thread in the block
    // must reach both __syncthreads() barriers below. Out-of-range threads
    // load zero padding and simply skip the final store.
    T sum = 0;
    int tiles_len = (M + TILE_SIZE - 1) / TILE_SIZE;
    for (int tile = 0; tile < tiles_len; tile++) {
        int aCol = tile * TILE_SIZE + tx;
        int bRow = tile * TILE_SIZE + ty;
        // Zero-pad out-of-range elements instead of reading out of bounds.
        // (The old code dereferenced B past its end and masked afterwards.)
        sA[ty][tx] = (row < N && aCol < M) ? A[row * M + aCol] : T(0);
        sB[ty][tx] = (bRow < M && col < K) ? B[bRow * K + col] : T(0);
        __syncthreads();  // tile fully loaded before anyone reads it

        for (int k = 0; k < TILE_SIZE; k++) {
            sum += sA[ty][k] * sB[k][tx];
        }
        __syncthreads();  // all reads done before the next iteration's loads
    }

    if (row < N && col < K) {
        C[row * K + col] = sum;
    }
}

// Naive matrix multiply: one thread per C element, all operands read
// straight from global memory. Baseline for the benchmark.
template <typename T>
__global__ void dumb_mat_mul(T *A, T *B, T *C, int N, int M, int K) {
    int col = blockIdx.x * blockDim.x + threadIdx.x;
    int row = blockIdx.y * blockDim.y + threadIdx.y;
    // C has N rows; the old guard tested `row >= M`, which is only
    // correct when N == M.
    if (col >= K || row >= N) return;
    T sum = 0;
    for (int i = 0; i < M; i++) {
        sum += A[row * M + i] * B[i * K + col];
    }
    C[row * K + col] = sum;
}

// Problem size and configuration. These object-like macros (N, M, K) are
// deliberately defined AFTER the kernels and AFTER the #includes so they
// cannot expand inside kernel parameter names or standard headers.
#define N 1024
#define M 1024
#define K 1024
#define NO_PRINT 1
#define BLOCK_DIM 32
#define MAT_TYPE int
#define MAT_FMT "%d\t"

#define A_LEN (N * M)
#define B_LEN (M * K)
#define C_LEN (N * K)
#define A_SIZE (sizeof(MAT_TYPE) * N * M)
#define B_SIZE (sizeof(MAT_TYPE) * M * K)
#define C_SIZE (sizeof(MAT_TYPE) * N * K)

// Print an n x m row-major matrix, one printf(fmt, elem) per element.
template <typename T>
void mat_print(T *a, const char *fmt, int n, int m) {
    for (auto row = 0; row < n; row++) {
        for (auto col = 0; col < m; col++) {
            printf(fmt, a[row * m + col]);
        }
        printf("\n");
    }
}

int main() {
    std::random_device rd;
    std::mt19937 engine(rd());
    std::uniform_int_distribution<MAT_TYPE> dist(1, 10);

    // One host allocation holding A, B and C back to back.
    auto buf = (MAT_TYPE *)malloc(A_SIZE + B_SIZE + C_SIZE);
    if (buf == nullptr) {
        fprintf(stderr, "host allocation failed\n");
        return 1;
    }
    for (auto i = 0; i < A_LEN + B_LEN; i++) {
        buf[i] = dist(engine);
    }
    MAT_TYPE *a = buf;
    MAT_TYPE *b = a + A_LEN;
    MAT_TYPE *c = b + B_LEN;

#if NO_PRINT == 0
    printf("\na\n");
    mat_print(a, MAT_FMT, N, M);
    printf("\nb\n");
    mat_print(b, MAT_FMT, M, K);
#endif

    MAT_TYPE *d_a, *d_b, *d_c;
    CUDA_CHECK(cudaMalloc(&d_a, A_SIZE));
    CUDA_CHECK(cudaMalloc(&d_b, B_SIZE));
    CUDA_CHECK(cudaMalloc(&d_c, C_SIZE));
    CUDA_CHECK(cudaMemcpy(d_a, a, A_SIZE, cudaMemcpyHostToDevice));
    CUDA_CHECK(cudaMemcpy(d_b, b, B_SIZE, cudaMemcpyHostToDevice));

    dim3 blockDim(BLOCK_DIM, BLOCK_DIM);
    // Ceil-divide so the grid covers the whole N x K output. (A fixed 1x1
    // grid would compute only the top-left BLOCK_DIM x BLOCK_DIM tile of C,
    // leaving the rest of the "result" uninitialized.)
    dim3 gridDim((K + BLOCK_DIM - 1) / BLOCK_DIM,
                 (N + BLOCK_DIM - 1) / BLOCK_DIM);

    // Time the tiled kernel: keep launching until ~1 s of kernel time has
    // accumulated, then report the per-launch average.
    int cycles = 0;
    microseconds duration(0);
    while (duration.count() < 1e6) {
        auto start = high_resolution_clock::now();
        mat_mul<<<gridDim, blockDim>>>(d_a, d_b, d_c, N, M, K);
        CUDA_CHECK(cudaGetLastError());       // launch-config errors
        CUDA_CHECK(cudaDeviceSynchronize());  // execution errors + timing
        auto end = high_resolution_clock::now();
        cycles++;
        duration += duration_cast<microseconds>(end - start);
    }

#if NO_PRINT == 0
    CUDA_CHECK(cudaMemcpy(c, d_c, C_SIZE, cudaMemcpyDeviceToHost));
    printf("\nc\n");
    mat_print(c, MAT_FMT, N, K);
#endif
    printf("optimized mul take %f usec avg in %d cycles\n",
           (float)(duration.count()) / cycles, cycles);

    // Same measurement for the naive kernel.
    cycles = 0;
    duration = microseconds(0);
    while (duration.count() < 1e6) {
        auto start = high_resolution_clock::now();
        dumb_mat_mul<<<gridDim, blockDim>>>(d_a, d_b, d_c, N, M, K);
        CUDA_CHECK(cudaGetLastError());
        CUDA_CHECK(cudaDeviceSynchronize());
        auto end = high_resolution_clock::now();
        cycles++;
        duration += duration_cast<microseconds>(end - start);
    }

#if NO_PRINT == 0
    CUDA_CHECK(cudaMemcpy(c, d_c, C_SIZE, cudaMemcpyDeviceToHost));
    printf("\nc\n");
    mat_print(c, MAT_FMT, N, K);
#endif
    printf("dumb mul take %f usec avg in %d cycles\n",
           (float)(duration.count()) / cycles, cycles);

    // Free the DEVICE buffers (the old code passed the host pointers
    // a/b/c to cudaFree, which is an invalid-value error and a leak).
    CUDA_CHECK(cudaFree(d_a));
    CUDA_CHECK(cudaFree(d_b));
    CUDA_CHECK(cudaFree(d_c));
    free(buf);
    return 0;
}