02 - Dataset Engineering
Overview
Datasets are the bedrock of machine learning; "Garbage In, Garbage Out" is the field's golden rule. This article takes a deep dive into key practices: dataset construction, annotation management, quality control, and data augmentation.
The Dataset Lifecycle
Dataset Management at a Glance
┌─────────────────────────────────────────────────────────────────────┐
│                     Dataset Lifecycle Management                    │
├─────────────────────────────────────────────────────────────────────┤
│                                                                     │
│  ┌──────────┐  ┌──────────┐  ┌──────────┐  ┌──────────┐            │
│  │ Collect  │─►│  Clean   │─►│  Label   │─►│ Validate │            │
│  └──────────┘  └──────────┘  └──────────┘  └──────────┘            │
│       │             │             │             │                   │
│       ▼             ▼             ▼             ▼                   │
│  ┌──────────┐  ┌──────────┐  ┌──────────┐  ┌──────────┐            │
│  │ Augment  │◄─│  Split   │◄─│ Version  │◄─│ Publish  │            │
│  └──────────┘  └──────────┘  └──────────┘  └──────────┘            │
│                                                                     │
│  ┌─────────────────────────────────────────────────────────────┐   │
│  │                Core Capability Requirements                 │   │
│  │                                                             │   │
│  │  • Traceability: source and processing history per record  │   │
│  │  • Reproducibility: same conditions yield the same dataset │   │
│  │  • Version control: a complete record of dataset changes   │   │
│  │  • Quality assurance: multi-dimensional quality checks     │   │
│  │  • Privacy compliance: handling and protecting sensitive   │   │
│  │    data                                                     │   │
│  └─────────────────────────────────────────────────────────────┘   │
│                                                                     │
└─────────────────────────────────────────────────────────────────────┘
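Of the capabilities above, reproducibility is the one most often broken in practice: purely random train/val/test splits reshuffle every time the pipeline reruns. Below is a minimal sketch of a deterministic, hash-based split; the function name assign_split and the 80/10/10 ratios are illustrative, not from any particular library.

# deterministic_split.py - a reproducibility sketch, not a library API
import hashlib

def assign_split(record_id: str, train: float = 0.8, val: float = 0.1) -> str:
    """Map a record to train/val/test from a stable hash of its ID.

    The same ID always lands in the same split, so rerunning the
    pipeline on the same data reproduces the partition exactly.
    """
    digest = hashlib.sha256(record_id.encode()).hexdigest()
    bucket = int(digest[:8], 16) / 0xFFFFFFFF  # roughly uniform in [0, 1]
    if bucket < train:
        return "train"
    if bucket < train + val:
        return "val"
    return "test"

if __name__ == "__main__":
    for rid in (f"record-{i}" for i in range(5)):
        print(rid, assign_split(rid))

Because the assignment depends only on the record ID, newly collected records never move existing ones between splits, which also guards against train/test leakage across dataset versions.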
Data Collection and Cleaning
A Data Collection Framework
# data_collection.py
"""
Data collection framework
Supports multi-source collection, incremental updates, and quality filtering
"""
import os
import json
import hashlib
import fnmatch
from typing import Dict, List, Optional, Any, Iterator, Callable
from dataclasses import dataclass, field
from datetime import datetime
from abc import ABC, abstractmethod
import requests
import pandas as pd
from sqlalchemy import create_engine
import boto3


@dataclass
class DataRecord:
    """A single collected record."""
    id: str
    source: str
    content: Any
    metadata: Dict[str, Any] = field(default_factory=dict)
    collected_at: datetime = field(default_factory=datetime.now)
    checksum: str = ""

    def __post_init__(self):
        if not self.checksum:
            self.checksum = self._compute_checksum()

    def _compute_checksum(self) -> str:
        content_str = json.dumps(self.content, sort_keys=True, default=str)
        return hashlib.sha256(content_str.encode()).hexdigest()[:16]


class DataSource(ABC):
    """Abstract base class for data sources."""

    @abstractmethod
    def connect(self):
        pass

    @abstractmethod
    def fetch(self, query: Dict[str, Any]) -> Iterator[DataRecord]:
        pass

    @abstractmethod
    def get_incremental_key(self) -> str:
        pass


class APIDataSource(DataSource):
    """API data source."""

    def __init__(
        self,
        base_url: str,
        auth_token: Optional[str] = None,
        rate_limit: int = 100  # reserved; not enforced in this example
    ):
        self.base_url = base_url
        self.auth_token = auth_token
        self.rate_limit = rate_limit
        self.session = None

    def connect(self):
        self.session = requests.Session()
        if self.auth_token:
            self.session.headers["Authorization"] = f"Bearer {self.auth_token}"

    def fetch(self, query: Dict[str, Any]) -> Iterator[DataRecord]:
        endpoint = query.get("endpoint", "")
        params = query.get("params", {})
        pagination_key = query.get("pagination_key", "page")
        data_key = query.get("data_key", "data")
        page = 1
        while True:
            params[pagination_key] = page
            response = self.session.get(
                f"{self.base_url}/{endpoint}",
                params=params
            )
            response.raise_for_status()
            data = response.json()
            items = data.get(data_key, [])
            if not items:
                break
            for item in items:
                yield DataRecord(
                    id=str(item.get("id", hashlib.md5(
                        json.dumps(item, sort_keys=True).encode()
                    ).hexdigest())),
                    source=f"api:{self.base_url}/{endpoint}",
                    content=item,
                    metadata={"page": page}
                )
            page += 1

    def get_incremental_key(self) -> str:
        return "updated_at"


class DatabaseDataSource(DataSource):
    """Database data source."""

    def __init__(self, connection_string: str):
        self.connection_string = connection_string
        self.engine = None

    def connect(self):
        self.engine = create_engine(self.connection_string)

    def fetch(self, query: Dict[str, Any]) -> Iterator[DataRecord]:
        sql = query.get("sql", "")
        params = query.get("params", {})
        id_column = query.get("id_column", "id")
        batch_size = query.get("batch_size", 10000)
        with self.engine.connect() as conn:
            offset = 0
            while True:
                # NOTE: the base query should include a stable ORDER BY,
                # otherwise LIMIT/OFFSET paging is not deterministic.
                batch_sql = f"{sql} LIMIT {batch_size} OFFSET {offset}"
                df = pd.read_sql(batch_sql, conn, params=params)
                if df.empty:
                    break
                for _, row in df.iterrows():
                    yield DataRecord(
                        id=str(row[id_column]),
                        source=f"db:{self.connection_string.split('@')[-1]}",
                        content=row.to_dict(),
                        metadata={"offset": offset}
                    )
                offset += batch_size

    def get_incremental_key(self) -> str:
        return "updated_at"


class S3DataSource(DataSource):
    """S3 data source."""

    def __init__(
        self,
        bucket: str,
        prefix: str = "",
        endpoint_url: Optional[str] = None
    ):
        self.bucket = bucket
        self.prefix = prefix
        self.endpoint_url = endpoint_url
        self.client = None

    def connect(self):
        self.client = boto3.client(
            "s3",
            endpoint_url=self.endpoint_url
        )

    def fetch(self, query: Dict[str, Any]) -> Iterator[DataRecord]:
        prefix = query.get("prefix", self.prefix)
        file_pattern = query.get("file_pattern", "*.json")
        paginator = self.client.get_paginator("list_objects_v2")
        for page in paginator.paginate(Bucket=self.bucket, Prefix=prefix):
            for obj in page.get("Contents", []):
                key = obj["Key"]
                # Simple glob-style pattern matching
                if not self._match_pattern(key, file_pattern):
                    continue
                response = self.client.get_object(Bucket=self.bucket, Key=key)
                content = json.loads(response["Body"].read().decode("utf-8"))
                yield DataRecord(
                    id=key,
                    source=f"s3://{self.bucket}/{key}",
                    content=content,
                    metadata={
                        "size": obj["Size"],
                        "last_modified": obj["LastModified"].isoformat()
                    }
                )

    def _match_pattern(self, key: str, pattern: str) -> bool:
        # Match against the object's base name so that "*.json"
        # matches keys at any depth under the prefix.
        return fnmatch.fnmatch(os.path.basename(key), pattern)

    def get_incremental_key(self) -> str:
        return "LastModified"


class DataCollector:
    """Data collector: dedup, filter, transform, persist."""

    def __init__(
        self,
        output_path: str,
        dedup_enabled: bool = True
    ):
        self.output_path = output_path
        self.dedup_enabled = dedup_enabled
        self.seen_ids = set()
        self.filters: List[Callable[[DataRecord], bool]] = []
        self.transformers: List[Callable[[DataRecord], DataRecord]] = []

    def add_filter(self, filter_fn: Callable[[DataRecord], bool]):
        """Register a filter; records failing any filter are dropped."""
        self.filters.append(filter_fn)

    def add_transformer(self, transformer_fn: Callable[[DataRecord], DataRecord]):
        """Register a transformer applied to every surviving record."""
        self.transformers.append(transformer_fn)

    def collect(
        self,
        source: DataSource,
        query: Dict[str, Any]
    ) -> Dict[str, int]:
        """Run a collection pass against one source."""
        source.connect()
        stats = {
            "total": 0,
            "filtered": 0,
            "duplicates": 0,
            "collected": 0
        }
        records = []
        for record in source.fetch(query):
            stats["total"] += 1
            # Deduplicate by record ID
            if self.dedup_enabled:
                if record.id in self.seen_ids:
                    stats["duplicates"] += 1
                    continue
                self.seen_ids.add(record.id)
            # Filter
            passed = True
            for filter_fn in self.filters:
                if not filter_fn(record):
                    passed = False
                    break
            if not passed:
                stats["filtered"] += 1
                continue
            # Transform
            for transformer_fn in self.transformers:
                record = transformer_fn(record)
            records.append(record)
            stats["collected"] += 1
        # Persist
        self._save_records(records)
        return stats

    def _save_records(self, records: List[DataRecord]):
        """Write records to a timestamped JSONL file."""
        os.makedirs(self.output_path, exist_ok=True)
        output_file = os.path.join(
            self.output_path,
            f"data_{datetime.now().strftime('%Y%m%d_%H%M%S')}.jsonl"
        )
        with open(output_file, "w") as f:
            for record in records:
                f.write(json.dumps({
                    "id": record.id,
                    "source": record.source,
                    "content": record.content,
                    "metadata": record.metadata,
                    "collected_at": record.collected_at.isoformat(),
                    "checksum": record.checksum
                }) + "\n")


class DataCleaner:
    """Data cleaner: applies an ordered list of cleaning rules."""

    def __init__(self):
        self.cleaning_rules: List[Callable[[Dict], Dict]] = []

    def add_rule(self, rule: Callable[[Dict], Dict]):
        """Register a cleaning rule."""
        self.cleaning_rules.append(rule)

    def clean(self, data: Dict) -> Optional[Dict]:
        """Clean one record; returns None if any rule rejects it."""
        for rule in self.cleaning_rules:
            try:
                data = rule(data)
                if data is None:
                    return None
            except Exception:
                return None
        return data

    @staticmethod
    def remove_nulls(data: Dict) -> Dict:
        """Drop keys whose value is None."""
        return {k: v for k, v in data.items() if v is not None}

    @staticmethod
    def normalize_text(data: Dict, fields: List[str]) -> Dict:
        """Normalize text fields."""
        import unicodedata
        import re
        for field in fields:
            if field in data and isinstance(data[field], str):
                text = data[field]
                # Unicode normalization
                text = unicodedata.normalize("NFKC", text)
                # Collapse extra whitespace
                text = re.sub(r"\s+", " ", text).strip()
                data[field] = text
        return data

    @staticmethod
    def validate_schema(data: Dict, required_fields: List[str]) -> Optional[Dict]:
        """Validate the schema; reject records missing required fields."""
        for field in required_fields:
            if field not in data:
                return None
        return data

    @staticmethod
    def deduplicate_content(data: Dict, content_field: str) -> Dict:
        """Attach a content hash for downstream deduplication."""
        if content_field in data:
            content = str(data[content_field])
            data["content_hash"] = hashlib.md5(content.encode()).hexdigest()
        return data


# Usage example
if __name__ == "__main__":
    # Create a collector
    collector = DataCollector(
        output_path="./collected_data",
        dedup_enabled=True
    )
    # Add a filter
    collector.add_filter(
        lambda r: len(str(r.content.get("text", ""))) > 100
    )
    # Add a transformer
    collector.add_transformer(
        lambda r: DataRecord(
            id=r.id,
            source=r.source,
            content={
                **r.content,
                "text_length": len(str(r.content.get("text", "")))
            },
            metadata=r.metadata,
            collected_at=r.collected_at
        )
    )
    # Collect from an API
    api_source = APIDataSource(
        base_url="https://api.example.com",
        auth_token="token"
    )
    stats = collector.collect(api_source, {
        "endpoint": "articles",
        "params": {"category": "tech"},
        "pagination_key": "page",
        "data_key": "items"
    })
    print(f"Collection stats: {stats}")
The Data Labeling System
Labeling Platform Architecture
┌─────────────────────────────────────────────────────────────────────┐
│                 Data Labeling Platform Architecture                 │
├─────────────────────────────────────────────────────────────────────┤
│  ┌─────────────────────────────────────────────────────────────┐   │
│  │                      Labeling Frontend                      │   │
│  │  ┌──────────┐  ┌──────────┐  ┌──────────┐  ┌──────────┐    │   │
│  │  │   Text   │  │  Image   │  │  Audio   │  │  Video   │    │   │
│  │  └──────────┘  └──────────┘  └──────────┘  └──────────┘    │   │
│  └─────────────────────────────────────────────────────────────┘   │
│                                │                                    │
│                                ▼                                    │
│  ┌─────────────────────────────────────────────────────────────┐   │
│  │                      Labeling Services                      │   │
│  │  ┌──────────┐  ┌──────────┐  ┌──────────┐  ┌──────────┐    │   │
│  │  │   Task   │  │ Quality  │  │Agreement │  │Annotator │    │   │
│  │  │Assignment│  │ Control  │  │  Checks  │  │   Mgmt   │    │   │
│  │  └──────────┘  └──────────┘  └──────────┘  └──────────┘    │   │
│  │  ┌──────────┐  ┌──────────┐  ┌──────────┐  ┌──────────┐    │   │
│  │  │Pre-label │  │  Active  │  │ Conflict │  │ Progress │    │   │
│  │  │(AI-aided)│  │ Learning │  │Resolution│  │ Tracking │    │   │
│  │  └──────────┘  └──────────┘  └──────────┘  └──────────┘    │   │
│  └─────────────────────────────────────────────────────────────┘   │
│                                │                                    │
│                                ▼                                    │
│  ┌─────────────────────────────────────────────────────────────┐   │
│  │                        Storage Layer                        │   │
│  │  ┌──────────┐  ┌──────────┐  ┌──────────┐                  │   │
│  │  │ Raw Data │  │  Labels  │  │ Metadata │                  │   │
│  │  │   (S3)   │  │   (DB)   │  │   (DB)   │                  │   │
│  │  └──────────┘  └──────────┘  └──────────┘                  │   │
│  └─────────────────────────────────────────────────────────────┘   │
└─────────────────────────────────────────────────────────────────────┘
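One box in this diagram, AI-assisted pre-labeling, has no counterpart in the implementation below. The idea is that a model proposes labels which annotators merely verify or correct. A minimal sketch, where predict_fn, its output format, and the "text" field are all assumptions:

# Hypothetical pre-labeling pass; predict_fn is an assumed model interface.
from typing import Any, Callable, Dict, List

def prelabel(
    items: List[Dict[str, Any]],
    predict_fn: Callable[[str], Dict[str, float]],
    min_confidence: float = 0.9,
) -> List[Dict[str, Any]]:
    """Attach a model suggestion to each item; only low-confidence
    suggestions are flagged for full human review."""
    out = []
    for item in items:
        probs = predict_fn(item["text"])  # e.g. {"positive": 0.95, ...}
        best = max(probs, key=probs.get)
        item = dict(item)  # do not mutate the caller's items
        item["pre_label"] = best
        item["pre_label_confidence"] = probs[best]
        item["needs_full_review"] = probs[best] < min_confidence
        out.append(item)
    return out

Verifying a suggestion is usually much faster than labeling from scratch, which is why pre-labeling pairs naturally with the active-learning selector implemented later in this section.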
Labeling System Implementation
# labeling_system.py
"""
Data labeling system
Supports multiple label types, quality control, and active learning
"""
import json
import uuid
from typing import Dict, List, Optional, Any
from dataclasses import dataclass, field
from datetime import datetime
from enum import Enum
from collections import Counter
import numpy as np
from sqlalchemy import (
    create_engine, func, Column, String, Integer, DateTime, Text, Float,
    ForeignKey, Enum as SQLEnum
)
from sqlalchemy.orm import declarative_base, sessionmaker, relationship
from sqlalchemy.dialects.postgresql import JSONB

Base = declarative_base()


class TaskStatus(str, Enum):
    PENDING = "pending"
    IN_PROGRESS = "in_progress"
    COMPLETED = "completed"
    REVIEWED = "reviewed"
    REJECTED = "rejected"


class LabelType(str, Enum):
    CLASSIFICATION = "classification"
    NER = "ner"
    SEQUENCE = "sequence"
    BOUNDING_BOX = "bounding_box"
    SEGMENTATION = "segmentation"
    QA = "qa"
    RANKING = "ranking"


# ==================== Data models ====================

class LabelingProject(Base):
    """A labeling project."""
    __tablename__ = "labeling_projects"
    id = Column(String(36), primary_key=True, default=lambda: str(uuid.uuid4()))
    name = Column(String(255), nullable=False)
    description = Column(Text)
    label_type = Column(SQLEnum(LabelType), nullable=False)
    label_schema = Column(JSONB)  # Label schema definition
    guidelines = Column(Text)     # Annotation guidelines
    created_at = Column(DateTime, default=datetime.utcnow)
    updated_at = Column(DateTime, default=datetime.utcnow, onupdate=datetime.utcnow)
    tasks = relationship("LabelingTask", back_populates="project")


class LabelingTask(Base):
    """A labeling task."""
    __tablename__ = "labeling_tasks"
    id = Column(String(36), primary_key=True, default=lambda: str(uuid.uuid4()))
    project_id = Column(String(36), ForeignKey("labeling_projects.id"), nullable=False)
    data_id = Column(String(255), nullable=False)  # ID of the source record
    data_content = Column(JSONB)                   # Content to be labeled
    status = Column(SQLEnum(TaskStatus), default=TaskStatus.PENDING)
    priority = Column(Integer, default=0)
    assigned_to = Column(String(255))
    created_at = Column(DateTime, default=datetime.utcnow)
    completed_at = Column(DateTime)
    project = relationship("LabelingProject", back_populates="tasks")
    annotations = relationship("Annotation", back_populates="task")


class Annotation(Base):
    """An annotation result."""
    __tablename__ = "annotations"
    id = Column(String(36), primary_key=True, default=lambda: str(uuid.uuid4()))
    task_id = Column(String(36), ForeignKey("labeling_tasks.id"), nullable=False)
    annotator_id = Column(String(255), nullable=False)
    label = Column(JSONB)                 # The label payload
    confidence = Column(Float)            # Annotator-reported confidence
    time_spent_seconds = Column(Integer)  # Time spent
    created_at = Column(DateTime, default=datetime.utcnow)
    is_ground_truth = Column(Integer, default=0)  # 1 if part of the golden set
    task = relationship("LabelingTask", back_populates="annotations")


class AnnotatorStats(Base):
    """Per-annotator statistics."""
    __tablename__ = "annotator_stats"
    id = Column(String(36), primary_key=True, default=lambda: str(uuid.uuid4()))
    annotator_id = Column(String(255), nullable=False, unique=True)
    total_annotations = Column(Integer, default=0)
    accuracy = Column(Float, default=0.0)
    avg_time_seconds = Column(Float, default=0.0)
    agreement_rate = Column(Float, default=0.0)
    updated_at = Column(DateTime, default=datetime.utcnow, onupdate=datetime.utcnow)


# ==================== Labeling service ====================

@dataclass
class LabelSchema:
    """Label schema."""
    label_type: LabelType
    classes: List[str] = field(default_factory=list)
    attributes: Dict[str, List[str]] = field(default_factory=dict)
    relations: List[str] = field(default_factory=list)
    constraints: Dict[str, Any] = field(default_factory=dict)


class LabelingService:
    """Labeling service."""

    def __init__(self, database_url: str):
        self.engine = create_engine(database_url)
        Base.metadata.create_all(self.engine)
        self.Session = sessionmaker(bind=self.engine)

    # ==================== Project management ====================

    def create_project(
        self,
        name: str,
        label_type: LabelType,
        label_schema: LabelSchema,
        description: str = "",
        guidelines: str = ""
    ) -> str:
        """Create a labeling project."""
        session = self.Session()
        try:
            project = LabelingProject(
                name=name,
                description=description,
                label_type=label_type,
                label_schema={
                    "classes": label_schema.classes,
                    "attributes": label_schema.attributes,
                    "relations": label_schema.relations,
                    "constraints": label_schema.constraints
                },
                guidelines=guidelines
            )
            session.add(project)
            session.commit()
            return project.id
        finally:
            session.close()

    def import_data(
        self,
        project_id: str,
        data_items: List[Dict[str, Any]],
        id_field: str = "id",
        content_field: str = "content"
    ) -> int:
        """Import data to be labeled."""
        session = self.Session()
        try:
            count = 0
            for item in data_items:
                task = LabelingTask(
                    project_id=project_id,
                    data_id=str(item.get(id_field, str(uuid.uuid4()))),
                    data_content=item.get(content_field, item),
                    status=TaskStatus.PENDING
                )
                session.add(task)
                count += 1
            session.commit()
            return count
        finally:
            session.close()

    # ==================== Task assignment ====================

    def get_next_task(
        self,
        project_id: str,
        annotator_id: str,
        strategy: str = "random"
    ) -> Optional[Dict[str, Any]]:
        """Fetch the next task for an annotator."""
        session = self.Session()
        try:
            query = session.query(LabelingTask).filter(
                LabelingTask.project_id == project_id,
                LabelingTask.status == TaskStatus.PENDING
            )
            if strategy == "priority":
                query = query.order_by(LabelingTask.priority.desc())
            elif strategy == "random":
                query = query.order_by(func.random())
            task = query.first()
            if task:
                # Assign the task
                task.status = TaskStatus.IN_PROGRESS
                task.assigned_to = annotator_id
                session.commit()
                return {
                    "task_id": task.id,
                    "data_id": task.data_id,
                    "data_content": task.data_content,
                    "project_id": task.project_id
                }
            return None
        finally:
            session.close()

    def submit_annotation(
        self,
        task_id: str,
        annotator_id: str,
        label: Dict[str, Any],
        confidence: float = 1.0,
        time_spent_seconds: int = 0
    ) -> str:
        """Submit an annotation."""
        session = self.Session()
        try:
            annotation = Annotation(
                task_id=task_id,
                annotator_id=annotator_id,
                label=label,
                confidence=confidence,
                time_spent_seconds=time_spent_seconds
            )
            session.add(annotation)
            # Update task status
            task = session.query(LabelingTask).filter_by(id=task_id).first()
            if task:
                task.status = TaskStatus.COMPLETED
                task.completed_at = datetime.utcnow()
            # Update annotator statistics
            self._update_annotator_stats(session, annotator_id)
            session.commit()
            return annotation.id
        finally:
            session.close()

    def _update_annotator_stats(self, session, annotator_id: str):
        """Recompute an annotator's statistics."""
        stats = session.query(AnnotatorStats).filter_by(
            annotator_id=annotator_id
        ).first()
        if not stats:
            stats = AnnotatorStats(annotator_id=annotator_id)
            session.add(stats)
        # Aggregate over all of the annotator's annotations
        annotations = session.query(Annotation).filter_by(
            annotator_id=annotator_id
        ).all()
        stats.total_annotations = len(annotations)
        if annotations:
            times = [a.time_spent_seconds for a in annotations if a.time_spent_seconds]
            if times:
                stats.avg_time_seconds = float(np.mean(times))


class QualityController:
    """Quality controller."""

    def __init__(self, labeling_service: LabelingService):
        self.service = labeling_service

    def compute_inter_annotator_agreement(
        self,
        project_id: str,
        metric: str = "cohen_kappa"
    ) -> Dict[str, float]:
        """Compute inter-annotator agreement.

        NOTE: `metric` is accepted but currently unused; only
        exact-match agreement is computed here.
        """
        session = self.service.Session()
        try:
            # Tasks annotated by more than one annotator
            multi_annotated = session.query(
                Annotation.task_id
            ).group_by(
                Annotation.task_id
            ).having(func.count(Annotation.id) > 1).all()
            task_ids = [t[0] for t in multi_annotated]
            if not task_ids:
                return {"agreement": 0.0}
            # Fetch their annotations
            annotations = session.query(Annotation).filter(
                Annotation.task_id.in_(task_ids)
            ).all()
            # Group by task
            task_annotations = {}
            for ann in annotations:
                if ann.task_id not in task_annotations:
                    task_annotations[ann.task_id] = []
                task_annotations[ann.task_id].append(ann)
            # Agreement score
            agreements = []
            for task_id, anns in task_annotations.items():
                if len(anns) >= 2:
                    # Simplification: compare the first two annotations
                    label1 = json.dumps(anns[0].label, sort_keys=True)
                    label2 = json.dumps(anns[1].label, sort_keys=True)
                    agreements.append(1.0 if label1 == label2 else 0.0)
            return {
                "agreement": float(np.mean(agreements)) if agreements else 0.0,
                "sample_size": len(agreements)
            }
        finally:
            session.close()

    def create_golden_set(
        self,
        project_id: str,
        task_ids: List[str],
        verified_labels: Dict[str, Dict]
    ):
        """Create a golden (ground-truth) set."""
        session = self.service.Session()
        try:
            for task_id in task_ids:
                if task_id in verified_labels:
                    annotation = Annotation(
                        task_id=task_id,
                        annotator_id="gold_standard",
                        label=verified_labels[task_id],
                        confidence=1.0,
                        is_ground_truth=1
                    )
                    session.add(annotation)
            session.commit()
        finally:
            session.close()

    def evaluate_annotator(
        self,
        annotator_id: str,
        project_id: str
    ) -> Dict[str, float]:
        """Evaluate an annotator against the golden set."""
        session = self.service.Session()
        try:
            # The annotator's annotations in this project
            annotator_anns = session.query(Annotation).join(
                LabelingTask
            ).filter(
                LabelingTask.project_id == project_id,
                Annotation.annotator_id == annotator_id
            ).all()
            # The golden annotations
            gold_anns = session.query(Annotation).join(
                LabelingTask
            ).filter(
                LabelingTask.project_id == project_id,
                Annotation.is_ground_truth == 1
            ).all()
            gold_by_task = {a.task_id: a.label for a in gold_anns}
            # Accuracy against the golden set
            correct = 0
            total = 0
            for ann in annotator_anns:
                if ann.task_id in gold_by_task:
                    total += 1
                    if json.dumps(ann.label, sort_keys=True) == json.dumps(
                        gold_by_task[ann.task_id], sort_keys=True
                    ):
                        correct += 1
            return {
                "accuracy": correct / total if total > 0 else 0.0,
                "evaluated_samples": total,
                "total_annotations": len(annotator_anns)
            }
        finally:
            session.close()

    def resolve_conflicts(
        self,
        task_id: str,
        resolution_method: str = "majority_vote"
    ) -> Dict[str, Any]:
        """Resolve conflicting annotations for a task."""
        session = self.service.Session()
        try:
            annotations = session.query(Annotation).filter_by(
                task_id=task_id
            ).all()
            if len(annotations) < 2:
                return {"resolved": False, "reason": "Not enough annotations"}
            if resolution_method == "majority_vote":
                # Majority vote
                label_counts = Counter()
                for ann in annotations:
                    label_key = json.dumps(ann.label, sort_keys=True)
                    label_counts[label_key] += 1
                most_common = label_counts.most_common(1)[0]
                resolved_label = json.loads(most_common[0])
                confidence = most_common[1] / len(annotations)
            elif resolution_method == "confidence_weighted":
                # Confidence-weighted vote
                label_scores = {}
                for ann in annotations:
                    label_key = json.dumps(ann.label, sort_keys=True)
                    if label_key not in label_scores:
                        label_scores[label_key] = 0
                    label_scores[label_key] += ann.confidence or 1.0
                best_label = max(label_scores.items(), key=lambda x: x[1])
                resolved_label = json.loads(best_label[0])
                confidence = best_label[1] / sum(label_scores.values())
            else:
                return {"resolved": False, "reason": "Unknown method"}
            return {
                "resolved": True,
                "label": resolved_label,
                "confidence": confidence,
                "method": resolution_method,
                "annotator_count": len(annotations)
            }
        finally:
            session.close()


class ActiveLearningSelector:
    """Active-learning sample selector."""

    def __init__(self, labeling_service: LabelingService):
        self.service = labeling_service

    def select_uncertain_samples(
        self,
        project_id: str,
        model_predictions: Dict[str, Dict[str, float]],
        n_samples: int = 100,
        strategy: str = "uncertainty"
    ) -> List[str]:
        """Select the samples the model is least certain about."""
        session = self.service.Session()
        try:
            # Unlabeled tasks
            unlabeled_tasks = session.query(LabelingTask).filter(
                LabelingTask.project_id == project_id,
                LabelingTask.status == TaskStatus.PENDING
            ).all()
            task_ids = [t.data_id for t in unlabeled_tasks]
            # Uncertainty scores (higher = more uncertain)
            scores = []
            for task_id in task_ids:
                if task_id in model_predictions:
                    probs = list(model_predictions[task_id].values())
                    if strategy == "uncertainty":
                        # Entropy
                        entropy = -sum(p * np.log(p + 1e-10) for p in probs)
                        scores.append((task_id, entropy))
                    elif strategy == "margin":
                        # Margin sampling
                        sorted_probs = sorted(probs, reverse=True)
                        margin = (sorted_probs[0] - sorted_probs[1]
                                  if len(sorted_probs) > 1 else 1.0)
                        # Negated: a smaller margin means more uncertainty
                        scores.append((task_id, -margin))
                    elif strategy == "least_confidence":
                        # Least confidence
                        scores.append((task_id, -max(probs)))
            # Sort and pick the top n
            scores.sort(key=lambda x: x[1], reverse=True)
            selected = [s[0] for s in scores[:n_samples]]
            # Bump their priority
            for task in unlabeled_tasks:
                if task.data_id in selected:
                    task.priority = 10  # Raise priority
            session.commit()
            return selected
        finally:
            session.close()


# Usage example
if __name__ == "__main__":
    # Initialize the service
    service = LabelingService("postgresql://user:pass@localhost/labeling")
    # Create a project
    schema = LabelSchema(
        label_type=LabelType.CLASSIFICATION,
        classes=["positive", "negative", "neutral"],
        attributes={"confidence": ["high", "medium", "low"]}
    )
    project_id = service.create_project(
        name="Sentiment Analysis",
        label_type=LabelType.CLASSIFICATION,
        label_schema=schema,
        guidelines="Label the sentiment of each text..."
    )
    # Import data
    data = [
        {"id": "1", "content": {"text": "This is great!"}},
        {"id": "2", "content": {"text": "This is terrible."}}
    ]
    service.import_data(project_id, data)
    # An annotator fetches a task
    task = service.get_next_task(project_id, "annotator_1")
    if task:
        # Submit an annotation
        service.submit_annotation(
            task_id=task["task_id"],
            annotator_id="annotator_1",
            label={"class": "positive", "confidence": "high"},
            time_spent_seconds=30
        )
    # Quality control
    qc = QualityController(service)
    agreement = qc.compute_inter_annotator_agreement(project_id)
    print(f"Agreement: {agreement}")
Data Quality Validation
Multi-Dimensional Quality Checks
# data_quality.py
"""
Data quality validation framework
Multi-dimensional checks, automated validation, quality reports
"""
from typing import Dict, List, Optional, Any, Callable
from dataclasses import dataclass, field
from datetime import datetime
from enum import Enum
import re
import pandas as pd
import numpy as np


class QualityDimension(str, Enum):
    COMPLETENESS = "completeness"
    ACCURACY = "accuracy"
    CONSISTENCY = "consistency"
    TIMELINESS = "timeliness"
    UNIQUENESS = "uniqueness"
    VALIDITY = "validity"


class Severity(str, Enum):
    CRITICAL = "critical"
    WARNING = "warning"
    INFO = "info"


@dataclass
class QualityCheckResult:
    """Result of a single quality check."""
    check_name: str
    dimension: QualityDimension
    passed: bool
    score: float
    severity: Severity
    details: Dict[str, Any] = field(default_factory=dict)
    failed_samples: List[Any] = field(default_factory=list)


@dataclass
class QualityReport:
    """Aggregated quality report."""
    dataset_name: str
    check_time: datetime
    total_records: int
    overall_score: float
    dimension_scores: Dict[str, float]
    check_results: List[QualityCheckResult]
    recommendations: List[str]


class DataQualityChecker:
    """Data quality checker."""

    def __init__(self):
        self.checks: List[Callable[[pd.DataFrame], QualityCheckResult]] = []

    def add_check(self, check: Callable[[pd.DataFrame], QualityCheckResult]):
        """Register a check."""
        self.checks.append(check)

    def run(self, df: pd.DataFrame, dataset_name: str = "dataset") -> QualityReport:
        """Run all registered checks."""
        results = []
        for check in self.checks:
            try:
                result = check(df)
                results.append(result)
            except Exception as e:
                results.append(QualityCheckResult(
                    check_name=check.__name__,
                    dimension=QualityDimension.VALIDITY,
                    passed=False,
                    score=0.0,
                    severity=Severity.CRITICAL,
                    details={"error": str(e)}
                ))
        # Per-dimension scores
        dimension_scores = {}
        for dim in QualityDimension:
            dim_results = [r for r in results if r.dimension == dim]
            if dim_results:
                dimension_scores[dim.value] = float(np.mean([r.score for r in dim_results]))
        # Overall score
        overall_score = float(np.mean(list(dimension_scores.values()))) if dimension_scores else 0.0
        # Recommendations
        recommendations = self._generate_recommendations(results)
        return QualityReport(
            dataset_name=dataset_name,
            check_time=datetime.now(),
            total_records=len(df),
            overall_score=overall_score,
            dimension_scores=dimension_scores,
            check_results=results,
            recommendations=recommendations
        )

    def _generate_recommendations(
        self,
        results: List[QualityCheckResult]
    ) -> List[str]:
        """Turn failed checks into actionable recommendations."""
        recommendations = []
        failed_critical = [r for r in results if not r.passed and r.severity == Severity.CRITICAL]
        failed_warning = [r for r in results if not r.passed and r.severity == Severity.WARNING]
        for result in failed_critical:
            recommendations.append(
                f"[CRITICAL] {result.check_name}: {result.details.get('message', 'Check failed')}"
            )
        for result in failed_warning:
            recommendations.append(
                f"[WARNING] {result.check_name}: {result.details.get('message', 'Check failed')}"
            )
        return recommendations


# ==================== Predefined checks ====================

class CompletenessChecks:
    """Completeness checks."""

    @staticmethod
    def null_check(
        columns: Optional[List[str]] = None,
        threshold: float = 0.05
    ) -> Callable:
        """Null-ratio check."""
        def check(df: pd.DataFrame) -> QualityCheckResult:
            cols = columns or df.columns.tolist()
            null_ratios = {}
            failed_cols = []
            for col in cols:
                if col in df.columns:
                    ratio = df[col].isnull().mean()
                    null_ratios[col] = ratio
                    if ratio > threshold:
                        failed_cols.append(col)
            passed = len(failed_cols) == 0
            score = 1.0 - float(np.mean(list(null_ratios.values()))) if null_ratios else 1.0
            return QualityCheckResult(
                check_name="null_check",
                dimension=QualityDimension.COMPLETENESS,
                passed=passed,
                score=score,
                severity=Severity.WARNING if passed else Severity.CRITICAL,
                details={
                    "null_ratios": null_ratios,
                    "threshold": threshold,
                    "failed_columns": failed_cols,
                    "message": f"Columns with null ratio > {threshold}: {failed_cols}"
                }
            )
        return check

    @staticmethod
    def required_fields_check(required: List[str]) -> Callable:
        """Required-field presence check."""
        def check(df: pd.DataFrame) -> QualityCheckResult:
            missing = [col for col in required if col not in df.columns]
            passed = len(missing) == 0
            score = (len(required) - len(missing)) / len(required) if required else 1.0
            return QualityCheckResult(
                check_name="required_fields_check",
                dimension=QualityDimension.COMPLETENESS,
                passed=passed,
                score=score,
                severity=Severity.CRITICAL,
                details={
                    "required": required,
                    "missing": missing,
                    "message": f"Missing required fields: {missing}"
                }
            )
        return check


class ConsistencyChecks:
    """Consistency checks."""

    @staticmethod
    def format_check(
        column: str,
        pattern: str,
        sample_size: int = 1000
    ) -> Callable:
        """Format-consistency check against a regex."""
        def check(df: pd.DataFrame) -> QualityCheckResult:
            if column not in df.columns:
                return QualityCheckResult(
                    check_name=f"format_check_{column}",
                    dimension=QualityDimension.CONSISTENCY,
                    passed=False,
                    score=0.0,
                    severity=Severity.CRITICAL,
                    details={"message": f"Column {column} not found"}
                )
            non_null = df[column].dropna()
            # Sample at most sample_size of the non-null values
            sample = non_null.sample(min(sample_size, len(non_null)))
            regex = re.compile(pattern)
            matches = sample.apply(lambda x: bool(regex.match(str(x))))
            match_ratio = float(matches.mean()) if len(matches) else 1.0
            passed = match_ratio >= 0.95
            failed_samples = sample[~matches].head(10).tolist()
            return QualityCheckResult(
                check_name=f"format_check_{column}",
                dimension=QualityDimension.CONSISTENCY,
                passed=passed,
                score=match_ratio,
                severity=Severity.WARNING,
                details={
                    "column": column,
                    "pattern": pattern,
                    "match_ratio": match_ratio,
                    "message": f"Format mismatch ratio: {1 - match_ratio:.2%}"
                },
                failed_samples=failed_samples
            )
        return check

    @staticmethod
    def cross_field_consistency(
        conditions: List[Dict[str, Any]]
    ) -> Callable:
        """Cross-field consistency check; each condition's query selects violating rows."""
        def check(df: pd.DataFrame) -> QualityCheckResult:
            violations = []
            for cond in conditions:
                query = cond["query"]
                name = cond.get("name", query)
                try:
                    violating = df.query(query)
                    if len(violating) > 0:
                        violations.append({
                            "rule": name,
                            "count": len(violating),
                            "samples": violating.head(5).to_dict("records")
                        })
                except Exception as e:
                    violations.append({
                        "rule": name,
                        "error": str(e)
                    })
            passed = len(violations) == 0
            score = 1.0 - len(violations) / len(conditions) if conditions else 1.0
            return QualityCheckResult(
                check_name="cross_field_consistency",
                dimension=QualityDimension.CONSISTENCY,
                passed=passed,
                score=score,
                severity=Severity.WARNING,
                details={
                    "conditions": conditions,
                    "violations": violations,
                    "message": f"Found {len(violations)} consistency violations"
                }
            )
        return check


class UniquenessChecks:
    """Uniqueness checks."""

    @staticmethod
    def duplicate_check(
        columns: Optional[List[str]] = None,
        threshold: float = 0.01
    ) -> Callable:
        """Duplicate-row check."""
        def check(df: pd.DataFrame) -> QualityCheckResult:
            cols = columns or df.columns.tolist()
            duplicates = df.duplicated(subset=cols, keep=False)
            dup_ratio = duplicates.mean()
            passed = dup_ratio <= threshold
            return QualityCheckResult(
                check_name="duplicate_check",
                dimension=QualityDimension.UNIQUENESS,
                passed=passed,
                score=1.0 - dup_ratio,
                severity=Severity.WARNING,
                details={
                    "columns": cols,
                    "duplicate_ratio": dup_ratio,
                    "duplicate_count": int(duplicates.sum()),
                    "threshold": threshold,
                    "message": f"Duplicate ratio: {dup_ratio:.2%}"
                },
                failed_samples=df[duplicates].head(10).to_dict("records")
            )
        return check

    @staticmethod
    def primary_key_check(columns: List[str]) -> Callable:
        """Primary-key uniqueness check."""
        def check(df: pd.DataFrame) -> QualityCheckResult:
            is_unique = not df.duplicated(subset=columns).any()
            return QualityCheckResult(
                check_name="primary_key_check",
                dimension=QualityDimension.UNIQUENESS,
                passed=is_unique,
                score=1.0 if is_unique else 0.0,
                severity=Severity.CRITICAL,
                details={
                    "primary_key_columns": columns,
                    "is_unique": is_unique,
                    "message": "Primary key is unique" if is_unique else "Primary key has duplicates"
                }
            )
        return check


class ValidityChecks:
    """Validity checks."""

    @staticmethod
    def range_check(
        column: str,
        min_val: Optional[float] = None,
        max_val: Optional[float] = None
    ) -> Callable:
        """Numeric range check."""
        def check(df: pd.DataFrame) -> QualityCheckResult:
            if column not in df.columns:
                return QualityCheckResult(
                    check_name=f"range_check_{column}",
                    dimension=QualityDimension.VALIDITY,
                    passed=False,
                    score=0.0,
                    severity=Severity.CRITICAL,
                    details={"message": f"Column {column} not found"}
                )
            values = df[column].dropna()
            # Keep the original index so the boolean masks align with `values`
            out_of_range = pd.Series(False, index=values.index)
            if min_val is not None:
                out_of_range |= values < min_val
            if max_val is not None:
                out_of_range |= values > max_val
            valid_ratio = 1.0 - out_of_range.mean()
            passed = valid_ratio >= 0.99
            return QualityCheckResult(
                check_name=f"range_check_{column}",
                dimension=QualityDimension.VALIDITY,
                passed=passed,
                score=valid_ratio,
                severity=Severity.WARNING,
                details={
                    "column": column,
                    "min": min_val,
                    "max": max_val,
                    "out_of_range_count": int(out_of_range.sum()),
                    "valid_ratio": valid_ratio,
                    "message": f"Values out of range: {int(out_of_range.sum())}"
                }
            )
        return check

    @staticmethod
    def categorical_check(
        column: str,
        valid_values: List[Any]
    ) -> Callable:
        """Categorical validity check."""
        def check(df: pd.DataFrame) -> QualityCheckResult:
            if column not in df.columns:
                return QualityCheckResult(
                    check_name=f"categorical_check_{column}",
                    dimension=QualityDimension.VALIDITY,
                    passed=False,
                    score=0.0,
                    severity=Severity.CRITICAL,
                    details={"message": f"Column {column} not found"}
                )
            values = df[column].dropna()
            invalid = ~values.isin(valid_values)
            valid_ratio = 1.0 - invalid.mean()
            passed = valid_ratio >= 0.99
            invalid_values = values[invalid].unique().tolist()[:10]
            return QualityCheckResult(
                check_name=f"categorical_check_{column}",
                dimension=QualityDimension.VALIDITY,
                passed=passed,
                score=valid_ratio,
                severity=Severity.WARNING,
                details={
                    "column": column,
                    "valid_values": valid_values,
                    "invalid_values": invalid_values,
                    "invalid_count": int(invalid.sum()),
                    "message": f"Invalid values: {invalid_values}"
                }
            )
        return check


class DistributionChecks:
    """Distribution checks."""

    @staticmethod
    def distribution_drift_check(
        column: str,
        reference_stats: Dict[str, float],
        threshold: float = 0.1
    ) -> Callable:
        """Distribution drift check against reference statistics."""
        def check(df: pd.DataFrame) -> QualityCheckResult:
            if column not in df.columns:
                return QualityCheckResult(
                    check_name=f"distribution_drift_{column}",
                    dimension=QualityDimension.VALIDITY,
                    passed=False,
                    score=0.0,
                    severity=Severity.CRITICAL,
                    details={"message": f"Column {column} not found"}
                )
            current_stats = {
                "mean": df[column].mean(),
                "std": df[column].std(),
                "median": df[column].median()
            }
            # Relative drift of each statistic versus the reference
            drifts = {}
            for stat, ref_val in reference_stats.items():
                if stat in current_stats and ref_val != 0:
                    drift = abs(current_stats[stat] - ref_val) / abs(ref_val)
                    drifts[stat] = drift
            max_drift = max(drifts.values()) if drifts else 0.0
            passed = max_drift <= threshold
            return QualityCheckResult(
                check_name=f"distribution_drift_{column}",
                dimension=QualityDimension.VALIDITY,
                passed=passed,
                score=1.0 - min(max_drift, 1.0),
                severity=Severity.WARNING,
                details={
                    "column": column,
                    "reference_stats": reference_stats,
                    "current_stats": current_stats,
                    "drifts": drifts,
                    "max_drift": max_drift,
                    "threshold": threshold,
                    "message": f"Max drift: {max_drift:.2%}"
                }
            )
        return check


# Usage example
if __name__ == "__main__":
    # Create a checker
    checker = DataQualityChecker()
    # Register checks
    checker.add_check(CompletenessChecks.null_check(threshold=0.05))
    checker.add_check(CompletenessChecks.required_fields_check(["id", "text", "label"]))
    checker.add_check(UniquenessChecks.duplicate_check(columns=["id"]))
    checker.add_check(UniquenessChecks.primary_key_check(["id"]))
    checker.add_check(ValidityChecks.categorical_check("label", ["positive", "negative", "neutral"]))
    checker.add_check(ConsistencyChecks.format_check("id", r"^[A-Z]{2}\d{6}$"))
    # Load data
    df = pd.DataFrame({
        "id": ["AB123456", "CD789012", "EF345678", "AB123456"],
        "text": ["Great product!", "Terrible.", None, "OK"],
        "label": ["positive", "negative", "neutral", "unknown"]
    })
    # Run the checks
    report = checker.run(df, "test_dataset")
    print(f"Overall Score: {report.overall_score:.2%}")
    print("\nDimension Scores:")
    for dim, score in report.dimension_scores.items():
        print(f"  {dim}: {score:.2%}")
    print("\nRecommendations:")
    for rec in report.recommendations:
        print(f"  - {rec}")
Summary
Dataset engineering is the foundation of ML project success:
Data collection
- Multi-source collection
- Incremental update mechanisms
- Deduplication and filtering
Data labeling
- Building a labeling platform
- Quality-control workflows
- Active-learning optimization
Data quality
- Multi-dimensional checks
- Automated validation
- Continuous monitoring
Best practices
- Standardized annotation guidelines
- Cross-validation across multiple annotators
- Regular quality audits
The next chapter explores the design and implementation of a Feature Store.