HiHuo
首页
博客
手册
工具
关于
首页
博客
手册
工具
关于
  • AI 基础设施深度教程

    • AI Infra 深度教程
    • GPU容器化

      • 01-GPU 架构基础
      • NVIDIA 容器运行时
      • GPU 共享与隔离
      • GPU 监控与调试
    • Kubernetes GPU调度

      • Device Plugin 机制深度解析
      • GPU 调度器实现
      • 拓扑感知调度
      • 弹性 GPU 调度
    • AI训练平台

      • 分布式训练框架
      • 训练任务调度
      • 模型存储与管理
      • 实验管理
      • 超参数优化
    • 推理服务

      • 推理引擎原理
      • 模型服务框架
      • 动态批处理
      • 推理优化技术
      • 多模型服务
    • 异构计算

      • 05-异构计算
      • 异构计算概述
      • GPU 虚拟化技术
      • NPU 与专用 AI 芯片
      • 设备拓扑感知调度
      • 算力池化与弹性调度
    • AI工作流引擎

      • 06-AI工作流引擎
      • AI 工作流引擎概述
      • Kubeflow Pipelines 深度实践
      • 03-Argo Workflows 深度实践
      • 04-数据版本管理
      • 05-实验跟踪与模型注册
    • MLOps实践

      • 07-MLOps实践
      • 01-MLOps 成熟度模型
      • 02-数据集工程
      • 03-Feature Store 特征存储
      • 04-模型评测体系
      • 05-模型安全与治理
    • AIOps实践

      • 08-AIOps实践
      • 01-AIOps概述与架构
      • 02-异常检测算法
      • 03-根因分析与告警聚合
      • 04-智能运维决策
      • 05-AIOps平台实战
    • 面试专题

      • 09-面试专题
      • 01-AI基础设施核心面试题
      • 02-大模型面试题
      • 03-系统设计面试题
    • CUDA编程与算子开发

      • 10-CUDA 编程与算子开发
      • 01-CUDA编程模型与内存层次
      • 02-高性能 Kernel 开发实战
      • 03-Tensor Core 与矩阵运算
      • 04-算子融合与优化技术
      • 05-Triton 编程入门
    • 通信与网络底层

      • 11-通信与网络底层
      • 01-NCCL 源码深度解析
      • 02-AllReduce 算法实现
      • 03-RDMA与InfiniBand原理
      • 04-网络拓扑与通信优化
      • 05-大规模集群网络架构
    • 框架源码解析

      • 12-框架源码解析
      • 01-PyTorch分布式源码解析
      • 02-DeepSpeed源码深度解析
      • 03-Megatron-LM源码解析
      • 04-vLLM推理引擎源码解析
      • 05-HuggingFace Transformers源码解析
    • 编译优化与图优化

      • 13-编译优化与图优化
      • 01-深度学习编译器概述
      • 02-TorchDynamo与torch.compile
      • 03-XLA编译器深度解析
      • 04-算子融合与Kernel优化
      • 05-自动调度与代码生成

03-Tensor Core 与矩阵运算

概述

Tensor Core 是 NVIDIA GPU 上专门用于加速矩阵乘法的硬件单元,从 Volta 架构开始引入。本文深入讲解 Tensor Core 的原理、编程接口以及如何利用它实现高性能的深度学习算子。

Tensor Core 架构

硬件演进

┌─────────────────────────────────────────────────────────────────────────┐
│                      Tensor Core 架构演进                                │
├─────────────────────────────────────────────────────────────────────────┤
│                                                                         │
│  Volta (V100) - 第一代 Tensor Core                                      │
│  ├── 每个 SM 有 8 个 Tensor Core                                        │
│  ├── 支持 FP16 输入,FP16/FP32 累加                                     │
│  ├── 单次操作:4x4x4 矩阵乘加                                           │
│  └── 峰值性能:125 TFLOPS (FP16)                                        │
│                                                                         │
│  Turing (T4/RTX 20) - 第二代 Tensor Core                                │
│  ├── 新增 INT8/INT4 支持                                                │
│  ├── 单次操作:4x4x4 (FP16) 或 8x8x4 (INT8)                            │
│  └── 峰值性能:130 TFLOPS (FP16)                                        │
│                                                                         │
│  Ampere (A100) - 第三代 Tensor Core                                     │
│  ├── 每个 SM 有 4 个 Tensor Core                                        │
│  ├── 新增 TF32、BF16、FP64 支持                                         │
│  ├── 稀疏性支持:2:4 结构化稀疏,2x 性能                                 │
│  ├── 单次操作:8x4x8 (FP16) 或 16x8x8 (INT8)                           │
│  └── 峰值性能:312 TFLOPS (FP16), 624 TFLOPS (INT8)                     │
│                                                                         │
│  Hopper (H100) - 第四代 Tensor Core                                     │
│  ├── 新增 FP8 (E4M3/E5M2) 支持                                         │
│  ├── Transformer Engine 集成                                            │
│  ├── 单次操作:16x8x16 (FP16)                                          │
│  └── 峰值性能:989 TFLOPS (FP16), 1979 TFLOPS (FP8)                     │
│                                                                         │
└─────────────────────────────────────────────────────────────────────────┘

Tensor Core 操作原理

┌─────────────────────────────────────────────────────────────────────────┐
│                    Tensor Core 矩阵乘加操作                              │
├─────────────────────────────────────────────────────────────────────────┤
│                                                                         │
│  D = A × B + C                                                          │
│                                                                         │
│  WMMA (Warp Matrix Multiply Accumulate):                                │
│                                                                         │
│      A [M, K]        B [K, N]        C [M, N]        D [M, N]           │
│    ┌─────────┐    ┌─────────┐    ┌─────────┐    ┌─────────┐            │
│    │         │    │         │    │         │    │         │            │
│    │  16x16  │ ×  │  16x16  │ +  │  16x16  │ =  │  16x16  │            │
│    │  FP16   │    │  FP16   │    │  FP32   │    │  FP32   │            │
│    │         │    │         │    │         │    │         │            │
│    └─────────┘    └─────────┘    └─────────┘    └─────────┘            │
│                                                                         │
│  数据分布:                                                              │
│  ├── 整个 Warp (32 线程) 协作完成一次 WMMA                              │
│  ├── 每个线程持有矩阵的一部分 (fragment)                                 │
│  └── 通过 warp shuffle 和 Tensor Core 完成计算                          │
│                                                                         │
│  支持的形状 (M, N, K):                                                  │
│  ├── FP16: 16x16x16, 32x8x16, 8x32x16                                  │
│  ├── TF32: 16x16x8                                                     │
│  ├── BF16: 16x16x16, 32x8x16, 8x32x16                                  │
│  ├── INT8: 16x16x16, 32x8x16, 8x32x16                                  │
│  └── FP64: 8x8x4                                                       │
│                                                                         │
└─────────────────────────────────────────────────────────────────────────┘

WMMA API 编程

基本 WMMA 使用

#include <mma.h>
using namespace nvcuda;

// WMMA 参数
const int WMMA_M = 16;
const int WMMA_N = 16;
const int WMMA_K = 16;

__global__ void wmma_gemm_kernel(
    const half* __restrict__ A,
    const half* __restrict__ B,
    float* __restrict__ C,
    int M, int N, int K
) {
    // Warp 索引
    int warp_m = (blockIdx.y * blockDim.y + threadIdx.y);
    int warp_n = (blockIdx.x * blockDim.x + threadIdx.x) / 32;

    // 声明 fragments
    wmma::fragment<wmma::matrix_a, WMMA_M, WMMA_N, WMMA_K, half, wmma::row_major> a_frag;
    wmma::fragment<wmma::matrix_b, WMMA_M, WMMA_N, WMMA_K, half, wmma::row_major> b_frag;
    wmma::fragment<wmma::accumulator, WMMA_M, WMMA_N, WMMA_K, float> c_frag;

    // 初始化累加器为 0
    wmma::fill_fragment(c_frag, 0.0f);

    // 计算起始位置
    int a_row = warp_m * WMMA_M;
    int b_col = warp_n * WMMA_N;

    // K 维度迭代
    for (int k = 0; k < K; k += WMMA_K) {
        if (a_row < M && k < K && b_col < N) {
            // 加载 A fragment
            wmma::load_matrix_sync(a_frag, A + a_row * K + k, K);

            // 加载 B fragment
            wmma::load_matrix_sync(b_frag, B + k * N + b_col, N);

            // 执行矩阵乘加
            wmma::mma_sync(c_frag, a_frag, b_frag, c_frag);
        }
    }

    // 存储结果
    if (a_row < M && b_col < N) {
        wmma::store_matrix_sync(C + a_row * N + b_col, c_frag, N, wmma::mem_row_major);
    }
}

void launch_wmma_gemm(
    const half* A, const half* B, float* C,
    int M, int N, int K
) {
    // 每个 warp 计算 16x16 输出
    // Block: 4 warps (128 threads)
    dim3 block(128, 1);

    // Grid: 覆盖整个输出矩阵
    dim3 grid(
        (N + WMMA_N * 4 - 1) / (WMMA_N * 4),
        (M + WMMA_M - 1) / WMMA_M
    );

    wmma_gemm_kernel<<<grid, block>>>(A, B, C, M, N, K);
}

共享内存优化的 WMMA

#define WMMA_M 16
#define WMMA_N 16
#define WMMA_K 16
#define WARP_SIZE 32

// Block tile 大小
#define BLOCK_M 128
#define BLOCK_N 128
#define BLOCK_K 32

// Warp tile 大小
#define WARP_M 64
#define WARP_N 64

__global__ void wmma_gemm_shared(
    const half* __restrict__ A,
    const half* __restrict__ B,
    float* __restrict__ C,
    int M, int N, int K
) {
    // 共享内存
    __shared__ half As[BLOCK_K][BLOCK_M];  // 转置存储
    __shared__ half Bs[BLOCK_K][BLOCK_N];

    // 线程/Warp 索引
    int tid = threadIdx.x;
    int warp_id = tid / WARP_SIZE;
    int lane_id = tid % WARP_SIZE;

    // Block 在 grid 中的位置
    int block_m = blockIdx.y * BLOCK_M;
    int block_n = blockIdx.x * BLOCK_N;

    // Warp 在 block 内的位置
    int warps_per_row = BLOCK_N / WARP_N;  // 2
    int warp_row = warp_id / warps_per_row;
    int warp_col = warp_id % warps_per_row;

    // 每个 warp 计算 WARP_M x WARP_N = 64x64
    // 使用 (WARP_M/WMMA_M) x (WARP_N/WMMA_N) = 4x4 = 16 个 WMMA 操作

    // 声明 fragments
    wmma::fragment<wmma::matrix_a, WMMA_M, WMMA_N, WMMA_K, half, wmma::row_major> a_frag[4];
    wmma::fragment<wmma::matrix_b, WMMA_M, WMMA_N, WMMA_K, half, wmma::row_major> b_frag[4];
    wmma::fragment<wmma::accumulator, WMMA_M, WMMA_N, WMMA_K, float> c_frag[4][4];

    // 初始化累加器
    #pragma unroll
    for (int i = 0; i < 4; i++) {
        #pragma unroll
        for (int j = 0; j < 4; j++) {
            wmma::fill_fragment(c_frag[i][j], 0.0f);
        }
    }

    // K 维度分块迭代
    for (int bk = 0; bk < K; bk += BLOCK_K) {
        // 协作加载 A 到共享内存
        // 每个线程加载多个元素
        #pragma unroll
        for (int i = 0; i < BLOCK_M * BLOCK_K / (blockDim.x * 8); i++) {
            int idx = tid + i * blockDim.x;
            int row = idx / (BLOCK_K / 8);
            int col = (idx % (BLOCK_K / 8)) * 8;

            if (block_m + row < M && bk + col < K) {
                // 向量化加载
                *reinterpret_cast<float4*>(&As[col][row]) =
                    *reinterpret_cast<const float4*>(&A[(block_m + row) * K + bk + col]);
            }
        }

        // 协作加载 B 到共享内存
        #pragma unroll
        for (int i = 0; i < BLOCK_K * BLOCK_N / (blockDim.x * 8); i++) {
            int idx = tid + i * blockDim.x;
            int row = idx / (BLOCK_N / 8);
            int col = (idx % (BLOCK_N / 8)) * 8;

            if (bk + row < K && block_n + col < N) {
                *reinterpret_cast<float4*>(&Bs[row][col]) =
                    *reinterpret_cast<const float4*>(&B[(bk + row) * N + block_n + col]);
            }
        }

        __syncthreads();

        // WMMA 计算
        #pragma unroll
        for (int k = 0; k < BLOCK_K; k += WMMA_K) {
            // 加载 A fragments (4 个 WMMA_M x WMMA_K)
            #pragma unroll
            for (int m = 0; m < 4; m++) {
                int a_row = warp_row * WARP_M + m * WMMA_M;
                wmma::load_matrix_sync(
                    a_frag[m],
                    &As[k][a_row],
                    BLOCK_M  // leading dimension
                );
            }

            // 加载 B fragments (4 个 WMMA_K x WMMA_N)
            #pragma unroll
            for (int n = 0; n < 4; n++) {
                int b_col = warp_col * WARP_N + n * WMMA_N;
                wmma::load_matrix_sync(
                    b_frag[n],
                    &Bs[k][b_col],
                    BLOCK_N
                );
            }

            // 计算 4x4 个 WMMA
            #pragma unroll
            for (int m = 0; m < 4; m++) {
                #pragma unroll
                for (int n = 0; n < 4; n++) {
                    wmma::mma_sync(c_frag[m][n], a_frag[m], b_frag[n], c_frag[m][n]);
                }
            }
        }

        __syncthreads();
    }

    // 存储结果
    #pragma unroll
    for (int m = 0; m < 4; m++) {
        #pragma unroll
        for (int n = 0; n < 4; n++) {
            int c_row = block_m + warp_row * WARP_M + m * WMMA_M;
            int c_col = block_n + warp_col * WARP_N + n * WMMA_N;

            if (c_row < M && c_col < N) {
                wmma::store_matrix_sync(
                    &C[c_row * N + c_col],
                    c_frag[m][n],
                    N,
                    wmma::mem_row_major
                );
            }
        }
    }
}

MMA PTX 指令

直接使用 PTX

/*
 * 对于更细粒度的控制,可以直接使用 PTX 汇编
 * mma.sync 指令支持更多配置选项
 */

// FP16 MMA: m16n8k16
__device__ void mma_m16n8k16_fp16(
    uint32_t* D,        // 4 个 uint32 (8 个 FP16)
    uint32_t* A,        // 8 个 uint32 (16 个 FP16)
    uint32_t* B,        // 4 个 uint32 (8 个 FP16)
    uint32_t* C         // 4 个 uint32 (8 个 FP16)
) {
    asm volatile(
        "mma.sync.aligned.m16n8k16.row.col.f16.f16.f16.f16 "
        "{%0, %1, %2, %3}, "
        "{%4, %5, %6, %7, %8, %9, %10, %11}, "
        "{%12, %13, %14, %15}, "
        "{%16, %17, %18, %19};\n"
        : "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3])
        : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]),
          "r"(A[4]), "r"(A[5]), "r"(A[6]), "r"(A[7]),
          "r"(B[0]), "r"(B[1]), "r"(B[2]), "r"(B[3]),
          "r"(C[0]), "r"(C[1]), "r"(C[2]), "r"(C[3])
    );
}

// TF32 MMA: m16n8k8
__device__ void mma_m16n8k8_tf32(
    float* D,           // 4 个 float
    uint32_t* A,        // 4 个 uint32 (TF32 in uint32)
    uint32_t* B,        // 2 个 uint32
    float* C            // 4 个 float
) {
    asm volatile(
        "mma.sync.aligned.m16n8k8.row.col.f32.tf32.tf32.f32 "
        "{%0, %1, %2, %3}, "
        "{%4, %5, %6, %7}, "
        "{%8, %9}, "
        "{%10, %11, %12, %13};\n"
        : "=f"(D[0]), "=f"(D[1]), "=f"(D[2]), "=f"(D[3])
        : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]),
          "r"(B[0]), "r"(B[1]),
          "f"(C[0]), "f"(C[1]), "f"(C[2]), "f"(C[3])
    );
}

// INT8 MMA: m16n8k32
__device__ void mma_m16n8k32_s8(
    int32_t* D,         // 4 个 int32
    uint32_t* A,        // 8 个 uint32 (32 个 INT8)
    uint32_t* B,        // 4 个 uint32 (16 个 INT8)
    int32_t* C          // 4 个 int32
) {
    asm volatile(
        "mma.sync.aligned.m16n8k32.row.col.s32.s8.s8.s32 "
        "{%0, %1, %2, %3}, "
        "{%4, %5, %6, %7, %8, %9, %10, %11}, "
        "{%12, %13, %14, %15}, "
        "{%16, %17, %18, %19};\n"
        : "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3])
        : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]),
          "r"(A[4]), "r"(A[5]), "r"(A[6]), "r"(A[7]),
          "r"(B[0]), "r"(B[1]), "r"(B[2]), "r"(B[3]),
          "r"(C[0]), "r"(C[1]), "r"(C[2]), "r"(C[3])
    );
}

MMA 数据布局

/*
 * MMA 操作中数据在线程间的分布
 * 以 m16n8k16 FP16 为例:
 *
 * A 矩阵 [16, 16] - row_major
 * 每个线程持有 8 个元素
 * 线程 T 持有: A[T/4, (T%4)*2], A[T/4, (T%4)*2+1], ...
 *
 * B 矩阵 [16, 8] - col_major
 * 每个线程持有 4 个元素
 *
 * C/D 矩阵 [16, 8]
 * 每个线程持有 4 个元素
 */

// 手动实现 WMMA 加载(理解数据布局)
__device__ void load_matrix_a_m16n8k16(
    uint32_t frag[8],
    const half* ptr,
    int ldm  // leading dimension
) {
    int lane = threadIdx.x % 32;

    // 计算该线程负责的行和列
    int row0 = lane / 4;
    int row1 = row0 + 8;
    int col = (lane % 4) * 2;

    // 每个线程加载 8 个 half
    // 打包成 4 个 uint32
    half2* frag_h2 = reinterpret_cast<half2*>(frag);

    frag_h2[0] = *reinterpret_cast<const half2*>(&ptr[row0 * ldm + col]);
    frag_h2[1] = *reinterpret_cast<const half2*>(&ptr[row0 * ldm + col + 8]);
    frag_h2[2] = *reinterpret_cast<const half2*>(&ptr[row1 * ldm + col]);
    frag_h2[3] = *reinterpret_cast<const half2*>(&ptr[row1 * ldm + col + 8]);
}

__device__ void load_matrix_b_m16n8k16(
    uint32_t frag[4],
    const half* ptr,
    int ldm
) {
    int lane = threadIdx.x % 32;

    int row0 = lane / 4;
    int row1 = row0 + 8;
    int col = (lane % 4) * 2;

    half2* frag_h2 = reinterpret_cast<half2*>(frag);

    frag_h2[0] = *reinterpret_cast<const half2*>(&ptr[row0 * ldm + col]);
    frag_h2[1] = *reinterpret_cast<const half2*>(&ptr[row1 * ldm + col]);
}

__device__ void store_matrix_c_m16n8k16(
    half* ptr,
    uint32_t frag[4],
    int ldm
) {
    int lane = threadIdx.x % 32;

    int row0 = lane / 4;
    int row1 = row0 + 8;
    int col = (lane % 4) * 2;

    half2* frag_h2 = reinterpret_cast<half2*>(frag);

    *reinterpret_cast<half2*>(&ptr[row0 * ldm + col]) = frag_h2[0];
    *reinterpret_cast<half2*>(&ptr[row1 * ldm + col]) = frag_h2[1];
}

CUTLASS 介绍

CUTLASS 架构

┌─────────────────────────────────────────────────────────────────────────┐
│                      CUTLASS 层次化架构                                  │
├─────────────────────────────────────────────────────────────────────────┤
│                                                                         │
│  ┌─────────────────────────────────────────────────────────────────┐   │
│  │                     Device-level                                 │   │
│  │   问题分解为 Thread Block Tiles                                   │   │
│  └─────────────────────────────┬───────────────────────────────────┘   │
│                                │                                        │
│  ┌─────────────────────────────┴───────────────────────────────────┐   │
│  │                     Thread Block-level                           │   │
│  │   ├── Thread Block Tile → Warp Tiles                            │   │
│  │   ├── 管理共享内存                                                │   │
│  │   └── 协调数据移动                                                │   │
│  └─────────────────────────────┬───────────────────────────────────┘   │
│                                │                                        │
│  ┌─────────────────────────────┴───────────────────────────────────┐   │
│  │                     Warp-level                                   │   │
│  │   ├── Warp Tile → MMA 操作                                       │   │
│  │   ├── 管理寄存器 fragments                                        │   │
│  │   └── 执行 Tensor Core MMA                                       │   │
│  └─────────────────────────────┬───────────────────────────────────┘   │
│                                │                                        │
│  ┌─────────────────────────────┴───────────────────────────────────┐   │
│  │                     Thread-level                                 │   │
│  │   ├── 数据加载/存储                                               │   │
│  │   └── 后处理 (epilogue)                                          │   │
│  └─────────────────────────────────────────────────────────────────┘   │
│                                                                         │
│  核心组件:                                                              │
│  ├── Gemm: 通用矩阵乘法                                                 │
│  ├── Conv: 卷积操作                                                     │
│  ├── Epilogue: 后处理 (bias, activation, quantization)                 │
│  ├── Layout: 数据布局 (RowMajor, ColumnMajor, etc.)                    │
│  └── Iterator: 内存访问模式                                             │
│                                                                         │
└─────────────────────────────────────────────────────────────────────────┘

使用 CUTLASS

#include <cutlass/cutlass.h>
#include <cutlass/gemm/device/gemm.h>
#include <cutlass/util/host_tensor.h>

// 定义 GEMM 类型
using ElementInputA = cutlass::half_t;
using ElementInputB = cutlass::half_t;
using ElementOutput = cutlass::half_t;
using ElementAccumulator = float;

using LayoutInputA = cutlass::layout::RowMajor;
using LayoutInputB = cutlass::layout::RowMajor;
using LayoutOutput = cutlass::layout::RowMajor;

// 定义 GEMM 操作
using Gemm = cutlass::gemm::device::Gemm<
    ElementInputA,
    LayoutInputA,
    ElementInputB,
    LayoutInputB,
    ElementOutput,
    LayoutOutput,
    ElementAccumulator,
    cutlass::arch::OpClassTensorOp,    // 使用 Tensor Core
    cutlass::arch::Sm80,               // Ampere 架构
    cutlass::gemm::GemmShape<128, 256, 64>,   // Thread Block Tile
    cutlass::gemm::GemmShape<64, 64, 64>,     // Warp Tile
    cutlass::gemm::GemmShape<16, 8, 16>,      // MMA Shape (m16n8k16)
    cutlass::epilogue::thread::LinearCombination<
        ElementOutput,
        128 / cutlass::sizeof_bits<ElementOutput>::value,
        ElementAccumulator,
        ElementAccumulator
    >,
    cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
    3  // Pipeline stages
>;

cudaError_t cutlass_gemm(
    cutlass::half_t* A,
    cutlass::half_t* B,
    cutlass::half_t* C,
    cutlass::half_t* D,
    int M, int N, int K,
    float alpha, float beta
) {
    // 定义问题大小
    cutlass::gemm::GemmCoord problem_size(M, N, K);

    // 创建 GEMM 参数
    typename Gemm::Arguments arguments{
        problem_size,
        {A, K},     // TensorRef for A
        {B, N},     // TensorRef for B
        {C, N},     // TensorRef for C (source)
        {D, N},     // TensorRef for D (destination)
        {alpha, beta}
    };

    // 实例化 GEMM
    Gemm gemm_op;

    // 检查参数
    cutlass::Status status = gemm_op.can_implement(arguments);
    if (status != cutlass::Status::kSuccess) {
        return cudaErrorInvalidValue;
    }

    // 获取工作空间大小
    size_t workspace_size = Gemm::get_workspace_size(arguments);

    // 分配工作空间
    cutlass::device_memory::allocation<uint8_t> workspace(workspace_size);

    // 初始化
    status = gemm_op.initialize(arguments, workspace.get());
    if (status != cutlass::Status::kSuccess) {
        return cudaErrorInvalidValue;
    }

    // 执行
    status = gemm_op();
    if (status != cutlass::Status::kSuccess) {
        return cudaErrorInvalidValue;
    }

    return cudaSuccess;
}

CUTLASS 3.x with CuTe

/*
 * CUTLASS 3.x 引入 CuTe 库
 * CuTe = CUDA Templates for Linear Algebra Primitives
 * 提供更高级的抽象来处理张量布局和分区
 */

#include <cute/tensor.hpp>
#include <cutlass/cutlass.h>
#include <cutlass/gemm/collective/collective_builder.hpp>

using namespace cute;

// 使用 CuTe 定义布局
auto layout_A = make_layout(make_shape(M, K), make_stride(K, 1));  // Row-major
auto layout_B = make_layout(make_shape(K, N), make_stride(N, 1));  // Row-major

// 创建张量
auto tensor_A = make_tensor(make_gmem_ptr(ptr_A), layout_A);
auto tensor_B = make_tensor(make_gmem_ptr(ptr_B), layout_B);

// 分区
// Thread Block Tile
auto tiled_A = local_tile(tensor_A, make_shape(BM, BK), make_coord(bm, bk));
auto tiled_B = local_tile(tensor_B, make_shape(BK, BN), make_coord(bk, bn));

// 定义 MMA 操作
using MMA_Op = SM80_16x8x16_F16F16F16F16_TN;
auto mma = MMA_Op{};

// 获取 MMA 布局
auto mma_A = mma.get_layoutA_MK();
auto mma_B = mma.get_layoutB_NK();
auto mma_C = mma.get_layoutC_MN();

// CUTLASS 3.x Collective GEMM Builder
using CollectiveOp = cutlass::gemm::collective::CollectiveBuilder<
    cutlass::arch::Sm90,
    cutlass::arch::OpClassTensorOp,
    cute::half_t, cute::LayoutRight, 16,
    cute::half_t, cute::LayoutRight, 16,
    float,
    cute::Shape<cute::_128, cute::_256, cute::_64>,
    cute::Shape<cute::_1, cute::_1, cute::_1>,
    cutlass::gemm::collective::StageCountAutoCarveout<sizeof(float)>,
    cutlass::gemm::collective::KernelScheduleAuto
>::CollectiveOp;

FP8 与 Transformer Engine

FP8 数据格式

┌─────────────────────────────────────────────────────────────────────────┐
│                         FP8 数据格式                                     │
├─────────────────────────────────────────────────────────────────────────┤
│                                                                         │
│  E4M3 (更高精度,较小动态范围)                                           │
│  ├── 1 位符号 + 4 位指数 + 3 位尾数                                     │
│  ├── 动态范围:±448                                                     │
│  └── 适用于:前向传播中的权重和激活                                      │
│                                                                         │
│  E5M2 (较低精度,更大动态范围)                                           │
│  ├── 1 位符号 + 5 位指数 + 2 位尾数                                     │
│  ├── 动态范围:±57344                                                   │
│  └── 适用于:反向传播中的梯度                                            │
│                                                                         │
│  对比:                                                                  │
│  ┌─────────────────────────────────────────────────────────────┐       │
│  │ 格式    │ 符号 │ 指数 │ 尾数 │ 最大值   │ 最小正数          │       │
│  ├─────────────────────────────────────────────────────────────┤       │
│  │ FP32   │ 1    │ 8    │ 23   │ 3.4e38   │ 1.2e-38          │       │
│  │ FP16   │ 1    │ 5    │ 10   │ 65504    │ 6.1e-5           │       │
│  │ BF16   │ 1    │ 8    │ 7    │ 3.4e38   │ 1.2e-38          │       │
│  │ E4M3   │ 1    │ 4    │ 3    │ 448      │ 2^-9             │       │
│  │ E5M2   │ 1    │ 5    │ 2    │ 57344    │ 2^-16            │       │
│  └─────────────────────────────────────────────────────────────┘       │
│                                                                         │
└─────────────────────────────────────────────────────────────────────────┘

Transformer Engine 使用

# Transformer Engine Python API
import transformer_engine.pytorch as te
import torch

# 创建 FP8 线性层
linear = te.Linear(
    in_features=4096,
    out_features=4096,
    bias=True
)

# 设置 FP8 配置
fp8_recipe = te.recipe.DelayedScaling(
    margin=0,
    interval=1,
    fp8_format=te.recipe.Format.HYBRID,  # E4M3 for forward, E5M2 for backward
    amax_history_len=1024,
    amax_compute_algo="max"
)

# FP8 前向传播
with te.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe):
    output = linear(input)

# FP8 注意力层
class FP8Attention(te.TransformerLayer):
    def __init__(self, hidden_size, num_heads):
        super().__init__(
            hidden_size=hidden_size,
            ffn_hidden_size=4 * hidden_size,
            num_attention_heads=num_heads,
            attention_dropout=0.0,
            hidden_dropout=0.0,
            self_attn_mask_type="causal"
        )

# 混合精度训练
model = FP8Attention(4096, 32).cuda()
optimizer = torch.optim.AdamW(model.parameters())

for batch in dataloader:
    with te.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe):
        loss = model(batch)

    loss.backward()
    optimizer.step()

FP8 GEMM 实现

// H100 FP8 MMA
#include <cuda_fp8.h>

// FP8 类型定义
using fp8_e4m3 = __nv_fp8_e4m3;
using fp8_e5m2 = __nv_fp8_e5m2;

__device__ void mma_m16n8k32_fp8(
    float* D,
    uint32_t* A,  // 32 个 E4M3
    uint32_t* B,  // 16 个 E4M3
    float* C
) {
    #if __CUDA_ARCH__ >= 900  // Hopper
    asm volatile(
        "mma.sync.aligned.m16n8k32.row.col.f32.e4m3.e4m3.f32 "
        "{%0, %1, %2, %3}, "
        "{%4, %5, %6, %7}, "
        "{%8, %9}, "
        "{%10, %11, %12, %13};\n"
        : "=f"(D[0]), "=f"(D[1]), "=f"(D[2]), "=f"(D[3])
        : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]),
          "r"(B[0]), "r"(B[1]),
          "f"(C[0]), "f"(C[1]), "f"(C[2]), "f"(C[3])
    );
    #endif
}

// FP8 量化
__device__ __forceinline__ fp8_e4m3 quantize_fp8_e4m3(float val, float scale) {
    return __nv_cvt_float_to_fp8(val * scale, __NV_SATFINITE, __NV_E4M3);
}

// FP8 反量化
__device__ __forceinline__ float dequantize_fp8_e4m3(fp8_e4m3 val, float scale) {
    return __half2float(__nv_cvt_fp8_to_halfraw(val, __NV_E4M3)) / scale;
}

// 动态量化 kernel
__global__ void compute_scale_and_quantize(
    const float* input,
    fp8_e4m3* output,
    float* scale,
    int n
) {
    __shared__ float s_amax;

    // 找最大绝对值
    float local_max = 0.0f;
    for (int i = threadIdx.x; i < n; i += blockDim.x) {
        local_max = fmaxf(local_max, fabsf(input[i]));
    }

    float amax = block_reduce_max<256>(local_max);

    if (threadIdx.x == 0) {
        s_amax = amax;
        // E4M3 最大值是 448
        *scale = 448.0f / (amax + 1e-12f);
    }
    __syncthreads();

    float s = *scale;

    // 量化
    for (int i = threadIdx.x; i < n; i += blockDim.x) {
        output[i] = quantize_fp8_e4m3(input[i], s);
    }
}

性能优化策略

Tensor Core 利用率优化

/*
 * 最大化 Tensor Core 利用率的关键:
 *
 * 1. 数据对齐
 *    - 矩阵维度是 8/16 的倍数
 *    - 数据地址 16 字节对齐
 *
 * 2. 足够的并行度
 *    - 同时有多个 MMA 操作可执行
 *    - 通过增大 tile 大小或启动更多 warp
 *
 * 3. 隐藏数据移动延迟
 *    - 使用软件流水线 (double/triple buffering)
 *    - 异步拷贝 (cp.async)
 *
 * 4. 选择合适的 MMA 形状
 *    - 根据问题大小选择最优 tile 配置
 */

// 软件流水线示例
template<int STAGES>
__global__ void pipelined_gemm(
    const half* A, const half* B, float* C,
    int M, int N, int K
) {
    // 多缓冲共享内存
    __shared__ half As[STAGES][BLOCK_K][BLOCK_M];
    __shared__ half Bs[STAGES][BLOCK_K][BLOCK_N];

    // Fragments
    wmma::fragment<...> a_frag, b_frag;
    wmma::fragment<wmma::accumulator, ...> c_frag;
    wmma::fill_fragment(c_frag, 0.0f);

    int stage = 0;

    // 填充流水线
    for (int s = 0; s < STAGES - 1; s++) {
        load_tile_async(A, B, As[s], Bs[s], s * BLOCK_K);
    }
    cp_async_wait<STAGES - 2>();
    __syncthreads();

    // 主循环
    for (int k = 0; k < K; k += BLOCK_K) {
        // 加载下一个 tile
        if (k + (STAGES - 1) * BLOCK_K < K) {
            load_tile_async(A, B,
                           As[(stage + STAGES - 1) % STAGES],
                           Bs[(stage + STAGES - 1) % STAGES],
                           k + (STAGES - 1) * BLOCK_K);
        }

        // 等待当前 tile 就绪
        cp_async_wait<STAGES - 2>();
        __syncthreads();

        // 计算当前 tile
        wmma::load_matrix_sync(a_frag, As[stage], BLOCK_M);
        wmma::load_matrix_sync(b_frag, Bs[stage], BLOCK_N);
        wmma::mma_sync(c_frag, a_frag, b_frag, c_frag);

        stage = (stage + 1) % STAGES;
    }

    // 存储结果
    wmma::store_matrix_sync(C, c_frag, N, wmma::mem_row_major);
}

性能对比

┌─────────────────────────────────────────────────────────────────────────┐
│                    GEMM 性能对比 (A100, M=N=K=4096)                      │
├─────────────────────────────────────────────────────────────────────────┤
│                                                                         │
│  实现方式                          │ 性能 (TFLOPS) │ 带宽利用率         │
│  ─────────────────────────────────┼───────────────┼──────────────────  │
│  Naive CUDA (FP32)                │     0.5       │      5%           │
│  Tiled Shared Memory (FP32)       │     5.0       │     40%           │
│  cuBLAS FP32                      │    19.5       │     95%           │
│  ─────────────────────────────────┼───────────────┼──────────────────  │
│  Naive WMMA FP16                  │    30.0       │     20%           │
│  Tiled WMMA FP16                  │   180.0       │     60%           │
│  CUTLASS FP16                     │   275.0       │     88%           │
│  cuBLAS FP16 Tensor Core          │   290.0       │     93%           │
│  ─────────────────────────────────┼───────────────┼──────────────────  │
│  CUTLASS INT8                     │   550.0       │     89%           │
│  cuBLAS INT8 Tensor Core          │   580.0       │     93%           │
│                                                                         │
│  理论峰值:                                                              │
│  ├── FP32 CUDA Core: 19.5 TFLOPS                                       │
│  ├── FP16 Tensor Core: 312 TFLOPS                                      │
│  └── INT8 Tensor Core: 624 TOPS                                        │
│                                                                         │
└─────────────────────────────────────────────────────────────────────────┘

总结

Tensor Core 使用要点

要点说明
数据类型FP16/BF16/TF32/INT8/FP8
对齐要求矩阵维度为 8/16 倍数
编程接口WMMA API / PTX / CUTLASS
性能关键Tile 大小、流水线深度、数据布局

最佳实践

□ 使用 CUTLASS 或 cuBLAS 作为基准
□ 矩阵维度 padding 到合适倍数
□ 使用软件流水线隐藏延迟
□ 合理配置 Block/Warp Tile 大小
□ 考虑数据类型选择 (精度 vs 性能)
Prev
02-高性能 Kernel 开发实战
Next
04-算子融合与优化技术