03-Argo Workflows Deep Dive

Overview

Argo Workflows is a Kubernetes-native workflow engine that is widely used for CI/CD, data processing, and machine learning pipelines. This article takes a deep look at the architecture of Argo Workflows, its advanced features, and best practices for ML workloads.
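
Before diving into the architecture, here is the smallest possible end-to-end interaction: submitting a hello-world Workflow through the Kubernetes API, using the same CustomObjectsApi that the event handler later in this article relies on. This is only a sketch; the argo-workflows namespace and local kubeconfig access are assumptions, and in day-to-day use you would usually go through argo submit or the UI instead.

# submit_hello_world.py - minimal Workflow submission sketch
from kubernetes import client, config

config.load_kube_config()  # use config.load_incluster_config() when running inside the cluster

# A minimal Workflow: a single container step that echoes a message.
workflow = {
    "apiVersion": "argoproj.io/v1alpha1",
    "kind": "Workflow",
    "metadata": {"generateName": "hello-world-"},
    "spec": {
        "entrypoint": "main",
        "templates": [{
            "name": "main",
            "container": {
                "image": "busybox:1.36",
                "command": ["echo", "hello argo"],
            },
        }],
    },
}

created = client.CustomObjectsApi().create_namespaced_custom_object(
    group="argoproj.io",
    version="v1alpha1",
    namespace="argo-workflows",  # assumption: the managed namespace used throughout this article
    plural="workflows",
    body=workflow,
)
print("submitted:", created["metadata"]["name"])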

Argo Workflows Architecture

Core Components

┌─────────────────────────────────────────────────────────────────────┐
│                        Argo Workflows 架构                          │
├─────────────────────────────────────────────────────────────────────┤
│  ┌─────────────┐  ┌─────────────┐  ┌─────────────┐                 │
│  │  Argo CLI   │  │  Argo UI    │  │  Argo Events│                 │
│  └──────┬──────┘  └──────┬──────┘  └──────┬──────┘                 │
│         │                │                │                         │
│         ▼                ▼                ▼                         │
│  ┌─────────────────────────────────────────────────────────────┐   │
│  │                    Argo Server                               │   │
│  │  ┌──────────┐  ┌──────────┐  ┌──────────┐  ┌──────────┐    │   │
│  │  │ API Svc  │  │ Auth Svc │  │ Artifact │  │ Archive  │    │   │
│  │  │          │  │          │  │ Driver   │  │ Service  │    │   │
│  │  └──────────┘  └──────────┘  └──────────┘  └──────────┘    │   │
│  └─────────────────────────────────────────────────────────────┘   │
│                              │                                      │
│                              ▼                                      │
│  ┌─────────────────────────────────────────────────────────────┐   │
│  │                 Workflow Controller                          │   │
│  │  ┌──────────┐  ┌──────────┐  ┌──────────┐  ┌──────────┐    │   │
│  │  │ Informer │  │ Executor │  │ Garbage  │  │ Metrics  │    │   │
│  │  │          │  │          │  │ Collector│  │ Exporter │    │   │
│  │  └──────────┘  └──────────┘  └──────────┘  └──────────┘    │   │
│  └─────────────────────────────────────────────────────────────┘   │
│                              │                                      │
│                              ▼                                      │
│  ┌─────────────────────────────────────────────────────────────┐   │
│  │                    Kubernetes                                │   │
│  │  ┌──────────┐  ┌──────────┐  ┌──────────┐  ┌──────────┐    │   │
│  │  │ Workflow │  │ Pod      │  │ PVC      │  │ ConfigMap│    │   │
│  │  │ CRD      │  │          │  │          │  │          │    │   │
│  │  └──────────┘  └──────────┘  └──────────┘  └──────────┘    │   │
│  └─────────────────────────────────────────────────────────────┘   │
└─────────────────────────────────────────────────────────────────────┘

Installation and Deployment

# argo-install.yaml - 生产级部署配置
apiVersion: v1
kind: Namespace
metadata:
  name: argo
---
# Argo Workflows Controller
apiVersion: apps/v1
kind: Deployment
metadata:
  name: workflow-controller
  namespace: argo
spec:
  replicas: 2
  selector:
    matchLabels:
      app: workflow-controller
  template:
    metadata:
      labels:
        app: workflow-controller
    spec:
      serviceAccountName: argo-workflow-controller
      containers:
      - name: workflow-controller
        image: quay.io/argoproj/workflow-controller:v3.5.0
        args:
        - --configmap
        - workflow-controller-configmap
        - --executor-image
        - quay.io/argoproj/argoexec:v3.5.0
        - --namespaced
        - --managed-namespace
        - argo-workflows
        env:
        - name: LEADER_ELECTION_IDENTITY
          valueFrom:
            fieldRef:
              fieldPath: metadata.name
        resources:
          requests:
            cpu: 100m
            memory: 256Mi
          limits:
            cpu: 500m
            memory: 512Mi
        ports:
        - containerPort: 9090
          name: metrics
        livenessProbe:
          httpGet:
            path: /healthz
            port: 6060
          initialDelaySeconds: 30
          periodSeconds: 30
---
# Argo Server
apiVersion: apps/v1
kind: Deployment
metadata:
  name: argo-server
  namespace: argo
spec:
  replicas: 2
  selector:
    matchLabels:
      app: argo-server
  template:
    metadata:
      labels:
        app: argo-server
    spec:
      serviceAccountName: argo-server
      containers:
      - name: argo-server
        image: quay.io/argoproj/argocli:v3.5.0
        args:
        - server
        - --auth-mode=sso
        - --secure
        ports:
        - containerPort: 2746
          name: web
        env:
        - name: ARGO_SECURE
          value: "true"
        - name: ARGO_BASE_HREF
          value: /argo/
        resources:
          requests:
            cpu: 100m
            memory: 256Mi
          limits:
            cpu: 500m
            memory: 512Mi
        readinessProbe:
          httpGet:
            path: /
            port: 2746
            scheme: HTTPS
          initialDelaySeconds: 10
          periodSeconds: 10
---
# Controller ConfigMap
apiVersion: v1
kind: ConfigMap
metadata:
  name: workflow-controller-configmap
  namespace: argo
data:
  # 执行器配置
  executor: |
    resources:
      requests:
        cpu: 10m
        memory: 64Mi
      limits:
        cpu: 100m
        memory: 128Mi

  # Artifact 存储配置
  artifactRepository: |
    archiveLogs: true
    s3:
      bucket: argo-artifacts
      endpoint: minio.argo:9000
      insecure: true
      accessKeySecret:
        name: argo-artifacts
        key: accesskey
      secretKeySecret:
        name: argo-artifacts
        key: secretkey

  # 持久化配置
  persistence: |
    connectionPool:
      maxIdleConns: 100
      maxOpenConns: 0
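    # nodeStatusOffLoad stores node status in the database instead of on the Workflow
    # object, keeping very large workflows under etcd's object size limit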
    nodeStatusOffLoad: true
    archive: true
    archiveTTL: 30d
    postgresql:
      host: postgres.argo
      port: 5432
      database: argo
      tableName: argo_workflows
      userNameSecret:
        name: argo-postgres-config
        key: username
      passwordSecret:
        name: argo-postgres-config
        key: password

  # 资源限制
  workflowDefaults: |
    spec:
      ttlStrategy:
        secondsAfterCompletion: 86400
        secondsAfterSuccess: 3600
        secondsAfterFailure: 172800
      podGC:
        strategy: OnPodCompletion
      activeDeadlineSeconds: 86400

  # 并发控制
  parallelism: "50"
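
After applying the manifests, a quick health check is to confirm that both Deployments report ready replicas. A minimal sketch, assuming the argo namespace used above and local kubeconfig access:

# check_install.py - verify the controller and server are up
from kubernetes import client, config

config.load_kube_config()
apps = client.AppsV1Api()

for name in ("workflow-controller", "argo-server"):
    dep = apps.read_namespaced_deployment(name=name, namespace="argo")
    ready = dep.status.ready_replicas or 0
    print(f"{name}: {ready}/{dep.spec.replicas} replicas ready")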

Workflow Templates in Detail

Basic Template Types

# workflow-templates.yaml
apiVersion: argoproj.io/v1alpha1
kind: WorkflowTemplate
metadata:
  name: ml-workflow-templates
  namespace: argo-workflows
spec:
  # 1. Container 模板
  templates:
  - name: container-example
    container:
      image: python:3.9
      command: [python]
      args: ["{{inputs.parameters.script}}"]
      resources:
        requests:
          memory: "1Gi"
          cpu: "500m"
        limits:
          memory: "2Gi"
          cpu: "1"
      volumeMounts:
      - name: workdir
        mountPath: /workdir
    inputs:
      parameters:
      - name: script
    outputs:
      artifacts:
      - name: output
        path: /workdir/output

  # 2. Script 模板
  - name: script-example
    script:
      image: python:3.9
      command: [python]
      source: |
        import json
        import sys

        data = json.loads('{{inputs.parameters.data}}')
        result = {"processed": data, "count": len(data)}

        with open('/tmp/result.json', 'w') as f:
            json.dump(result, f)

        print(json.dumps(result))
      resources:
        requests:
          memory: "512Mi"
          cpu: "200m"
    inputs:
      parameters:
      - name: data
    outputs:
      parameters:
      - name: result
        valueFrom:
          path: /tmp/result.json

  # 3. Resource 模板 - 创建 K8s 资源
  - name: resource-example
    resource:
      action: create
      setOwnerReference: true
      manifest: |
        apiVersion: batch/v1
        kind: Job
        metadata:
          generateName: pytorch-job-
        spec:
          template:
            spec:
              containers:
              - name: pytorch
                image: pytorch/pytorch:2.0.0-cuda11.7-cudnn8-runtime
                command: ["python", "-c", "print('Hello PyTorch')"]
              restartPolicy: Never

  # 4. Suspend 模板 - 人工审批
  - name: approval-gate
    suspend: {}

  # 5. HTTP 模板 - API 调用
  - name: http-example
    http:
      url: "https://api.example.com/webhook"
      method: POST
      headers:
      - name: Content-Type
        value: application/json
      body: |
        {"workflow": "{{workflow.name}}", "status": "{{inputs.parameters.status}}"}
      successCondition: response.statusCode == 200
    inputs:
      parameters:
      - name: status

---
# DAG 工作流模板
apiVersion: argoproj.io/v1alpha1
kind: WorkflowTemplate
metadata:
  name: ml-dag-template
  namespace: argo-workflows
spec:
  entrypoint: ml-pipeline

  arguments:
    parameters:
    - name: dataset-path
      value: "s3://data/training"
    - name: model-type
      value: "resnet50"
    - name: epochs
      value: "100"

  templates:
  - name: ml-pipeline
    dag:
      tasks:
      # 数据准备阶段
      - name: validate-data
        template: data-validation
        arguments:
          parameters:
          - name: path
            value: "{{workflow.parameters.dataset-path}}"

      - name: preprocess
        template: data-preprocessing
        dependencies: [validate-data]
        arguments:
          artifacts:
          - name: raw-data
            from: "{{tasks.validate-data.outputs.artifacts.validated-data}}"

      # 特征工程(并行)
      - name: feature-extraction
        template: extract-features
        dependencies: [preprocess]
        arguments:
          artifacts:
          - name: processed-data
            from: "{{tasks.preprocess.outputs.artifacts.processed-data}}"

      - name: feature-selection
        template: select-features
        dependencies: [feature-extraction]
        arguments:
          artifacts:
          - name: features
            from: "{{tasks.feature-extraction.outputs.artifacts.features}}"

      # 模型训练
      - name: train-model
        template: model-training
        dependencies: [feature-selection]
        arguments:
          parameters:
          - name: model-type
            value: "{{workflow.parameters.model-type}}"
          - name: epochs
            value: "{{workflow.parameters.epochs}}"
          artifacts:
          - name: training-data
            from: "{{tasks.feature-selection.outputs.artifacts.selected-features}}"

      # 模型评估
      - name: evaluate-model
        template: model-evaluation
        dependencies: [train-model]
        arguments:
          artifacts:
          - name: model
            from: "{{tasks.train-model.outputs.artifacts.model}}"
          - name: test-data
            from: "{{tasks.preprocess.outputs.artifacts.test-data}}"

      # 条件部署
      - name: deploy-model
        template: model-deployment
        dependencies: [evaluate-model]
        when: "{{tasks.evaluate-model.outputs.parameters.accuracy}} > 0.9"
        arguments:
          artifacts:
          - name: model
            from: "{{tasks.train-model.outputs.artifacts.model}}"

  # 模板定义
  - name: data-validation
    inputs:
      parameters:
      - name: path
    outputs:
      artifacts:
      - name: validated-data
        path: /data/validated
    container:
      image: ml-platform/data-validator:v1.0
      command: [python, validate.py]
      args:
      - --input={{inputs.parameters.path}}
      - --output=/data/validated

  - name: data-preprocessing
    inputs:
      artifacts:
      - name: raw-data
        path: /data/raw
    outputs:
      artifacts:
      - name: processed-data
        path: /data/processed
      - name: test-data
        path: /data/test
    container:
      image: ml-platform/preprocessor:v1.0
      command: [python, preprocess.py]
      args:
      - --input=/data/raw
      - --output=/data/processed
      - --test-output=/data/test

  - name: model-training
    inputs:
      parameters:
      - name: model-type
      - name: epochs
      artifacts:
      - name: training-data
        path: /data/train
    outputs:
      artifacts:
      - name: model
        path: /models/trained
      - name: metrics
        path: /metrics
    container:
      image: ml-platform/trainer:v1.0
      command: [python, train.py]
      args:
      - --model={{inputs.parameters.model-type}}
      - --epochs={{inputs.parameters.epochs}}
      - --data=/data/train
      - --output=/models/trained
      resources:
        requests:
          nvidia.com/gpu: 1
          memory: "16Gi"
          cpu: "4"
        limits:
          nvidia.com/gpu: 1
          memory: "32Gi"
          cpu: "8"

  - name: model-evaluation
    inputs:
      artifacts:
      - name: model
        path: /models
      - name: test-data
        path: /data/test
    outputs:
      parameters:
      - name: accuracy
        valueFrom:
          path: /metrics/accuracy.txt
      artifacts:
      - name: report
        path: /reports
    container:
      image: ml-platform/evaluator:v1.0
      command: [python, evaluate.py]

  - name: model-deployment
    inputs:
      artifacts:
      - name: model
        path: /models
    container:
      image: ml-platform/deployer:v1.0
      command: [python, deploy.py]
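
A WorkflowTemplate is only a reusable definition; to actually run ml-dag-template you submit a Workflow that references it via workflowTemplateRef and overrides its parameters. A minimal sketch (the parameter values are illustrative):

# run_ml_dag.py - launch a run from the WorkflowTemplate defined above
from kubernetes import client, config

config.load_kube_config()

run = {
    "apiVersion": "argoproj.io/v1alpha1",
    "kind": "Workflow",
    "metadata": {"generateName": "ml-dag-run-"},
    "spec": {
        # Reuse the template defined above instead of inlining every step.
        "workflowTemplateRef": {"name": "ml-dag-template"},
        "arguments": {
            "parameters": [
                {"name": "dataset-path", "value": "s3://data/training"},
                {"name": "model-type", "value": "resnet50"},
                {"name": "epochs", "value": "50"},
            ]
        },
    },
}

client.CustomObjectsApi().create_namespaced_custom_object(
    group="argoproj.io", version="v1alpha1",
    namespace="argo-workflows", plural="workflows", body=run,
)

From the CLI, argo submit --from workflowtemplate/ml-dag-template -p epochs=50 achieves the same thing.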

Advanced DAG Patterns

Dynamic DAG Generation

# dynamic-dag.yaml
apiVersion: argoproj.io/v1alpha1
kind: Workflow
metadata:
  generateName: dynamic-dag-
spec:
  entrypoint: main

  templates:
  - name: main
    steps:
    # 第一步:生成任务列表
    - - name: generate-tasks
        template: generate-task-list

    # 第二步:动态扩展执行
    - - name: process-tasks
        template: process-task
        arguments:
          parameters:
          - name: task-id
            value: "{{item}}"
        withParam: "{{steps.generate-tasks.outputs.result}}"

    # 第三步:聚合结果
    - - name: aggregate
        template: aggregate-results
        arguments:
          artifacts:
          - name: results
            from: "{{steps.process-tasks.outputs.artifacts.result}}"

  - name: generate-task-list
    script:
      image: python:3.9
      command: [python]
      source: |
        import json
        # 动态生成任务列表
        tasks = [f"task-{i}" for i in range(10)]
        print(json.dumps(tasks))

  - name: process-task
    inputs:
      parameters:
      - name: task-id
    outputs:
      artifacts:
      - name: result
        path: /tmp/result.json
    script:
      image: python:3.9
      command: [python]
      source: |
        import json
        task_id = "{{inputs.parameters.task-id}}"
        result = {"task_id": task_id, "status": "completed"}
        with open('/tmp/result.json', 'w') as f:
            json.dump(result, f)

  - name: aggregate-results
    inputs:
      artifacts:
      - name: results
        path: /tmp/results
    script:
      image: python:3.9
      command: [python]
      source: |
        import os
        import json

        results = []
        for f in os.listdir('/tmp/results'):
            with open(f'/tmp/results/{f}') as file:
                results.append(json.load(file))

        print(f"Aggregated {len(results)} results")
        print(json.dumps(results, indent=2))

---
# 条件分支与循环
apiVersion: argoproj.io/v1alpha1
kind: Workflow
metadata:
  generateName: conditional-loop-
spec:
  entrypoint: main
  arguments:
    parameters:
    - name: iterations
      value: "5"
    - name: threshold
      value: "0.8"

  templates:
  - name: main
    dag:
      tasks:
      - name: init
        template: initialize

      - name: training-loop
        template: training-iteration
        dependencies: [init]
        arguments:
          parameters:
          - name: iteration
            value: "{{item}}"
          - name: prev-model
            value: "{{tasks.init.outputs.parameters.model-path}}"
        withSequence:
          count: "{{workflow.parameters.iterations}}"

      - name: final-eval
        template: final-evaluation
        dependencies: [training-loop]

  - name: initialize
    outputs:
      parameters:
      - name: model-path
        valueFrom:
          path: /tmp/model-path
    script:
      image: python:3.9
      command: [python]
      source: |
        model_path = "/models/initial"
        with open('/tmp/model-path', 'w') as f:
            f.write(model_path)
        print(f"Initialized model at {model_path}")

  - name: training-iteration
    inputs:
      parameters:
      - name: iteration
      - name: prev-model
    outputs:
      parameters:
      - name: accuracy
        valueFrom:
          path: /tmp/accuracy
    script:
      image: python:3.9
      command: [python]
      source: |
        import random

        iteration = {{inputs.parameters.iteration}}
        prev_model = "{{inputs.parameters.prev-model}}"

        # 模拟训练
        accuracy = 0.7 + (iteration * 0.05) + random.uniform(0, 0.1)

        with open('/tmp/accuracy', 'w') as f:
            f.write(str(accuracy))

        print(f"Iteration {iteration}: accuracy = {accuracy:.4f}")

  - name: final-evaluation
    script:
      image: python:3.9
      command: [python]
      source: |
        print("Final evaluation completed")
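
Both of the workflows above are asynchronous once submitted. When driving them from an external script (a CI job, for example), a common pattern is to poll the Workflow's status.phase until it reaches a terminal state. A minimal sketch using the same CustomObjectsApi; the function name and timeouts are illustrative:

# wait_for_workflow.py - poll a Workflow until it finishes
import time

from kubernetes import client, config

config.load_kube_config()
api = client.CustomObjectsApi()

def wait_for_workflow(name: str, namespace: str = "argo-workflows", timeout: int = 3600) -> str:
    """Return the terminal phase (Succeeded/Failed/Error) or raise on timeout."""
    deadline = time.time() + timeout
    while time.time() < deadline:
        wf = api.get_namespaced_custom_object(
            group="argoproj.io", version="v1alpha1",
            namespace=namespace, plural="workflows", name=name,
        )
        phase = wf.get("status", {}).get("phase", "Pending")
        if phase in ("Succeeded", "Failed", "Error"):
            return phase
        time.sleep(15)
    raise TimeoutError(f"workflow {name} did not finish within {timeout}s")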

Distributed Training Workflows

# distributed-training-workflow.yaml
apiVersion: argoproj.io/v1alpha1
kind: Workflow
metadata:
  generateName: distributed-training-
spec:
  entrypoint: distributed-pipeline

  arguments:
    parameters:
    - name: num-workers
      value: "4"
    - name: model
      value: "llama-7b"
    - name: dataset
      value: "s3://data/pretrain"

  volumes:
  - name: shared-data
    persistentVolumeClaim:
      claimName: training-data-pvc

  templates:
  - name: distributed-pipeline
    dag:
      tasks:
      - name: prepare-data
        template: data-preparation
        arguments:
          parameters:
          - name: dataset
            value: "{{workflow.parameters.dataset}}"
          - name: num-shards
            value: "{{workflow.parameters.num-workers}}"

      - name: launch-master
        template: training-master
        dependencies: [prepare-data]
        arguments:
          parameters:
          - name: num-workers
            value: "{{workflow.parameters.num-workers}}"
          - name: model
            value: "{{workflow.parameters.model}}"

      - name: launch-workers
        template: training-worker
        dependencies: [prepare-data, launch-master]
        arguments:
          parameters:
          - name: worker-id
            value: "{{item}}"
          - name: master-addr
            value: "{{tasks.launch-master.ip}}"
        withSequence:
          count: "{{workflow.parameters.num-workers}}"

      - name: wait-completion
        template: wait-for-training
        dependencies: [launch-workers]
        arguments:
          parameters:
          - name: master-addr
            value: "{{tasks.launch-master.ip}}"

      - name: collect-checkpoints
        template: checkpoint-collection
        dependencies: [wait-completion]

      - name: merge-model
        template: model-merging
        dependencies: [collect-checkpoints]

  - name: data-preparation
    inputs:
      parameters:
      - name: dataset
      - name: num-shards
    outputs:
      artifacts:
      - name: data-manifest
        path: /output/manifest.json
    container:
      image: ml-platform/data-prep:v1.0
      command: [python, shard_data.py]
      args:
      - --input={{inputs.parameters.dataset}}
      - --num-shards={{inputs.parameters.num-shards}}
      - --output=/output
      volumeMounts:
      - name: shared-data
        mountPath: /output
      resources:
        requests:
          cpu: "4"
          memory: "16Gi"

  - name: training-master
    inputs:
      parameters:
      - name: num-workers
      - name: model
    # Run the master as a daemon so it stays up while the workers train;
    # downstream tasks reference its address via {{tasks.launch-master.ip}}
    daemon: true
    container:
      image: ml-platform/distributed-trainer:v1.0
      command: [python, master.py]
      args:
      - --model={{inputs.parameters.model}}
      - --num-workers={{inputs.parameters.num-workers}}
      - --port=29500
      env:
      - name: MASTER_ADDR
        valueFrom:
          fieldRef:
            fieldPath: status.podIP
      - name: MASTER_PORT
        value: "29500"
      ports:
      - containerPort: 29500
      resources:
        requests:
          nvidia.com/gpu: 1
          cpu: "8"
          memory: "32Gi"
        limits:
          nvidia.com/gpu: 1
          memory: "64Gi"
      volumeMounts:
      - name: shared-data
        mountPath: /data

  - name: training-worker
    inputs:
      parameters:
      - name: worker-id
      - name: master-addr
    container:
      image: ml-platform/distributed-trainer:v1.0
      command: [python, worker.py]
      args:
      - --master-addr={{inputs.parameters.master-addr}}
      - --master-port=29500
      - --worker-id={{inputs.parameters.worker-id}}
      - --data-shard=/data/shard-{{inputs.parameters.worker-id}}
      env:
      - name: WORLD_SIZE
        value: "{{workflow.parameters.num-workers}}"
      - name: RANK
        value: "{{inputs.parameters.worker-id}}"
      resources:
        requests:
          nvidia.com/gpu: 1
          cpu: "8"
          memory: "32Gi"
        limits:
          nvidia.com/gpu: 1
          memory: "64Gi"
      volumeMounts:
      - name: shared-data
        mountPath: /data

  - name: wait-for-training
    inputs:
      parameters:
      - name: master-addr
    script:
      image: curlimages/curl:latest
      command: [sh]
      source: |
        while true; do
          status=$(curl -s http://{{inputs.parameters.master-addr}}:29500/status)
          if [ "$status" = "completed" ]; then
            echo "Training completed"
            break
          fi
          sleep 60
        done

  - name: checkpoint-collection
    outputs:
      artifacts:
      - name: checkpoints
        path: /checkpoints
    container:
      image: ml-platform/checkpoint-collector:v1.0
      command: [python, collect.py]
      volumeMounts:
      - name: shared-data
        mountPath: /data

  - name: model-merging
    container:
      image: ml-platform/model-merger:v1.0
      command: [python, merge.py]
      args:
      - --checkpoints=/data/checkpoints
      - --output=/models/final
      resources:
        requests:
          cpu: "16"
          memory: "64Gi"
      volumeMounts:
      - name: shared-data
        mountPath: /data
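
The training-worker template only wires up the master address, RANK, and WORLD_SIZE; the container image is expected to turn them into a process group. Below is a sketch of what the assumed worker.py entrypoint could look like with PyTorch DDP (the argument names mirror the args passed above; the actual training loop is omitted):

# worker.py - sketch of the entrypoint assumed by the training-worker template
import argparse
import os

import torch.distributed as dist

def main() -> None:
    parser = argparse.ArgumentParser()
    parser.add_argument("--master-addr", required=True)
    parser.add_argument("--master-port", default="29500")
    parser.add_argument("--worker-id", type=int, required=True)
    parser.add_argument("--data-shard", required=True)
    args = parser.parse_args()

    # torch.distributed's env:// init method reads these four variables.
    os.environ["MASTER_ADDR"] = args.master_addr
    os.environ["MASTER_PORT"] = args.master_port
    os.environ.setdefault("RANK", str(args.worker_id))
    os.environ.setdefault("WORLD_SIZE", "1")

    dist.init_process_group(backend="nccl", init_method="env://")
    rank, world = dist.get_rank(), dist.get_world_size()
    print(f"rank {rank}/{world} training on shard {args.data_shard}")

    # ... build the model, wrap it in DistributedDataParallel, run the training loop ...

    dist.destroy_process_group()

if __name__ == "__main__":
    main()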

Artifact Management

Artifact Configuration and Usage

# artifact_manager.py
"""
Argo Workflows Artifact 管理器
支持 S3、GCS、MinIO、HDFS 等多种存储后端
"""

import os
import yaml
import boto3
from typing import Dict, List, Optional, Any
from dataclasses import dataclass, field
from abc import ABC, abstractmethod
import hashlib
import json


@dataclass
class ArtifactSpec:
    """Artifact 规格定义"""
    name: str
    path: str
    artifact_type: str = "file"  # file, directory, archive
    compression: Optional[str] = None  # gzip, zip, tar
    mode: int = 0o644
    optional: bool = False

    def to_argo_spec(self) -> Dict[str, Any]:
        spec = {
            "name": self.name,
            "path": self.path
        }
        if self.compression:
            spec["archive"] = {"none": {}} if self.compression == "none" else {self.compression: {}}
        if self.optional:
            spec["optional"] = True
        return spec


class ArtifactRepository(ABC):
    """Artifact 存储仓库抽象基类"""

    @abstractmethod
    def upload(self, local_path: str, remote_key: str) -> str:
        pass

    @abstractmethod
    def download(self, remote_key: str, local_path: str) -> bool:
        pass

    @abstractmethod
    def exists(self, remote_key: str) -> bool:
        pass

    @abstractmethod
    def delete(self, remote_key: str) -> bool:
        pass

    @abstractmethod
    def list(self, prefix: str) -> List[str]:
        pass

    @abstractmethod
    def get_argo_config(self) -> Dict[str, Any]:
        pass


class S3ArtifactRepository(ArtifactRepository):
    """S3 Artifact 存储"""

    def __init__(
        self,
        bucket: str,
        endpoint: Optional[str] = None,
        region: str = "us-east-1",
        access_key: Optional[str] = None,
        secret_key: Optional[str] = None,
        insecure: bool = False,
        key_prefix: str = "artifacts"
    ):
        self.bucket = bucket
        self.endpoint = endpoint
        self.region = region
        self.insecure = insecure
        self.key_prefix = key_prefix

        # 初始化 S3 客户端
        config = {
            "region_name": region
        }
        if endpoint:
            config["endpoint_url"] = f"{'http' if insecure else 'https'}://{endpoint}"
        if access_key and secret_key:
            config["aws_access_key_id"] = access_key
            config["aws_secret_access_key"] = secret_key

        self.client = boto3.client("s3", **config)

    def upload(self, local_path: str, remote_key: str) -> str:
        """上传文件到 S3"""
        full_key = f"{self.key_prefix}/{remote_key}"

        if os.path.isdir(local_path):
            # 上传目录
            for root, dirs, files in os.walk(local_path):
                for file in files:
                    file_path = os.path.join(root, file)
                    rel_path = os.path.relpath(file_path, local_path)
                    file_key = f"{full_key}/{rel_path}"
                    self.client.upload_file(file_path, self.bucket, file_key)
        else:
            # 上传文件
            self.client.upload_file(local_path, self.bucket, full_key)

        return f"s3://{self.bucket}/{full_key}"

    def download(self, remote_key: str, local_path: str) -> bool:
        """从 S3 下载文件"""
        full_key = f"{self.key_prefix}/{remote_key}"

        try:
            # 检查是否是目录(前缀)
            response = self.client.list_objects_v2(
                Bucket=self.bucket,
                Prefix=full_key,
                MaxKeys=2
            )

            objects = response.get("Contents", [])
            if len(objects) == 0:
                return False

            if len(objects) == 1 and objects[0]["Key"] == full_key:
                # 单文件
                os.makedirs(os.path.dirname(local_path), exist_ok=True)
                self.client.download_file(self.bucket, full_key, local_path)
            else:
                # 目录
                os.makedirs(local_path, exist_ok=True)
                paginator = self.client.get_paginator("list_objects_v2")

                for page in paginator.paginate(Bucket=self.bucket, Prefix=full_key):
                    for obj in page.get("Contents", []):
                        key = obj["Key"]
                        rel_path = key[len(full_key):].lstrip("/")
                        file_path = os.path.join(local_path, rel_path)
                        os.makedirs(os.path.dirname(file_path), exist_ok=True)
                        self.client.download_file(self.bucket, key, file_path)

            return True
        except Exception as e:
            print(f"Download failed: {e}")
            return False

    def exists(self, remote_key: str) -> bool:
        """检查文件是否存在"""
        full_key = f"{self.key_prefix}/{remote_key}"
        try:
            self.client.head_object(Bucket=self.bucket, Key=full_key)
            return True
        except Exception:
            # 检查是否是前缀
            response = self.client.list_objects_v2(
                Bucket=self.bucket,
                Prefix=full_key,
                MaxKeys=1
            )
            return len(response.get("Contents", [])) > 0

    def delete(self, remote_key: str) -> bool:
        """删除文件或目录"""
        full_key = f"{self.key_prefix}/{remote_key}"
        try:
            # 删除所有匹配的对象
            paginator = self.client.get_paginator("list_objects_v2")

            for page in paginator.paginate(Bucket=self.bucket, Prefix=full_key):
                objects = page.get("Contents", [])
                if objects:
                    delete_objects = [{"Key": obj["Key"]} for obj in objects]
                    self.client.delete_objects(
                        Bucket=self.bucket,
                        Delete={"Objects": delete_objects}
                    )
            return True
        except Exception as e:
            print(f"Delete failed: {e}")
            return False

    def list(self, prefix: str = "") -> List[str]:
        """列出指定前缀下的所有文件"""
        full_prefix = f"{self.key_prefix}/{prefix}" if prefix else self.key_prefix
        keys = []

        paginator = self.client.get_paginator("list_objects_v2")
        for page in paginator.paginate(Bucket=self.bucket, Prefix=full_prefix):
            for obj in page.get("Contents", []):
                # 去除 key_prefix
                key = obj["Key"]
                if key.startswith(self.key_prefix):
                    key = key[len(self.key_prefix):].lstrip("/")
                keys.append(key)

        return keys

    def get_argo_config(self) -> Dict[str, Any]:
        """获取 Argo 配置"""
        config = {
            "s3": {
                "bucket": self.bucket,
                "keyPrefix": self.key_prefix,
                "region": self.region
            }
        }

        if self.endpoint:
            config["s3"]["endpoint"] = self.endpoint
            config["s3"]["insecure"] = self.insecure

        return config


class ArtifactCache:
    """Artifact 缓存管理"""

    def __init__(
        self,
        repository: ArtifactRepository,
        cache_dir: str = "/tmp/artifact-cache"
    ):
        self.repository = repository
        self.cache_dir = cache_dir
        self.manifest_file = os.path.join(cache_dir, "manifest.json")
        os.makedirs(cache_dir, exist_ok=True)
        self._load_manifest()

    def _load_manifest(self):
        """加载缓存清单"""
        if os.path.exists(self.manifest_file):
            with open(self.manifest_file) as f:
                self.manifest = json.load(f)
        else:
            self.manifest = {}

    def _save_manifest(self):
        """保存缓存清单"""
        with open(self.manifest_file, "w") as f:
            json.dump(self.manifest, f, indent=2)

    def _compute_hash(self, path: str) -> str:
        """计算文件/目录的 hash"""
        hasher = hashlib.sha256()

        if os.path.isfile(path):
            with open(path, "rb") as f:
                for chunk in iter(lambda: f.read(8192), b""):
                    hasher.update(chunk)
        else:
            for root, dirs, files in sorted(os.walk(path)):
                for file in sorted(files):
                    file_path = os.path.join(root, file)
                    rel_path = os.path.relpath(file_path, path)
                    hasher.update(rel_path.encode())
                    with open(file_path, "rb") as f:
                        for chunk in iter(lambda: f.read(8192), b""):
                            hasher.update(chunk)

        return hasher.hexdigest()

    def get_cached(self, key: str, output_path: str) -> bool:
        """获取缓存的 artifact"""
        if key not in self.manifest:
            return False

        cache_entry = self.manifest[key]
        cache_path = os.path.join(self.cache_dir, cache_entry["hash"])

        if os.path.exists(cache_path):
            # 从本地缓存复制
            if os.path.isdir(cache_path):
                import shutil
                shutil.copytree(cache_path, output_path)
            else:
                import shutil
                shutil.copy2(cache_path, output_path)
            return True

        # 从远程下载
        if self.repository.download(key, output_path):
            # 更新本地缓存
            if os.path.isdir(output_path):
                import shutil
                shutil.copytree(output_path, cache_path)
            else:
                import shutil
                shutil.copy2(output_path, cache_path)
            return True

        return False

    def put_cached(self, key: str, local_path: str) -> str:
        """缓存 artifact"""
        file_hash = self._compute_hash(local_path)
        cache_path = os.path.join(self.cache_dir, file_hash)

        # 保存到本地缓存
        if not os.path.exists(cache_path):
            import shutil
            if os.path.isdir(local_path):
                shutil.copytree(local_path, cache_path)
            else:
                shutil.copy2(local_path, cache_path)

        # 上传到远程
        remote_url = self.repository.upload(local_path, key)

        # 更新清单
        self.manifest[key] = {
            "hash": file_hash,
            "remote_url": remote_url,
            "timestamp": os.path.getmtime(local_path)
        }
        self._save_manifest()

        return remote_url

    def invalidate(self, key: str):
        """使缓存失效"""
        if key in self.manifest:
            cache_entry = self.manifest[key]
            cache_path = os.path.join(self.cache_dir, cache_entry["hash"])

            if os.path.exists(cache_path):
                import shutil
                if os.path.isdir(cache_path):
                    shutil.rmtree(cache_path)
                else:
                    os.remove(cache_path)

            del self.manifest[key]
            self._save_manifest()


class WorkflowArtifactManager:
    """工作流 Artifact 管理器"""

    def __init__(self, repository: ArtifactRepository):
        self.repository = repository
        self.cache = ArtifactCache(repository)

    def generate_artifact_config(
        self,
        workflow_name: str,
        inputs: List[ArtifactSpec],
        outputs: List[ArtifactSpec]
    ) -> Dict[str, Any]:
        """生成工作流 artifact 配置"""
        config = {
            "inputs": {
                "artifacts": [spec.to_argo_spec() for spec in inputs]
            },
            "outputs": {
                "artifacts": [spec.to_argo_spec() for spec in outputs]
            }
        }

        # 添加 artifact 仓库配置
        for artifact in config["outputs"]["artifacts"]:
            artifact["s3"] = {
                # Double the braces so the rendered key keeps the literal Argo variable {{workflow.uid}}
                "key": f"{workflow_name}/{{{{workflow.uid}}}}/{artifact['name']}"
            }

        return config

    def setup_input_artifacts(
        self,
        workflow_id: str,
        artifacts: List[Dict[str, Any]]
    ) -> Dict[str, str]:
        """设置输入 artifacts"""
        paths = {}

        for artifact in artifacts:
            name = artifact["name"]
            source_key = artifact.get("key", f"inputs/{name}")
            local_path = artifact["path"]

            if self.cache.get_cached(source_key, local_path):
                paths[name] = local_path
            else:
                raise FileNotFoundError(f"Artifact {name} not found: {source_key}")

        return paths

    def upload_output_artifacts(
        self,
        workflow_id: str,
        artifacts: List[Dict[str, Any]]
    ) -> Dict[str, str]:
        """上传输出 artifacts"""
        urls = {}

        for artifact in artifacts:
            name = artifact["name"]
            local_path = artifact["path"]
            remote_key = f"workflows/{workflow_id}/outputs/{name}"

            url = self.cache.put_cached(remote_key, local_path)
            urls[name] = url

        return urls


# 使用示例
if __name__ == "__main__":
    # 创建 S3 仓库
    repo = S3ArtifactRepository(
        bucket="argo-artifacts",
        endpoint="minio.argo:9000",
        access_key="admin",
        secret_key="password",
        insecure=True
    )

    # 创建管理器
    manager = WorkflowArtifactManager(repo)

    # 定义 artifacts
    inputs = [
        ArtifactSpec(name="training-data", path="/data/train"),
        ArtifactSpec(name="config", path="/config/params.yaml")
    ]

    outputs = [
        ArtifactSpec(name="model", path="/models/trained", compression="gzip"),
        ArtifactSpec(name="metrics", path="/metrics/results.json")
    ]

    # 生成配置
    config = manager.generate_artifact_config(
        workflow_name="ml-training",
        inputs=inputs,
        outputs=outputs
    )

    print(yaml.dump(config, default_flow_style=False))
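
The get_argo_config() helper above is shaped after the artifactRepository block in the controller ConfigMap, so the repository definition can be pushed back into the workflow-controller-configmap shown earlier. A sketch; the secret names are assumptions mirroring that ConfigMap, and credentials stay in Secrets rather than in the ConfigMap itself:

# register_repository.py - write the repository config into the controller ConfigMap
import yaml
from kubernetes import client, config as k8s_config

k8s_config.load_kube_config()

repo = S3ArtifactRepository(bucket="argo-artifacts", endpoint="minio.argo:9000", insecure=True)
argo_repo = repo.get_argo_config()
argo_repo["archiveLogs"] = True
# Reference credentials via Secrets (names assumed to match the earlier install manifests).
argo_repo["s3"].update({
    "accessKeySecret": {"name": "argo-artifacts", "key": "accesskey"},
    "secretKeySecret": {"name": "argo-artifacts", "key": "secretkey"},
})

client.CoreV1Api().patch_namespaced_config_map(
    name="workflow-controller-configmap",
    namespace="argo",
    body={"data": {"artifactRepository": yaml.dump(argo_repo)}},
)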

Workflow Monitoring and Observability

Prometheus Metrics Integration

# argo-monitoring.yaml
apiVersion: v1
kind: ConfigMap
metadata:
  name: argo-prometheus-rules
  namespace: monitoring
data:
  argo-workflows.yaml: |
    groups:
    - name: argo-workflows
      rules:
      # 工作流状态指标
      - record: argo_workflow_status_phase
        expr: |
          sum by (namespace, phase) (
            argo_workflows_count{phase!=""}
          )

      # 工作流成功率
      - record: argo_workflow_success_rate
        expr: |
          sum(rate(argo_workflows_count{phase="Succeeded"}[1h])) /
          sum(rate(argo_workflows_count{phase=~"Succeeded|Failed|Error"}[1h]))

      # 平均工作流执行时间
      - record: argo_workflow_duration_seconds_avg
        expr: |
          avg by (workflow_name) (
            argo_workflows_pods_duration_seconds_sum /
            argo_workflows_pods_duration_seconds_count
          )

      # 工作流排队时间
      - record: argo_workflow_queue_duration_seconds_avg
        expr: |
          avg(argo_workflows_queue_duration_seconds_sum /
              argo_workflows_queue_duration_seconds_count)

      # 告警规则
      - alert: ArgoWorkflowHighFailureRate
        expr: |
          (1 - argo_workflow_success_rate) > 0.1
        for: 10m
        labels:
          severity: warning
        annotations:
          summary: "High workflow failure rate"
          description: "Workflow failure rate is above 10%"

      - alert: ArgoWorkflowStuck
        expr: |
          argo_workflows_count{phase="Running"} > 0
          and
          changes(argo_workflows_count{phase="Running"}[30m]) == 0
        for: 30m
        labels:
          severity: critical
        annotations:
          summary: "Workflow appears stuck"
          description: "The number of Running workflows has not changed for 30+ minutes, which may indicate stuck workflows"

      - alert: ArgoWorkflowQueueBacklog
        expr: |
          argo_workflows_count{phase="Pending"} > 50
        for: 15m
        labels:
          severity: warning
        annotations:
          summary: "Large workflow queue backlog"
          description: "More than 50 workflows pending"

      - alert: ArgoControllerDown
        expr: |
          up{job="workflow-controller"} == 0
        for: 5m
        labels:
          severity: critical
        annotations:
          summary: "Argo Workflow Controller is down"
          description: "The workflow controller has been down for 5 minutes"

---
# Grafana Dashboard ConfigMap
apiVersion: v1
kind: ConfigMap
metadata:
  name: argo-grafana-dashboard
  namespace: monitoring
  labels:
    grafana_dashboard: "1"
data:
  argo-workflows.json: |
    {
      "title": "Argo Workflows Dashboard",
      "uid": "argo-workflows",
      "panels": [
        {
          "title": "Workflow Status Distribution",
          "type": "piechart",
          "targets": [
            {
              "expr": "sum by (phase) (argo_workflows_count)",
              "legendFormat": "{{phase}}"
            }
          ]
        },
        {
          "title": "Workflow Success Rate (1h)",
          "type": "gauge",
          "targets": [
            {
              "expr": "argo_workflow_success_rate * 100"
            }
          ],
          "options": {
            "thresholds": {
              "steps": [
                {"color": "red", "value": 0},
                {"color": "yellow", "value": 80},
                {"color": "green", "value": 95}
              ]
            }
          }
        },
        {
          "title": "Workflow Execution Time (p95)",
          "type": "timeseries",
          "targets": [
            {
              "expr": "histogram_quantile(0.95, rate(argo_workflows_pods_duration_seconds_bucket[5m]))",
              "legendFormat": "p95"
            },
            {
              "expr": "histogram_quantile(0.50, rate(argo_workflows_pods_duration_seconds_bucket[5m]))",
              "legendFormat": "p50"
            }
          ]
        },
        {
          "title": "Active Workflows",
          "type": "timeseries",
          "targets": [
            {
              "expr": "argo_workflows_count{phase='Running'}",
              "legendFormat": "Running"
            },
            {
              "expr": "argo_workflows_count{phase='Pending'}",
              "legendFormat": "Pending"
            }
          ]
        },
        {
          "title": "Workflow Completion Rate",
          "type": "timeseries",
          "targets": [
            {
              "expr": "sum(rate(argo_workflows_count{phase='Succeeded'}[5m]))",
              "legendFormat": "Succeeded"
            },
            {
              "expr": "sum(rate(argo_workflows_count{phase='Failed'}[5m]))",
              "legendFormat": "Failed"
            }
          ]
        },
        {
          "title": "Resource Usage by Workflow",
          "type": "table",
          "targets": [
            {
              "expr": "topk(10, sum by (workflow) (argo_workflows_pods_cpu_usage_seconds_total))",
              "legendFormat": "{{workflow}}"
            }
          ]
        }
      ]
    }
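
Once the recording rules are loaded, the derived series can also be pulled programmatically through Prometheus' HTTP API, for example to gate a promotion step on recent workflow health. A minimal sketch; the in-cluster Prometheus URL is an assumption:

# query_workflow_health.py - read recorded workflow metrics from Prometheus
import requests

PROMETHEUS_URL = "http://prometheus-server.monitoring:9090"  # assumption: typical in-cluster service

def instant_query(expr: str) -> list:
    """Run an instant query against the Prometheus HTTP API and return the result vector."""
    resp = requests.get(
        f"{PROMETHEUS_URL}/api/v1/query",
        params={"query": expr},
        timeout=10,
    )
    resp.raise_for_status()
    return resp.json()["data"]["result"]

for sample in instant_query("argo_workflow_success_rate"):
    print(sample["metric"], sample["value"])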

Workflow Event Handling

# workflow_event_handler.py
"""
Argo Workflows 事件处理器
处理工作流生命周期事件、告警通知等
"""

import asyncio
import json
from typing import Dict, Any, Callable, List, Optional
from dataclasses import dataclass
from datetime import datetime
import aiohttp
from kubernetes import client, config, watch
from kubernetes.client.rest import ApiException
import logging

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


@dataclass
class WorkflowEvent:
    """工作流事件"""
    event_type: str  # ADDED, MODIFIED, DELETED
    workflow_name: str
    namespace: str
    phase: str
    message: Optional[str]
    start_time: Optional[datetime]
    finish_time: Optional[datetime]
    resource_duration: Optional[int]
    raw_object: Dict[str, Any]

    @property
    def duration_seconds(self) -> Optional[float]:
        if self.start_time and self.finish_time:
            return (self.finish_time - self.start_time).total_seconds()
        return None

    @classmethod
    def from_k8s_event(cls, event: Dict[str, Any]) -> "WorkflowEvent":
        obj = event["object"]
        metadata = obj.get("metadata", {})
        status = obj.get("status", {})

        start_time = None
        finish_time = None

        if status.get("startedAt"):
            start_time = datetime.fromisoformat(
                status["startedAt"].replace("Z", "+00:00")
            )
        if status.get("finishedAt"):
            finish_time = datetime.fromisoformat(
                status["finishedAt"].replace("Z", "+00:00")
            )

        return cls(
            event_type=event["type"],
            workflow_name=metadata.get("name", ""),
            namespace=metadata.get("namespace", ""),
            phase=status.get("phase", "Unknown"),
            message=status.get("message"),
            start_time=start_time,
            finish_time=finish_time,
            resource_duration=status.get("resourcesDuration", {}).get("cpu"),
            raw_object=obj
        )


class NotificationChannel:
    """通知渠道基类"""

    async def send(self, event: WorkflowEvent, message: str):
        raise NotImplementedError


class SlackNotificationChannel(NotificationChannel):
    """Slack 通知渠道"""

    def __init__(self, webhook_url: str):
        self.webhook_url = webhook_url

    async def send(self, event: WorkflowEvent, message: str):
        color_map = {
            "Succeeded": "good",
            "Failed": "danger",
            "Error": "danger",
            "Running": "#439FE0"
        }

        payload = {
            "attachments": [{
                "color": color_map.get(event.phase, "#808080"),
                "title": f"Workflow {event.phase}: {event.workflow_name}",
                "text": message,
                "fields": [
                    {"title": "Namespace", "value": event.namespace, "short": True},
                    {"title": "Phase", "value": event.phase, "short": True}
                ],
                "footer": "Argo Workflows",
                "ts": int(datetime.now().timestamp())
            }]
        }

        if event.duration_seconds:
            payload["attachments"][0]["fields"].append({
                "title": "Duration",
                "value": f"{event.duration_seconds:.1f}s",
                "short": True
            })

        async with aiohttp.ClientSession() as session:
            async with session.post(self.webhook_url, json=payload) as resp:
                if resp.status != 200:
                    logger.error(f"Slack notification failed: {await resp.text()}")


class WebhookNotificationChannel(NotificationChannel):
    """通用 Webhook 通知渠道"""

    def __init__(self, url: str, headers: Optional[Dict[str, str]] = None):
        self.url = url
        self.headers = headers or {}

    async def send(self, event: WorkflowEvent, message: str):
        payload = {
            "workflow_name": event.workflow_name,
            "namespace": event.namespace,
            "phase": event.phase,
            "message": message,
            "event_type": event.event_type,
            "timestamp": datetime.now().isoformat(),
            "duration_seconds": event.duration_seconds
        }

        async with aiohttp.ClientSession() as session:
            async with session.post(
                self.url,
                json=payload,
                headers=self.headers
            ) as resp:
                if resp.status >= 400:
                    logger.error(f"Webhook notification failed: {await resp.text()}")


class WorkflowEventHandler:
    """工作流事件处理器"""

    def __init__(self, namespace: Optional[str] = None):
        self.namespace = namespace
        self.handlers: Dict[str, List[Callable]] = {
            "Running": [],
            "Succeeded": [],
            "Failed": [],
            "Error": [],
            "Pending": []
        }
        self.notification_channels: List[NotificationChannel] = []
        self._setup_k8s_client()

    def _setup_k8s_client(self):
        """设置 Kubernetes 客户端"""
        try:
            config.load_incluster_config()
        except config.ConfigException:
            config.load_kube_config()

        self.custom_api = client.CustomObjectsApi()

    def on_phase(self, phase: str):
        """装饰器:注册事件处理函数"""
        def decorator(func: Callable):
            if phase in self.handlers:
                self.handlers[phase].append(func)
            return func
        return decorator

    def add_notification_channel(self, channel: NotificationChannel):
        """添加通知渠道"""
        self.notification_channels.append(channel)

    async def _handle_event(self, event: WorkflowEvent):
        """处理单个事件"""
        phase = event.phase

        # 调用注册的处理函数
        for handler in self.handlers.get(phase, []):
            try:
                if asyncio.iscoroutinefunction(handler):
                    await handler(event)
                else:
                    handler(event)
            except Exception as e:
                logger.error(f"Handler error: {e}")

        # 发送通知
        if phase in ["Succeeded", "Failed", "Error"]:
            message = self._format_message(event)
            for channel in self.notification_channels:
                try:
                    await channel.send(event, message)
                except Exception as e:
                    logger.error(f"Notification error: {e}")

    def _format_message(self, event: WorkflowEvent) -> str:
        """格式化通知消息"""
        lines = [
            f"Workflow: {event.workflow_name}",
            f"Status: {event.phase}"
        ]

        if event.message:
            lines.append(f"Message: {event.message}")

        if event.duration_seconds:
            minutes, seconds = divmod(int(event.duration_seconds), 60)
            lines.append(f"Duration: {minutes}m {seconds}s")

        return "\n".join(lines)

    async def watch_workflows(self):
        """监听工作流事件"""
        logger.info(f"Starting workflow watcher for namespace: {self.namespace or 'all'}")

        while True:
            try:
                w = watch.Watch()

                kwargs = {
                    "group": "argoproj.io",
                    "version": "v1alpha1",
                    "plural": "workflows"
                }

                if self.namespace:
                    kwargs["namespace"] = self.namespace
                    stream = w.stream(
                        self.custom_api.list_namespaced_custom_object,
                        **kwargs
                    )
                else:
                    stream = w.stream(
                        self.custom_api.list_cluster_custom_object,
                        **kwargs
                    )

                for event in stream:
                    try:
                        workflow_event = WorkflowEvent.from_k8s_event(event)
                        await self._handle_event(workflow_event)
                    except Exception as e:
                        logger.error(f"Event processing error: {e}")

            except ApiException as e:
                logger.error(f"K8s API error: {e}")
                await asyncio.sleep(5)
            except Exception as e:
                logger.error(f"Watch error: {e}")
                await asyncio.sleep(5)


class WorkflowMetricsCollector:
    """工作流指标收集器"""

    def __init__(self, handler: WorkflowEventHandler):
        self.handler = handler
        self.metrics = {
            "total_workflows": 0,
            "succeeded": 0,
            "failed": 0,
            "running": 0,
            "total_duration": 0.0
        }
        self._register_handlers()

    def _register_handlers(self):
        """注册指标收集处理函数"""

        @self.handler.on_phase("Running")
        def on_running(event: WorkflowEvent):
            self.metrics["running"] += 1
            self.metrics["total_workflows"] += 1

        @self.handler.on_phase("Succeeded")
        def on_succeeded(event: WorkflowEvent):
            self.metrics["running"] = max(0, self.metrics["running"] - 1)
            self.metrics["succeeded"] += 1
            if event.duration_seconds:
                self.metrics["total_duration"] += event.duration_seconds

        @self.handler.on_phase("Failed")
        def on_failed(event: WorkflowEvent):
            self.metrics["running"] = max(0, self.metrics["running"] - 1)
            self.metrics["failed"] += 1
            if event.duration_seconds:
                self.metrics["total_duration"] += event.duration_seconds

    def get_metrics(self) -> Dict[str, Any]:
        """获取当前指标"""
        completed = self.metrics["succeeded"] + self.metrics["failed"]
        return {
            **self.metrics,
            "success_rate": (
                self.metrics["succeeded"] / completed if completed > 0 else 0
            ),
            "avg_duration": (
                self.metrics["total_duration"] / completed if completed > 0 else 0
            )
        }


# 使用示例
async def main():
    # 创建事件处理器
    handler = WorkflowEventHandler(namespace="argo-workflows")

    # 添加 Slack 通知
    handler.add_notification_channel(
        SlackNotificationChannel(
            webhook_url="https://hooks.slack.com/services/xxx"
        )
    )

    # 添加自定义 Webhook
    handler.add_notification_channel(
        WebhookNotificationChannel(
            url="https://api.example.com/workflow-events",
            headers={"Authorization": "Bearer token"}
        )
    )

    # 创建指标收集器
    metrics_collector = WorkflowMetricsCollector(handler)

    # 注册自定义处理函数
    @handler.on_phase("Failed")
    async def handle_failure(event: WorkflowEvent):
        logger.error(f"Workflow failed: {event.workflow_name}")
        logger.error(f"Message: {event.message}")
        # 可以在这里添加自动重试逻辑

    @handler.on_phase("Succeeded")
    async def handle_success(event: WorkflowEvent):
        logger.info(f"Workflow succeeded: {event.workflow_name}")
        metrics = metrics_collector.get_metrics()
        logger.info(f"Current success rate: {metrics['success_rate']:.2%}")

    # 启动监听
    await handler.watch_workflows()


if __name__ == "__main__":
    asyncio.run(main())

Best Practices

Workflow Design Principles

# best-practices-workflow.yaml
apiVersion: argoproj.io/v1alpha1
kind: WorkflowTemplate
metadata:
  name: ml-best-practices
  namespace: argo-workflows
  labels:
    app.kubernetes.io/name: ml-pipeline
    app.kubernetes.io/component: workflow
spec:
  # 1. 设置合理的超时和重试
  entrypoint: main
  activeDeadlineSeconds: 86400  # 24小时全局超时

  # 2. 资源清理策略
  ttlStrategy:
    secondsAfterCompletion: 3600
    secondsAfterSuccess: 3600
    secondsAfterFailure: 86400

  podGC:
    strategy: OnPodCompletion
    labelSelector:
      matchLabels:
        cleanup: "true"

  # 3. 服务账号和安全
  serviceAccountName: ml-workflow-sa
  securityContext:
    runAsNonRoot: true
    runAsUser: 1000

  # 4. 节点选择和容忍
  nodeSelector:
    node-type: ml-worker

  tolerations:
  - key: "nvidia.com/gpu"
    operator: "Exists"
    effect: "NoSchedule"

  # 5. 资源配额
  podSpecPatch: |
    containers:
      - name: main
        resources:
          requests:
            cpu: "100m"
            memory: "256Mi"

  templates:
  - name: main
    dag:
      tasks:
      - name: step1
        template: reliable-step
        arguments:
          parameters:
          - name: step-name
            value: "data-prep"

      - name: step2
        template: gpu-step
        dependencies: [step1]

  # 6. 可靠的步骤模板
  - name: reliable-step
    inputs:
      parameters:
      - name: step-name
    # 重试策略
    retryStrategy:
      limit: 3
      retryPolicy: "Always"
      backoff:
        duration: "10s"
        factor: 2
        maxDuration: "5m"
    # 超时设置
    activeDeadlineSeconds: 3600
    container:
      image: ml-platform/worker:v1.0
      command: [python, process.py]
      args:
      - --step={{inputs.parameters.step-name}}
      # 资源限制
      resources:
        requests:
          cpu: "500m"
          memory: "1Gi"
        limits:
          cpu: "2"
          memory: "4Gi"
      # 存活检查
      livenessProbe:
        exec:
          command: [cat, /tmp/healthy]
        initialDelaySeconds: 30
        periodSeconds: 30
      # 优雅退出
      lifecycle:
        preStop:
          exec:
            command: ["/bin/sh", "-c", "sleep 5"]
    # Memoization 缓存
    memoize:
      key: "{{inputs.parameters.step-name}}-{{workflow.parameters.data-version}}"
      maxAge: "24h"
      cache:
        configMap:
          name: workflow-cache

  # 7. GPU 步骤模板
  - name: gpu-step
    activeDeadlineSeconds: 7200
    retryStrategy:
      limit: 2
      retryPolicy: "OnFailure"
    podSpecPatch: |
      containers:
        - name: main
          resources:
            limits:
              nvidia.com/gpu: 1
    container:
      image: ml-platform/gpu-worker:v1.0
      command: [python, train.py]
      env:
      - name: CUDA_VISIBLE_DEVICES
        value: "0"
      - name: NCCL_DEBUG
        value: "INFO"
      volumeMounts:
      - name: shm
        mountPath: /dev/shm
    volumes:
    - name: shm
      emptyDir:
        medium: Memory
        sizeLimit: "16Gi"

---
# 工作流 RBAC
apiVersion: v1
kind: ServiceAccount
metadata:
  name: ml-workflow-sa
  namespace: argo-workflows
---
apiVersion: rbac.authorization.k8s.io/v1
kind: Role
metadata:
  name: ml-workflow-role
  namespace: argo-workflows
rules:
- apiGroups: [""]
  resources: ["pods", "pods/log"]
  verbs: ["get", "list", "watch"]
- apiGroups: [""]
  resources: ["configmaps"]
  verbs: ["get", "list", "watch", "create", "update"]
- apiGroups: [""]
  resources: ["persistentvolumeclaims"]
  verbs: ["get", "list", "watch", "create", "delete"]
- apiGroups: ["argoproj.io"]
  resources: ["workflows", "workflowtemplates"]
  verbs: ["get", "list", "watch"]
# Needed with the emissary executor (Argo v3.4+): workflow pods report step outputs
# by creating WorkflowTaskResults under their own service account
- apiGroups: ["argoproj.io"]
  resources: ["workflowtaskresults"]
  verbs: ["create", "patch"]
---
apiVersion: rbac.authorization.k8s.io/v1
kind: RoleBinding
metadata:
  name: ml-workflow-rolebinding
  namespace: argo-workflows
subjects:
- kind: ServiceAccount
  name: ml-workflow-sa
  namespace: argo-workflows
roleRef:
  kind: Role
  name: ml-workflow-role
  apiGroup: rbac.authorization.k8s.io
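
As a companion to the TTL settings above, a small audit script can flag Workflows whose spec carries no explicit ttlStrategy and that might otherwise accumulate on clusters where no controller-level default is configured. A sketch using the list API:

# audit_ttl.py - flag workflows without an explicit TTL strategy
from kubernetes import client, config

config.load_kube_config()
api = client.CustomObjectsApi()

workflows = api.list_namespaced_custom_object(
    group="argoproj.io", version="v1alpha1",
    namespace="argo-workflows", plural="workflows",
)

for wf in workflows.get("items", []):
    if "ttlStrategy" not in wf.get("spec", {}):
        name = wf["metadata"]["name"]
        phase = wf.get("status", {}).get("phase", "Unknown")
        print(f"{name} ({phase}) has no ttlStrategy set")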

Summary

The core strengths of Argo Workflows as a Kubernetes-native workflow engine:

  1. Cloud-native design

    • Built entirely on Kubernetes CRDs
    • Leverages native Kubernetes resource management
    • Integrates easily with the cloud-native ecosystem
  2. Powerful DAG support

    • Flexible task dependency definitions
    • Dynamic DAG generation
    • Conditional branching and loops
  3. Enterprise-grade features

    • Comprehensive RBAC controls
    • Multi-tenancy support
    • Audit logging
  4. ML/AI friendly

    • GPU resource scheduling
    • Distributed training support
    • Artifact management
  5. Observability

    • Prometheus metrics
    • Structured logging
    • Event-driven notifications

In the next chapter we turn to data version management and look at how tools such as DVC and LakeFS are used to manage data versions in ML projects.
