HiHuo
首页
博客
手册
工具
关于
首页
博客
手册
工具
关于
  • AI 基础设施深度教程

    • AI Infra 深度教程
    • GPU容器化

      • 01-GPU 架构基础
      • NVIDIA 容器运行时
      • GPU 共享与隔离
      • GPU 监控与调试
    • Kubernetes GPU调度

      • Device Plugin 机制深度解析
      • GPU 调度器实现
      • 拓扑感知调度
      • 弹性 GPU 调度
    • AI训练平台

      • 分布式训练框架
      • 训练任务调度
      • 模型存储与管理
      • 实验管理
      • 超参数优化
    • 推理服务

      • 推理引擎原理
      • 模型服务框架
      • 动态批处理
      • 推理优化技术
      • 多模型服务
    • 异构计算

      • 05-异构计算
      • 异构计算概述
      • GPU 虚拟化技术
      • NPU 与专用 AI 芯片
      • 设备拓扑感知调度
      • 算力池化与弹性调度
    • AI工作流引擎

      • 06-AI工作流引擎
      • AI 工作流引擎概述
      • Kubeflow Pipelines 深度实践
      • 03-Argo Workflows 深度实践
      • 04-数据版本管理
      • 05-实验跟踪与模型注册
    • MLOps实践

      • 07-MLOps实践
      • 01-MLOps 成熟度模型
      • 02-数据集工程
      • 03-Feature Store 特征存储
      • 04-模型评测体系
      • 05-模型安全与治理
    • AIOps实践

      • 08-AIOps实践
      • 01-AIOps概述与架构
      • 02-异常检测算法
      • 03-根因分析与告警聚合
      • 04-智能运维决策
      • 05-AIOps平台实战
    • 面试专题

      • 09-面试专题
      • 01-AI基础设施核心面试题
      • 02-大模型面试题
      • 03-系统设计面试题
    • CUDA编程与算子开发

      • 10-CUDA 编程与算子开发
      • 01-CUDA编程模型与内存层次
      • 02-高性能 Kernel 开发实战
      • 03-Tensor Core 与矩阵运算
      • 04-算子融合与优化技术
      • 05-Triton 编程入门
    • 通信与网络底层

      • 11-通信与网络底层
      • 01-NCCL 源码深度解析
      • 02-AllReduce 算法实现
      • 03-RDMA与InfiniBand原理
      • 04-网络拓扑与通信优化
      • 05-大规模集群网络架构
    • 框架源码解析

      • 12-框架源码解析
      • 01-PyTorch分布式源码解析
      • 02-DeepSpeed源码深度解析
      • 03-Megatron-LM源码解析
      • 04-vLLM推理引擎源码解析
      • 05-HuggingFace Transformers源码解析
    • 编译优化与图优化

      • 13-编译优化与图优化
      • 01-深度学习编译器概述
      • 02-TorchDynamo与torch.compile
      • 03-XLA编译器深度解析
      • 04-算子融合与Kernel优化
      • 05-自动调度与代码生成

AI 工作流引擎概述

概述

AI 工作流引擎是管理机器学习全生命周期的核心组件,它将数据处理、模型训练、评估、部署等环节编排成自动化流水线。本文介绍工作流引擎的核心概念、主流方案对比及架构设计。

为什么需要工作流引擎

ML 开发的挑战

┌─────────────────────────────────────────────────────────────────┐
│                    ML 开发生命周期挑战                            │
├─────────────────────────────────────────────────────────────────┤
│                                                                 │
│  传统开发方式的问题:                                               │
│                                                                 │
│  ┌─────────────────────────────────────────────────────────┐   │
│  │                                                         │   │
│  │  数据准备 ──► 特征工程 ──► 训练 ──► 评估 ──► 部署       │   │
│  │     │           │          │        │        │         │   │
│  │     ▼           ▼          ▼        ▼        ▼         │   │
│  │  ┌─────┐    ┌─────┐    ┌─────┐  ┌─────┐  ┌─────┐      │   │
│  │  │手动  │    │手动  │    │手动  │  │手动  │  │手动  │      │   │
│  │  │脚本  │    │脚本  │    │脚本  │  │脚本  │  │脚本  │      │   │
│  │  └─────┘    └─────┘    └─────┘  └─────┘  └─────┘      │   │
│  │                                                         │   │
│  │  问题:                                                   │   │
│  │  • 步骤之间缺乏关联                                       │   │
│  │  • 难以复现实验                                          │   │
│  │  • 手动触发易出错                                        │   │
│  │  • 资源管理困难                                          │   │
│  │  • 缺乏版本控制                                          │   │
│  │                                                         │   │
│  └─────────────────────────────────────────────────────────┘   │
│                                                                 │
│  工作流引擎解决方案:                                              │
│                                                                 │
│  ┌─────────────────────────────────────────────────────────┐   │
│  │                                                         │   │
│  │            ┌────────────────────────────┐               │   │
│  │            │     Workflow Engine        │               │   │
│  │            │  ┌──────────────────────┐  │               │   │
│  │            │  │   DAG 调度器          │  │               │   │
│  │            │  └──────────────────────┘  │               │   │
│  │            │  ┌──────────────────────┐  │               │   │
│  │            │  │   资源管理器          │  │               │   │
│  │            │  └──────────────────────┘  │               │   │
│  │            │  ┌──────────────────────┐  │               │   │
│  │            │  │   元数据存储          │  │               │   │
│  │            │  └──────────────────────┘  │               │   │
│  │            └────────────┬───────────────┘               │   │
│  │                         │                               │   │
│  │    ┌────────────────────┼────────────────────┐         │   │
│  │    │                    │                    │         │   │
│  │    ▼                    ▼                    ▼         │   │
│  │  ┌─────┐            ┌─────┐            ┌─────┐        │   │
│  │  │Step1│───────────►│Step2│───────────►│Step3│        │   │
│  │  │数据 │            │训练 │            │部署 │        │   │
│  │  └─────┘            └─────┘            └─────┘        │   │
│  │                                                         │   │
│  │  优势:                                                   │   │
│  │  • 自动化编排                                            │   │
│  │  • 完整可复现                                            │   │
│  │  • 资源自动分配                                          │   │
│  │  • 全程可追溯                                            │   │
│  │                                                         │   │
│  └─────────────────────────────────────────────────────────┘   │
│                                                                 │
└─────────────────────────────────────────────────────────────────┘

核心价值

# 工作流引擎核心价值
workflow_engine_value:
  # 1. 自动化
  automation:
    - 端到端流水线自动执行
    - 触发器驱动 (定时/事件/手动)
    - 失败自动重试
    - 条件分支执行

  # 2. 可复现性
  reproducibility:
    - 代码版本化
    - 数据版本化
    - 环境容器化
    - 参数记录

  # 3. 可扩展性
  scalability:
    - 分布式执行
    - 弹性资源
    - 并行处理
    - 跨集群调度

  # 4. 可观测性
  observability:
    - 执行状态跟踪
    - 日志聚合
    - 指标监控
    - 血缘追溯

  # 5. 协作效率
  collaboration:
    - 组件复用
    - 模板共享
    - 权限管理
    - 审计日志

主流工作流引擎对比

方案概览

┌─────────────────────────────────────────────────────────────────┐
│                    主流 ML 工作流引擎对比                         │
├─────────────────────────────────────────────────────────────────┤
│                                                                 │
│  ┌─────────────────────────────────────────────────────────┐   │
│  │                    Kubeflow Pipelines                    │   │
│  │  ┌─────────────────────────────────────────────────┐    │   │
│  │  │  • Kubernetes 原生                               │    │   │
│  │  │  • Argo Workflow 后端                            │    │   │
│  │  │  • Python SDK (KFP)                             │    │   │
│  │  │  • 组件市场                                      │    │   │
│  │  │  适用: 大规模 ML 平台                             │    │   │
│  │  └─────────────────────────────────────────────────┘    │   │
│  └─────────────────────────────────────────────────────────┘   │
│                                                                 │
│  ┌─────────────────────────────────────────────────────────┐   │
│  │                       Argo Workflows                     │   │
│  │  ┌─────────────────────────────────────────────────┐    │   │
│  │  │  • 通用容器工作流                                │    │   │
│  │  │  • YAML 定义                                    │    │   │
│  │  │  • DAG + Steps 模式                             │    │   │
│  │  │  • 丰富的模板功能                                │    │   │
│  │  │  适用: 通用 CI/CD + ML                           │    │   │
│  │  └─────────────────────────────────────────────────┘    │   │
│  └─────────────────────────────────────────────────────────┘   │
│                                                                 │
│  ┌─────────────────────────────────────────────────────────┐   │
│  │                        Apache Airflow                    │   │
│  │  ┌─────────────────────────────────────────────────┐    │   │
│  │  │  • Python 原生定义                               │    │   │
│  │  │  • 丰富的 Operator                              │    │   │
│  │  │  • 成熟的调度系统                                │    │   │
│  │  │  • 强大的 UI                                    │    │   │
│  │  │  适用: 数据工程 + ML Pipeline                    │    │   │
│  │  └─────────────────────────────────────────────────┘    │   │
│  └─────────────────────────────────────────────────────────┘   │
│                                                                 │
│  ┌─────────────────────────────────────────────────────────┐   │
│  │                         Prefect                          │   │
│  │  ┌─────────────────────────────────────────────────┐    │   │
│  │  │  • 现代 Python API                              │    │   │
│  │  │  • 动态工作流                                    │    │   │
│  │  │  • 本地/云端混合                                 │    │   │
│  │  │  • 简单易用                                      │    │   │
│  │  │  适用: 中小规模数据/ML 流水线                     │    │   │
│  │  └─────────────────────────────────────────────────┘    │   │
│  └─────────────────────────────────────────────────────────┘   │
│                                                                 │
│  ┌─────────────────────────────────────────────────────────┐   │
│  │                        MLflow                            │   │
│  │  ┌─────────────────────────────────────────────────┐    │   │
│  │  │  • 实验跟踪为核心                                │    │   │
│  │  │  • 模型注册表                                    │    │   │
│  │  │  • Projects 定义                                │    │   │
│  │  │  • 多框架支持                                    │    │   │
│  │  │  适用: 实验管理 + 简单流水线                      │    │   │
│  │  └─────────────────────────────────────────────────┘    │   │
│  └─────────────────────────────────────────────────────────┘   │
│                                                                 │
└─────────────────────────────────────────────────────────────────┘

详细对比

# 工作流引擎详细对比
comparison:
  kubeflow_pipelines:
    pros:
      - Kubernetes 原生集成
      - 强大的 ML 组件生态
      - 支持分布式训练
      - 与 Kubeflow 生态无缝集成
    cons:
      - 学习曲线陡峭
      - 部署复杂
      - 资源消耗大
    best_for:
      - 大规模 ML 平台
      - Kubernetes 环境
      - 需要完整 MLOps 能力

  argo_workflows:
    pros:
      - 灵活的 YAML 定义
      - 丰富的工作流模式
      - 轻量级部署
      - 活跃的社区
    cons:
      - ML 特定功能较少
      - 需要 Kubernetes
      - UI 功能有限
    best_for:
      - 通用容器工作流
      - CI/CD 集成
      - 定制化需求高

  airflow:
    pros:
      - 成熟稳定
      - 丰富的 Operator
      - 强大的调度能力
      - 完善的监控
    cons:
      - 不是 Kubernetes 原生
      - 扩展性有限
      - 实时性较差
    best_for:
      - 数据工程
      - 批处理任务
      - 传统环境

  prefect:
    pros:
      - 现代 Python API
      - 动态工作流支持
      - 本地开发友好
      - 错误处理优秀
    cons:
      - 相对较新
      - 企业功能需付费
      - 生态相对小
    best_for:
      - 中小规模
      - Python 为主
      - 快速迭代

工作流引擎架构

通用架构设计

┌─────────────────────────────────────────────────────────────────┐
│                    工作流引擎通用架构                              │
├─────────────────────────────────────────────────────────────────┤
│                                                                 │
│  用户接口层                                                       │
│  ┌──────────────────────────────────────────────────────────┐  │
│  │  ┌──────────┐  ┌──────────┐  ┌──────────┐  ┌──────────┐ │  │
│  │  │  Web UI  │  │   CLI    │  │   SDK    │  │   API    │ │  │
│  │  └────┬─────┘  └────┬─────┘  └────┬─────┘  └────┬─────┘ │  │
│  └───────┼─────────────┼─────────────┼─────────────┼────────┘  │
│          └─────────────┴──────┬──────┴─────────────┘           │
│                               │                                 │
│  ════════════════════════════╪════════════════════════════════ │
│                               │                                 │
│  控制平面                      │                                 │
│  ┌────────────────────────────┴─────────────────────────────┐  │
│  │                                                          │  │
│  │  ┌─────────────────────────────────────────────────────┐ │  │
│  │  │                   API Server                        │ │  │
│  │  │  • 工作流 CRUD                                       │ │  │
│  │  │  • 运行管理                                          │ │  │
│  │  │  • 认证授权                                          │ │  │
│  │  └─────────────────────────────────────────────────────┘ │  │
│  │                                                          │  │
│  │  ┌───────────────┐  ┌───────────────┐  ┌──────────────┐ │  │
│  │  │   Scheduler   │  │   Executor    │  │   Monitor    │ │  │
│  │  │               │  │               │  │              │ │  │
│  │  │ • DAG 解析    │  │ • 任务分发    │  │ • 状态监控   │ │  │
│  │  │ • 依赖计算    │  │ • 重试管理    │  │ • 日志收集   │ │  │
│  │  │ • 触发管理    │  │ • 并发控制    │  │ • 指标采集   │ │  │
│  │  └───────┬───────┘  └───────┬───────┘  └──────┬───────┘ │  │
│  │          │                  │                 │          │  │
│  └──────────┼──────────────────┼─────────────────┼──────────┘  │
│             │                  │                 │              │
│  ═══════════╪══════════════════╪═════════════════╪════════════ │
│             │                  │                 │              │
│  数据平面   │                  │                 │              │
│  ┌──────────┼──────────────────┼─────────────────┼──────────┐  │
│  │          │                  │                 │          │  │
│  │  ┌───────┴───────┐  ┌───────┴───────┐  ┌─────┴────────┐ │  │
│  │  │  Metadata DB  │  │  Artifact     │  │    Queue     │ │  │
│  │  │               │  │  Storage      │  │              │ │  │
│  │  │ • 工作流定义  │  │ • 模型文件    │  │ • 任务队列   │ │  │
│  │  │ • 运行记录    │  │ • 数据集      │  │ • 事件队列   │ │  │
│  │  │ • 血缘关系    │  │ • 日志文件    │  │              │ │  │
│  │  └───────────────┘  └───────────────┘  └──────────────┘ │  │
│  │                                                          │  │
│  └──────────────────────────────────────────────────────────┘  │
│                               │                                 │
│  ═════════════════════════════╪═══════════════════════════════ │
│                               │                                 │
│  执行层                        │                                 │
│  ┌────────────────────────────┴─────────────────────────────┐  │
│  │                                                          │  │
│  │  ┌─────────────┐  ┌─────────────┐  ┌─────────────┐      │  │
│  │  │ Kubernetes  │  │   Docker    │  │   Local     │      │  │
│  │  │  Executor   │  │  Executor   │  │  Executor   │      │  │
│  │  └──────┬──────┘  └──────┬──────┘  └──────┬──────┘      │  │
│  │         │                │                │              │  │
│  │    ┌────┴────┐      ┌────┴────┐      ┌────┴────┐        │  │
│  │    │  Pods   │      │Containers│      │ Process │        │  │
│  │    └─────────┘      └─────────┘      └─────────┘        │  │
│  │                                                          │  │
│  └──────────────────────────────────────────────────────────┘  │
│                                                                 │
└─────────────────────────────────────────────────────────────────┘

核心组件设计

"""
工作流引擎核心组件抽象
"""
from abc import ABC, abstractmethod
from typing import Dict, List, Optional, Any, Callable
from dataclasses import dataclass, field
from enum import Enum
from datetime import datetime
import uuid


class TaskState(Enum):
    """任务状态"""
    PENDING = "pending"
    QUEUED = "queued"
    RUNNING = "running"
    SUCCESS = "success"
    FAILED = "failed"
    SKIPPED = "skipped"
    CANCELLED = "cancelled"


class WorkflowState(Enum):
    """工作流状态"""
    PENDING = "pending"
    RUNNING = "running"
    SUCCESS = "success"
    FAILED = "failed"
    CANCELLED = "cancelled"


@dataclass
class TaskSpec:
    """任务规格"""
    name: str
    image: str
    command: List[str]
    args: List[str] = field(default_factory=list)
    env: Dict[str, str] = field(default_factory=dict)
    resources: Dict[str, Any] = field(default_factory=dict)
    inputs: Dict[str, Any] = field(default_factory=dict)
    outputs: Dict[str, Any] = field(default_factory=dict)
    retry_policy: Dict[str, Any] = field(default_factory=dict)
    timeout: int = 3600  # 秒


@dataclass
class Task:
    """任务定义"""
    id: str
    name: str
    spec: TaskSpec
    dependencies: List[str] = field(default_factory=list)
    state: TaskState = TaskState.PENDING
    start_time: Optional[datetime] = None
    end_time: Optional[datetime] = None
    retry_count: int = 0
    error_message: Optional[str] = None


@dataclass
class WorkflowSpec:
    """工作流规格"""
    name: str
    description: str = ""
    tasks: List[Task] = field(default_factory=list)
    parameters: Dict[str, Any] = field(default_factory=dict)
    triggers: List[Dict] = field(default_factory=list)
    labels: Dict[str, str] = field(default_factory=dict)


@dataclass
class WorkflowRun:
    """工作流运行实例"""
    id: str = field(default_factory=lambda: str(uuid.uuid4()))
    workflow_id: str = ""
    state: WorkflowState = WorkflowState.PENDING
    parameters: Dict[str, Any] = field(default_factory=dict)
    task_runs: Dict[str, Task] = field(default_factory=dict)
    start_time: Optional[datetime] = None
    end_time: Optional[datetime] = None
    triggered_by: str = "manual"


class DAGParser:
    """DAG 解析器"""

    def __init__(self):
        self.tasks: Dict[str, Task] = {}
        self.adjacency: Dict[str, List[str]] = {}  # 邻接表
        self.in_degree: Dict[str, int] = {}  # 入度

    def parse(self, workflow_spec: WorkflowSpec) -> 'DAGParser':
        """解析工作流为 DAG"""
        for task in workflow_spec.tasks:
            self.tasks[task.id] = task
            self.adjacency[task.id] = []
            self.in_degree[task.id] = len(task.dependencies)

            for dep in task.dependencies:
                if dep not in self.adjacency:
                    self.adjacency[dep] = []
                self.adjacency[dep].append(task.id)

        return self

    def get_ready_tasks(self, completed: set) -> List[str]:
        """获取可执行的任务"""
        ready = []
        for task_id, degree in self.in_degree.items():
            if task_id in completed:
                continue
            # 检查所有依赖是否完成
            task = self.tasks[task_id]
            if all(dep in completed for dep in task.dependencies):
                ready.append(task_id)
        return ready

    def topological_sort(self) -> List[str]:
        """拓扑排序"""
        in_degree = self.in_degree.copy()
        queue = [task_id for task_id, degree in in_degree.items() if degree == 0]
        result = []

        while queue:
            task_id = queue.pop(0)
            result.append(task_id)

            for next_task in self.adjacency.get(task_id, []):
                in_degree[next_task] -= 1
                if in_degree[next_task] == 0:
                    queue.append(next_task)

        if len(result) != len(self.tasks):
            raise ValueError("Workflow contains cycles")

        return result

    def validate(self) -> List[str]:
        """验证 DAG"""
        errors = []

        # 检查循环依赖
        try:
            self.topological_sort()
        except ValueError as e:
            errors.append(str(e))

        # 检查依赖是否存在
        for task in self.tasks.values():
            for dep in task.dependencies:
                if dep not in self.tasks:
                    errors.append(f"Task {task.id} depends on unknown task {dep}")

        return errors


class Executor(ABC):
    """执行器抽象基类"""

    @abstractmethod
    async def execute(self, task: Task, context: Dict) -> TaskState:
        """执行任务"""
        pass

    @abstractmethod
    async def cancel(self, task_id: str):
        """取消任务"""
        pass

    @abstractmethod
    async def get_logs(self, task_id: str) -> str:
        """获取日志"""
        pass


class Scheduler:
    """调度器"""

    def __init__(
        self,
        executor: Executor,
        max_concurrent: int = 10
    ):
        self.executor = executor
        self.max_concurrent = max_concurrent
        self.running_tasks: Dict[str, Task] = {}
        self.completed_tasks: set = set()

    async def run_workflow(self, run: WorkflowRun, dag: DAGParser) -> WorkflowState:
        """运行工作流"""
        run.state = WorkflowState.RUNNING
        run.start_time = datetime.now()

        try:
            while True:
                # 获取可执行任务
                ready_tasks = dag.get_ready_tasks(self.completed_tasks)

                if not ready_tasks and not self.running_tasks:
                    # 所有任务完成
                    break

                # 启动新任务
                for task_id in ready_tasks:
                    if len(self.running_tasks) >= self.max_concurrent:
                        break
                    if task_id not in self.running_tasks:
                        task = dag.tasks[task_id]
                        await self._start_task(task, run)

                # 等待任务完成
                await self._wait_for_completion()

            # 检查最终状态
            all_success = all(
                dag.tasks[tid].state == TaskState.SUCCESS
                for tid in dag.tasks
            )
            run.state = WorkflowState.SUCCESS if all_success else WorkflowState.FAILED

        except Exception as e:
            run.state = WorkflowState.FAILED
            raise

        finally:
            run.end_time = datetime.now()

        return run.state

    async def _start_task(self, task: Task, run: WorkflowRun):
        """启动任务"""
        task.state = TaskState.RUNNING
        task.start_time = datetime.now()
        self.running_tasks[task.id] = task
        run.task_runs[task.id] = task

        # 异步执行
        context = {"workflow_run_id": run.id, "parameters": run.parameters}
        state = await self.executor.execute(task, context)

        task.state = state
        task.end_time = datetime.now()

        del self.running_tasks[task.id]

        if state == TaskState.SUCCESS:
            self.completed_tasks.add(task.id)
        elif state == TaskState.FAILED and task.retry_count < task.spec.retry_policy.get("max_retries", 0):
            # 重试
            task.retry_count += 1
            task.state = TaskState.PENDING

    async def _wait_for_completion(self):
        """等待任务完成"""
        import asyncio
        await asyncio.sleep(1)


class WorkflowEngine:
    """工作流引擎"""

    def __init__(
        self,
        executor: Executor,
        metadata_store: 'MetadataStore' = None
    ):
        self.executor = executor
        self.metadata_store = metadata_store
        self.scheduler = Scheduler(executor)

        self.workflows: Dict[str, WorkflowSpec] = {}
        self.runs: Dict[str, WorkflowRun] = {}

    def register_workflow(self, spec: WorkflowSpec) -> str:
        """注册工作流"""
        workflow_id = str(uuid.uuid4())
        self.workflows[workflow_id] = spec

        if self.metadata_store:
            self.metadata_store.save_workflow(workflow_id, spec)

        return workflow_id

    async def run_workflow(
        self,
        workflow_id: str,
        parameters: Dict[str, Any] = None
    ) -> WorkflowRun:
        """运行工作流"""
        spec = self.workflows.get(workflow_id)
        if not spec:
            raise ValueError(f"Workflow {workflow_id} not found")

        # 创建运行实例
        run = WorkflowRun(
            workflow_id=workflow_id,
            parameters=parameters or {}
        )
        self.runs[run.id] = run

        # 解析 DAG
        dag = DAGParser().parse(spec)
        errors = dag.validate()
        if errors:
            raise ValueError(f"Invalid workflow: {errors}")

        # 执行
        await self.scheduler.run_workflow(run, dag)

        if self.metadata_store:
            self.metadata_store.save_run(run)

        return run

    def get_run(self, run_id: str) -> Optional[WorkflowRun]:
        """获取运行实例"""
        return self.runs.get(run_id)

    def list_runs(
        self,
        workflow_id: str = None,
        state: WorkflowState = None
    ) -> List[WorkflowRun]:
        """列出运行实例"""
        runs = list(self.runs.values())

        if workflow_id:
            runs = [r for r in runs if r.workflow_id == workflow_id]
        if state:
            runs = [r for r in runs if r.state == state]

        return runs


class MetadataStore(ABC):
    """元数据存储抽象"""

    @abstractmethod
    def save_workflow(self, workflow_id: str, spec: WorkflowSpec):
        pass

    @abstractmethod
    def get_workflow(self, workflow_id: str) -> Optional[WorkflowSpec]:
        pass

    @abstractmethod
    def save_run(self, run: WorkflowRun):
        pass

    @abstractmethod
    def get_run(self, run_id: str) -> Optional[WorkflowRun]:
        pass


# Python DSL for defining workflows
class WorkflowBuilder:
    """工作流构建器 - Python DSL"""

    def __init__(self, name: str, description: str = ""):
        self.name = name
        self.description = description
        self.tasks: List[Task] = []
        self.parameters: Dict[str, Any] = {}

    def add_parameter(self, name: str, default: Any = None, type_hint: str = "string"):
        """添加参数"""
        self.parameters[name] = {
            "default": default,
            "type": type_hint
        }
        return self

    def add_task(
        self,
        name: str,
        image: str,
        command: List[str],
        dependencies: List[str] = None,
        **kwargs
    ) -> 'WorkflowBuilder':
        """添加任务"""
        task = Task(
            id=name,
            name=name,
            spec=TaskSpec(
                name=name,
                image=image,
                command=command,
                **kwargs
            ),
            dependencies=dependencies or []
        )
        self.tasks.append(task)
        return self

    def build(self) -> WorkflowSpec:
        """构建工作流"""
        return WorkflowSpec(
            name=self.name,
            description=self.description,
            tasks=self.tasks,
            parameters=self.parameters
        )


# 装饰器风格定义
def task(
    name: str = None,
    image: str = None,
    dependencies: List[str] = None,
    **kwargs
):
    """任务装饰器"""
    def decorator(func: Callable):
        task_name = name or func.__name__
        # 从函数签名推断参数
        import inspect
        sig = inspect.signature(func)

        task_spec = TaskSpec(
            name=task_name,
            image=image or "python:3.9",
            command=["python", "-c", inspect.getsource(func)],
            **kwargs
        )

        func._task_spec = task_spec
        func._dependencies = dependencies or []
        return func

    return decorator


# 使用示例
if __name__ == "__main__":
    # 使用 Builder 模式
    workflow = (
        WorkflowBuilder("ml-pipeline", "ML Training Pipeline")
        .add_parameter("learning_rate", 0.001)
        .add_parameter("epochs", 10)
        .add_task(
            name="preprocess",
            image="python:3.9",
            command=["python", "preprocess.py"],
            resources={"cpu": "2", "memory": "4Gi"}
        )
        .add_task(
            name="train",
            image="pytorch/pytorch:2.0.0",
            command=["python", "train.py"],
            dependencies=["preprocess"],
            resources={"nvidia.com/gpu": "1", "memory": "16Gi"}
        )
        .add_task(
            name="evaluate",
            image="python:3.9",
            command=["python", "evaluate.py"],
            dependencies=["train"]
        )
        .build()
    )

    print(f"Workflow: {workflow.name}")
    print(f"Tasks: {[t.name for t in workflow.tasks]}")

工作流模式

常见工作流模式

┌─────────────────────────────────────────────────────────────────┐
│                      常见工作流模式                               │
├─────────────────────────────────────────────────────────────────┤
│                                                                 │
│  1. 顺序模式 (Sequential)                                        │
│  ┌─────┐    ┌─────┐    ┌─────┐    ┌─────┐                      │
│  │  A  │───►│  B  │───►│  C  │───►│  D  │                      │
│  └─────┘    └─────┘    └─────┘    └─────┘                      │
│                                                                 │
│  2. 并行模式 (Parallel)                                          │
│             ┌─────┐                                             │
│          ┌─►│  B  │─┐                                          │
│  ┌─────┐ │  └─────┘ │  ┌─────┐                                 │
│  │  A  │─┤          ├─►│  D  │                                 │
│  └─────┘ │  ┌─────┐ │  └─────┘                                 │
│          └─►│  C  │─┘                                          │
│             └─────┘                                             │
│                                                                 │
│  3. 条件分支 (Conditional)                                       │
│             ┌─────┐                                             │
│          Y  │  B  │                                             │
│         ┌──►└─────┘                                             │
│  ┌─────┐│                                                       │
│  │ A ? ├┤                                                       │
│  └─────┘│                                                       │
│         └──►┌─────┐                                             │
│          N  │  C  │                                             │
│             └─────┘                                             │
│                                                                 │
│  4. 循环模式 (Loop)                                              │
│  ┌─────┐    ┌─────┐    ┌─────┐                                 │
│  │  A  │───►│  B  │───►│ C ? │                                 │
│  └─────┘    └──▲──┘    └──┬──┘                                 │
│               │      N    │ Y                                   │
│               └───────────┘                                     │
│                                                                 │
│  5. 扇出-扇入 (Fan-out/Fan-in)                                   │
│             ┌─────┐                                             │
│          ┌─►│ B1  │─┐                                          │
│          │  └─────┘ │                                          │
│  ┌─────┐ │  ┌─────┐ │  ┌─────┐                                 │
│  │  A  │─┼─►│ B2  │─┼─►│  C  │                                 │
│  └─────┘ │  └─────┘ │  └─────┘                                 │
│          │  ┌─────┐ │                                          │
│          └─►│ B3  │─┘                                          │
│             └─────┘                                             │
│                                                                 │
│  6. 子工作流 (Sub-workflow)                                      │
│  ┌─────┐    ┌───────────────┐    ┌─────┐                       │
│  │  A  │───►│  Sub-Workflow │───►│  C  │                       │
│  └─────┘    │  ┌───┐ ┌───┐ │    └─────┘                       │
│             │  │ X │►│ Y │ │                                   │
│             │  └───┘ └───┘ │                                   │
│             └───────────────┘                                   │
│                                                                 │
└─────────────────────────────────────────────────────────────────┘

模式实现

"""
工作流模式实现
"""
from typing import List, Dict, Any, Callable, Optional
from dataclasses import dataclass
import asyncio


@dataclass
class TaskResult:
    """任务结果"""
    task_id: str
    success: bool
    output: Any
    error: Optional[str] = None


class WorkflowPatterns:
    """工作流模式实现"""

    @staticmethod
    async def sequential(
        tasks: List[Callable],
        initial_input: Any = None
    ) -> List[TaskResult]:
        """
        顺序执行模式
        每个任务的输出作为下一个任务的输入
        """
        results = []
        current_input = initial_input

        for i, task in enumerate(tasks):
            try:
                output = await task(current_input)
                result = TaskResult(
                    task_id=f"task_{i}",
                    success=True,
                    output=output
                )
                current_input = output
            except Exception as e:
                result = TaskResult(
                    task_id=f"task_{i}",
                    success=False,
                    output=None,
                    error=str(e)
                )
                results.append(result)
                break  # 失败时停止

            results.append(result)

        return results

    @staticmethod
    async def parallel(
        tasks: List[Callable],
        inputs: List[Any] = None
    ) -> List[TaskResult]:
        """
        并行执行模式
        所有任务同时执行
        """
        if inputs is None:
            inputs = [None] * len(tasks)

        async def run_task(idx: int, task: Callable, input_data: Any) -> TaskResult:
            try:
                output = await task(input_data)
                return TaskResult(
                    task_id=f"task_{idx}",
                    success=True,
                    output=output
                )
            except Exception as e:
                return TaskResult(
                    task_id=f"task_{idx}",
                    success=False,
                    output=None,
                    error=str(e)
                )

        results = await asyncio.gather(*[
            run_task(i, task, inp)
            for i, (task, inp) in enumerate(zip(tasks, inputs))
        ])

        return list(results)

    @staticmethod
    async def conditional(
        condition: Callable[[], bool],
        if_true: Callable,
        if_false: Callable,
        input_data: Any = None
    ) -> TaskResult:
        """
        条件分支模式
        根据条件执行不同分支
        """
        try:
            branch = if_true if await condition() else if_false
            output = await branch(input_data)
            return TaskResult(
                task_id="conditional",
                success=True,
                output=output
            )
        except Exception as e:
            return TaskResult(
                task_id="conditional",
                success=False,
                output=None,
                error=str(e)
            )

    @staticmethod
    async def loop(
        task: Callable,
        condition: Callable[[Any], bool],
        initial_input: Any,
        max_iterations: int = 100
    ) -> List[TaskResult]:
        """
        循环模式
        重复执行直到条件不满足
        """
        results = []
        current_input = initial_input
        iteration = 0

        while iteration < max_iterations:
            try:
                output = await task(current_input)
                result = TaskResult(
                    task_id=f"iteration_{iteration}",
                    success=True,
                    output=output
                )
                results.append(result)

                # 检查继续条件
                if not await condition(output):
                    break

                current_input = output
                iteration += 1

            except Exception as e:
                result = TaskResult(
                    task_id=f"iteration_{iteration}",
                    success=False,
                    output=None,
                    error=str(e)
                )
                results.append(result)
                break

        return results

    @staticmethod
    async def fan_out_fan_in(
        scatter_task: Callable[[Any], List[Any]],
        process_task: Callable[[Any], Any],
        gather_task: Callable[[List[Any]], Any],
        input_data: Any
    ) -> TaskResult:
        """
        扇出-扇入模式
        1. 拆分输入为多个部分
        2. 并行处理每个部分
        3. 聚合结果
        """
        try:
            # 扇出
            scattered = await scatter_task(input_data)

            # 并行处理
            processed = await asyncio.gather(*[
                process_task(item) for item in scattered
            ])

            # 扇入
            result = await gather_task(list(processed))

            return TaskResult(
                task_id="fan_out_fan_in",
                success=True,
                output=result
            )
        except Exception as e:
            return TaskResult(
                task_id="fan_out_fan_in",
                success=False,
                output=None,
                error=str(e)
            )

    @staticmethod
    async def retry_with_backoff(
        task: Callable,
        input_data: Any,
        max_retries: int = 3,
        base_delay: float = 1.0,
        max_delay: float = 60.0
    ) -> TaskResult:
        """
        带退避的重试模式
        """
        last_error = None

        for attempt in range(max_retries + 1):
            try:
                output = await task(input_data)
                return TaskResult(
                    task_id=f"attempt_{attempt}",
                    success=True,
                    output=output
                )
            except Exception as e:
                last_error = str(e)

                if attempt < max_retries:
                    # 指数退避
                    delay = min(base_delay * (2 ** attempt), max_delay)
                    await asyncio.sleep(delay)

        return TaskResult(
            task_id=f"attempt_{max_retries}",
            success=False,
            output=None,
            error=f"Failed after {max_retries + 1} attempts: {last_error}"
        )


# ML 特定模式
class MLWorkflowPatterns:
    """ML 工作流模式"""

    @staticmethod
    async def hyperparameter_search(
        train_func: Callable,
        param_grid: List[Dict],
        evaluate_func: Callable,
        select_best: Callable[[List[Dict]], Dict]
    ) -> Dict:
        """
        超参数搜索模式
        """
        # 并行训练多个模型
        results = []

        async def train_and_evaluate(params: Dict) -> Dict:
            model = await train_func(params)
            score = await evaluate_func(model)
            return {"params": params, "model": model, "score": score}

        results = await asyncio.gather(*[
            train_and_evaluate(params) for params in param_grid
        ])

        # 选择最佳
        best = select_best(list(results))
        return best

    @staticmethod
    async def cross_validation(
        data_splitter: Callable[[Any, int], List[tuple]],
        train_func: Callable,
        evaluate_func: Callable,
        data: Any,
        n_folds: int = 5
    ) -> Dict:
        """
        交叉验证模式
        """
        folds = await data_splitter(data, n_folds)
        scores = []

        for i, (train_data, val_data) in enumerate(folds):
            model = await train_func(train_data)
            score = await evaluate_func(model, val_data)
            scores.append(score)

        return {
            "scores": scores,
            "mean_score": sum(scores) / len(scores),
            "std_score": (sum((s - sum(scores)/len(scores))**2 for s in scores) / len(scores)) ** 0.5
        }

    @staticmethod
    async def incremental_training(
        model_loader: Callable,
        data_fetcher: Callable,
        train_step: Callable,
        checkpoint_saver: Callable,
        should_continue: Callable[[Dict], bool],
        initial_checkpoint: str = None
    ) -> Dict:
        """
        增量训练模式
        """
        # 加载模型
        model = await model_loader(initial_checkpoint)
        metrics_history = []

        while True:
            # 获取新数据
            data = await data_fetcher()
            if data is None:
                break

            # 训练一步
            metrics = await train_step(model, data)
            metrics_history.append(metrics)

            # 保存检查点
            checkpoint = await checkpoint_saver(model, metrics)

            # 检查是否继续
            if not await should_continue(metrics):
                break

        return {
            "final_checkpoint": checkpoint,
            "metrics_history": metrics_history
        }


# 使用示例
async def example():
    """工作流模式使用示例"""

    # 定义简单任务
    async def preprocess(data):
        print(f"Preprocessing: {data}")
        return f"processed_{data}"

    async def train(data):
        print(f"Training with: {data}")
        return f"model_{data}"

    async def evaluate(model):
        print(f"Evaluating: {model}")
        return 0.95

    # 顺序执行
    results = await WorkflowPatterns.sequential(
        [preprocess, train, evaluate],
        initial_input="raw_data"
    )
    print(f"Sequential results: {results}")

    # 并行执行
    async def process_shard(shard):
        return f"processed_{shard}"

    parallel_results = await WorkflowPatterns.parallel(
        [process_shard, process_shard, process_shard],
        ["shard_1", "shard_2", "shard_3"]
    )
    print(f"Parallel results: {parallel_results}")


if __name__ == "__main__":
    asyncio.run(example())

Kubernetes 集成

Argo Workflow 示例

# Argo Workflow ML Pipeline 示例
apiVersion: argoproj.io/v1alpha1
kind: Workflow
metadata:
  generateName: ml-pipeline-
spec:
  entrypoint: ml-pipeline

  # 参数定义
  arguments:
    parameters:
    - name: learning-rate
      value: "0.001"
    - name: epochs
      value: "10"
    - name: model-name
      value: "resnet50"

  # 工作流模板
  templates:
  - name: ml-pipeline
    dag:
      tasks:
      # 数据预处理
      - name: preprocess
        template: preprocess-template
        arguments:
          parameters:
          - name: input-path
            value: "s3://data/raw"

      # 并行训练多个模型
      - name: train-model-1
        template: train-template
        dependencies: [preprocess]
        arguments:
          parameters:
          - name: learning-rate
            value: "{{workflow.parameters.learning-rate}}"
          - name: model-variant
            value: "variant-1"

      - name: train-model-2
        template: train-template
        dependencies: [preprocess]
        arguments:
          parameters:
          - name: learning-rate
            value: "0.0001"
          - name: model-variant
            value: "variant-2"

      # 评估
      - name: evaluate
        template: evaluate-template
        dependencies: [train-model-1, train-model-2]

      # 选择最佳模型
      - name: select-best
        template: select-best-template
        dependencies: [evaluate]

      # 部署
      - name: deploy
        template: deploy-template
        dependencies: [select-best]
        when: "{{tasks.select-best.outputs.parameters.should-deploy}} == true"

  # 预处理模板
  - name: preprocess-template
    inputs:
      parameters:
      - name: input-path
    container:
      image: python:3.9
      command: [python, preprocess.py]
      args:
        - --input={{inputs.parameters.input-path}}
        - --output=/data/processed
      resources:
        requests:
          memory: "4Gi"
          cpu: "2"
      volumeMounts:
      - name: data-volume
        mountPath: /data
    outputs:
      artifacts:
      - name: processed-data
        path: /data/processed

  # 训练模板
  - name: train-template
    inputs:
      parameters:
      - name: learning-rate
      - name: model-variant
    container:
      image: pytorch/pytorch:2.0.0-cuda11.7-cudnn8-runtime
      command: [python, train.py]
      args:
        - --lr={{inputs.parameters.learning-rate}}
        - --variant={{inputs.parameters.model-variant}}
        - --epochs={{workflow.parameters.epochs}}
      resources:
        limits:
          nvidia.com/gpu: 1
          memory: "32Gi"
        requests:
          nvidia.com/gpu: 1
          memory: "32Gi"
    outputs:
      artifacts:
      - name: model
        path: /output/model
      parameters:
      - name: train-loss
        valueFrom:
          path: /output/metrics.json

  # 评估模板
  - name: evaluate-template
    container:
      image: python:3.9
      command: [python, evaluate.py]
    outputs:
      parameters:
      - name: best-model
        valueFrom:
          path: /output/best_model.txt

  # 部署模板
  - name: deploy-template
    resource:
      action: apply
      manifest: |
        apiVersion: serving.kubeflow.org/v1beta1
        kind: InferenceService
        metadata:
          name: {{workflow.parameters.model-name}}
        spec:
          predictor:
            pytorch:
              storageUri: "s3://models/{{workflow.name}}"

  # 持久卷
  volumeClaimTemplates:
  - metadata:
      name: data-volume
    spec:
      accessModes: ["ReadWriteOnce"]
      resources:
        requests:
          storage: 100Gi

---
# 定时触发的工作流
apiVersion: argoproj.io/v1alpha1
kind: CronWorkflow
metadata:
  name: ml-pipeline-daily
spec:
  schedule: "0 2 * * *"  # 每天凌晨2点
  timezone: "Asia/Shanghai"
  workflowSpec:
    entrypoint: ml-pipeline
    # ... 同上

最佳实践

工作流设计原则

# 工作流设计最佳实践
best_practices:
  # 1. 任务设计
  task_design:
    - name: "单一职责"
      description: "每个任务只做一件事"
      example: "分离数据下载和数据处理"

    - name: "幂等性"
      description: "重复执行产生相同结果"
      implementation:
        - 使用确定性随机种子
        - 避免依赖外部可变状态

    - name: "容器化"
      description: "每个任务使用独立容器"
      benefits:
        - 环境隔离
        - 可移植性
        - 版本控制

  # 2. 数据传递
  data_passing:
    - name: "小数据用参数"
      description: "简单值通过参数传递"
      limit: "< 1MB"

    - name: "大数据用 Artifact"
      description: "文件和数据集通过存储传递"
      storage:
        - S3/GCS
        - PVC
        - MinIO

    - name: "避免重复传输"
      description: "使用共享存储或缓存"

  # 3. 错误处理
  error_handling:
    retry:
      max_retries: 3
      backoff:
        type: exponential
        initial: 10s
        max: 5m

    failure_strategy:
      - fail_fast: "关键路径失败立即停止"
      - continue: "非关键任务失败继续"
      - retry_from_checkpoint: "从检查点恢复"

  # 4. 资源管理
  resource_management:
    - name: "明确资源需求"
      fields: [cpu, memory, gpu]

    - name: "使用资源配额"
      purpose: "防止资源滥用"

    - name: "清理临时资源"
      method: "设置 TTL 或手动清理"

  # 5. 可观测性
  observability:
    logging:
      - 结构化日志
      - 统一日志收集
      - 日志级别分离

    metrics:
      - 任务执行时间
      - 成功/失败率
      - 资源使用率

    tracing:
      - 分布式追踪
      - 血缘关系记录

总结

AI 工作流引擎是 ML 工程化的核心组件:

  1. 核心价值:自动化、可复现、可扩展、可观测
  2. 主流方案:Kubeflow Pipelines、Argo Workflows、Airflow、Prefect
  3. 架构组成:DAG 调度器、执行器、元数据存储
  4. 工作流模式:顺序、并行、条件、循环、扇出扇入

选择工作流引擎时需要考虑:

  • 团队技术栈(Python/YAML)
  • 基础设施(Kubernetes/传统环境)
  • 规模需求(单机/分布式)
  • 生态集成(MLOps 工具链)
Prev
06-AI工作流引擎
Next
Kubeflow Pipelines 深度实践