03 - Argo Workflows in Depth
Overview
Argo Workflows is a Kubernetes-native workflow engine widely used for CI/CD, data processing, and machine learning pipelines. This article examines the architecture of Argo Workflows, its advanced features, and best practices for ML scenarios.
Argo Workflows Architecture
Core Components
┌─────────────────────────────────────────────────────────────────────┐
│ Argo Workflows 架构 │
├─────────────────────────────────────────────────────────────────────┤
│ ┌─────────────┐ ┌─────────────┐ ┌─────────────┐ │
│ │ Argo CLI │ │ Argo UI │ │ Argo Events│ │
│ └──────┬──────┘ └──────┬──────┘ └──────┬──────┘ │
│ │ │ │ │
│ ▼ ▼ ▼ │
│ ┌─────────────────────────────────────────────────────────────┐ │
│ │ Argo Server │ │
│ │ ┌──────────┐ ┌──────────┐ ┌──────────┐ ┌──────────┐ │ │
│ │ │ API Svc │ │ Auth Svc │ │ Artifact │ │ Archive │ │ │
│ │ │ │ │ │ │ Driver │ │ Service │ │ │
│ │ └──────────┘ └──────────┘ └──────────┘ └──────────┘ │ │
│ └─────────────────────────────────────────────────────────────┘ │
│ │ │
│ ▼ │
│ ┌─────────────────────────────────────────────────────────────┐ │
│ │ Workflow Controller │ │
│ │ ┌──────────┐ ┌──────────┐ ┌──────────┐ ┌──────────┐ │ │
│ │ │ Informer │ │ Executor │ │ Garbage │ │ Metrics │ │ │
│ │ │ │ │ │ │ Collector│ │ Exporter │ │ │
│ │ └──────────┘ └──────────┘ └──────────┘ └──────────┘ │ │
│ └─────────────────────────────────────────────────────────────┘ │
│ │ │
│ ▼ │
│ ┌─────────────────────────────────────────────────────────────┐ │
│ │ Kubernetes │ │
│ │ ┌──────────┐ ┌──────────┐ ┌──────────┐ ┌──────────┐ │ │
│ │ │ Workflow │ │ Pod │ │ PVC │ │ ConfigMap│ │ │
│ │ │ CRD │ │ │ │ │ │ │ │ │
│ │ └──────────┘ └──────────┘ └──────────┘ └──────────┘ │ │
│ └─────────────────────────────────────────────────────────────┘ │
└─────────────────────────────────────────────────────────────────────┘
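The controller never talks to clients directly: the CLI, UI, and Argo Server all manipulate `Workflow` custom resources, the controller's informer reconciles them into pods, and the resulting phase is written back into the same object's `status`. The sketch below reads that status straight from the CRD with the Kubernetes Python client; the `argo-workflows` namespace is an assumption carried over from the configuration used later in this article.

# list-workflows.py - a minimal sketch: read Workflow CRs and the phase the
# controller has written back. Assumes cluster access and that workflows run
# in the "argo-workflows" namespace.
from kubernetes import client, config

config.load_kube_config()  # use config.load_incluster_config() inside the cluster
api = client.CustomObjectsApi()

workflows = api.list_namespaced_custom_object(
    group="argoproj.io",
    version="v1alpha1",
    namespace="argo-workflows",
    plural="workflows",
)
for wf in workflows.get("items", []):
    name = wf["metadata"]["name"]
    phase = wf.get("status", {}).get("phase", "Pending")
    print(f"{name}: {phase}")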
Installation and Deployment
# argo-install.yaml - 生产级部署配置
apiVersion: v1
kind: Namespace
metadata:
name: argo
---
# Argo Workflows Controller
apiVersion: apps/v1
kind: Deployment
metadata:
name: workflow-controller
namespace: argo
spec:
replicas: 2
selector:
matchLabels:
app: workflow-controller
template:
metadata:
labels:
app: workflow-controller
spec:
serviceAccountName: argo-workflow-controller
containers:
- name: workflow-controller
image: quay.io/argoproj/workflow-controller:v3.5.0
args:
- --configmap
- workflow-controller-configmap
- --executor-image
- quay.io/argoproj/argoexec:v3.5.0
- --namespaced
- --managed-namespace
- argo-workflows
env:
- name: LEADER_ELECTION_IDENTITY
valueFrom:
fieldRef:
fieldPath: metadata.name
resources:
requests:
cpu: 100m
memory: 256Mi
limits:
cpu: 500m
memory: 512Mi
ports:
- containerPort: 9090
name: metrics
livenessProbe:
httpGet:
path: /healthz
port: 6060
initialDelaySeconds: 30
periodSeconds: 30
---
# Argo Server
apiVersion: apps/v1
kind: Deployment
metadata:
name: argo-server
namespace: argo
spec:
replicas: 2
selector:
matchLabels:
app: argo-server
template:
metadata:
labels:
app: argo-server
spec:
serviceAccountName: argo-server
containers:
- name: argo-server
image: quay.io/argoproj/argocli:v3.5.0
args:
- server
- --auth-mode=sso
- --secure
ports:
- containerPort: 2746
name: web
env:
- name: ARGO_SECURE
value: "true"
- name: ARGO_BASE_HREF
value: /argo/
resources:
requests:
cpu: 100m
memory: 256Mi
limits:
cpu: 500m
memory: 512Mi
readinessProbe:
httpGet:
path: /
port: 2746
scheme: HTTPS
initialDelaySeconds: 10
periodSeconds: 10
---
# Controller ConfigMap
apiVersion: v1
kind: ConfigMap
metadata:
name: workflow-controller-configmap
namespace: argo
data:
# 执行器配置
executor: |
resources:
requests:
cpu: 10m
memory: 64Mi
limits:
cpu: 100m
memory: 128Mi
# Artifact 存储配置
artifactRepository: |
archiveLogs: true
s3:
bucket: argo-artifacts
endpoint: minio.argo:9000
insecure: true
accessKeySecret:
name: argo-artifacts
key: accesskey
secretKeySecret:
name: argo-artifacts
key: secretkey
# 持久化配置
persistence: |
connectionPool:
maxIdleConns: 100
maxOpenConns: 0
nodeStatusOffLoad: true
archive: true
archiveTTL: 30d
postgresql:
host: postgres.argo
port: 5432
database: argo
tableName: argo_workflows
userNameSecret:
name: argo-postgres-config
key: username
passwordSecret:
name: argo-postgres-config
key: password
# 资源限制
workflowDefaults: |
spec:
ttlStrategy:
secondsAfterCompletion: 86400
secondsAfterSuccess: 3600
secondsAfterFailure: 172800
podGC:
strategy: OnPodCompletion
activeDeadlineSeconds: 86400
# 并发控制
parallelism: "50"
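One easy mistake with `workflow-controller-configmap` is forgetting that every value under `data:` is itself a YAML document embedded as a string (hence the `|` block scalars). A minimal sketch of generating such an entry from plain Python dictionaries, mirroring the defaults above:

# render-configmap.py - a sketch of generating a workflow-controller-configmap
# entry; note that each value under data: is a YAML *string*, not nested YAML.
import yaml

workflow_defaults = {
    "spec": {
        "ttlStrategy": {
            "secondsAfterCompletion": 86400,
            "secondsAfterSuccess": 3600,
            "secondsAfterFailure": 172800,
        },
        "podGC": {"strategy": "OnPodCompletion"},
        "activeDeadlineSeconds": 86400,
    }
}

configmap = {
    "apiVersion": "v1",
    "kind": "ConfigMap",
    "metadata": {"name": "workflow-controller-configmap", "namespace": "argo"},
    "data": {
        "workflowDefaults": yaml.safe_dump(workflow_defaults, sort_keys=False),
        "parallelism": "50",
    },
}
print(yaml.safe_dump(configmap, sort_keys=False))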
Workflow Templates in Detail
Basic Template Types
# workflow-templates.yaml
apiVersion: argoproj.io/v1alpha1
kind: WorkflowTemplate
metadata:
name: ml-workflow-templates
namespace: argo-workflows
spec:
# 1. Container 模板
templates:
- name: container-example
container:
image: python:3.9
command: [python]
args: ["{{inputs.parameters.script}}"]
resources:
requests:
memory: "1Gi"
cpu: "500m"
limits:
memory: "2Gi"
cpu: "1"
volumeMounts:
- name: workdir
mountPath: /workdir
inputs:
parameters:
- name: script
outputs:
artifacts:
- name: output
path: /workdir/output
# 2. Script 模板
- name: script-example
script:
image: python:3.9
command: [python]
source: |
import json
import sys
data = json.loads('{{inputs.parameters.data}}')
result = {"processed": data, "count": len(data)}
with open('/tmp/result.json', 'w') as f:
json.dump(result, f)
print(json.dumps(result))
resources:
requests:
memory: "512Mi"
cpu: "200m"
inputs:
parameters:
- name: data
outputs:
parameters:
- name: result
valueFrom:
path: /tmp/result.json
# 3. Resource 模板 - 创建 K8s 资源
- name: resource-example
resource:
action: create
setOwnerReference: true
manifest: |
apiVersion: batch/v1
kind: Job
metadata:
generateName: pytorch-job-
spec:
template:
spec:
containers:
- name: pytorch
image: pytorch/pytorch:2.0.0-cuda11.7-cudnn8-runtime
command: ["python", "-c", "print('Hello PyTorch')"]
restartPolicy: Never
# 4. Suspend 模板 - 人工审批
- name: approval-gate
suspend: {}
# 5. HTTP 模板 - API 调用
- name: http-example
http:
url: "https://api.example.com/webhook"
method: POST
headers:
- name: Content-Type
value: application/json
body: |
{"workflow": "{{workflow.name}}", "status": "{{inputs.parameters.status}}"}
successCondition: response.statusCode == 200
inputs:
parameters:
- name: status
---
# DAG 工作流模板
apiVersion: argoproj.io/v1alpha1
kind: WorkflowTemplate
metadata:
name: ml-dag-template
namespace: argo-workflows
spec:
entrypoint: ml-pipeline
arguments:
parameters:
- name: dataset-path
value: "s3://data/training"
- name: model-type
value: "resnet50"
- name: epochs
value: "100"
templates:
- name: ml-pipeline
dag:
tasks:
# 数据准备阶段
- name: validate-data
template: data-validation
arguments:
parameters:
- name: path
value: "{{workflow.parameters.dataset-path}}"
- name: preprocess
template: data-preprocessing
dependencies: [validate-data]
arguments:
artifacts:
- name: raw-data
from: "{{tasks.validate-data.outputs.artifacts.validated-data}}"
# 特征工程(并行)
- name: feature-extraction
template: extract-features
dependencies: [preprocess]
arguments:
artifacts:
- name: processed-data
from: "{{tasks.preprocess.outputs.artifacts.processed-data}}"
- name: feature-selection
template: select-features
dependencies: [feature-extraction]
arguments:
artifacts:
- name: features
from: "{{tasks.feature-extraction.outputs.artifacts.features}}"
# 模型训练
- name: train-model
template: model-training
dependencies: [feature-selection]
arguments:
parameters:
- name: model-type
value: "{{workflow.parameters.model-type}}"
- name: epochs
value: "{{workflow.parameters.epochs}}"
artifacts:
- name: training-data
from: "{{tasks.feature-selection.outputs.artifacts.selected-features}}"
# 模型评估
- name: evaluate-model
template: model-evaluation
dependencies: [train-model]
arguments:
artifacts:
- name: model
from: "{{tasks.train-model.outputs.artifacts.model}}"
- name: test-data
from: "{{tasks.preprocess.outputs.artifacts.test-data}}"
# 条件部署
- name: deploy-model
template: model-deployment
dependencies: [evaluate-model]
when: "{{tasks.evaluate-model.outputs.parameters.accuracy}} > 0.9"
arguments:
artifacts:
- name: model
from: "{{tasks.train-model.outputs.artifacts.model}}"
# 模板定义
- name: data-validation
inputs:
parameters:
- name: path
outputs:
artifacts:
- name: validated-data
path: /data/validated
container:
image: ml-platform/data-validator:v1.0
command: [python, validate.py]
args:
- --input={{inputs.parameters.path}}
- --output=/data/validated
- name: data-preprocessing
inputs:
artifacts:
- name: raw-data
path: /data/raw
outputs:
artifacts:
- name: processed-data
path: /data/processed
- name: test-data
path: /data/test
container:
image: ml-platform/preprocessor:v1.0
command: [python, preprocess.py]
args:
- --input=/data/raw
- --output=/data/processed
- --test-output=/data/test
- name: model-training
inputs:
parameters:
- name: model-type
- name: epochs
artifacts:
- name: training-data
path: /data/train
outputs:
artifacts:
- name: model
path: /models/trained
- name: metrics
path: /metrics
container:
image: ml-platform/trainer:v1.0
command: [python, train.py]
args:
- --model={{inputs.parameters.model-type}}
- --epochs={{inputs.parameters.epochs}}
- --data=/data/train
- --output=/models/trained
resources:
requests:
nvidia.com/gpu: 1
memory: "16Gi"
cpu: "4"
limits:
nvidia.com/gpu: 1
memory: "32Gi"
cpu: "8"
- name: model-evaluation
inputs:
artifacts:
- name: model
path: /models
- name: test-data
path: /data/test
outputs:
parameters:
- name: accuracy
valueFrom:
path: /metrics/accuracy.txt
artifacts:
- name: report
path: /reports
container:
image: ml-platform/evaluator:v1.0
command: [python, evaluate.py]
- name: model-deployment
inputs:
artifacts:
- name: model
path: /models
container:
image: ml-platform/deployer:v1.0
command: [python, deploy.py]
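A `WorkflowTemplate` is only a definition; nothing runs until a `Workflow` references it. The sketch below submits a run of `ml-dag-template` through `workflowTemplateRef` and overrides two parameters, using the Kubernetes Python client (the namespace is an assumption).

# submit-from-template.py - a sketch: launch a Workflow from the ml-dag-template
# WorkflowTemplate above via workflowTemplateRef.
from kubernetes import client, config

config.load_kube_config()
api = client.CustomObjectsApi()

workflow = {
    "apiVersion": "argoproj.io/v1alpha1",
    "kind": "Workflow",
    "metadata": {"generateName": "ml-dag-run-"},
    "spec": {
        "workflowTemplateRef": {"name": "ml-dag-template"},
        "arguments": {
            "parameters": [
                {"name": "model-type", "value": "resnet101"},
                {"name": "epochs", "value": "50"},
            ]
        },
    },
}

created = api.create_namespaced_custom_object(
    group="argoproj.io",
    version="v1alpha1",
    namespace="argo-workflows",
    plural="workflows",
    body=workflow,
)
print("Submitted:", created["metadata"]["name"])

Running `argo submit --from workflowtemplate/ml-dag-template -p model-type=resnet101` from the CLI is roughly equivalent.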
Advanced DAG Patterns
Dynamic DAG Generation
# dynamic-dag.yaml
apiVersion: argoproj.io/v1alpha1
kind: Workflow
metadata:
generateName: dynamic-dag-
spec:
entrypoint: main
templates:
- name: main
steps:
# 第一步:生成任务列表
- - name: generate-tasks
template: generate-task-list
# 第二步:动态扩展执行
- - name: process-tasks
template: process-task
arguments:
parameters:
- name: task-id
value: "{{item}}"
withParam: "{{steps.generate-tasks.outputs.result}}"
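# Note: Argo aggregates the output *parameters* of a fanned-out step (available
# as a JSON list via {{steps.process-tasks.outputs.parameters}}), but it does not
# aggregate output artifacts automatically; fan-out tasks typically write results
# to a shared volume or a common S3 key prefix that the aggregation step reads.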
# 第三步:聚合结果
- - name: aggregate
template: aggregate-results
arguments:
artifacts:
- name: results
from: "{{steps.process-tasks.outputs.artifacts.result}}"
- name: generate-task-list
script:
image: python:3.9
command: [python]
source: |
import json
# 动态生成任务列表
tasks = [f"task-{i}" for i in range(10)]
print(json.dumps(tasks))
- name: process-task
inputs:
parameters:
- name: task-id
outputs:
artifacts:
- name: result
path: /tmp/result.json
script:
image: python:3.9
command: [python]
source: |
import json
task_id = "{{inputs.parameters.task-id}}"
result = {"task_id": task_id, "status": "completed"}
with open('/tmp/result.json', 'w') as f:
json.dump(result, f)
- name: aggregate-results
inputs:
artifacts:
- name: results
path: /tmp/results
script:
image: python:3.9
command: [python]
source: |
import os
import json
results = []
for f in os.listdir('/tmp/results'):
with open(f'/tmp/results/{f}') as file:
results.append(json.load(file))
print(f"Aggregated {len(results)} results")
print(json.dumps(results, indent=2))
---
# 条件分支与循环
apiVersion: argoproj.io/v1alpha1
kind: Workflow
metadata:
generateName: conditional-loop-
spec:
entrypoint: main
arguments:
parameters:
- name: iterations
value: "5"
- name: threshold
value: "0.8"
templates:
- name: main
dag:
tasks:
- name: init
template: initialize
- name: training-loop
template: training-iteration
dependencies: [init]
arguments:
parameters:
- name: iteration
value: "{{item}}"
- name: prev-model
value: "{{tasks.init.outputs.parameters.model-path}}"
withSequence:
count: "{{workflow.parameters.iterations}}"
- name: final-eval
template: final-evaluation
dependencies: [training-loop]
- name: initialize
outputs:
parameters:
- name: model-path
valueFrom:
path: /tmp/model-path
script:
image: python:3.9
command: [python]
source: |
model_path = "/models/initial"
with open('/tmp/model-path', 'w') as f:
f.write(model_path)
print(f"Initialized model at {model_path}")
- name: training-iteration
inputs:
parameters:
- name: iteration
- name: prev-model
outputs:
parameters:
- name: accuracy
valueFrom:
path: /tmp/accuracy
script:
image: python:3.9
command: [python]
source: |
import random
iteration = {{inputs.parameters.iteration}}
prev_model = "{{inputs.parameters.prev-model}}"
# 模拟训练
accuracy = 0.7 + (iteration * 0.05) + random.uniform(0, 0.1)
with open('/tmp/accuracy', 'w') as f:
f.write(str(accuracy))
print(f"Iteration {iteration}: accuracy = {accuracy:.4f}")
- name: final-evaluation
script:
image: python:3.9
command: [python]
source: |
print("Final evaluation completed")
Distributed Training Workflows
# distributed-training-workflow.yaml
apiVersion: argoproj.io/v1alpha1
kind: Workflow
metadata:
generateName: distributed-training-
spec:
entrypoint: distributed-pipeline
arguments:
parameters:
- name: num-workers
value: "4"
- name: model
value: "llama-7b"
- name: dataset
value: "s3://data/pretrain"
volumes:
- name: shared-data
persistentVolumeClaim:
claimName: training-data-pvc
templates:
- name: distributed-pipeline
dag:
tasks:
- name: prepare-data
template: data-preparation
arguments:
parameters:
- name: dataset
value: "{{workflow.parameters.dataset}}"
- name: num-shards
value: "{{workflow.parameters.num-workers}}"
- name: launch-master
template: training-master
dependencies: [prepare-data]
arguments:
parameters:
- name: num-workers
value: "{{workflow.parameters.num-workers}}"
- name: model
value: "{{workflow.parameters.model}}"
- name: launch-workers
template: training-worker
dependencies: [prepare-data, launch-master]
arguments:
parameters:
- name: worker-id
value: "{{item}}"
- name: master-addr
value: "{{tasks.launch-master.outputs.parameters.master-addr}}"
withSequence:
count: "{{workflow.parameters.num-workers}}"
- name: wait-completion
template: wait-for-training
dependencies: [launch-workers]
arguments:
parameters:
- name: master-addr
value: "{{tasks.launch-master.outputs.parameters.master-addr}}"
- name: collect-checkpoints
template: checkpoint-collection
dependencies: [wait-completion]
- name: merge-model
template: model-merging
dependencies: [collect-checkpoints]
- name: data-preparation
inputs:
parameters:
- name: dataset
- name: num-shards
outputs:
artifacts:
- name: data-manifest
path: /output/manifest.json
container:
image: ml-platform/data-prep:v1.0
command: [python, shard_data.py]
args:
- --input={{inputs.parameters.dataset}}
- --num-shards={{inputs.parameters.num-shards}}
- --output=/output
volumeMounts:
- name: shared-data
mountPath: /output
resources:
requests:
cpu: "4"
memory: "16Gi"
- name: training-master
inputs:
parameters:
- name: num-workers
- name: model
outputs:
parameters:
- name: master-addr
valueFrom:
path: /tmp/master-addr
# 使用 daemon 模式保持 master 运行
daemon: true
container:
image: ml-platform/distributed-trainer:v1.0
command: [python, master.py]
args:
- --model={{inputs.parameters.model}}
- --num-workers={{inputs.parameters.num-workers}}
- --port=29500
env:
- name: MASTER_ADDR
valueFrom:
fieldRef:
fieldPath: status.podIP
- name: MASTER_PORT
value: "29500"
ports:
- containerPort: 29500
resources:
requests:
nvidia.com/gpu: 1
cpu: "8"
memory: "32Gi"
limits:
nvidia.com/gpu: 1
memory: "64Gi"
volumeMounts:
- name: shared-data
mountPath: /data
- name: training-worker
inputs:
parameters:
- name: worker-id
- name: master-addr
container:
image: ml-platform/distributed-trainer:v1.0
command: [python, worker.py]
args:
- --master-addr={{inputs.parameters.master-addr}}
- --master-port=29500
- --worker-id={{inputs.parameters.worker-id}}
- --data-shard=/data/shard-{{inputs.parameters.worker-id}}
env:
- name: WORLD_SIZE
value: "{{workflow.parameters.num-workers}}"
- name: RANK
value: "{{inputs.parameters.worker-id}}"
resources:
requests:
nvidia.com/gpu: 1
cpu: "8"
memory: "32Gi"
limits:
nvidia.com/gpu: 1
memory: "64Gi"
volumeMounts:
- name: shared-data
mountPath: /data
- name: wait-for-training
inputs:
parameters:
- name: master-addr
script:
image: curlimages/curl:latest
command: [sh]
source: |
while true; do
status=$(curl -s http://{{inputs.parameters.master-addr}}:29500/status)
if [ "$status" = "completed" ]; then
echo "Training completed"
break
fi
sleep 60
done
- name: checkpoint-collection
outputs:
artifacts:
- name: checkpoints
path: /checkpoints
container:
image: ml-platform/checkpoint-collector:v1.0
command: [python, collect.py]
volumeMounts:
- name: shared-data
mountPath: /data
- name: model-merging
container:
image: ml-platform/model-merger:v1.0
command: [python, merge.py]
args:
- --checkpoints=/data/checkpoints
- --output=/models/final
resources:
requests:
cpu: "16"
memory: "64Gi"
volumeMounts:
- name: shared-data
mountPath: /data
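Inside the training containers, the environment variables injected by the workflow (`MASTER_ADDR`, `MASTER_PORT`, `RANK`, `WORLD_SIZE`) match exactly what `torch.distributed` reads when initialized with `env://`. The `worker.py` entrypoint baked into the illustrative trainer image might start roughly like the sketch below; the training loop itself is elided.

# worker.py (illustrative) - consuming the env vars set by the workflow;
# MASTER_ADDR/MASTER_PORT/RANK/WORLD_SIZE are read by init_method="env://".
import torch
import torch.distributed as dist

def main():
    dist.init_process_group(backend="nccl", init_method="env://")
    rank = dist.get_rank()
    world_size = dist.get_world_size()
    torch.cuda.set_device(rank % torch.cuda.device_count())
    print(f"worker rank={rank}/{world_size} joined the job")

    # ... build the model, wrap it in DistributedDataParallel, train on the local shard ...

    dist.barrier()                 # make sure every rank finished before teardown
    dist.destroy_process_group()

if __name__ == "__main__":
    main()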
Artifact Management
Artifact Configuration and Usage
# artifact_manager.py
"""
Argo Workflows Artifact 管理器
支持 S3、GCS、MinIO、HDFS 等多种存储后端
"""
import os
import yaml
import boto3
from typing import Dict, List, Optional, Any
from dataclasses import dataclass, field
from abc import ABC, abstractmethod
import hashlib
import json
@dataclass
class ArtifactSpec:
"""Artifact 规格定义"""
name: str
path: str
artifact_type: str = "file" # file, directory, archive
    compression: Optional[str] = None  # none, tar, zip ("gzip" is accepted as an alias for tar)
mode: int = 0o644
optional: bool = False
def to_argo_spec(self) -> Dict[str, Any]:
spec = {
"name": self.name,
"path": self.path
}
        if self.compression:
            # Argo's archive strategies are "none", "tar" (gzipped tarball) and "zip";
            # treat "gzip" as an alias for "tar".
            strategy = "tar" if self.compression == "gzip" else self.compression
            spec["archive"] = {strategy: {}}
if self.optional:
spec["optional"] = True
return spec
class ArtifactRepository(ABC):
"""Artifact 存储仓库抽象基类"""
@abstractmethod
def upload(self, local_path: str, remote_key: str) -> str:
pass
@abstractmethod
def download(self, remote_key: str, local_path: str) -> bool:
pass
@abstractmethod
def exists(self, remote_key: str) -> bool:
pass
@abstractmethod
def delete(self, remote_key: str) -> bool:
pass
@abstractmethod
def list(self, prefix: str) -> List[str]:
pass
@abstractmethod
def get_argo_config(self) -> Dict[str, Any]:
pass
class S3ArtifactRepository(ArtifactRepository):
"""S3 Artifact 存储"""
def __init__(
self,
bucket: str,
endpoint: Optional[str] = None,
region: str = "us-east-1",
access_key: Optional[str] = None,
secret_key: Optional[str] = None,
insecure: bool = False,
key_prefix: str = "artifacts"
):
self.bucket = bucket
self.endpoint = endpoint
self.region = region
self.insecure = insecure
self.key_prefix = key_prefix
# 初始化 S3 客户端
config = {
"region_name": region
}
if endpoint:
config["endpoint_url"] = f"{'http' if insecure else 'https'}://{endpoint}"
if access_key and secret_key:
config["aws_access_key_id"] = access_key
config["aws_secret_access_key"] = secret_key
self.client = boto3.client("s3", **config)
def upload(self, local_path: str, remote_key: str) -> str:
"""上传文件到 S3"""
full_key = f"{self.key_prefix}/{remote_key}"
if os.path.isdir(local_path):
# 上传目录
for root, dirs, files in os.walk(local_path):
for file in files:
file_path = os.path.join(root, file)
rel_path = os.path.relpath(file_path, local_path)
file_key = f"{full_key}/{rel_path}"
self.client.upload_file(file_path, self.bucket, file_key)
else:
# 上传文件
self.client.upload_file(local_path, self.bucket, full_key)
return f"s3://{self.bucket}/{full_key}"
def download(self, remote_key: str, local_path: str) -> bool:
"""从 S3 下载文件"""
full_key = f"{self.key_prefix}/{remote_key}"
try:
# 检查是否是目录(前缀)
response = self.client.list_objects_v2(
Bucket=self.bucket,
Prefix=full_key,
MaxKeys=2
)
objects = response.get("Contents", [])
if len(objects) == 0:
return False
if len(objects) == 1 and objects[0]["Key"] == full_key:
# 单文件
os.makedirs(os.path.dirname(local_path), exist_ok=True)
self.client.download_file(self.bucket, full_key, local_path)
else:
# 目录
os.makedirs(local_path, exist_ok=True)
paginator = self.client.get_paginator("list_objects_v2")
for page in paginator.paginate(Bucket=self.bucket, Prefix=full_key):
for obj in page.get("Contents", []):
key = obj["Key"]
rel_path = key[len(full_key):].lstrip("/")
file_path = os.path.join(local_path, rel_path)
os.makedirs(os.path.dirname(file_path), exist_ok=True)
self.client.download_file(self.bucket, key, file_path)
return True
except Exception as e:
print(f"Download failed: {e}")
return False
def exists(self, remote_key: str) -> bool:
"""检查文件是否存在"""
full_key = f"{self.key_prefix}/{remote_key}"
try:
self.client.head_object(Bucket=self.bucket, Key=full_key)
return True
        except Exception:
# 检查是否是前缀
response = self.client.list_objects_v2(
Bucket=self.bucket,
Prefix=full_key,
MaxKeys=1
)
return len(response.get("Contents", [])) > 0
def delete(self, remote_key: str) -> bool:
"""删除文件或目录"""
full_key = f"{self.key_prefix}/{remote_key}"
try:
# 删除所有匹配的对象
paginator = self.client.get_paginator("list_objects_v2")
for page in paginator.paginate(Bucket=self.bucket, Prefix=full_key):
objects = page.get("Contents", [])
if objects:
delete_objects = [{"Key": obj["Key"]} for obj in objects]
self.client.delete_objects(
Bucket=self.bucket,
Delete={"Objects": delete_objects}
)
return True
except Exception as e:
print(f"Delete failed: {e}")
return False
def list(self, prefix: str = "") -> List[str]:
"""列出指定前缀下的所有文件"""
full_prefix = f"{self.key_prefix}/{prefix}" if prefix else self.key_prefix
keys = []
paginator = self.client.get_paginator("list_objects_v2")
for page in paginator.paginate(Bucket=self.bucket, Prefix=full_prefix):
for obj in page.get("Contents", []):
# 去除 key_prefix
key = obj["Key"]
if key.startswith(self.key_prefix):
key = key[len(self.key_prefix):].lstrip("/")
keys.append(key)
return keys
def get_argo_config(self) -> Dict[str, Any]:
"""获取 Argo 配置"""
config = {
"s3": {
"bucket": self.bucket,
"keyPrefix": self.key_prefix,
"region": self.region
}
}
if self.endpoint:
config["s3"]["endpoint"] = self.endpoint
config["s3"]["insecure"] = self.insecure
return config
class ArtifactCache:
"""Artifact 缓存管理"""
def __init__(
self,
repository: ArtifactRepository,
cache_dir: str = "/tmp/artifact-cache"
):
self.repository = repository
self.cache_dir = cache_dir
self.manifest_file = os.path.join(cache_dir, "manifest.json")
os.makedirs(cache_dir, exist_ok=True)
self._load_manifest()
def _load_manifest(self):
"""加载缓存清单"""
if os.path.exists(self.manifest_file):
with open(self.manifest_file) as f:
self.manifest = json.load(f)
else:
self.manifest = {}
def _save_manifest(self):
"""保存缓存清单"""
with open(self.manifest_file, "w") as f:
json.dump(self.manifest, f, indent=2)
def _compute_hash(self, path: str) -> str:
"""计算文件/目录的 hash"""
hasher = hashlib.sha256()
if os.path.isfile(path):
with open(path, "rb") as f:
for chunk in iter(lambda: f.read(8192), b""):
hasher.update(chunk)
else:
for root, dirs, files in sorted(os.walk(path)):
for file in sorted(files):
file_path = os.path.join(root, file)
rel_path = os.path.relpath(file_path, path)
hasher.update(rel_path.encode())
with open(file_path, "rb") as f:
for chunk in iter(lambda: f.read(8192), b""):
hasher.update(chunk)
return hasher.hexdigest()
def get_cached(self, key: str, output_path: str) -> bool:
"""获取缓存的 artifact"""
if key not in self.manifest:
return False
cache_entry = self.manifest[key]
cache_path = os.path.join(self.cache_dir, cache_entry["hash"])
if os.path.exists(cache_path):
# 从本地缓存复制
if os.path.isdir(cache_path):
import shutil
shutil.copytree(cache_path, output_path)
else:
import shutil
shutil.copy2(cache_path, output_path)
return True
# 从远程下载
if self.repository.download(key, output_path):
# 更新本地缓存
if os.path.isdir(output_path):
import shutil
shutil.copytree(output_path, cache_path)
else:
import shutil
shutil.copy2(output_path, cache_path)
return True
return False
def put_cached(self, key: str, local_path: str) -> str:
"""缓存 artifact"""
file_hash = self._compute_hash(local_path)
cache_path = os.path.join(self.cache_dir, file_hash)
# 保存到本地缓存
if not os.path.exists(cache_path):
import shutil
if os.path.isdir(local_path):
shutil.copytree(local_path, cache_path)
else:
shutil.copy2(local_path, cache_path)
# 上传到远程
remote_url = self.repository.upload(local_path, key)
# 更新清单
self.manifest[key] = {
"hash": file_hash,
"remote_url": remote_url,
"timestamp": os.path.getmtime(local_path)
}
self._save_manifest()
return remote_url
def invalidate(self, key: str):
"""使缓存失效"""
if key in self.manifest:
cache_entry = self.manifest[key]
cache_path = os.path.join(self.cache_dir, cache_entry["hash"])
if os.path.exists(cache_path):
import shutil
if os.path.isdir(cache_path):
shutil.rmtree(cache_path)
else:
os.remove(cache_path)
del self.manifest[key]
self._save_manifest()
class WorkflowArtifactManager:
"""工作流 Artifact 管理器"""
def __init__(self, repository: ArtifactRepository):
self.repository = repository
self.cache = ArtifactCache(repository)
def generate_artifact_config(
self,
workflow_name: str,
inputs: List[ArtifactSpec],
outputs: List[ArtifactSpec]
) -> Dict[str, Any]:
"""生成工作流 artifact 配置"""
config = {
"inputs": {
"artifacts": [spec.to_argo_spec() for spec in inputs]
},
"outputs": {
"artifacts": [spec.to_argo_spec() for spec in outputs]
}
}
# 添加 artifact 仓库配置
for artifact in config["outputs"]["artifacts"]:
artifact["s3"] = {
"key": f"{workflow_name}/{{workflow.uid}}/{artifact['name']}"
}
return config
def setup_input_artifacts(
self,
workflow_id: str,
artifacts: List[Dict[str, Any]]
) -> Dict[str, str]:
"""设置输入 artifacts"""
paths = {}
for artifact in artifacts:
name = artifact["name"]
source_key = artifact.get("key", f"inputs/{name}")
local_path = artifact["path"]
if self.cache.get_cached(source_key, local_path):
paths[name] = local_path
else:
raise FileNotFoundError(f"Artifact {name} not found: {source_key}")
return paths
def upload_output_artifacts(
self,
workflow_id: str,
artifacts: List[Dict[str, Any]]
) -> Dict[str, str]:
"""上传输出 artifacts"""
urls = {}
for artifact in artifacts:
name = artifact["name"]
local_path = artifact["path"]
remote_key = f"workflows/{workflow_id}/outputs/{name}"
url = self.cache.put_cached(remote_key, local_path)
urls[name] = url
return urls
# 使用示例
if __name__ == "__main__":
# 创建 S3 仓库
repo = S3ArtifactRepository(
bucket="argo-artifacts",
endpoint="minio.argo:9000",
access_key="admin",
secret_key="password",
insecure=True
)
# 创建管理器
manager = WorkflowArtifactManager(repo)
# 定义 artifacts
inputs = [
ArtifactSpec(name="training-data", path="/data/train"),
ArtifactSpec(name="config", path="/config/params.yaml")
]
outputs = [
ArtifactSpec(name="model", path="/models/trained", compression="gzip"),
ArtifactSpec(name="metrics", path="/metrics/results.json")
]
# 生成配置
config = manager.generate_artifact_config(
workflow_name="ml-training",
inputs=inputs,
outputs=outputs
)
print(yaml.dump(config, default_flow_style=False))
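Since `get_argo_config()` already emits the structure the controller expects under `artifactRepository`, the same repository object can keep the Python side and the controller ConfigMap in sync. A brief sketch, reusing the `S3ArtifactRepository` class defined above:

# Sketch: render the controller's artifactRepository entry from the repository object.
repo = S3ArtifactRepository(
    bucket="argo-artifacts",
    endpoint="minio.argo:9000",
    insecure=True,
)
artifact_repository = {"archiveLogs": True, **repo.get_argo_config()}
# Paste this output under the artifactRepository key of workflow-controller-configmap.
print(yaml.safe_dump(artifact_repository, sort_keys=False))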
Workflow Monitoring and Observability
Prometheus Metrics Integration
# argo-monitoring.yaml
apiVersion: v1
kind: ConfigMap
metadata:
name: argo-prometheus-rules
namespace: monitoring
data:
argo-workflows.yaml: |
groups:
- name: argo-workflows
rules:
# 工作流状态指标
- record: argo_workflow_status_phase
expr: |
sum by (namespace, phase) (
argo_workflows_count{phase!=""}
)
# 工作流成功率
- record: argo_workflow_success_rate
expr: |
sum(rate(argo_workflows_count{phase="Succeeded"}[1h])) /
sum(rate(argo_workflows_count{phase=~"Succeeded|Failed|Error"}[1h]))
# 平均工作流执行时间
- record: argo_workflow_duration_seconds_avg
expr: |
avg by (workflow_name) (
argo_workflows_pods_duration_seconds_sum /
argo_workflows_pods_duration_seconds_count
)
# 工作流排队时间
- record: argo_workflow_queue_duration_seconds_avg
expr: |
avg(argo_workflows_queue_duration_seconds_sum /
argo_workflows_queue_duration_seconds_count)
# 告警规则
- alert: ArgoWorkflowHighFailureRate
expr: |
(1 - argo_workflow_success_rate) > 0.1
for: 10m
labels:
severity: warning
annotations:
summary: "High workflow failure rate"
description: "Workflow failure rate is above 10%"
- alert: ArgoWorkflowStuck
expr: |
argo_workflows_count{phase="Running"} > 0
and
changes(argo_workflows_count{phase="Running"}[30m]) == 0
for: 30m
labels:
severity: critical
annotations:
summary: "Workflow appears stuck"
description: "Workflow has been in Running state for 30+ minutes without change"
- alert: ArgoWorkflowQueueBacklog
expr: |
argo_workflows_count{phase="Pending"} > 50
for: 15m
labels:
severity: warning
annotations:
summary: "Large workflow queue backlog"
description: "More than 50 workflows pending"
- alert: ArgoControllerDown
expr: |
up{job="workflow-controller"} == 0
for: 5m
labels:
severity: critical
annotations:
summary: "Argo Workflow Controller is down"
description: "The workflow controller has been down for 5 minutes"
---
# Grafana Dashboard ConfigMap
apiVersion: v1
kind: ConfigMap
metadata:
name: argo-grafana-dashboard
namespace: monitoring
labels:
grafana_dashboard: "1"
data:
argo-workflows.json: |
{
"title": "Argo Workflows Dashboard",
"uid": "argo-workflows",
"panels": [
{
"title": "Workflow Status Distribution",
"type": "piechart",
"targets": [
{
"expr": "sum by (phase) (argo_workflows_count)",
"legendFormat": "{{phase}}"
}
]
},
{
"title": "Workflow Success Rate (1h)",
"type": "gauge",
"targets": [
{
"expr": "argo_workflow_success_rate * 100"
}
],
"options": {
"thresholds": {
"steps": [
{"color": "red", "value": 0},
{"color": "yellow", "value": 80},
{"color": "green", "value": 95}
]
}
}
},
{
"title": "Workflow Execution Time (p95)",
"type": "timeseries",
"targets": [
{
"expr": "histogram_quantile(0.95, rate(argo_workflows_pods_duration_seconds_bucket[5m]))",
"legendFormat": "p95"
},
{
"expr": "histogram_quantile(0.50, rate(argo_workflows_pods_duration_seconds_bucket[5m]))",
"legendFormat": "p50"
}
]
},
{
"title": "Active Workflows",
"type": "timeseries",
"targets": [
{
"expr": "argo_workflows_count{phase='Running'}",
"legendFormat": "Running"
},
{
"expr": "argo_workflows_count{phase='Pending'}",
"legendFormat": "Pending"
}
]
},
{
"title": "Workflow Completion Rate",
"type": "timeseries",
"targets": [
{
"expr": "sum(rate(argo_workflows_count{phase='Succeeded'}[5m]))",
"legendFormat": "Succeeded"
},
{
"expr": "sum(rate(argo_workflows_count{phase='Failed'}[5m]))",
"legendFormat": "Failed"
}
]
},
{
"title": "Resource Usage by Workflow",
"type": "table",
"targets": [
{
"expr": "topk(10, sum by (workflow) (argo_workflows_pods_cpu_usage_seconds_total))",
"legendFormat": "{{workflow}}"
}
]
}
]
}
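The recording rules above are also convenient to query programmatically, for example to gate a model promotion on recent pipeline health. A small sketch against the standard Prometheus HTTP API; the Prometheus URL is an assumption:

# success-rate-gate.py - a sketch: query the recorded success-rate metric and
# fail a promotion gate if it drops too low.
import requests

PROMETHEUS_URL = "http://prometheus.monitoring:9090"

def workflow_success_rate() -> float:
    resp = requests.get(
        f"{PROMETHEUS_URL}/api/v1/query",
        params={"query": "argo_workflow_success_rate"},
        timeout=10,
    )
    resp.raise_for_status()
    result = resp.json()["data"]["result"]
    if not result:
        raise RuntimeError("no data for argo_workflow_success_rate")
    return float(result[0]["value"][1])

if __name__ == "__main__":
    rate = workflow_success_rate()
    print(f"workflow success rate: {rate:.2%}")
    if rate < 0.9:
        raise SystemExit("success rate below 90%, blocking promotion")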
Workflow Event Handling
# workflow_event_handler.py
"""
Argo Workflows 事件处理器
处理工作流生命周期事件、告警通知等
"""
import asyncio
import json
from typing import Dict, Any, Callable, List, Optional
from dataclasses import dataclass
from datetime import datetime
import aiohttp
from kubernetes import client, config, watch
from kubernetes.client.rest import ApiException
import logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
@dataclass
class WorkflowEvent:
"""工作流事件"""
event_type: str # ADDED, MODIFIED, DELETED
workflow_name: str
namespace: str
phase: str
message: Optional[str]
start_time: Optional[datetime]
finish_time: Optional[datetime]
resource_duration: Optional[int]
raw_object: Dict[str, Any]
@property
def duration_seconds(self) -> Optional[float]:
if self.start_time and self.finish_time:
return (self.finish_time - self.start_time).total_seconds()
return None
@classmethod
def from_k8s_event(cls, event: Dict[str, Any]) -> "WorkflowEvent":
obj = event["object"]
metadata = obj.get("metadata", {})
status = obj.get("status", {})
start_time = None
finish_time = None
if status.get("startedAt"):
start_time = datetime.fromisoformat(
status["startedAt"].replace("Z", "+00:00")
)
if status.get("finishedAt"):
finish_time = datetime.fromisoformat(
status["finishedAt"].replace("Z", "+00:00")
)
return cls(
event_type=event["type"],
workflow_name=metadata.get("name", ""),
namespace=metadata.get("namespace", ""),
phase=status.get("phase", "Unknown"),
message=status.get("message"),
start_time=start_time,
finish_time=finish_time,
resource_duration=status.get("resourcesDuration", {}).get("cpu"),
raw_object=obj
)
class NotificationChannel:
"""通知渠道基类"""
async def send(self, event: WorkflowEvent, message: str):
raise NotImplementedError
class SlackNotificationChannel(NotificationChannel):
"""Slack 通知渠道"""
def __init__(self, webhook_url: str):
self.webhook_url = webhook_url
async def send(self, event: WorkflowEvent, message: str):
color_map = {
"Succeeded": "good",
"Failed": "danger",
"Error": "danger",
"Running": "#439FE0"
}
payload = {
"attachments": [{
"color": color_map.get(event.phase, "#808080"),
"title": f"Workflow {event.phase}: {event.workflow_name}",
"text": message,
"fields": [
{"title": "Namespace", "value": event.namespace, "short": True},
{"title": "Phase", "value": event.phase, "short": True}
],
"footer": "Argo Workflows",
"ts": int(datetime.now().timestamp())
}]
}
if event.duration_seconds:
payload["attachments"][0]["fields"].append({
"title": "Duration",
"value": f"{event.duration_seconds:.1f}s",
"short": True
})
async with aiohttp.ClientSession() as session:
async with session.post(self.webhook_url, json=payload) as resp:
if resp.status != 200:
logger.error(f"Slack notification failed: {await resp.text()}")
class WebhookNotificationChannel(NotificationChannel):
"""通用 Webhook 通知渠道"""
def __init__(self, url: str, headers: Optional[Dict[str, str]] = None):
self.url = url
self.headers = headers or {}
async def send(self, event: WorkflowEvent, message: str):
payload = {
"workflow_name": event.workflow_name,
"namespace": event.namespace,
"phase": event.phase,
"message": message,
"event_type": event.event_type,
"timestamp": datetime.now().isoformat(),
"duration_seconds": event.duration_seconds
}
async with aiohttp.ClientSession() as session:
async with session.post(
self.url,
json=payload,
headers=self.headers
) as resp:
if resp.status >= 400:
logger.error(f"Webhook notification failed: {await resp.text()}")
class WorkflowEventHandler:
"""工作流事件处理器"""
def __init__(self, namespace: Optional[str] = None):
self.namespace = namespace
self.handlers: Dict[str, List[Callable]] = {
"Running": [],
"Succeeded": [],
"Failed": [],
"Error": [],
"Pending": []
}
self.notification_channels: List[NotificationChannel] = []
self._setup_k8s_client()
def _setup_k8s_client(self):
"""设置 Kubernetes 客户端"""
try:
config.load_incluster_config()
        except Exception:
config.load_kube_config()
self.custom_api = client.CustomObjectsApi()
def on_phase(self, phase: str):
"""装饰器:注册事件处理函数"""
def decorator(func: Callable):
if phase in self.handlers:
self.handlers[phase].append(func)
return func
return decorator
def add_notification_channel(self, channel: NotificationChannel):
"""添加通知渠道"""
self.notification_channels.append(channel)
async def _handle_event(self, event: WorkflowEvent):
"""处理单个事件"""
phase = event.phase
# 调用注册的处理函数
for handler in self.handlers.get(phase, []):
try:
if asyncio.iscoroutinefunction(handler):
await handler(event)
else:
handler(event)
except Exception as e:
logger.error(f"Handler error: {e}")
# 发送通知
if phase in ["Succeeded", "Failed", "Error"]:
message = self._format_message(event)
for channel in self.notification_channels:
try:
await channel.send(event, message)
except Exception as e:
logger.error(f"Notification error: {e}")
def _format_message(self, event: WorkflowEvent) -> str:
"""格式化通知消息"""
lines = [
f"Workflow: {event.workflow_name}",
f"Status: {event.phase}"
]
if event.message:
lines.append(f"Message: {event.message}")
if event.duration_seconds:
minutes, seconds = divmod(int(event.duration_seconds), 60)
lines.append(f"Duration: {minutes}m {seconds}s")
return "\n".join(lines)
async def watch_workflows(self):
"""监听工作流事件"""
logger.info(f"Starting workflow watcher for namespace: {self.namespace or 'all'}")
while True:
try:
w = watch.Watch()
kwargs = {
"group": "argoproj.io",
"version": "v1alpha1",
"plural": "workflows"
}
if self.namespace:
kwargs["namespace"] = self.namespace
stream = w.stream(
self.custom_api.list_namespaced_custom_object,
**kwargs
)
else:
stream = w.stream(
self.custom_api.list_cluster_custom_object,
**kwargs
)
for event in stream:
try:
workflow_event = WorkflowEvent.from_k8s_event(event)
await self._handle_event(workflow_event)
except Exception as e:
logger.error(f"Event processing error: {e}")
except ApiException as e:
logger.error(f"K8s API error: {e}")
await asyncio.sleep(5)
except Exception as e:
logger.error(f"Watch error: {e}")
await asyncio.sleep(5)
class WorkflowMetricsCollector:
"""工作流指标收集器"""
def __init__(self, handler: WorkflowEventHandler):
self.handler = handler
self.metrics = {
"total_workflows": 0,
"succeeded": 0,
"failed": 0,
"running": 0,
"total_duration": 0.0
}
self._register_handlers()
def _register_handlers(self):
"""注册指标收集处理函数"""
@self.handler.on_phase("Running")
def on_running(event: WorkflowEvent):
self.metrics["running"] += 1
self.metrics["total_workflows"] += 1
@self.handler.on_phase("Succeeded")
def on_succeeded(event: WorkflowEvent):
self.metrics["running"] = max(0, self.metrics["running"] - 1)
self.metrics["succeeded"] += 1
if event.duration_seconds:
self.metrics["total_duration"] += event.duration_seconds
@self.handler.on_phase("Failed")
def on_failed(event: WorkflowEvent):
self.metrics["running"] = max(0, self.metrics["running"] - 1)
self.metrics["failed"] += 1
if event.duration_seconds:
self.metrics["total_duration"] += event.duration_seconds
def get_metrics(self) -> Dict[str, Any]:
"""获取当前指标"""
completed = self.metrics["succeeded"] + self.metrics["failed"]
return {
**self.metrics,
"success_rate": (
self.metrics["succeeded"] / completed if completed > 0 else 0
),
"avg_duration": (
self.metrics["total_duration"] / completed if completed > 0 else 0
)
}
# 使用示例
async def main():
# 创建事件处理器
handler = WorkflowEventHandler(namespace="argo-workflows")
# 添加 Slack 通知
handler.add_notification_channel(
SlackNotificationChannel(
webhook_url="https://hooks.slack.com/services/xxx"
)
)
# 添加自定义 Webhook
handler.add_notification_channel(
WebhookNotificationChannel(
url="https://api.example.com/workflow-events",
headers={"Authorization": "Bearer token"}
)
)
# 创建指标收集器
metrics_collector = WorkflowMetricsCollector(handler)
# 注册自定义处理函数
@handler.on_phase("Failed")
async def handle_failure(event: WorkflowEvent):
logger.error(f"Workflow failed: {event.workflow_name}")
logger.error(f"Message: {event.message}")
# 可以在这里添加自动重试逻辑
@handler.on_phase("Succeeded")
async def handle_success(event: WorkflowEvent):
logger.info(f"Workflow succeeded: {event.workflow_name}")
metrics = metrics_collector.get_metrics()
logger.info(f"Current success rate: {metrics['success_rate']:.2%}")
# 启动监听
await handler.watch_workflows()
if __name__ == "__main__":
asyncio.run(main())
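The `WorkflowMetricsCollector` above only keeps counters in memory. If those numbers should sit next to the controller's own metrics in Prometheus, they can be exported with the standard `prometheus_client` library, as in the sketch below; the port and metric names are illustrative, and the classes come from the module above.

# Sketch: export the collector's numbers as Prometheus gauges next to the event watcher.
from prometheus_client import Gauge, start_http_server

success_rate_gauge = Gauge("mlops_workflow_success_rate", "Observed workflow success rate")
running_gauge = Gauge("mlops_workflows_running", "Workflows currently observed as Running")

async def export_metrics(collector: WorkflowMetricsCollector, interval: int = 30):
    while True:
        snapshot = collector.get_metrics()
        success_rate_gauge.set(snapshot["success_rate"])
        running_gauge.set(snapshot["running"])
        await asyncio.sleep(interval)

async def main_with_metrics():
    handler = WorkflowEventHandler(namespace="argo-workflows")
    collector = WorkflowMetricsCollector(handler)
    start_http_server(9102)  # scrape this port from Prometheus
    await asyncio.gather(handler.watch_workflows(), export_metrics(collector))

# Run with: asyncio.run(main_with_metrics())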
Best Practices
Workflow Design Principles
# best-practices-workflow.yaml
apiVersion: argoproj.io/v1alpha1
kind: WorkflowTemplate
metadata:
name: ml-best-practices
namespace: argo-workflows
labels:
app.kubernetes.io/name: ml-pipeline
app.kubernetes.io/component: workflow
spec:
# 1. 设置合理的超时和重试
entrypoint: main
activeDeadlineSeconds: 86400 # 24小时全局超时
# 2. 资源清理策略
ttlStrategy:
secondsAfterCompletion: 3600
secondsAfterSuccess: 3600
secondsAfterFailure: 86400
podGC:
strategy: OnPodCompletion
labelSelector:
matchLabels:
cleanup: "true"
# 3. 服务账号和安全
serviceAccountName: ml-workflow-sa
securityContext:
runAsNonRoot: true
runAsUser: 1000
# 4. 节点选择和容忍
nodeSelector:
node-type: ml-worker
tolerations:
- key: "nvidia.com/gpu"
operator: "Exists"
effect: "NoSchedule"
# 5. 资源配额
podSpecPatch: |
containers:
- name: main
resources:
requests:
cpu: "100m"
memory: "256Mi"
templates:
- name: main
dag:
tasks:
- name: step1
template: reliable-step
arguments:
parameters:
- name: step-name
value: "data-prep"
- name: step2
template: gpu-step
dependencies: [step1]
# 6. 可靠的步骤模板
- name: reliable-step
inputs:
parameters:
- name: step-name
# 重试策略
retryStrategy:
limit: 3
retryPolicy: "Always"
backoff:
duration: "10s"
factor: 2
maxDuration: "5m"
# 超时设置
activeDeadlineSeconds: 3600
container:
image: ml-platform/worker:v1.0
command: [python, process.py]
args:
- --step={{inputs.parameters.step-name}}
# 资源限制
resources:
requests:
cpu: "500m"
memory: "1Gi"
limits:
cpu: "2"
memory: "4Gi"
# 存活检查
livenessProbe:
exec:
command: [cat, /tmp/healthy]
initialDelaySeconds: 30
periodSeconds: 30
# 优雅退出
lifecycle:
preStop:
exec:
command: ["/bin/sh", "-c", "sleep 5"]
# Memoization 缓存
memoize:
key: "{{inputs.parameters.step-name}}-{{workflow.parameters.data-version}}"
maxAge: "24h"
cache:
configMap:
name: workflow-cache
# 7. GPU 步骤模板
- name: gpu-step
activeDeadlineSeconds: 7200
retryStrategy:
limit: 2
retryPolicy: "OnFailure"
podSpecPatch: |
containers:
- name: main
resources:
limits:
nvidia.com/gpu: 1
container:
image: ml-platform/gpu-worker:v1.0
command: [python, train.py]
env:
- name: CUDA_VISIBLE_DEVICES
value: "0"
- name: NCCL_DEBUG
value: "INFO"
volumeMounts:
- name: shm
mountPath: /dev/shm
volumes:
- name: shm
emptyDir:
medium: Memory
sizeLimit: "16Gi"
---
# 工作流 RBAC
apiVersion: v1
kind: ServiceAccount
metadata:
name: ml-workflow-sa
namespace: argo-workflows
---
apiVersion: rbac.authorization.k8s.io/v1
kind: Role
metadata:
name: ml-workflow-role
namespace: argo-workflows
rules:
- apiGroups: [""]
resources: ["pods", "pods/log"]
verbs: ["get", "list", "watch"]
- apiGroups: [""]
resources: ["configmaps"]
verbs: ["get", "list", "watch", "create", "update"]
- apiGroups: [""]
resources: ["persistentvolumeclaims"]
verbs: ["get", "list", "watch", "create", "delete"]
- apiGroups: ["argoproj.io"]
resources: ["workflows", "workflowtemplates"]
verbs: ["get", "list", "watch"]
---
apiVersion: rbac.authorization.k8s.io/v1
kind: RoleBinding
metadata:
name: ml-workflow-rolebinding
namespace: argo-workflows
subjects:
- kind: ServiceAccount
  name: ml-workflow-sa
  namespace: argo-workflows
roleRef:
kind: Role
name: ml-workflow-role
apiGroup: rbac.authorization.k8s.io
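These conventions drift easily as templates accumulate, so it can help to lint manifests against the checklist above before submission. The sketch below covers only the points made in this section and nothing more:

# lint-workflow.py - a sketch that checks a manifest against this section's checklist.
# Usage: python lint-workflow.py workflow.yaml
import sys
import yaml

REQUIRED_SPEC_FIELDS = ["ttlStrategy", "podGC", "activeDeadlineSeconds", "serviceAccountName"]

def lint_workflow(manifest: dict) -> list:
    problems = []
    spec = manifest.get("spec", {})
    for field in REQUIRED_SPEC_FIELDS:
        if field not in spec:
            problems.append(f"spec.{field} is not set")
    for tmpl in spec.get("templates", []):
        runtime = tmpl.get("container") or tmpl.get("script")
        if runtime and "resources" not in runtime:
            problems.append(f"template '{tmpl.get('name')}' has no resource requests/limits")
        if runtime and "retryStrategy" not in tmpl:
            problems.append(f"template '{tmpl.get('name')}' has no retryStrategy")
    return problems

if __name__ == "__main__":
    with open(sys.argv[1]) as f:
        for doc in yaml.safe_load_all(f):
            if doc and doc.get("kind") in ("Workflow", "WorkflowTemplate"):
                for problem in lint_workflow(doc):
                    print(f"{doc['metadata']['name']}: {problem}")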
Summary
The core strengths of Argo Workflows as a Kubernetes-native workflow engine:
Cloud-native design
- Built entirely on Kubernetes CRDs
- Leverages native K8s resource management
- Integrates easily with the cloud-native ecosystem
Strong DAG support
- Flexible task dependency definitions
- Dynamic DAG generation
- Conditional branching and loops
Enterprise-grade features
- Fine-grained RBAC
- Multi-tenancy support
- Audit logging
ML/AI friendly
- GPU resource scheduling
- Distributed training support
- Artifact management
Observability
- Prometheus metrics
- Structured logging
- Event-driven notifications
In the next chapter we turn to data version management and look at how tools such as DVC and LakeFS manage data versions in ML projects.