CUDA GEMM: Intro & Naive Implementation
中文版本请访问知乎 Chinese Version Here
1. 引言 (Introduction)
什么是 GEMM
- 通用矩阵乘法 (General Matrix Multiply, GEMM) 运算定义为 $C = A \times B$。给定矩阵 $A \in \mathbb{R}^{M \times K}$ 与 $B \in \mathbb{R}^{K \times N}$,计算目标矩阵 $C \in \mathbb{R}^{M \times N}$。
为何优化 GEMM
- 手写并逐级优化 GEMM 核函数,来深入理解底层硬件架构(内存层级、计算单元、指令调度)。
目标
- 实现并逐步优化单精度浮点 (FP 32) GEMM 核函数。目标将算法从内存受限 (Memory-Bound) 的基础实现,优化至接近 NVIDIA RTX 5080 (Blackwell) 硬件极限的高利用率架构。
性能度量体系 (Measurement)
总浮点运算量 (Total FLOPs)
针对输出矩阵 $C$ 中的每一个元素,需执行 $K$ 次乘法与 $K$ 次加法(在 GPU 中通常融合成 FMA 指令)。全局计算负载计算如下:
\[Total\ FLOPs = 2MNK\]实际吞吐量 (Achieved TFLOPS)
核函数执行期间的实际计算性能:
\[Achieved\ TFLOPS = \frac{2MNK}{Execution\ Time\ (s) \times 10^{12}}\]算术强度与 Roofline 模型 (Arithmetic Intensity)
算术强度定义为计算量与内存数据传输量的比值,用于判断算子是计算受限还是内存受限:
\[Arithmetic\ Intensity = \frac{Total\ FLOPs}{Total\ Memory\ Traffic\ (bytes)}\]硬件参数与机器平衡点 (RTX 5080)
- 理论峰值算力 (FP 32):约 56.28 TFLOPS
- 峰值显存带宽 (GDDR 7):约 960 GB/s
计算该硬件的机器平衡点 (Machine Balance):
\[Machine\ Balance = \frac{Peak\ Compute}{Peak\ Bandwidth} = \frac{56280\ GFLOPs/s}{960\ GB/s} \approx 58.625\ FLOPs/Byte\]- 算子的算术强度低于 58.625 FLOPs/Byte 时,可以认为 kernel 是 Memory Bound。
2 基础实现 (Naive GEMM)
2.1 核心思想
采用最直接的映射策略。每个 CUDA Thread 负责计算输出矩阵 $C$ 中的一个元素。最小的可操控单元对应最小的矩阵元素。
__global__ void gemm_naive_kernel(const float* A, const float* B, float* C, int M, int N, int K) {
// 映射 2D 线程网格/块索引至矩阵的行与列
int row = blockIdx.y * blockDim.y + threadIdx.y;
int col = blockIdx.x * blockDim.x + threadIdx.x;
// 边界检查,防止非对齐维度的越界访问
if (row < M && col < N) {
float sum = 0.0f;
// 计算 A 的第 row 行与 B 的第 col 列的内积
for (int i = 0; i < K; ++i) {
sum += A[row * K + i] * B[i * N + col];
}
// 将结果写回全局内存
C[row * N + col] = sum;
}
}
void launch_gemm_naive(const float* A, const float* B, float* C, int M, int N, int K) {
// 设定标准的 16x16 线程块
dim3 block(16, 16);
// 向上取整计算 Grid 维度
dim3 grid((N + block.x - 1) / block.x, (M + block.y - 1) / block.y);
gemm_naive_kernel<<<grid, block>>>(A, B, C, M, N, K);
}
2.2 代码级内存映射解析
- 多维索引与物理内存的映射 在 C/C++ 中,二维数组的物理内存排布为行优先 (Row-Major)。即列下标
col是内存地址中最内层(连续)的变化维度。相邻两列 ([row][col]与[row][col+1]) 在物理内存中相邻。 - 线程维度绑定策略 GPU 执行指令的最小调度单位是 Warp(包含 32 个线程)。在同一个 Warp 内,
threadIdx.x的编号是连续的(是最内层(连续)的变化维度)。为促成合并内存访问 (Coalesced Access),必须将矩阵的连续内存维度映射到线程的连续编号上。因此代码逻辑设定为:矩阵的列索引col绑定至threadIdx.x,行索引row绑定至threadIdx.y。 - Warp 内的访问行为理论判断
- 读取
B[i * N + col]:索引依赖col(threadIdx.x)。Warp 内相邻线程读取连续的地址,形成连续访存。 - 读取
A[row * K + i]:索引依赖row(threadIdx.y)。若 Warp 内所有线程的threadIdx.y相同,则形成同址广播访存。
- 读取
2.3 Profiling 数据量化验证
算术强度评估
针对输出矩阵 $C$ 中的单一元素:
- 计算量:执行 $K$ 次乘法与 $K$ 次加法,总计 $2K$ FLOPs。
- 访存量:全局内存读取矩阵 A 的一行($K$ 个 float)与矩阵 B 的一列($K$ 个 float),总计读取 $2K$ 个浮点数。数据传输量为 $2K \times 4$ bytes = $8K$ bytes。
- 算术强度:$2K / 8K = \mathbf{0.25\ FLOPs/Byte}$。 与 RTX 5080 的机器平衡点 ($58.625$) 相比,该算法严重受限于内存带宽。
底层物理内存搬运分析
传输原则基准
- Sector(扇区):GPU 显存(DRAM/L 2 Cache)与计算核心(SM/L 1 Cache)之间数据传输的最小不可分割物理单位,固定为 32 字节。
- 在理想(合并访存)状态下,一个 warp 的 32 个线程访问连续的物理内存地址地址,且地址上的信息都被使用。假设数据类型为
float(4 字节)。一次单次全局读取命令需要 32 线程 * 4 字节 = 128 字节,即 4 个 连续 sectors.
访存行为微观拆解 (基于 dim3 block(16, 16)) 由于 Block 宽度为 16,无法达到合并访存状态。
- 读取 B 矩阵 (
B[i * N + col]):- 行为:前 16 线程与后 16 线程请求完全重复的
col范围 (0~15)。 - 有效数据:16 个不重复 float,计 64 字节。
- 物理执行:触发 2 个连续扇区。
- 行为:前 16 线程与后 16 线程请求完全重复的
- 读取 A 矩阵 (
A[row * K + i]):- 行为:前 16 线程与后 16 线程分别请求相隔一整行(K 个元素)的物理地址。
- 有效数据:2 个不重复 float,计 8 字节。
- 物理执行:因地址跨度过大,强制触发 2 个独立扇区。
NCU 硬件计数器印证
gemm_naive_kernel(const float *, const float *, float *, int, int, int) (128, 128, 1)x(16, 16, 1), Context 1, Stream 7, Device 0, CC 12.0
Section: Command line profiler metrics
----------------------------------------------- ----------- -------------
Metric Name Metric Unit Metric Value
----------------------------------------------- ----------- -------------
l1tex__t_requests_pipe_lsu_mem_global_op_ld.sum 536,870,912
l1tex__t_sectors_pipe_lsu_mem_global_op_ld.sum sector 1,073,741,824
----------------------------------------------- ----------- -------------
指标一:扇区请求比 (Sectors per Request)
- 理论计算:循环内发出的 2 次请求(读 A + 读 B),共计触发 4 个扇区。比值为 $4/2 = \mathbf{2.0\ 扇区/请求}$。
- 实测数据:
1,073,741,824 / 536,870,912 = 2.0。数据吻合。
Section: Memory Workload Analysis Tables
OPT Est. Speedup: 42.18%
The memory access pattern for global loads from L1TEX might not be optimal. On average, only 18.0 of the 32
bytes transmitted per sector are utilized by each thread. This could possibly be caused by a stride between
threads. Check the Source Counters section for uncoalesced global loads.
指标二:平均扇区有效利用率 (bytes per Sector)
- 单次循环有效利用数据:读 B (64 B) + 读 A (8 B) = 72 bytes。
- 单次循环实际搬运数据:4 个扇区 $\times$ 32 B = 128 bytes。
- 总共利用了 4 个扇区,平均每个扇区的 32 个 bytes 中,有效数据量为 18 bytes。
- 实测数据:
only 18.0 of the 32 bytes。数据吻合。
2.4 优化点:物理对齐
修改 blockDim.x 为 32 的倍数
内存访问分析
Section: Memory Workload Analysis Tables
OPT Est. Speedup: 17.21%
The memory access pattern for global loads from L1TEX might not be optimal. On average, only 26.4 of the 32
bytes transmitted per sector are utilized by each thread. This could possibly be caused by a stride between
threads. Check the Source Counters section for uncoalesced global loads.
- 取矩阵 B (
B[i * N + col]):100% 满载合并访存- 行为:Warp 内 32 个线程请求连续的 32 个列地址(
col连续)。 - 有效数据量:32 个独立
float,共计 128 字节。 - 物理行为:数据量精确覆盖 4 个对齐的 32 字节扇区。LSU 触发单次 128 字节的合并事务,平均每扇区利用率达到 32/32 字节的物理极值。
- 行为:Warp 内 32 个线程请求连续的 32 个列地址(
- 读取矩阵 A (
A[row * K + i]):同址完美广播- 行为:Warp 内 32 个线程计算处于同一逻辑行(
row相同),请求完全相同的单一物理地址。 - 有效数据量:1 个
float,共计 4 字节。 - 物理行为:LSU 触发 1 个扇区 (32 字节) 请求。数据抵达 L 1 Cache 后,触发 L 1 内部同周期广播(Broadcast)机制,将单一数据同时分发至 32 个 SIMT 通道,消除跨行离散惩罚。
- 行为:Warp 内 32 个线程计算处于同一逻辑行(
- 平均扇区有效利用率 (bytes per Sector)
- 单次循环有效利用数据:读 B (128 B) + 读 A (4B) = 132 bytes。
- 单次循环实际搬运数据:5个扇区 $\times$ 32 B = 160 bytes。
- 总共利用了 4 个扇区,平均每个扇区的 32 个 bytes 中,有效数据量为 26.5 bytes。
- 实测数据:
only 26.4 of the 32 bytes。数据吻合。
算术强度分析
- 仅解决了“单次请求内部的物理利用率”问题,未解决“跨请求的数据复用”问题。每次执行乘加运算(FMA),线程依然需要去全局内存 (Global Memory) 发起真实的物理读取。同一块数据矩阵依然被不同的 Thread Block 重复拉取无数次。无法提高算术强度
2.5 Performance
| M | N | K | Kernel | Time (ms) | TFLOPS |
|---|---|---|---|---|---|
| 2048 | 2048 | 2048 | NAIVE_16 X 16 | 5.1235 | 3.35 |
| 2048 | 2048 | 2048 | NAIVE_32 X 16 | 4.9678 | 3.46 |
| 2048 | 2048 | 2048 | CUBLAS | 0.5127 | 33.51 |
- $Total FLOPS =2 * 2048^{3} \approx 17.18\;GFLOPS$
2.6 总结
[ 计算核心 (CUDA Cores) ]
| ^
| | 4 Bytes (单线程单次请求 1 个 float)
v |
+-----------------------+
| 寄存器 (Registers) | <-- 累加操作 (sum += ...) 发生在此处
+-----------------------+
^
| L1 命中则直接返回 / 同周期广播 (Broadcast)
|
+-----------------------+
| L1 Cache / TEX | <-- 位于 SM 内部,管理单位为 Cache Line (128 Bytes)
+-----------------------+
^
| 32 Bytes (1 个 Sector,物理传输的最小不可分割单位)
| LSU (Load/Store Unit) 在此层级发起物理请求
|
+-----------------------+
| L2 Cache | <-- 所有 SM 共享
+-----------------------+
^
| 内存事务 (Memory Transaction:32 / 64 / 128 Bytes)
|
+-----------------------+
| Global Memory (DRAM) | <-- 矩阵 A、矩阵 B、矩阵 C 所在位置
+-----------------------+
Naive GEMM 架构被死死钉在性能底谷的核心原因,可以归结为一个致命的物理缺陷:极低的算术强度引发了全局内存的灾难性冗余读取。
1. 缺陷剖析:O(N) 级别的冗余拉取
在 Naive 实现中,计算逻辑是“以线程为孤岛”的。计算 Ci,j 的线程需要读取矩阵 A 的第 i 行;而计算其右侧相邻元素 Ci,j+1 的线程,再次独立发起了对矩阵 A 第 i 行的全局内存读取请求。
在物理层面上,矩阵 A 的同一行、矩阵 B 的同一列,被数以万计的不同线程反复从高延迟的 Global Memory 中重复搬运。对于 N×N 的方阵,Global Memory 中的每一个浮点数,都被物理拉取了整整 N 次。在 N=2048 的测试中,硬件显存总线承担了理论值 2048 倍的数据搬运重压,彻底击穿了 GDDR 7 的带宽上限。
Enjoy Reading This Article?
Here are some more articles you might like to read next: