AI 工作流引擎概述
概述
AI 工作流引擎是管理机器学习全生命周期的核心组件,它将数据处理、模型训练、评估、部署等环节编排成自动化流水线。本文介绍工作流引擎的核心概念、主流方案对比及架构设计。
为什么需要工作流引擎
ML 开发的挑战
┌─────────────────────────────────────────────────────────────────┐
│ ML 开发生命周期挑战 │
├─────────────────────────────────────────────────────────────────┤
│ │
│ 传统开发方式的问题: │
│ │
│ ┌─────────────────────────────────────────────────────────┐ │
│ │ │ │
│ │ 数据准备 ──► 特征工程 ──► 训练 ──► 评估 ──► 部署 │ │
│ │ │ │ │ │ │ │ │
│ │ ▼ ▼ ▼ ▼ ▼ │ │
│ │ ┌─────┐ ┌─────┐ ┌─────┐ ┌─────┐ ┌─────┐ │ │
│ │ │手动 │ │手动 │ │手动 │ │手动 │ │手动 │ │ │
│ │ │脚本 │ │脚本 │ │脚本 │ │脚本 │ │脚本 │ │ │
│ │ └─────┘ └─────┘ └─────┘ └─────┘ └─────┘ │ │
│ │ │ │
│ │ 问题: │ │
│ │ • 步骤之间缺乏关联 │ │
│ │ • 难以复现实验 │ │
│ │ • 手动触发易出错 │ │
│ │ • 资源管理困难 │ │
│ │ • 缺乏版本控制 │ │
│ │ │ │
│ └─────────────────────────────────────────────────────────┘ │
│ │
│ 工作流引擎解决方案: │
│ │
│ ┌─────────────────────────────────────────────────────────┐ │
│ │ │ │
│ │ ┌────────────────────────────┐ │ │
│ │ │ Workflow Engine │ │ │
│ │ │ ┌──────────────────────┐ │ │ │
│ │ │ │ DAG 调度器 │ │ │ │
│ │ │ └──────────────────────┘ │ │ │
│ │ │ ┌──────────────────────┐ │ │ │
│ │ │ │ 资源管理器 │ │ │ │
│ │ │ └──────────────────────┘ │ │ │
│ │ │ ┌──────────────────────┐ │ │ │
│ │ │ │ 元数据存储 │ │ │ │
│ │ │ └──────────────────────┘ │ │ │
│ │ └────────────┬───────────────┘ │ │
│ │ │ │ │
│ │ ┌────────────────────┼────────────────────┐ │ │
│ │ │ │ │ │ │
│ │ ▼ ▼ ▼ │ │
│ │ ┌─────┐ ┌─────┐ ┌─────┐ │ │
│ │ │Step1│───────────►│Step2│───────────►│Step3│ │ │
│ │ │数据 │ │训练 │ │部署 │ │ │
│ │ └─────┘ └─────┘ └─────┘ │ │
│ │ │ │
│ │ 优势: │ │
│ │ • 自动化编排 │ │
│ │ • 完整可复现 │ │
│ │ • 资源自动分配 │ │
│ │ • 全程可追溯 │ │
│ │ │ │
│ └─────────────────────────────────────────────────────────┘ │
│ │
└─────────────────────────────────────────────────────────────────┘
核心价值
# 工作流引擎核心价值
workflow_engine_value:
# 1. 自动化
automation:
- 端到端流水线自动执行
- 触发器驱动 (定时/事件/手动)
- 失败自动重试
- 条件分支执行
# 2. 可复现性
reproducibility:
- 代码版本化
- 数据版本化
- 环境容器化
- 参数记录
# 3. 可扩展性
scalability:
- 分布式执行
- 弹性资源
- 并行处理
- 跨集群调度
# 4. 可观测性
observability:
- 执行状态跟踪
- 日志聚合
- 指标监控
- 血缘追溯
# 5. 协作效率
collaboration:
- 组件复用
- 模板共享
- 权限管理
- 审计日志
主流工作流引擎对比
方案概览
┌─────────────────────────────────────────────────────────────────┐
│ 主流 ML 工作流引擎对比 │
├─────────────────────────────────────────────────────────────────┤
│ │
│ ┌─────────────────────────────────────────────────────────┐ │
│ │ Kubeflow Pipelines │ │
│ │ ┌─────────────────────────────────────────────────┐ │ │
│ │ │ • Kubernetes 原生 │ │ │
│ │ │ • Argo Workflow 后端 │ │ │
│ │ │ • Python SDK (KFP) │ │ │
│ │ │ • 组件市场 │ │ │
│ │ │ 适用: 大规模 ML 平台 │ │ │
│ │ └─────────────────────────────────────────────────┘ │ │
│ └─────────────────────────────────────────────────────────┘ │
│ │
│ ┌─────────────────────────────────────────────────────────┐ │
│ │ Argo Workflows │ │
│ │ ┌─────────────────────────────────────────────────┐ │ │
│ │ │ • 通用容器工作流 │ │ │
│ │ │ • YAML 定义 │ │ │
│ │ │ • DAG + Steps 模式 │ │ │
│ │ │ • 丰富的模板功能 │ │ │
│ │ │ 适用: 通用 CI/CD + ML │ │ │
│ │ └─────────────────────────────────────────────────┘ │ │
│ └─────────────────────────────────────────────────────────┘ │
│ │
│ ┌─────────────────────────────────────────────────────────┐ │
│ │ Apache Airflow │ │
│ │ ┌─────────────────────────────────────────────────┐ │ │
│ │ │ • Python 原生定义 │ │ │
│ │ │ • 丰富的 Operator │ │ │
│ │ │ • 成熟的调度系统 │ │ │
│ │ │ • 强大的 UI │ │ │
│ │ │ 适用: 数据工程 + ML Pipeline │ │ │
│ │ └─────────────────────────────────────────────────┘ │ │
│ └─────────────────────────────────────────────────────────┘ │
│ │
│ ┌─────────────────────────────────────────────────────────┐ │
│ │ Prefect │ │
│ │ ┌─────────────────────────────────────────────────┐ │ │
│ │ │ • 现代 Python API │ │ │
│ │ │ • 动态工作流 │ │ │
│ │ │ • 本地/云端混合 │ │ │
│ │ │ • 简单易用 │ │ │
│ │ │ 适用: 中小规模数据/ML 流水线 │ │ │
│ │ └─────────────────────────────────────────────────┘ │ │
│ └─────────────────────────────────────────────────────────┘ │
│ │
│ ┌─────────────────────────────────────────────────────────┐ │
│ │ MLflow │ │
│ │ ┌─────────────────────────────────────────────────┐ │ │
│ │ │ • 实验跟踪为核心 │ │ │
│ │ │ • 模型注册表 │ │ │
│ │ │ • Projects 定义 │ │ │
│ │ │ • 多框架支持 │ │ │
│ │ │ 适用: 实验管理 + 简单流水线 │ │ │
│ │ └─────────────────────────────────────────────────┘ │ │
│ └─────────────────────────────────────────────────────────┘ │
│ │
└─────────────────────────────────────────────────────────────────┘
详细对比
# 工作流引擎详细对比
comparison:
kubeflow_pipelines:
pros:
- Kubernetes 原生集成
- 强大的 ML 组件生态
- 支持分布式训练
- 与 Kubeflow 生态无缝集成
cons:
- 学习曲线陡峭
- 部署复杂
- 资源消耗大
best_for:
- 大规模 ML 平台
- Kubernetes 环境
- 需要完整 MLOps 能力
argo_workflows:
pros:
- 灵活的 YAML 定义
- 丰富的工作流模式
- 轻量级部署
- 活跃的社区
cons:
- ML 特定功能较少
- 需要 Kubernetes
- UI 功能有限
best_for:
- 通用容器工作流
- CI/CD 集成
- 定制化需求高
airflow:
pros:
- 成熟稳定
- 丰富的 Operator
- 强大的调度能力
- 完善的监控
cons:
- 不是 Kubernetes 原生
- 扩展性有限
- 实时性较差
best_for:
- 数据工程
- 批处理任务
- 传统环境
prefect:
pros:
- 现代 Python API
- 动态工作流支持
- 本地开发友好
- 错误处理优秀
cons:
- 相对较新
- 企业功能需付费
- 生态相对小
best_for:
- 中小规模
- Python 为主
- 快速迭代
工作流引擎架构
通用架构设计
┌─────────────────────────────────────────────────────────────────┐
│ 工作流引擎通用架构 │
├─────────────────────────────────────────────────────────────────┤
│ │
│ 用户接口层 │
│ ┌──────────────────────────────────────────────────────────┐ │
│ │ ┌──────────┐ ┌──────────┐ ┌──────────┐ ┌──────────┐ │ │
│ │ │ Web UI │ │ CLI │ │ SDK │ │ API │ │ │
│ │ └────┬─────┘ └────┬─────┘ └────┬─────┘ └────┬─────┘ │ │
│ └───────┼─────────────┼─────────────┼─────────────┼────────┘ │
│ └─────────────┴──────┬──────┴─────────────┘ │
│ │ │
│ ════════════════════════════╪════════════════════════════════ │
│ │ │
│ 控制平面 │ │
│ ┌────────────────────────────┴─────────────────────────────┐ │
│ │ │ │
│ │ ┌─────────────────────────────────────────────────────┐ │ │
│ │ │ API Server │ │ │
│ │ │ • 工作流 CRUD │ │ │
│ │ │ • 运行管理 │ │ │
│ │ │ • 认证授权 │ │ │
│ │ └─────────────────────────────────────────────────────┘ │ │
│ │ │ │
│ │ ┌───────────────┐ ┌───────────────┐ ┌──────────────┐ │ │
│ │ │ Scheduler │ │ Executor │ │ Monitor │ │ │
│ │ │ │ │ │ │ │ │ │
│ │ │ • DAG 解析 │ │ • 任务分发 │ │ • 状态监控 │ │ │
│ │ │ • 依赖计算 │ │ • 重试管理 │ │ • 日志收集 │ │ │
│ │ │ • 触发管理 │ │ • 并发控制 │ │ • 指标采集 │ │ │
│ │ └───────┬───────┘ └───────┬───────┘ └──────┬───────┘ │ │
│ │ │ │ │ │ │
│ └──────────┼──────────────────┼─────────────────┼──────────┘ │
│ │ │ │ │
│ ═══════════╪══════════════════╪═════════════════╪════════════ │
│ │ │ │ │
│ 数据平面 │ │ │ │
│ ┌──────────┼──────────────────┼─────────────────┼──────────┐ │
│ │ │ │ │ │ │
│ │ ┌───────┴───────┐ ┌───────┴───────┐ ┌─────┴────────┐ │ │
│ │ │ Metadata DB │ │ Artifact │ │ Queue │ │ │
│ │ │ │ │ Storage │ │ │ │ │
│ │ │ • 工作流定义 │ │ • 模型文件 │ │ • 任务队列 │ │ │
│ │ │ • 运行记录 │ │ • 数据集 │ │ • 事件队列 │ │ │
│ │ │ • 血缘关系 │ │ • 日志文件 │ │ │ │ │
│ │ └───────────────┘ └───────────────┘ └──────────────┘ │ │
│ │ │ │
│ └──────────────────────────────────────────────────────────┘ │
│ │ │
│ ═════════════════════════════╪═══════════════════════════════ │
│ │ │
│ 执行层 │ │
│ ┌────────────────────────────┴─────────────────────────────┐ │
│ │ │ │
│ │ ┌─────────────┐ ┌─────────────┐ ┌─────────────┐ │ │
│ │ │ Kubernetes │ │ Docker │ │ Local │ │ │
│ │ │ Executor │ │ Executor │ │ Executor │ │ │
│ │ └──────┬──────┘ └──────┬──────┘ └──────┬──────┘ │ │
│ │ │ │ │ │ │
│ │ ┌────┴────┐ ┌────┴────┐ ┌────┴────┐ │ │
│ │ │ Pods │ │Containers│ │ Process │ │ │
│ │ └─────────┘ └─────────┘ └─────────┘ │ │
│ │ │ │
│ └──────────────────────────────────────────────────────────┘ │
│ │
└─────────────────────────────────────────────────────────────────┘
核心组件设计
"""
工作流引擎核心组件抽象
"""
from abc import ABC, abstractmethod
from typing import Dict, List, Optional, Any, Callable
from dataclasses import dataclass, field
from enum import Enum
from datetime import datetime
import uuid
class TaskState(Enum):
"""任务状态"""
PENDING = "pending"
QUEUED = "queued"
RUNNING = "running"
SUCCESS = "success"
FAILED = "failed"
SKIPPED = "skipped"
CANCELLED = "cancelled"
class WorkflowState(Enum):
"""工作流状态"""
PENDING = "pending"
RUNNING = "running"
SUCCESS = "success"
FAILED = "failed"
CANCELLED = "cancelled"
@dataclass
class TaskSpec:
"""任务规格"""
name: str
image: str
command: List[str]
args: List[str] = field(default_factory=list)
env: Dict[str, str] = field(default_factory=dict)
resources: Dict[str, Any] = field(default_factory=dict)
inputs: Dict[str, Any] = field(default_factory=dict)
outputs: Dict[str, Any] = field(default_factory=dict)
retry_policy: Dict[str, Any] = field(default_factory=dict)
timeout: int = 3600 # 秒
@dataclass
class Task:
"""任务定义"""
id: str
name: str
spec: TaskSpec
dependencies: List[str] = field(default_factory=list)
state: TaskState = TaskState.PENDING
start_time: Optional[datetime] = None
end_time: Optional[datetime] = None
retry_count: int = 0
error_message: Optional[str] = None
@dataclass
class WorkflowSpec:
"""工作流规格"""
name: str
description: str = ""
tasks: List[Task] = field(default_factory=list)
parameters: Dict[str, Any] = field(default_factory=dict)
triggers: List[Dict] = field(default_factory=list)
labels: Dict[str, str] = field(default_factory=dict)
@dataclass
class WorkflowRun:
"""工作流运行实例"""
id: str = field(default_factory=lambda: str(uuid.uuid4()))
workflow_id: str = ""
state: WorkflowState = WorkflowState.PENDING
parameters: Dict[str, Any] = field(default_factory=dict)
task_runs: Dict[str, Task] = field(default_factory=dict)
start_time: Optional[datetime] = None
end_time: Optional[datetime] = None
triggered_by: str = "manual"
class DAGParser:
"""DAG 解析器"""
def __init__(self):
self.tasks: Dict[str, Task] = {}
self.adjacency: Dict[str, List[str]] = {} # 邻接表
self.in_degree: Dict[str, int] = {} # 入度
def parse(self, workflow_spec: WorkflowSpec) -> 'DAGParser':
"""解析工作流为 DAG"""
for task in workflow_spec.tasks:
self.tasks[task.id] = task
self.adjacency[task.id] = []
self.in_degree[task.id] = len(task.dependencies)
for dep in task.dependencies:
if dep not in self.adjacency:
self.adjacency[dep] = []
self.adjacency[dep].append(task.id)
return self
def get_ready_tasks(self, completed: set) -> List[str]:
"""获取可执行的任务"""
ready = []
for task_id, degree in self.in_degree.items():
if task_id in completed:
continue
# 检查所有依赖是否完成
task = self.tasks[task_id]
if all(dep in completed for dep in task.dependencies):
ready.append(task_id)
return ready
def topological_sort(self) -> List[str]:
"""拓扑排序"""
in_degree = self.in_degree.copy()
queue = [task_id for task_id, degree in in_degree.items() if degree == 0]
result = []
while queue:
task_id = queue.pop(0)
result.append(task_id)
for next_task in self.adjacency.get(task_id, []):
in_degree[next_task] -= 1
if in_degree[next_task] == 0:
queue.append(next_task)
if len(result) != len(self.tasks):
raise ValueError("Workflow contains cycles")
return result
def validate(self) -> List[str]:
"""验证 DAG"""
errors = []
# 检查循环依赖
try:
self.topological_sort()
except ValueError as e:
errors.append(str(e))
# 检查依赖是否存在
for task in self.tasks.values():
for dep in task.dependencies:
if dep not in self.tasks:
errors.append(f"Task {task.id} depends on unknown task {dep}")
return errors
class Executor(ABC):
"""执行器抽象基类"""
@abstractmethod
async def execute(self, task: Task, context: Dict) -> TaskState:
"""执行任务"""
pass
@abstractmethod
async def cancel(self, task_id: str):
"""取消任务"""
pass
@abstractmethod
async def get_logs(self, task_id: str) -> str:
"""获取日志"""
pass
class Scheduler:
"""调度器"""
def __init__(
self,
executor: Executor,
max_concurrent: int = 10
):
self.executor = executor
self.max_concurrent = max_concurrent
self.running_tasks: Dict[str, Task] = {}
self.completed_tasks: set = set()
async def run_workflow(self, run: WorkflowRun, dag: DAGParser) -> WorkflowState:
"""运行工作流"""
run.state = WorkflowState.RUNNING
run.start_time = datetime.now()
try:
while True:
# 获取可执行任务
ready_tasks = dag.get_ready_tasks(self.completed_tasks)
if not ready_tasks and not self.running_tasks:
# 所有任务完成
break
# 启动新任务
for task_id in ready_tasks:
if len(self.running_tasks) >= self.max_concurrent:
break
if task_id not in self.running_tasks:
task = dag.tasks[task_id]
await self._start_task(task, run)
# 等待任务完成
await self._wait_for_completion()
# 检查最终状态
all_success = all(
dag.tasks[tid].state == TaskState.SUCCESS
for tid in dag.tasks
)
run.state = WorkflowState.SUCCESS if all_success else WorkflowState.FAILED
except Exception as e:
run.state = WorkflowState.FAILED
raise
finally:
run.end_time = datetime.now()
return run.state
async def _start_task(self, task: Task, run: WorkflowRun):
"""启动任务"""
task.state = TaskState.RUNNING
task.start_time = datetime.now()
self.running_tasks[task.id] = task
run.task_runs[task.id] = task
# 异步执行
context = {"workflow_run_id": run.id, "parameters": run.parameters}
state = await self.executor.execute(task, context)
task.state = state
task.end_time = datetime.now()
del self.running_tasks[task.id]
if state == TaskState.SUCCESS:
self.completed_tasks.add(task.id)
elif state == TaskState.FAILED and task.retry_count < task.spec.retry_policy.get("max_retries", 0):
# 重试
task.retry_count += 1
task.state = TaskState.PENDING
async def _wait_for_completion(self):
"""等待任务完成"""
import asyncio
await asyncio.sleep(1)
class WorkflowEngine:
"""工作流引擎"""
def __init__(
self,
executor: Executor,
metadata_store: 'MetadataStore' = None
):
self.executor = executor
self.metadata_store = metadata_store
self.scheduler = Scheduler(executor)
self.workflows: Dict[str, WorkflowSpec] = {}
self.runs: Dict[str, WorkflowRun] = {}
def register_workflow(self, spec: WorkflowSpec) -> str:
"""注册工作流"""
workflow_id = str(uuid.uuid4())
self.workflows[workflow_id] = spec
if self.metadata_store:
self.metadata_store.save_workflow(workflow_id, spec)
return workflow_id
async def run_workflow(
self,
workflow_id: str,
parameters: Dict[str, Any] = None
) -> WorkflowRun:
"""运行工作流"""
spec = self.workflows.get(workflow_id)
if not spec:
raise ValueError(f"Workflow {workflow_id} not found")
# 创建运行实例
run = WorkflowRun(
workflow_id=workflow_id,
parameters=parameters or {}
)
self.runs[run.id] = run
# 解析 DAG
dag = DAGParser().parse(spec)
errors = dag.validate()
if errors:
raise ValueError(f"Invalid workflow: {errors}")
# 执行
await self.scheduler.run_workflow(run, dag)
if self.metadata_store:
self.metadata_store.save_run(run)
return run
def get_run(self, run_id: str) -> Optional[WorkflowRun]:
"""获取运行实例"""
return self.runs.get(run_id)
def list_runs(
self,
workflow_id: str = None,
state: WorkflowState = None
) -> List[WorkflowRun]:
"""列出运行实例"""
runs = list(self.runs.values())
if workflow_id:
runs = [r for r in runs if r.workflow_id == workflow_id]
if state:
runs = [r for r in runs if r.state == state]
return runs
class MetadataStore(ABC):
"""元数据存储抽象"""
@abstractmethod
def save_workflow(self, workflow_id: str, spec: WorkflowSpec):
pass
@abstractmethod
def get_workflow(self, workflow_id: str) -> Optional[WorkflowSpec]:
pass
@abstractmethod
def save_run(self, run: WorkflowRun):
pass
@abstractmethod
def get_run(self, run_id: str) -> Optional[WorkflowRun]:
pass
# Python DSL for defining workflows
class WorkflowBuilder:
"""工作流构建器 - Python DSL"""
def __init__(self, name: str, description: str = ""):
self.name = name
self.description = description
self.tasks: List[Task] = []
self.parameters: Dict[str, Any] = {}
def add_parameter(self, name: str, default: Any = None, type_hint: str = "string"):
"""添加参数"""
self.parameters[name] = {
"default": default,
"type": type_hint
}
return self
def add_task(
self,
name: str,
image: str,
command: List[str],
dependencies: List[str] = None,
**kwargs
) -> 'WorkflowBuilder':
"""添加任务"""
task = Task(
id=name,
name=name,
spec=TaskSpec(
name=name,
image=image,
command=command,
**kwargs
),
dependencies=dependencies or []
)
self.tasks.append(task)
return self
def build(self) -> WorkflowSpec:
"""构建工作流"""
return WorkflowSpec(
name=self.name,
description=self.description,
tasks=self.tasks,
parameters=self.parameters
)
# 装饰器风格定义
def task(
name: str = None,
image: str = None,
dependencies: List[str] = None,
**kwargs
):
"""任务装饰器"""
def decorator(func: Callable):
task_name = name or func.__name__
# 从函数签名推断参数
import inspect
sig = inspect.signature(func)
task_spec = TaskSpec(
name=task_name,
image=image or "python:3.9",
command=["python", "-c", inspect.getsource(func)],
**kwargs
)
func._task_spec = task_spec
func._dependencies = dependencies or []
return func
return decorator
# 使用示例
if __name__ == "__main__":
# 使用 Builder 模式
workflow = (
WorkflowBuilder("ml-pipeline", "ML Training Pipeline")
.add_parameter("learning_rate", 0.001)
.add_parameter("epochs", 10)
.add_task(
name="preprocess",
image="python:3.9",
command=["python", "preprocess.py"],
resources={"cpu": "2", "memory": "4Gi"}
)
.add_task(
name="train",
image="pytorch/pytorch:2.0.0",
command=["python", "train.py"],
dependencies=["preprocess"],
resources={"nvidia.com/gpu": "1", "memory": "16Gi"}
)
.add_task(
name="evaluate",
image="python:3.9",
command=["python", "evaluate.py"],
dependencies=["train"]
)
.build()
)
print(f"Workflow: {workflow.name}")
print(f"Tasks: {[t.name for t in workflow.tasks]}")
工作流模式
常见工作流模式
┌─────────────────────────────────────────────────────────────────┐
│ 常见工作流模式 │
├─────────────────────────────────────────────────────────────────┤
│ │
│ 1. 顺序模式 (Sequential) │
│ ┌─────┐ ┌─────┐ ┌─────┐ ┌─────┐ │
│ │ A │───►│ B │───►│ C │───►│ D │ │
│ └─────┘ └─────┘ └─────┘ └─────┘ │
│ │
│ 2. 并行模式 (Parallel) │
│ ┌─────┐ │
│ ┌─►│ B │─┐ │
│ ┌─────┐ │ └─────┘ │ ┌─────┐ │
│ │ A │─┤ ├─►│ D │ │
│ └─────┘ │ ┌─────┐ │ └─────┘ │
│ └─►│ C │─┘ │
│ └─────┘ │
│ │
│ 3. 条件分支 (Conditional) │
│ ┌─────┐ │
│ Y │ B │ │
│ ┌──►└─────┘ │
│ ┌─────┐│ │
│ │ A ? ├┤ │
│ └─────┘│ │
│ └──►┌─────┐ │
│ N │ C │ │
│ └─────┘ │
│ │
│ 4. 循环模式 (Loop) │
│ ┌─────┐ ┌─────┐ ┌─────┐ │
│ │ A │───►│ B │───►│ C ? │ │
│ └─────┘ └──▲──┘ └──┬──┘ │
│ │ N │ Y │
│ └───────────┘ │
│ │
│ 5. 扇出-扇入 (Fan-out/Fan-in) │
│ ┌─────┐ │
│ ┌─►│ B1 │─┐ │
│ │ └─────┘ │ │
│ ┌─────┐ │ ┌─────┐ │ ┌─────┐ │
│ │ A │─┼─►│ B2 │─┼─►│ C │ │
│ └─────┘ │ └─────┘ │ └─────┘ │
│ │ ┌─────┐ │ │
│ └─►│ B3 │─┘ │
│ └─────┘ │
│ │
│ 6. 子工作流 (Sub-workflow) │
│ ┌─────┐ ┌───────────────┐ ┌─────┐ │
│ │ A │───►│ Sub-Workflow │───►│ C │ │
│ └─────┘ │ ┌───┐ ┌───┐ │ └─────┘ │
│ │ │ X │►│ Y │ │ │
│ │ └───┘ └───┘ │ │
│ └───────────────┘ │
│ │
└─────────────────────────────────────────────────────────────────┘
模式实现
"""
工作流模式实现
"""
from typing import List, Dict, Any, Callable, Optional
from dataclasses import dataclass
import asyncio
@dataclass
class TaskResult:
"""任务结果"""
task_id: str
success: bool
output: Any
error: Optional[str] = None
class WorkflowPatterns:
"""工作流模式实现"""
@staticmethod
async def sequential(
tasks: List[Callable],
initial_input: Any = None
) -> List[TaskResult]:
"""
顺序执行模式
每个任务的输出作为下一个任务的输入
"""
results = []
current_input = initial_input
for i, task in enumerate(tasks):
try:
output = await task(current_input)
result = TaskResult(
task_id=f"task_{i}",
success=True,
output=output
)
current_input = output
except Exception as e:
result = TaskResult(
task_id=f"task_{i}",
success=False,
output=None,
error=str(e)
)
results.append(result)
break # 失败时停止
results.append(result)
return results
@staticmethod
async def parallel(
tasks: List[Callable],
inputs: List[Any] = None
) -> List[TaskResult]:
"""
并行执行模式
所有任务同时执行
"""
if inputs is None:
inputs = [None] * len(tasks)
async def run_task(idx: int, task: Callable, input_data: Any) -> TaskResult:
try:
output = await task(input_data)
return TaskResult(
task_id=f"task_{idx}",
success=True,
output=output
)
except Exception as e:
return TaskResult(
task_id=f"task_{idx}",
success=False,
output=None,
error=str(e)
)
results = await asyncio.gather(*[
run_task(i, task, inp)
for i, (task, inp) in enumerate(zip(tasks, inputs))
])
return list(results)
@staticmethod
async def conditional(
condition: Callable[[], bool],
if_true: Callable,
if_false: Callable,
input_data: Any = None
) -> TaskResult:
"""
条件分支模式
根据条件执行不同分支
"""
try:
branch = if_true if await condition() else if_false
output = await branch(input_data)
return TaskResult(
task_id="conditional",
success=True,
output=output
)
except Exception as e:
return TaskResult(
task_id="conditional",
success=False,
output=None,
error=str(e)
)
@staticmethod
async def loop(
task: Callable,
condition: Callable[[Any], bool],
initial_input: Any,
max_iterations: int = 100
) -> List[TaskResult]:
"""
循环模式
重复执行直到条件不满足
"""
results = []
current_input = initial_input
iteration = 0
while iteration < max_iterations:
try:
output = await task(current_input)
result = TaskResult(
task_id=f"iteration_{iteration}",
success=True,
output=output
)
results.append(result)
# 检查继续条件
if not await condition(output):
break
current_input = output
iteration += 1
except Exception as e:
result = TaskResult(
task_id=f"iteration_{iteration}",
success=False,
output=None,
error=str(e)
)
results.append(result)
break
return results
@staticmethod
async def fan_out_fan_in(
scatter_task: Callable[[Any], List[Any]],
process_task: Callable[[Any], Any],
gather_task: Callable[[List[Any]], Any],
input_data: Any
) -> TaskResult:
"""
扇出-扇入模式
1. 拆分输入为多个部分
2. 并行处理每个部分
3. 聚合结果
"""
try:
# 扇出
scattered = await scatter_task(input_data)
# 并行处理
processed = await asyncio.gather(*[
process_task(item) for item in scattered
])
# 扇入
result = await gather_task(list(processed))
return TaskResult(
task_id="fan_out_fan_in",
success=True,
output=result
)
except Exception as e:
return TaskResult(
task_id="fan_out_fan_in",
success=False,
output=None,
error=str(e)
)
@staticmethod
async def retry_with_backoff(
task: Callable,
input_data: Any,
max_retries: int = 3,
base_delay: float = 1.0,
max_delay: float = 60.0
) -> TaskResult:
"""
带退避的重试模式
"""
last_error = None
for attempt in range(max_retries + 1):
try:
output = await task(input_data)
return TaskResult(
task_id=f"attempt_{attempt}",
success=True,
output=output
)
except Exception as e:
last_error = str(e)
if attempt < max_retries:
# 指数退避
delay = min(base_delay * (2 ** attempt), max_delay)
await asyncio.sleep(delay)
return TaskResult(
task_id=f"attempt_{max_retries}",
success=False,
output=None,
error=f"Failed after {max_retries + 1} attempts: {last_error}"
)
# ML 特定模式
class MLWorkflowPatterns:
"""ML 工作流模式"""
@staticmethod
async def hyperparameter_search(
train_func: Callable,
param_grid: List[Dict],
evaluate_func: Callable,
select_best: Callable[[List[Dict]], Dict]
) -> Dict:
"""
超参数搜索模式
"""
# 并行训练多个模型
results = []
async def train_and_evaluate(params: Dict) -> Dict:
model = await train_func(params)
score = await evaluate_func(model)
return {"params": params, "model": model, "score": score}
results = await asyncio.gather(*[
train_and_evaluate(params) for params in param_grid
])
# 选择最佳
best = select_best(list(results))
return best
@staticmethod
async def cross_validation(
data_splitter: Callable[[Any, int], List[tuple]],
train_func: Callable,
evaluate_func: Callable,
data: Any,
n_folds: int = 5
) -> Dict:
"""
交叉验证模式
"""
folds = await data_splitter(data, n_folds)
scores = []
for i, (train_data, val_data) in enumerate(folds):
model = await train_func(train_data)
score = await evaluate_func(model, val_data)
scores.append(score)
return {
"scores": scores,
"mean_score": sum(scores) / len(scores),
"std_score": (sum((s - sum(scores)/len(scores))**2 for s in scores) / len(scores)) ** 0.5
}
@staticmethod
async def incremental_training(
model_loader: Callable,
data_fetcher: Callable,
train_step: Callable,
checkpoint_saver: Callable,
should_continue: Callable[[Dict], bool],
initial_checkpoint: str = None
) -> Dict:
"""
增量训练模式
"""
# 加载模型
model = await model_loader(initial_checkpoint)
metrics_history = []
while True:
# 获取新数据
data = await data_fetcher()
if data is None:
break
# 训练一步
metrics = await train_step(model, data)
metrics_history.append(metrics)
# 保存检查点
checkpoint = await checkpoint_saver(model, metrics)
# 检查是否继续
if not await should_continue(metrics):
break
return {
"final_checkpoint": checkpoint,
"metrics_history": metrics_history
}
# 使用示例
async def example():
"""工作流模式使用示例"""
# 定义简单任务
async def preprocess(data):
print(f"Preprocessing: {data}")
return f"processed_{data}"
async def train(data):
print(f"Training with: {data}")
return f"model_{data}"
async def evaluate(model):
print(f"Evaluating: {model}")
return 0.95
# 顺序执行
results = await WorkflowPatterns.sequential(
[preprocess, train, evaluate],
initial_input="raw_data"
)
print(f"Sequential results: {results}")
# 并行执行
async def process_shard(shard):
return f"processed_{shard}"
parallel_results = await WorkflowPatterns.parallel(
[process_shard, process_shard, process_shard],
["shard_1", "shard_2", "shard_3"]
)
print(f"Parallel results: {parallel_results}")
if __name__ == "__main__":
asyncio.run(example())
Kubernetes 集成
Argo Workflow 示例
# Argo Workflow ML Pipeline 示例
apiVersion: argoproj.io/v1alpha1
kind: Workflow
metadata:
generateName: ml-pipeline-
spec:
entrypoint: ml-pipeline
# 参数定义
arguments:
parameters:
- name: learning-rate
value: "0.001"
- name: epochs
value: "10"
- name: model-name
value: "resnet50"
# 工作流模板
templates:
- name: ml-pipeline
dag:
tasks:
# 数据预处理
- name: preprocess
template: preprocess-template
arguments:
parameters:
- name: input-path
value: "s3://data/raw"
# 并行训练多个模型
- name: train-model-1
template: train-template
dependencies: [preprocess]
arguments:
parameters:
- name: learning-rate
value: "{{workflow.parameters.learning-rate}}"
- name: model-variant
value: "variant-1"
- name: train-model-2
template: train-template
dependencies: [preprocess]
arguments:
parameters:
- name: learning-rate
value: "0.0001"
- name: model-variant
value: "variant-2"
# 评估
- name: evaluate
template: evaluate-template
dependencies: [train-model-1, train-model-2]
# 选择最佳模型
- name: select-best
template: select-best-template
dependencies: [evaluate]
# 部署
- name: deploy
template: deploy-template
dependencies: [select-best]
when: "{{tasks.select-best.outputs.parameters.should-deploy}} == true"
# 预处理模板
- name: preprocess-template
inputs:
parameters:
- name: input-path
container:
image: python:3.9
command: [python, preprocess.py]
args:
- --input={{inputs.parameters.input-path}}
- --output=/data/processed
resources:
requests:
memory: "4Gi"
cpu: "2"
volumeMounts:
- name: data-volume
mountPath: /data
outputs:
artifacts:
- name: processed-data
path: /data/processed
# 训练模板
- name: train-template
inputs:
parameters:
- name: learning-rate
- name: model-variant
container:
image: pytorch/pytorch:2.0.0-cuda11.7-cudnn8-runtime
command: [python, train.py]
args:
- --lr={{inputs.parameters.learning-rate}}
- --variant={{inputs.parameters.model-variant}}
- --epochs={{workflow.parameters.epochs}}
resources:
limits:
nvidia.com/gpu: 1
memory: "32Gi"
requests:
nvidia.com/gpu: 1
memory: "32Gi"
outputs:
artifacts:
- name: model
path: /output/model
parameters:
- name: train-loss
valueFrom:
path: /output/metrics.json
# 评估模板
- name: evaluate-template
container:
image: python:3.9
command: [python, evaluate.py]
outputs:
parameters:
- name: best-model
valueFrom:
path: /output/best_model.txt
# 部署模板
- name: deploy-template
resource:
action: apply
manifest: |
apiVersion: serving.kubeflow.org/v1beta1
kind: InferenceService
metadata:
name: {{workflow.parameters.model-name}}
spec:
predictor:
pytorch:
storageUri: "s3://models/{{workflow.name}}"
# 持久卷
volumeClaimTemplates:
- metadata:
name: data-volume
spec:
accessModes: ["ReadWriteOnce"]
resources:
requests:
storage: 100Gi
---
# 定时触发的工作流
apiVersion: argoproj.io/v1alpha1
kind: CronWorkflow
metadata:
name: ml-pipeline-daily
spec:
schedule: "0 2 * * *" # 每天凌晨2点
timezone: "Asia/Shanghai"
workflowSpec:
entrypoint: ml-pipeline
# ... 同上
最佳实践
工作流设计原则
# 工作流设计最佳实践
best_practices:
# 1. 任务设计
task_design:
- name: "单一职责"
description: "每个任务只做一件事"
example: "分离数据下载和数据处理"
- name: "幂等性"
description: "重复执行产生相同结果"
implementation:
- 使用确定性随机种子
- 避免依赖外部可变状态
- name: "容器化"
description: "每个任务使用独立容器"
benefits:
- 环境隔离
- 可移植性
- 版本控制
# 2. 数据传递
data_passing:
- name: "小数据用参数"
description: "简单值通过参数传递"
limit: "< 1MB"
- name: "大数据用 Artifact"
description: "文件和数据集通过存储传递"
storage:
- S3/GCS
- PVC
- MinIO
- name: "避免重复传输"
description: "使用共享存储或缓存"
# 3. 错误处理
error_handling:
retry:
max_retries: 3
backoff:
type: exponential
initial: 10s
max: 5m
failure_strategy:
- fail_fast: "关键路径失败立即停止"
- continue: "非关键任务失败继续"
- retry_from_checkpoint: "从检查点恢复"
# 4. 资源管理
resource_management:
- name: "明确资源需求"
fields: [cpu, memory, gpu]
- name: "使用资源配额"
purpose: "防止资源滥用"
- name: "清理临时资源"
method: "设置 TTL 或手动清理"
# 5. 可观测性
observability:
logging:
- 结构化日志
- 统一日志收集
- 日志级别分离
metrics:
- 任务执行时间
- 成功/失败率
- 资源使用率
tracing:
- 分布式追踪
- 血缘关系记录
总结
AI 工作流引擎是 ML 工程化的核心组件:
- 核心价值:自动化、可复现、可扩展、可观测
- 主流方案:Kubeflow Pipelines、Argo Workflows、Airflow、Prefect
- 架构组成:DAG 调度器、执行器、元数据存储
- 工作流模式:顺序、并行、条件、循环、扇出扇入
选择工作流引擎时需要考虑:
- 团队技术栈(Python/YAML)
- 基础设施(Kubernetes/传统环境)
- 规模需求(单机/分布式)
- 生态集成(MLOps 工具链)