03-Tensor Core 与矩阵运算
概述
Tensor Core 是 NVIDIA GPU 上专门用于加速矩阵乘法的硬件单元,从 Volta 架构开始引入。本文深入讲解 Tensor Core 的原理、编程接口以及如何利用它实现高性能的深度学习算子。
Tensor Core 架构
硬件演进
┌─────────────────────────────────────────────────────────────────────────┐
│ Tensor Core 架构演进 │
├─────────────────────────────────────────────────────────────────────────┤
│ │
│ Volta (V100) - 第一代 Tensor Core │
│ ├── 每个 SM 有 8 个 Tensor Core │
│ ├── 支持 FP16 输入,FP16/FP32 累加 │
│ ├── 单次操作:4x4x4 矩阵乘加 │
│ └── 峰值性能:125 TFLOPS (FP16) │
│ │
│ Turing (T4/RTX 20) - 第二代 Tensor Core │
│ ├── 新增 INT8/INT4 支持 │
│ ├── 单次操作:4x4x4 (FP16) 或 8x8x4 (INT8) │
│ └── 峰值性能:130 TFLOPS (FP16) │
│ │
│ Ampere (A100) - 第三代 Tensor Core │
│ ├── 每个 SM 有 4 个 Tensor Core │
│ ├── 新增 TF32、BF16、FP64 支持 │
│ ├── 稀疏性支持:2:4 结构化稀疏,2x 性能 │
│ ├── 单次操作:8x4x8 (FP16) 或 16x8x8 (INT8) │
│ └── 峰值性能:312 TFLOPS (FP16), 624 TFLOPS (INT8) │
│ │
│ Hopper (H100) - 第四代 Tensor Core │
│ ├── 新增 FP8 (E4M3/E5M2) 支持 │
│ ├── Transformer Engine 集成 │
│ ├── 单次操作:16x8x16 (FP16) │
│ └── 峰值性能:989 TFLOPS (FP16), 1979 TFLOPS (FP8) │
│ │
└─────────────────────────────────────────────────────────────────────────┘
Tensor Core 操作原理
┌─────────────────────────────────────────────────────────────────────────┐
│ Tensor Core 矩阵乘加操作 │
├─────────────────────────────────────────────────────────────────────────┤
│ │
│ D = A × B + C │
│ │
│ WMMA (Warp Matrix Multiply Accumulate): │
│ │
│ A [M, K] B [K, N] C [M, N] D [M, N] │
│ ┌─────────┐ ┌─────────┐ ┌─────────┐ ┌─────────┐ │
│ │ │ │ │ │ │ │ │ │
│ │ 16x16 │ × │ 16x16 │ + │ 16x16 │ = │ 16x16 │ │
│ │ FP16 │ │ FP16 │ │ FP32 │ │ FP32 │ │
│ │ │ │ │ │ │ │ │ │
│ └─────────┘ └─────────┘ └─────────┘ └─────────┘ │
│ │
│ 数据分布: │
│ ├── 整个 Warp (32 线程) 协作完成一次 WMMA │
│ ├── 每个线程持有矩阵的一部分 (fragment) │
│ └── 通过 warp shuffle 和 Tensor Core 完成计算 │
│ │
│ 支持的形状 (M, N, K): │
│ ├── FP16: 16x16x16, 32x8x16, 8x32x16 │
│ ├── TF32: 16x16x8 │
│ ├── BF16: 16x16x16, 32x8x16, 8x32x16 │
│ ├── INT8: 16x16x16, 32x8x16, 8x32x16 │
│ └── FP64: 8x8x4 │
│ │
└─────────────────────────────────────────────────────────────────────────┘
WMMA API 编程
基本 WMMA 使用
#include <mma.h>
using namespace nvcuda;
// WMMA 参数
const int WMMA_M = 16;
const int WMMA_N = 16;
const int WMMA_K = 16;
__global__ void wmma_gemm_kernel(
const half* __restrict__ A,
const half* __restrict__ B,
float* __restrict__ C,
int M, int N, int K
) {
// Warp 索引
int warp_m = (blockIdx.y * blockDim.y + threadIdx.y);
int warp_n = (blockIdx.x * blockDim.x + threadIdx.x) / 32;
// 声明 fragments
wmma::fragment<wmma::matrix_a, WMMA_M, WMMA_N, WMMA_K, half, wmma::row_major> a_frag;
wmma::fragment<wmma::matrix_b, WMMA_M, WMMA_N, WMMA_K, half, wmma::row_major> b_frag;
wmma::fragment<wmma::accumulator, WMMA_M, WMMA_N, WMMA_K, float> c_frag;
// 初始化累加器为 0
wmma::fill_fragment(c_frag, 0.0f);
// 计算起始位置
int a_row = warp_m * WMMA_M;
int b_col = warp_n * WMMA_N;
// K 维度迭代
for (int k = 0; k < K; k += WMMA_K) {
if (a_row < M && k < K && b_col < N) {
// 加载 A fragment
wmma::load_matrix_sync(a_frag, A + a_row * K + k, K);
// 加载 B fragment
wmma::load_matrix_sync(b_frag, B + k * N + b_col, N);
// 执行矩阵乘加
wmma::mma_sync(c_frag, a_frag, b_frag, c_frag);
}
}
// 存储结果
if (a_row < M && b_col < N) {
wmma::store_matrix_sync(C + a_row * N + b_col, c_frag, N, wmma::mem_row_major);
}
}
void launch_wmma_gemm(
const half* A, const half* B, float* C,
int M, int N, int K
) {
// 每个 warp 计算 16x16 输出
// Block: 4 warps (128 threads)
dim3 block(128, 1);
// Grid: 覆盖整个输出矩阵
dim3 grid(
(N + WMMA_N * 4 - 1) / (WMMA_N * 4),
(M + WMMA_M - 1) / WMMA_M
);
wmma_gemm_kernel<<<grid, block>>>(A, B, C, M, N, K);
}
共享内存优化的 WMMA
#define WMMA_M 16
#define WMMA_N 16
#define WMMA_K 16
#define WARP_SIZE 32
// Block tile 大小
#define BLOCK_M 128
#define BLOCK_N 128
#define BLOCK_K 32
// Warp tile 大小
#define WARP_M 64
#define WARP_N 64
__global__ void wmma_gemm_shared(
const half* __restrict__ A,
const half* __restrict__ B,
float* __restrict__ C,
int M, int N, int K
) {
// 共享内存
__shared__ half As[BLOCK_K][BLOCK_M]; // 转置存储
__shared__ half Bs[BLOCK_K][BLOCK_N];
// 线程/Warp 索引
int tid = threadIdx.x;
int warp_id = tid / WARP_SIZE;
int lane_id = tid % WARP_SIZE;
// Block 在 grid 中的位置
int block_m = blockIdx.y * BLOCK_M;
int block_n = blockIdx.x * BLOCK_N;
// Warp 在 block 内的位置
int warps_per_row = BLOCK_N / WARP_N; // 2
int warp_row = warp_id / warps_per_row;
int warp_col = warp_id % warps_per_row;
// 每个 warp 计算 WARP_M x WARP_N = 64x64
// 使用 (WARP_M/WMMA_M) x (WARP_N/WMMA_N) = 4x4 = 16 个 WMMA 操作
// 声明 fragments
wmma::fragment<wmma::matrix_a, WMMA_M, WMMA_N, WMMA_K, half, wmma::row_major> a_frag[4];
wmma::fragment<wmma::matrix_b, WMMA_M, WMMA_N, WMMA_K, half, wmma::row_major> b_frag[4];
wmma::fragment<wmma::accumulator, WMMA_M, WMMA_N, WMMA_K, float> c_frag[4][4];
// 初始化累加器
#pragma unroll
for (int i = 0; i < 4; i++) {
#pragma unroll
for (int j = 0; j < 4; j++) {
wmma::fill_fragment(c_frag[i][j], 0.0f);
}
}
// K 维度分块迭代
for (int bk = 0; bk < K; bk += BLOCK_K) {
// 协作加载 A 到共享内存
// 每个线程加载多个元素
#pragma unroll
for (int i = 0; i < BLOCK_M * BLOCK_K / (blockDim.x * 8); i++) {
int idx = tid + i * blockDim.x;
int row = idx / (BLOCK_K / 8);
int col = (idx % (BLOCK_K / 8)) * 8;
if (block_m + row < M && bk + col < K) {
// 向量化加载
*reinterpret_cast<float4*>(&As[col][row]) =
*reinterpret_cast<const float4*>(&A[(block_m + row) * K + bk + col]);
}
}
// 协作加载 B 到共享内存
#pragma unroll
for (int i = 0; i < BLOCK_K * BLOCK_N / (blockDim.x * 8); i++) {
int idx = tid + i * blockDim.x;
int row = idx / (BLOCK_N / 8);
int col = (idx % (BLOCK_N / 8)) * 8;
if (bk + row < K && block_n + col < N) {
*reinterpret_cast<float4*>(&Bs[row][col]) =
*reinterpret_cast<const float4*>(&B[(bk + row) * N + block_n + col]);
}
}
__syncthreads();
// WMMA 计算
#pragma unroll
for (int k = 0; k < BLOCK_K; k += WMMA_K) {
// 加载 A fragments (4 个 WMMA_M x WMMA_K)
#pragma unroll
for (int m = 0; m < 4; m++) {
int a_row = warp_row * WARP_M + m * WMMA_M;
wmma::load_matrix_sync(
a_frag[m],
&As[k][a_row],
BLOCK_M // leading dimension
);
}
// 加载 B fragments (4 个 WMMA_K x WMMA_N)
#pragma unroll
for (int n = 0; n < 4; n++) {
int b_col = warp_col * WARP_N + n * WMMA_N;
wmma::load_matrix_sync(
b_frag[n],
&Bs[k][b_col],
BLOCK_N
);
}
// 计算 4x4 个 WMMA
#pragma unroll
for (int m = 0; m < 4; m++) {
#pragma unroll
for (int n = 0; n < 4; n++) {
wmma::mma_sync(c_frag[m][n], a_frag[m], b_frag[n], c_frag[m][n]);
}
}
}
__syncthreads();
}
// 存储结果
#pragma unroll
for (int m = 0; m < 4; m++) {
#pragma unroll
for (int n = 0; n < 4; n++) {
int c_row = block_m + warp_row * WARP_M + m * WMMA_M;
int c_col = block_n + warp_col * WARP_N + n * WMMA_N;
if (c_row < M && c_col < N) {
wmma::store_matrix_sync(
&C[c_row * N + c_col],
c_frag[m][n],
N,
wmma::mem_row_major
);
}
}
}
}
MMA PTX 指令
直接使用 PTX
/*
* 对于更细粒度的控制,可以直接使用 PTX 汇编
* mma.sync 指令支持更多配置选项
*/
// FP16 MMA: m16n8k16
__device__ void mma_m16n8k16_fp16(
uint32_t* D, // 4 个 uint32 (8 个 FP16)
uint32_t* A, // 8 个 uint32 (16 个 FP16)
uint32_t* B, // 4 个 uint32 (8 个 FP16)
uint32_t* C // 4 个 uint32 (8 个 FP16)
) {
asm volatile(
"mma.sync.aligned.m16n8k16.row.col.f16.f16.f16.f16 "
"{%0, %1, %2, %3}, "
"{%4, %5, %6, %7, %8, %9, %10, %11}, "
"{%12, %13, %14, %15}, "
"{%16, %17, %18, %19};\n"
: "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3])
: "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]),
"r"(A[4]), "r"(A[5]), "r"(A[6]), "r"(A[7]),
"r"(B[0]), "r"(B[1]), "r"(B[2]), "r"(B[3]),
"r"(C[0]), "r"(C[1]), "r"(C[2]), "r"(C[3])
);
}
// TF32 MMA: m16n8k8
__device__ void mma_m16n8k8_tf32(
float* D, // 4 个 float
uint32_t* A, // 4 个 uint32 (TF32 in uint32)
uint32_t* B, // 2 个 uint32
float* C // 4 个 float
) {
asm volatile(
"mma.sync.aligned.m16n8k8.row.col.f32.tf32.tf32.f32 "
"{%0, %1, %2, %3}, "
"{%4, %5, %6, %7}, "
"{%8, %9}, "
"{%10, %11, %12, %13};\n"
: "=f"(D[0]), "=f"(D[1]), "=f"(D[2]), "=f"(D[3])
: "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]),
"r"(B[0]), "r"(B[1]),
"f"(C[0]), "f"(C[1]), "f"(C[2]), "f"(C[3])
);
}
// INT8 MMA: m16n8k32
__device__ void mma_m16n8k32_s8(
int32_t* D, // 4 个 int32
uint32_t* A, // 8 个 uint32 (32 个 INT8)
uint32_t* B, // 4 个 uint32 (16 个 INT8)
int32_t* C // 4 个 int32
) {
asm volatile(
"mma.sync.aligned.m16n8k32.row.col.s32.s8.s8.s32 "
"{%0, %1, %2, %3}, "
"{%4, %5, %6, %7, %8, %9, %10, %11}, "
"{%12, %13, %14, %15}, "
"{%16, %17, %18, %19};\n"
: "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3])
: "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]),
"r"(A[4]), "r"(A[5]), "r"(A[6]), "r"(A[7]),
"r"(B[0]), "r"(B[1]), "r"(B[2]), "r"(B[3]),
"r"(C[0]), "r"(C[1]), "r"(C[2]), "r"(C[3])
);
}
MMA 数据布局
/*
* MMA 操作中数据在线程间的分布
* 以 m16n8k16 FP16 为例:
*
* A 矩阵 [16, 16] - row_major
* 每个线程持有 8 个元素
* 线程 T 持有: A[T/4, (T%4)*2], A[T/4, (T%4)*2+1], ...
*
* B 矩阵 [16, 8] - col_major
* 每个线程持有 4 个元素
*
* C/D 矩阵 [16, 8]
* 每个线程持有 4 个元素
*/
// 手动实现 WMMA 加载(理解数据布局)
__device__ void load_matrix_a_m16n8k16(
uint32_t frag[8],
const half* ptr,
int ldm // leading dimension
) {
int lane = threadIdx.x % 32;
// 计算该线程负责的行和列
int row0 = lane / 4;
int row1 = row0 + 8;
int col = (lane % 4) * 2;
// 每个线程加载 8 个 half
// 打包成 4 个 uint32
half2* frag_h2 = reinterpret_cast<half2*>(frag);
frag_h2[0] = *reinterpret_cast<const half2*>(&ptr[row0 * ldm + col]);
frag_h2[1] = *reinterpret_cast<const half2*>(&ptr[row0 * ldm + col + 8]);
frag_h2[2] = *reinterpret_cast<const half2*>(&ptr[row1 * ldm + col]);
frag_h2[3] = *reinterpret_cast<const half2*>(&ptr[row1 * ldm + col + 8]);
}
__device__ void load_matrix_b_m16n8k16(
uint32_t frag[4],
const half* ptr,
int ldm
) {
int lane = threadIdx.x % 32;
int row0 = lane / 4;
int row1 = row0 + 8;
int col = (lane % 4) * 2;
half2* frag_h2 = reinterpret_cast<half2*>(frag);
frag_h2[0] = *reinterpret_cast<const half2*>(&ptr[row0 * ldm + col]);
frag_h2[1] = *reinterpret_cast<const half2*>(&ptr[row1 * ldm + col]);
}
__device__ void store_matrix_c_m16n8k16(
half* ptr,
uint32_t frag[4],
int ldm
) {
int lane = threadIdx.x % 32;
int row0 = lane / 4;
int row1 = row0 + 8;
int col = (lane % 4) * 2;
half2* frag_h2 = reinterpret_cast<half2*>(frag);
*reinterpret_cast<half2*>(&ptr[row0 * ldm + col]) = frag_h2[0];
*reinterpret_cast<half2*>(&ptr[row1 * ldm + col]) = frag_h2[1];
}
CUTLASS 介绍
CUTLASS 架构
┌─────────────────────────────────────────────────────────────────────────┐
│ CUTLASS 层次化架构 │
├─────────────────────────────────────────────────────────────────────────┤
│ │
│ ┌─────────────────────────────────────────────────────────────────┐ │
│ │ Device-level │ │
│ │ 问题分解为 Thread Block Tiles │ │
│ └─────────────────────────────┬───────────────────────────────────┘ │
│ │ │
│ ┌─────────────────────────────┴───────────────────────────────────┐ │
│ │ Thread Block-level │ │
│ │ ├── Thread Block Tile → Warp Tiles │ │
│ │ ├── 管理共享内存 │ │
│ │ └── 协调数据移动 │ │
│ └─────────────────────────────┬───────────────────────────────────┘ │
│ │ │
│ ┌─────────────────────────────┴───────────────────────────────────┐ │
│ │ Warp-level │ │
│ │ ├── Warp Tile → MMA 操作 │ │
│ │ ├── 管理寄存器 fragments │ │
│ │ └── 执行 Tensor Core MMA │ │
│ └─────────────────────────────┬───────────────────────────────────┘ │
│ │ │
│ ┌─────────────────────────────┴───────────────────────────────────┐ │
│ │ Thread-level │ │
│ │ ├── 数据加载/存储 │ │
│ │ └── 后处理 (epilogue) │ │
│ └─────────────────────────────────────────────────────────────────┘ │
│ │
│ 核心组件: │
│ ├── Gemm: 通用矩阵乘法 │
│ ├── Conv: 卷积操作 │
│ ├── Epilogue: 后处理 (bias, activation, quantization) │
│ ├── Layout: 数据布局 (RowMajor, ColumnMajor, etc.) │
│ └── Iterator: 内存访问模式 │
│ │
└─────────────────────────────────────────────────────────────────────────┘
使用 CUTLASS
#include <cutlass/cutlass.h>
#include <cutlass/gemm/device/gemm.h>
#include <cutlass/util/host_tensor.h>
// 定义 GEMM 类型
using ElementInputA = cutlass::half_t;
using ElementInputB = cutlass::half_t;
using ElementOutput = cutlass::half_t;
using ElementAccumulator = float;
using LayoutInputA = cutlass::layout::RowMajor;
using LayoutInputB = cutlass::layout::RowMajor;
using LayoutOutput = cutlass::layout::RowMajor;
// 定义 GEMM 操作
using Gemm = cutlass::gemm::device::Gemm<
ElementInputA,
LayoutInputA,
ElementInputB,
LayoutInputB,
ElementOutput,
LayoutOutput,
ElementAccumulator,
cutlass::arch::OpClassTensorOp, // 使用 Tensor Core
cutlass::arch::Sm80, // Ampere 架构
cutlass::gemm::GemmShape<128, 256, 64>, // Thread Block Tile
cutlass::gemm::GemmShape<64, 64, 64>, // Warp Tile
cutlass::gemm::GemmShape<16, 8, 16>, // MMA Shape (m16n8k16)
cutlass::epilogue::thread::LinearCombination<
ElementOutput,
128 / cutlass::sizeof_bits<ElementOutput>::value,
ElementAccumulator,
ElementAccumulator
>,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
3 // Pipeline stages
>;
cudaError_t cutlass_gemm(
cutlass::half_t* A,
cutlass::half_t* B,
cutlass::half_t* C,
cutlass::half_t* D,
int M, int N, int K,
float alpha, float beta
) {
// 定义问题大小
cutlass::gemm::GemmCoord problem_size(M, N, K);
// 创建 GEMM 参数
typename Gemm::Arguments arguments{
problem_size,
{A, K}, // TensorRef for A
{B, N}, // TensorRef for B
{C, N}, // TensorRef for C (source)
{D, N}, // TensorRef for D (destination)
{alpha, beta}
};
// 实例化 GEMM
Gemm gemm_op;
// 检查参数
cutlass::Status status = gemm_op.can_implement(arguments);
if (status != cutlass::Status::kSuccess) {
return cudaErrorInvalidValue;
}
// 获取工作空间大小
size_t workspace_size = Gemm::get_workspace_size(arguments);
// 分配工作空间
cutlass::device_memory::allocation<uint8_t> workspace(workspace_size);
// 初始化
status = gemm_op.initialize(arguments, workspace.get());
if (status != cutlass::Status::kSuccess) {
return cudaErrorInvalidValue;
}
// 执行
status = gemm_op();
if (status != cutlass::Status::kSuccess) {
return cudaErrorInvalidValue;
}
return cudaSuccess;
}
CUTLASS 3.x with CuTe
/*
* CUTLASS 3.x 引入 CuTe 库
* CuTe = CUDA Templates for Linear Algebra Primitives
* 提供更高级的抽象来处理张量布局和分区
*/
#include <cute/tensor.hpp>
#include <cutlass/cutlass.h>
#include <cutlass/gemm/collective/collective_builder.hpp>
using namespace cute;
// 使用 CuTe 定义布局
auto layout_A = make_layout(make_shape(M, K), make_stride(K, 1)); // Row-major
auto layout_B = make_layout(make_shape(K, N), make_stride(N, 1)); // Row-major
// 创建张量
auto tensor_A = make_tensor(make_gmem_ptr(ptr_A), layout_A);
auto tensor_B = make_tensor(make_gmem_ptr(ptr_B), layout_B);
// 分区
// Thread Block Tile
auto tiled_A = local_tile(tensor_A, make_shape(BM, BK), make_coord(bm, bk));
auto tiled_B = local_tile(tensor_B, make_shape(BK, BN), make_coord(bk, bn));
// 定义 MMA 操作
using MMA_Op = SM80_16x8x16_F16F16F16F16_TN;
auto mma = MMA_Op{};
// 获取 MMA 布局
auto mma_A = mma.get_layoutA_MK();
auto mma_B = mma.get_layoutB_NK();
auto mma_C = mma.get_layoutC_MN();
// CUTLASS 3.x Collective GEMM Builder
using CollectiveOp = cutlass::gemm::collective::CollectiveBuilder<
cutlass::arch::Sm90,
cutlass::arch::OpClassTensorOp,
cute::half_t, cute::LayoutRight, 16,
cute::half_t, cute::LayoutRight, 16,
float,
cute::Shape<cute::_128, cute::_256, cute::_64>,
cute::Shape<cute::_1, cute::_1, cute::_1>,
cutlass::gemm::collective::StageCountAutoCarveout<sizeof(float)>,
cutlass::gemm::collective::KernelScheduleAuto
>::CollectiveOp;
FP8 与 Transformer Engine
FP8 数据格式
┌─────────────────────────────────────────────────────────────────────────┐
│ FP8 数据格式 │
├─────────────────────────────────────────────────────────────────────────┤
│ │
│ E4M3 (更高精度,较小动态范围) │
│ ├── 1 位符号 + 4 位指数 + 3 位尾数 │
│ ├── 动态范围:±448 │
│ └── 适用于:前向传播中的权重和激活 │
│ │
│ E5M2 (较低精度,更大动态范围) │
│ ├── 1 位符号 + 5 位指数 + 2 位尾数 │
│ ├── 动态范围:±57344 │
│ └── 适用于:反向传播中的梯度 │
│ │
│ 对比: │
│ ┌─────────────────────────────────────────────────────────────┐ │
│ │ 格式 │ 符号 │ 指数 │ 尾数 │ 最大值 │ 最小正数 │ │
│ ├─────────────────────────────────────────────────────────────┤ │
│ │ FP32 │ 1 │ 8 │ 23 │ 3.4e38 │ 1.2e-38 │ │
│ │ FP16 │ 1 │ 5 │ 10 │ 65504 │ 6.1e-5 │ │
│ │ BF16 │ 1 │ 8 │ 7 │ 3.4e38 │ 1.2e-38 │ │
│ │ E4M3 │ 1 │ 4 │ 3 │ 448 │ 2^-9 │ │
│ │ E5M2 │ 1 │ 5 │ 2 │ 57344 │ 2^-16 │ │
│ └─────────────────────────────────────────────────────────────┘ │
│ │
└─────────────────────────────────────────────────────────────────────────┘
Transformer Engine 使用
# Transformer Engine Python API
import transformer_engine.pytorch as te
import torch
# 创建 FP8 线性层
linear = te.Linear(
in_features=4096,
out_features=4096,
bias=True
)
# 设置 FP8 配置
fp8_recipe = te.recipe.DelayedScaling(
margin=0,
interval=1,
fp8_format=te.recipe.Format.HYBRID, # E4M3 for forward, E5M2 for backward
amax_history_len=1024,
amax_compute_algo="max"
)
# FP8 前向传播
with te.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe):
output = linear(input)
# FP8 注意力层
class FP8Attention(te.TransformerLayer):
def __init__(self, hidden_size, num_heads):
super().__init__(
hidden_size=hidden_size,
ffn_hidden_size=4 * hidden_size,
num_attention_heads=num_heads,
attention_dropout=0.0,
hidden_dropout=0.0,
self_attn_mask_type="causal"
)
# 混合精度训练
model = FP8Attention(4096, 32).cuda()
optimizer = torch.optim.AdamW(model.parameters())
for batch in dataloader:
with te.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe):
loss = model(batch)
loss.backward()
optimizer.step()
FP8 GEMM 实现
// H100 FP8 MMA
#include <cuda_fp8.h>
// FP8 类型定义
using fp8_e4m3 = __nv_fp8_e4m3;
using fp8_e5m2 = __nv_fp8_e5m2;
__device__ void mma_m16n8k32_fp8(
float* D,
uint32_t* A, // 32 个 E4M3
uint32_t* B, // 16 个 E4M3
float* C
) {
#if __CUDA_ARCH__ >= 900 // Hopper
asm volatile(
"mma.sync.aligned.m16n8k32.row.col.f32.e4m3.e4m3.f32 "
"{%0, %1, %2, %3}, "
"{%4, %5, %6, %7}, "
"{%8, %9}, "
"{%10, %11, %12, %13};\n"
: "=f"(D[0]), "=f"(D[1]), "=f"(D[2]), "=f"(D[3])
: "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]),
"r"(B[0]), "r"(B[1]),
"f"(C[0]), "f"(C[1]), "f"(C[2]), "f"(C[3])
);
#endif
}
// FP8 量化
__device__ __forceinline__ fp8_e4m3 quantize_fp8_e4m3(float val, float scale) {
return __nv_cvt_float_to_fp8(val * scale, __NV_SATFINITE, __NV_E4M3);
}
// FP8 反量化
__device__ __forceinline__ float dequantize_fp8_e4m3(fp8_e4m3 val, float scale) {
return __half2float(__nv_cvt_fp8_to_halfraw(val, __NV_E4M3)) / scale;
}
// 动态量化 kernel
__global__ void compute_scale_and_quantize(
const float* input,
fp8_e4m3* output,
float* scale,
int n
) {
__shared__ float s_amax;
// 找最大绝对值
float local_max = 0.0f;
for (int i = threadIdx.x; i < n; i += blockDim.x) {
local_max = fmaxf(local_max, fabsf(input[i]));
}
float amax = block_reduce_max<256>(local_max);
if (threadIdx.x == 0) {
s_amax = amax;
// E4M3 最大值是 448
*scale = 448.0f / (amax + 1e-12f);
}
__syncthreads();
float s = *scale;
// 量化
for (int i = threadIdx.x; i < n; i += blockDim.x) {
output[i] = quantize_fp8_e4m3(input[i], s);
}
}
性能优化策略
Tensor Core 利用率优化
/*
* 最大化 Tensor Core 利用率的关键:
*
* 1. 数据对齐
* - 矩阵维度是 8/16 的倍数
* - 数据地址 16 字节对齐
*
* 2. 足够的并行度
* - 同时有多个 MMA 操作可执行
* - 通过增大 tile 大小或启动更多 warp
*
* 3. 隐藏数据移动延迟
* - 使用软件流水线 (double/triple buffering)
* - 异步拷贝 (cp.async)
*
* 4. 选择合适的 MMA 形状
* - 根据问题大小选择最优 tile 配置
*/
// 软件流水线示例
template<int STAGES>
__global__ void pipelined_gemm(
const half* A, const half* B, float* C,
int M, int N, int K
) {
// 多缓冲共享内存
__shared__ half As[STAGES][BLOCK_K][BLOCK_M];
__shared__ half Bs[STAGES][BLOCK_K][BLOCK_N];
// Fragments
wmma::fragment<...> a_frag, b_frag;
wmma::fragment<wmma::accumulator, ...> c_frag;
wmma::fill_fragment(c_frag, 0.0f);
int stage = 0;
// 填充流水线
for (int s = 0; s < STAGES - 1; s++) {
load_tile_async(A, B, As[s], Bs[s], s * BLOCK_K);
}
cp_async_wait<STAGES - 2>();
__syncthreads();
// 主循环
for (int k = 0; k < K; k += BLOCK_K) {
// 加载下一个 tile
if (k + (STAGES - 1) * BLOCK_K < K) {
load_tile_async(A, B,
As[(stage + STAGES - 1) % STAGES],
Bs[(stage + STAGES - 1) % STAGES],
k + (STAGES - 1) * BLOCK_K);
}
// 等待当前 tile 就绪
cp_async_wait<STAGES - 2>();
__syncthreads();
// 计算当前 tile
wmma::load_matrix_sync(a_frag, As[stage], BLOCK_M);
wmma::load_matrix_sync(b_frag, Bs[stage], BLOCK_N);
wmma::mma_sync(c_frag, a_frag, b_frag, c_frag);
stage = (stage + 1) % STAGES;
}
// 存储结果
wmma::store_matrix_sync(C, c_frag, N, wmma::mem_row_major);
}
性能对比
┌─────────────────────────────────────────────────────────────────────────┐
│ GEMM 性能对比 (A100, M=N=K=4096) │
├─────────────────────────────────────────────────────────────────────────┤
│ │
│ 实现方式 │ 性能 (TFLOPS) │ 带宽利用率 │
│ ─────────────────────────────────┼───────────────┼────────────────── │
│ Naive CUDA (FP32) │ 0.5 │ 5% │
│ Tiled Shared Memory (FP32) │ 5.0 │ 40% │
│ cuBLAS FP32 │ 19.5 │ 95% │
│ ─────────────────────────────────┼───────────────┼────────────────── │
│ Naive WMMA FP16 │ 30.0 │ 20% │
│ Tiled WMMA FP16 │ 180.0 │ 60% │
│ CUTLASS FP16 │ 275.0 │ 88% │
│ cuBLAS FP16 Tensor Core │ 290.0 │ 93% │
│ ─────────────────────────────────┼───────────────┼────────────────── │
│ CUTLASS INT8 │ 550.0 │ 89% │
│ cuBLAS INT8 Tensor Core │ 580.0 │ 93% │
│ │
│ 理论峰值: │
│ ├── FP32 CUDA Core: 19.5 TFLOPS │
│ ├── FP16 Tensor Core: 312 TFLOPS │
│ └── INT8 Tensor Core: 624 TOPS │
│ │
└─────────────────────────────────────────────────────────────────────────┘
总结
Tensor Core 使用要点
| 要点 | 说明 |
|---|---|
| 数据类型 | FP16/BF16/TF32/INT8/FP8 |
| 对齐要求 | 矩阵维度为 8/16 倍数 |
| 编程接口 | WMMA API / PTX / CUTLASS |
| 性能关键 | Tile 大小、流水线深度、数据布局 |
最佳实践
□ 使用 CUTLASS 或 cuBLAS 作为基准
□ 矩阵维度 padding 到合适倍数
□ 使用软件流水线隐藏延迟
□ 合理配置 Block/Warp Tile 大小
□ 考虑数据类型选择 (精度 vs 性能)