[Advanced Applications] Day 27: Complex Task Orchestration and Execution – The Art of DAG Workflow Design
Introduction
When you need to build a complex AI system, a single linear task flow is nowhere near enough.
Imagine an AI data analysis platform: it must collect data, clean it, analyze it, generate a report, and send an email. These tasks have intricate dependencies: some can run in parallel, others must wait for upstream tasks to finish. And what happens when a step fails? Retry? Skip? Roll back?
This is where DAG workflows come in. This article walks through complex task orchestration and execution, from DAG design principles to workflow engine development, from code to best practices, helping you master the core skills for building complex AI systems.
1. Why Workflow Orchestration
1.1 The Limits of Simple Flows
A simple linear flow (run A→B→C→D in sequence) cannot cope with complex scenarios:
Complex inter-task dependencies: D must wait for both A and B, E must wait for C... hand-writing code to manage these dependencies quickly turns into a disaster.
Hard error handling: when a task fails, how do you retry? How do you roll back?
Messy state management: if a run is interrupted halfway, how do you resume? Which tasks already finished?
No monitoring or alerting: you only see the final success or failure, not which step things got stuck on.
1.2 What a DAG Buys You
A DAG (Directed Acyclic Graph) addresses these problems directly:
Clear dependency declaration: each task declares its upstream dependencies, and the system derives the execution order automatically.
State persistence: each task's status is persisted on completion, so a restarted run can resume.
Automatic retries: a failed task can be retried according to a configurable policy.
Parallel execution: tasks with no mutual dependencies run in parallel automatically, maximizing throughput.
1.3 The Mathematical Basis of a DAG
A DAG is a directed graph that satisfies:
Directed: edges have a direction, representing the dependency relation between tasks.
Acyclic: it contains no cycles, so execution can never loop forever.
Topologically sortable: its nodes can be arranged in a linear order that respects every dependency constraint.
from typing import Any, List, Dict, Set, Optional
from dataclasses import dataclass, field
from collections import deque
import json

@dataclass
class DAGNode:
    """A node in the DAG."""
    id: str
    name: str
    dependencies: List[str] = field(default_factory=list)  # upstream dependencies
    status: str = "pending"  # pending/running/success/failed/skipped
    result: Any = None
    error: Optional[str] = None
    retry_count: int = 0
    max_retries: int = 3

class DAG:
    """A directed acyclic graph of tasks."""
    def __init__(self):
        self.nodes: Dict[str, DAGNode] = {}

    def add_node(self, node: DAGNode):
        """Add a node."""
        self.nodes[node.id] = node

    def validate(self) -> bool:
        """Check that the DAG is valid (contains no dependency cycles)."""
        for node_id in self.nodes:
            if self._has_cycle_from(node_id):
                return False
        return True

    def _has_cycle_from(self, start_id: str) -> bool:
        """Check whether a cycle is reachable from start_id."""
        visited = set()
        rec_stack = set()

        def dfs(node_id):
            visited.add(node_id)
            rec_stack.add(node_id)
            node = self.nodes.get(node_id)
            if not node:
                return False
            for dep_id in node.dependencies:
                if dep_id not in self.nodes:
                    continue
                if dep_id in rec_stack:
                    return True
                if dep_id not in visited and dfs(dep_id):
                    return True
            rec_stack.remove(node_id)
            return False

        return dfs(start_id)

    def topological_sort(self) -> List[str]:
        """Topological sort: returns a valid execution order for the tasks."""
        # A node's in-degree is the number of dependencies it is still waiting on
        in_degree = {node_id: 0 for node_id in self.nodes}
        for node in self.nodes.values():
            for dep_id in node.dependencies:
                if dep_id in self.nodes:
                    in_degree[node.id] += 1
        # Nodes with in-degree 0 have no unmet dependencies and run first
        queue = deque([node_id for node_id, degree in in_degree.items() if degree == 0])
        result = []
        while queue:
            node_id = queue.popleft()
            result.append(node_id)
            # node_id is done: decrement the in-degree of every node that depends on it
            for other_id, node in self.nodes.items():
                if node_id in node.dependencies:
                    in_degree[other_id] -= 1
                    if in_degree[other_id] == 0:
                        queue.append(other_id)
        if len(result) != len(self.nodes):
            raise ValueError("Cyclic dependency detected; the DAG is invalid")
        return result

    def get_ready_tasks(self, completed: Set[str]) -> List[str]:
        """Return the tasks that are ready to run right now."""
        ready = []
        for node_id, node in self.nodes.items():
            if node_id in completed:
                continue
            if node.status in ["running", "success", "failed"]:
                continue
            # A task is ready only when all of its dependencies have completed
            deps_done = all(dep_id in completed for dep_id in node.dependencies)
            if deps_done:
                ready.append(node_id)
        return ready

    def to_json(self) -> str:
        """Export the DAG as JSON."""
        return json.dumps({
            node_id: {
                "id": node.id,
                "name": node.name,
                "dependencies": node.dependencies,
                "status": node.status
            }
            for node_id, node in self.nodes.items()
        }, indent=2)
# Usage example
dag = DAG()
# Add nodes and declare their dependencies
dag.add_node(DAGNode(id="a", name="Collect data", dependencies=[]))
dag.add_node(DAGNode(id="b", name="Clean data", dependencies=["a"]))
dag.add_node(DAGNode(id="c", name="Analyze data", dependencies=["a"]))
dag.add_node(DAGNode(id="d", name="Generate report", dependencies=["b", "c"]))
dag.add_node(DAGNode(id="e", name="Send email", dependencies=["d"]))
# Validate
print(f"DAG valid: {dag.validate()}")  # True
# Topological sort
execution_order = dag.topological_sort()
print(f"Execution order: {' -> '.join(execution_order)}")  # a -> b -> c -> d -> e
# Query the ready tasks
print(f"Ready initially: {dag.get_ready_tasks(set())}")                # ['a']
print(f"Ready after a: {dag.get_ready_tasks({'a'})}")                  # ['b', 'c']
print(f"Ready after a, b, c: {dag.get_ready_tasks({'a', 'b', 'c'})}")  # ['d']
2. Core Implementation of a Workflow Engine
2.1 Overall Architecture
The four components of a workflow engine (sketched as interfaces below):
Scheduler: decides when tasks are dispatched and in what order.
Executor: actually runs the tasks.
State Store: persists task and workflow state.
Event Bus: publish/subscribe channel for task status changes.
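Before reading the implementation, it helps to picture this division of labor. The sketch below is a hypothetical decomposition for orientation only; the engine that follows merges all four roles into a single class:

from typing import Any, Callable, List, Protocol

class Scheduler(Protocol):
    def next_ready_tasks(self) -> List[str]: ...          # decide what may run now

class Executor(Protocol):
    def execute(self, task_id: str, context: dict) -> Any: ...  # actually run a task

class StateStore(Protocol):
    def get(self, task_id: str) -> str: ...               # read a task's status
    def set(self, task_id: str, status: str) -> None: ... # persist a status change

class EventBus(Protocol):
    def subscribe(self, event_type: str, callback: Callable) -> None: ...
    def publish(self, event_type: str, data: Any) -> None: ...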
2.2 Core Engine Implementation
import time
import threading
from typing import Dict, Callable, Any, Optional
from dataclasses import dataclass
from enum import Enum
from collections import defaultdict
import logging

logger = logging.getLogger(__name__)

class TaskStatus(Enum):
    PENDING = "pending"
    QUEUED = "queued"
    RUNNING = "running"
    SUCCESS = "success"
    FAILED = "failed"
    RETRY = "retry"
    SKIPPED = "skipped"

@dataclass
class TaskResult:
    task_id: str
    status: TaskStatus
    output: Any = None
    error: Optional[str] = None
    duration: float = 0

class WorkflowEventBus:
    """Publish/subscribe bus for task events."""
    def __init__(self):
        self.listeners: Dict[str, list] = defaultdict(list)

    def subscribe(self, event_type: str, callback: Callable):
        """Subscribe to an event type."""
        self.listeners[event_type].append(callback)

    def publish(self, event_type: str, data: Any):
        """Publish an event to all subscribers."""
        for callback in self.listeners.get(event_type, []):
            try:
                callback(data)
            except Exception as e:
                logger.error(f"Error in event handler: {e}")
class WorkflowEngine:
    """Workflow engine: schedules and executes a DAG of tasks."""
    def __init__(self, max_workers: int = 4):
        self.dag: Optional[DAG] = None
        self.executor_map: Dict[str, Callable] = {}
        self.state_store: Dict[str, TaskStatus] = {}
        self.max_workers = max_workers
        self.event_bus = WorkflowEventBus()
        self.lock = threading.Lock()
        self.running_tasks: Dict[str, threading.Thread] = {}

    def register_task(self, task_id: str, executor: Callable):
        """Register the executor function for a task."""
        self.executor_map[task_id] = executor

    def load_dag(self, dag: DAG):
        """Load a DAG and reset all task state."""
        self.dag = dag
        self.state_store.clear()
        for node_id in dag.nodes:
            self.state_store[node_id] = TaskStatus.PENDING

    def run(self, context: Dict = None) -> Dict[str, TaskResult]:
        """Run the workflow to completion."""
        if not self.dag:
            raise ValueError("No DAG loaded")
        if not self.dag.validate():
            raise ValueError("DAG is invalid")
        context = context or {}
        results = {}
        # Subscribe to task-completion events
        self.event_bus.subscribe("task_completed", self._on_task_completed)
        # Main scheduling loop
        while True:
            completed = {tid for tid, status in self.state_store.items()
                         if status in [TaskStatus.SUCCESS, TaskStatus.FAILED, TaskStatus.SKIPPED]}
            # Stop once every task has reached a terminal state
            if len(completed) == len(self.dag.nodes):
                break
            # Find the tasks whose dependencies are all satisfied
            ready_tasks = self.dag.get_ready_tasks(completed)
            if not ready_tasks and len(completed) < len(self.dag.nodes):
                # Nothing is runnable but not everything finished. Tasks that are
                # queued or running will still complete, so only treat this as a
                # deadlock when no task is in flight.
                busy = any(status in [TaskStatus.QUEUED, TaskStatus.RUNNING]
                           for status in self.state_store.values())
                if not busy:
                    logger.error("Workflow deadlocked")
                    break
            # Launch the ready tasks
            for task_id in ready_tasks:
                if self.state_store[task_id] == TaskStatus.PENDING:
                    self._execute_task(task_id, context, results)
            time.sleep(0.1)  # avoid busy-waiting
        return results

    def _execute_task(self, task_id: str, context: Dict, results: Dict):
        """Execute a single task in a worker thread."""
        with self.lock:
            if self.state_store[task_id] != TaskStatus.PENDING:
                return
            self.state_store[task_id] = TaskStatus.QUEUED
        node = self.dag.nodes[task_id]
        node.status = "running"  # keep the DAG node in sync so it is not re-scheduled

        def run_in_thread():
            try:
                self.state_store[task_id] = TaskStatus.RUNNING
                logger.info(f"Starting task: {task_id}")
                start_time = time.time()
                if task_id in self.executor_map:
                    output = self.executor_map[task_id](context)
                else:
                    output = None
                duration = time.time() - start_time
                result = TaskResult(
                    task_id=task_id,
                    status=TaskStatus.SUCCESS,
                    output=output,
                    duration=duration
                )
                with self.lock:
                    self.state_store[task_id] = TaskStatus.SUCCESS
                    node.status = "success"
                    results[task_id] = result
                    context[task_id] = output  # expose the output to downstream tasks
                self.event_bus.publish("task_completed", result)
                logger.info(f"Task finished: {task_id} in {duration:.2f}s")
            except Exception as e:
                logger.error(f"Task failed: {task_id}, error: {e}")
                result = TaskResult(
                    task_id=task_id,
                    status=TaskStatus.FAILED,
                    error=str(e)
                )
                with self.lock:
                    self.state_store[task_id] = TaskStatus.FAILED
                    node.status = "failed"
                    results[task_id] = result
                self.event_bus.publish("task_completed", result)

        thread = threading.Thread(target=run_in_thread)
        thread.start()
        with self.lock:
            self.running_tasks[task_id] = thread

    def _on_task_completed(self, result: TaskResult):
        """Default callback for task-completion events."""
        logger.info(f"Received task_completed event: {result.task_id}")

    def get_status(self) -> Dict[str, TaskStatus]:
        """Snapshot of every task's status."""
        return dict(self.state_store)
# Usage example
def data_collection(context):
    """Collect raw data."""
    time.sleep(1)
    return {"data": [1, 2, 3, 4, 5]}

def data_cleaning(context):
    """Clean the collected data."""
    time.sleep(1)
    data = context.get("collection", {}).get("data", [])
    return {"cleaned": [x for x in data if x > 2]}

def data_analysis(context):
    """Analyze the cleaned data."""
    time.sleep(1)
    cleaned = context.get("cleaning", {}).get("cleaned", [])
    mean = sum(cleaned) / len(cleaned) if cleaned else 0
    return {"mean": mean, "count": len(cleaned)}

def generate_report(context):
    """Generate a report from the analysis."""
    time.sleep(1)
    analysis = context.get("analysis", {})
    return f"Analysis done: mean={analysis.get('mean')}, count={analysis.get('count')}"

def send_email(context):
    """Send the report by email."""
    time.sleep(0.5)
    report = context.get("report", "")
    return f"Email sent: {report[:50]}..."

# Create the engine
engine = WorkflowEngine()
# Register the task executors
engine.register_task("collection", data_collection)
engine.register_task("cleaning", data_cleaning)
engine.register_task("analysis", data_analysis)
engine.register_task("report", generate_report)
engine.register_task("email", send_email)
# Build the DAG; analysis consumes the cleaning output, so it depends on "cleaning"
dag = DAG()
dag.add_node(DAGNode(id="collection", name="Collect data", dependencies=[]))
dag.add_node(DAGNode(id="cleaning", name="Clean data", dependencies=["collection"]))
dag.add_node(DAGNode(id="analysis", name="Analyze data", dependencies=["cleaning"]))
dag.add_node(DAGNode(id="report", name="Generate report", dependencies=["cleaning", "analysis"]))
dag.add_node(DAGNode(id="email", name="Send email", dependencies=["report"]))
# Load and run
engine.load_dag(dag)
results = engine.run()
# Print the results
print("\nExecution results:")
for task_id, result in results.items():
    status = result.status.value
    duration = f"{result.duration:.2f}s" if result.duration else "-"
    print(f"  {task_id}: {status} ({duration})")

3. Task Dependencies and Parallel Optimization
3.1 Dynamic Task Generation
In some scenarios the number of tasks is not known until runtime, so tasks must be generated dynamically:
from typing import Any, List, Dict
from dataclasses import dataclass, field

@dataclass
class DynamicTask:
    """Definition of a dynamically generated task."""
    task_id: str
    params: Dict
    dependencies: List[str] = field(default_factory=list)

class DynamicDAGBuilder:
    """Builds a DAG whose tasks are generated at runtime."""
    def __init__(self):
        self.tasks: Dict[str, DynamicTask] = {}
        self.context: Dict = {}

    def add_task(self, task: DynamicTask):
        """Add a single task."""
        self.tasks[task.task_id] = task

    def add_dynamic_tasks(self, template: str, items: List[Dict],
                          depends_on: str = None):
        """Generate one task per data item, named template_0, template_1, ..."""
        for i, item in enumerate(items):
            task_id = f"{template}_{i}"
            # Every generated task shares the same upstream dependency
            dependencies = []
            if depends_on:
                dependencies.append(depends_on)
            task = DynamicTask(
                task_id=task_id,
                params=item,
                dependencies=dependencies
            )
            self.add_task(task)

    def build_dag(self) -> DAG:
        """Materialize the collected tasks into a DAG."""
        dag = DAG()
        for task_id, task in self.tasks.items():
            node = DAGNode(
                id=task_id,
                name=task.task_id,
                dependencies=task.dependencies
            )
            dag.add_node(node)
        return dag

    def collect_results(self, task_results: Dict[str, Any],
                        template: str) -> List[Dict]:
        """Gather the results of all tasks generated from a template."""
        results = []
        for task_id, result in task_results.items():
            if task_id.startswith(template):
                results.append(result)
        return results

# Usage example
builder = DynamicDAGBuilder()
# Add a static upstream task
builder.add_task(DynamicTask(task_id="fetch_urls", params={}, dependencies=[]))
# Dynamically add one scraping task per URL
urls = [
    {"url": "https://example.com/1", "name": "Page 1"},
    {"url": "https://example.com/2", "name": "Page 2"},
    {"url": "https://example.com/3", "name": "Page 3"},
]
builder.add_dynamic_tasks("scrape", urls, depends_on="fetch_urls")
# Add an aggregation task that waits for every scrape task
builder.add_task(DynamicTask(
    task_id="aggregate",
    params={},
    dependencies=[f"scrape_{i}" for i in range(len(urls))]
))
# Build the DAG
dag = builder.build_dag()
print(f"DAG node count: {len(dag.nodes)}")
print(f"Execution order: {' -> '.join(dag.topological_sort())}")
3.2 Controlling the Degree of Parallelism
Running tasks in parallel requires a cap on concurrency so that resources are not exhausted:
import threading
import time
from concurrent.futures import ThreadPoolExecutor, Future
from typing import Dict, Callable, Any

class Semaphore:
    """A counting semaphore built on a condition variable."""
    def __init__(self, value: int):
        self.value = value
        self.condition = threading.Condition()

    def acquire(self):
        with self.condition:
            while self.value <= 0:
                self.condition.wait()
            self.value -= 1

    def release(self):
        with self.condition:
            self.value += 1
            self.condition.notify()  # wake one waiting thread

class ParallelExecutor:
    """Runs submitted tasks with a bounded degree of parallelism."""
    def __init__(self, max_parallel: int = 4):
        self.max_parallel = max_parallel
        self.semaphore = Semaphore(max_parallel)
        self.executor = ThreadPoolExecutor(max_workers=max_parallel)
        self.futures: Dict[str, Future] = {}

    def submit(self, task_id: str, fn: Callable, *args, **kwargs) -> Future:
        """Submit a task; it waits for a semaphore slot before running."""
        def wrapped():
            self.semaphore.acquire()
            try:
                return fn(*args, **kwargs)
            finally:
                self.semaphore.release()
        future = self.executor.submit(wrapped)
        self.futures[task_id] = future
        return future

    def wait_all(self) -> Dict[str, Any]:
        """Block until every submitted task completes."""
        results = {}
        for task_id, future in self.futures.items():
            try:
                results[task_id] = future.result()
            except Exception as e:
                results[task_id] = {"error": str(e)}
        return results

    def shutdown(self):
        """Shut down the thread pool."""
        self.executor.shutdown(wait=True)

# Usage example
def process_item(item):
    """Process one item."""
    time.sleep(1)
    return f"processed_{item}"

executor = ParallelExecutor(max_parallel=3)
# Submit 10 tasks; at most 3 run at the same time
for i in range(10):
    executor.submit(f"task_{i}", process_item, i)
print("Waiting for tasks to finish...")
results = executor.wait_all()
print(f"Completed {len(results)} tasks")
executor.shutdown()
4. Error Handling and Retry Strategies
4.1 Retry Strategies
Common strategies for retrying a failed task:
Fixed interval: wait a constant amount of time after each failure before retrying.
Exponential backoff: double the waiting time after each successive failure.
Circuit breaker: temporarily skip a task after too many failures (see the sketch after this list).
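Fixed intervals and exponential backoff are implemented in section 4.2 below. The circuit breaker is not, so here is a minimal sketch of the idea (the threshold, cooldown, and half-open probe policy are simplified assumptions):

import time

class CircuitBreaker:
    """Skip calls for `cooldown` seconds after `failure_threshold` consecutive failures."""
    def __init__(self, failure_threshold: int = 5, cooldown: float = 30.0):
        self.failure_threshold = failure_threshold
        self.cooldown = cooldown
        self.failures = 0
        self.opened_at = None  # None means the circuit is closed (calls allowed)

    def call(self, fn, *args, **kwargs):
        if self.opened_at is not None:
            if time.time() - self.opened_at < self.cooldown:
                raise RuntimeError("circuit open: task temporarily skipped")
            self.opened_at = None  # cooldown elapsed: half-open, allow one probe
        try:
            result = fn(*args, **kwargs)
        except Exception:
            self.failures += 1
            if self.failures >= self.failure_threshold:
                self.opened_at = time.time()  # trip the breaker
            raise
        else:
            self.failures = 0  # a success closes the circuit again
            return result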
4.2 Implementing the Retry Mechanism
import time
from typing import Callable, Any
from dataclasses import dataclass
from enum import Enum
import logging

logger = logging.getLogger(__name__)

class BackoffStrategy(Enum):
    FIXED = "fixed"
    LINEAR = "linear"
    EXPONENTIAL = "exponential"

@dataclass
class RetryConfig:
    """Retry configuration."""
    max_retries: int = 3
    initial_delay: float = 1.0
    max_delay: float = 60.0
    backoff: BackoffStrategy = BackoffStrategy.EXPONENTIAL
    multiplier: float = 2.0
    retriable_errors: tuple = ()  # error types that are worth retrying

class RetryableTask:
    """Wraps a callable with automatic retries."""
    def __init__(self, config: RetryConfig = None):
        self.config = config or RetryConfig()

    def execute(self, fn: Callable, *args, **kwargs) -> Any:
        """Run the task, retrying automatically on failure."""
        last_error = None
        for attempt in range(self.config.max_retries + 1):
            try:
                result = fn(*args, **kwargs)
                if attempt > 0:
                    logger.info(f"Task succeeded after retry (attempt {attempt + 1})")
                return result
            except Exception as e:
                last_error = e
                # If a whitelist is configured, only retry the listed error types
                if self.config.retriable_errors:
                    if not any(isinstance(e, err_type) for err_type in self.config.retriable_errors):
                        logger.error(f"Task failed with a non-retriable error: {e}")
                        raise
                if attempt < self.config.max_retries:
                    delay = self._calculate_delay(attempt)
                    logger.warning(f"Task failed, retrying in {delay:.1f}s "
                                   f"(attempt {attempt + 1}/{self.config.max_retries}): {e}")
                    time.sleep(delay)
                else:
                    logger.error(f"Task failed; retries exhausted: {e}")
        raise last_error

    def _calculate_delay(self, attempt: int) -> float:
        """Compute the delay before the next attempt."""
        if self.config.backoff == BackoffStrategy.FIXED:
            delay = self.config.initial_delay
        elif self.config.backoff == BackoffStrategy.LINEAR:
            delay = self.config.initial_delay * (1 + attempt)
        elif self.config.backoff == BackoffStrategy.EXPONENTIAL:
            delay = self.config.initial_delay * (self.config.multiplier ** attempt)
        else:
            delay = self.config.initial_delay
        return min(delay, self.config.max_delay)
# Usage example
import random

def unreliable_task():
    """Simulates a flaky task that fails 70% of the time."""
    if random.random() < 0.7:
        raise ConnectionError("Network connection failed")
    return "Task completed successfully"

# Configuration
config = RetryConfig(
    max_retries=3,
    initial_delay=1.0,
    backoff=BackoffStrategy.EXPONENTIAL,
    multiplier=2.0,
    retriable_errors=(ConnectionError, TimeoutError)
)
retry_handler = RetryableTask(config)
# Execute: each call already retries up to 3 times internally
for i in range(5):
    try:
        result = retry_handler.execute(unreliable_task)
        print(f"Round {i+1}: {result}")
        break
    except ConnectionError as e:
        print(f"Round {i+1} failed: {e}")

5. Comparing Mainstream Workflow Frameworks
5.1 Framework Comparison
| Framework | Strengths | Typical Use Cases |
|---|---|---|
| Airflow | Mature feature set, rich ecosystem | Data pipelines, scheduled jobs |
| Prefect | Modern API, easy to use | Data engineering, ML workflows |
| Dagster | Code-as-config, type safety | Complex ML systems |
| Argo Workflows | Kubernetes-native, containerized execution | Cloud-native AI platforms |
| Temporal | Durable execution, strong consistency | Critical business logic |
5.2 Selection Advice
Data pipelines: prefer Airflow or Prefect; both have mature communities.
ML workflows: Prefect or Dagster, which support ML scenarios better.
Cloud-native platforms: Argo Workflows, with its deep Kubernetes integration.
Critical business logic: Temporal, for guaranteed execution reliability.
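For a feel of the programming model, here is roughly what a two-step pipeline looks like in Airflow's classic operator API (a minimal sketch assuming Airflow 2.4+; note that Airflow passes data between tasks via XCom rather than a shared context dict):

from datetime import datetime
from airflow import DAG
from airflow.operators.python import PythonOperator

def collect():
    # Return values are shared with downstream tasks via XCom
    return [1, 2, 3, 4, 5]

def clean():
    return "cleaned"

with DAG(dag_id="ai_data_pipeline", start_date=datetime(2024, 1, 1),
         schedule=None, catchup=False):
    t_collect = PythonOperator(task_id="collection", python_callable=collect)
    t_clean = PythonOperator(task_id="cleaning", python_callable=clean)
    # The >> operator declares the same edge as DAGNode(dependencies=["collection"])
    t_collect >> t_clean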

6. Hands-On: Building an AI Data Processing Pipeline
"""
A complete AI data processing pipeline example
"""
class AIDataPipeline:
    """AI data processing pipeline."""
    def __init__(self, config: dict):
        self.config = config
        self.engine = WorkflowEngine()
        self._register_tasks()

    def _register_tasks(self):
        """Register every task with the engine."""
        self.engine.register_task("fetch_data", self.fetch_data)
        self.engine.register_task("preprocess", self.preprocess)
        self.engine.register_task("feature_engineering", self.feature_engineering)
        self.engine.register_task("train_model", self.train_model)
        self.engine.register_task("evaluate_model", self.evaluate_model)
        self.engine.register_task("deploy_model", self.deploy_model)
        self.engine.register_task("notify", self.notify)

    def build_dag(self):
        """Build the pipeline DAG."""
        dag = DAG()
        dag.add_node(DAGNode(id="fetch_data", name="Fetch data", dependencies=[]))
        dag.add_node(DAGNode(id="preprocess", name="Preprocess", dependencies=["fetch_data"]))
        dag.add_node(DAGNode(id="feature_engineering", name="Feature engineering", dependencies=["preprocess"]))
        dag.add_node(DAGNode(id="train_model", name="Train model", dependencies=["feature_engineering"]))
        dag.add_node(DAGNode(id="evaluate_model", name="Evaluate model", dependencies=["train_model"]))
        dag.add_node(DAGNode(id="deploy_model", name="Deploy model", dependencies=["evaluate_model"]))
        dag.add_node(DAGNode(id="notify", name="Notify", dependencies=["deploy_model"]))
        self.engine.load_dag(dag)

    def fetch_data(self, context):
        """Fetch the raw data."""
        print("1. Fetching data...")
        return {"rows": 10000}

    def preprocess(self, context):
        """Clean the data."""
        rows = context.get("fetch_data", {}).get("rows", 0)
        print(f"2. Cleaning data ({rows} rows)...")
        return {"cleaned_rows": int(rows * 0.9)}

    def feature_engineering(self, context):
        """Engineer features."""
        rows = context.get("preprocess", {}).get("cleaned_rows", 0)
        print(f"3. Feature engineering ({rows} samples)...")
        return {"features": 50}

    def train_model(self, context):
        """Train the model."""
        features = context.get("feature_engineering", {}).get("features", 0)
        print(f"4. Training model ({features} features)...")
        return {"model_version": "v1.0"}

    def evaluate_model(self, context):
        """Evaluate the model."""
        version = context.get("train_model", {}).get("model_version", "")
        print(f"5. Evaluating model ({version})...")
        return {"accuracy": 0.95, "passed": True}

    def deploy_model(self, context):
        """Deploy the model."""
        eval_result = context.get("evaluate_model", {})
        if not eval_result.get("passed"):
            raise ValueError("Model evaluation did not pass")
        print("6. Deploying model...")
        return {"endpoint": "https://api.example.com/model"}

    def notify(self, context):
        """Send a notification."""
        endpoint = context.get("deploy_model", {}).get("endpoint", "")
        print(f"7. Sending notification (endpoint: {endpoint})...")
        return {"notified": True}

    def run(self):
        """Run the pipeline."""
        self.build_dag()
        print("=" * 50)
        print("Starting the AI data processing pipeline")
        print("=" * 50)
        start_time = time.time()
        results = self.engine.run()
        duration = time.time() - start_time
        print("\n" + "=" * 50)
        print(f"Pipeline finished in {duration:.2f}s")
        print("=" * 50)
        return results

# Run it
pipeline = AIDataPipeline({})
results = pipeline.run()
for task_id, result in results.items():
    status = "✅" if result.status == TaskStatus.SUCCESS else "❌"
    print(f"{status} {task_id}: {result.status.value}")
7. Summary
The DAG is the heart of complex task orchestration. It expresses inter-task dependencies as a graph, which makes complex business flows manageable.
The workflow engine brings the DAG to life. Scheduling, execution, state management, and error handling are all the engine's job.
Parallelism improves efficiency. Tasks without mutual dependencies execute concurrently, making full use of compute resources.
Error handling is what makes a workflow production-grade. Retry mechanisms, circuit-breaker policies, and rollback plans keep workflows running reliably in production.
Further Reading
- Airflow official documentation
- Prefect documentation
- Dagster documentation
- Temporal documentation
Exercises
Basic: model a data ETL flow as a DAG.
Intermediate: implement a workflow engine that supports dynamic task generation.
Challenge: add distributed execution support to the workflow engine.