03-数据集详解-从获取到预处理
1. 数据集是什么
1.1 核心概念
在机器学习中,数据集是用于训练、验证和测试模型的数据集合。一个完整的机器学习项目通常需要将数据集划分为三个部分:
训练集(Training Set)
- 用途:用于训练模型参数
- 特点:模型会反复"看到"这些数据,并通过梯度下降等算法调整参数
- 作用:让模型学习数据中的模式和规律
验证集(Validation Set)
- 用途:在训练过程中评估模型性能
- 特点:模型不直接用这些数据训练,但会根据验证结果调整超参数
- 作用:防止过拟合,选择最佳模型
测试集(Test Set)
- 用途:最终评估模型的泛化能力
- 特点:只在训练完全结束后使用一次
- 作用:给出模型在真实场景中的预期性能
1.2 数据集划分比例
常见的划分比例:
大数据集(100万+样本)
训练集:98%
验证集:1%
测试集:1%
例如:ImageNet的ILSVRC子集有约120万张训练图像,这种规模下甚至可以用99%训练、0.5%验证、0.5%测试
中等数据集(1万-100万样本)
训练集:80%
验证集:10%
测试集:10%
这是最常用的比例,适用于大多数项目
小数据集(1万以下样本)
训练集:70%
验证集:15%
测试集:15%
或者使用K折交叉验证
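下面给出一个简单的示意函数(写法不固定),把上述比例换算成具体的样本数,便于对照三种规模:
def split_sizes(n_samples, train=0.8, val=0.1, test=0.1):
    """按比例估算训练/验证/测试集样本数(仅作示意)"""
    assert abs(train + val + test - 1.0) < 1e-6, "比例之和应为1"
    n_train = int(n_samples * train)
    n_val = int(n_samples * val)
    n_test = n_samples - n_train - n_val  # 剩余样本全部归测试集,避免取整丢样本
    return n_train, n_val, n_test
# 对照上面三种规模
print(split_sizes(1_200_000, 0.98, 0.01, 0.01))  # 大数据集
print(split_sizes(100_000, 0.8, 0.1, 0.1))       # 中等数据集
print(split_sizes(5_000, 0.7, 0.15, 0.15))       # 小数据集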
1.3 划分方法
随机划分
from sklearn.model_selection import train_test_split
# 假设有1000个样本
X = [...] # 特征数据
y = [...] # 标签数据
# 先划分出测试集(10%)
X_temp, X_test, y_temp, y_test = train_test_split(
X, y, test_size=0.1, random_state=42
)
# 再从剩余90%中划分训练集和验证集(81% + 9%)
X_train, X_val, y_train, y_val = train_test_split(
X_temp, y_temp, test_size=0.1, random_state=42
)
print(f"训练集: {len(X_train)} 样本")
print(f"验证集: {len(X_val)} 样本")
print(f"测试集: {len(X_test)} 样本")
分层划分(Stratified Split)
对于分类任务,确保各类别比例相同:
from sklearn.model_selection import train_test_split
# stratify参数确保各类别比例一致
X_train, X_test, y_train, y_test = train_test_split(
X, y, test_size=0.2, stratify=y, random_state=42
)
# 验证类别分布
import numpy as np
print("训练集类别分布:", np.bincount(y_train) / len(y_train))
print("测试集类别分布:", np.bincount(y_test) / len(y_test))
K折交叉验证
小数据集使用K折交叉验证:
from sklearn.model_selection import KFold
import numpy as np
# 注意: X、y 需为 numpy 数组,才能用 train_idx/val_idx 做数组索引
X = np.asarray(X)
y = np.asarray(y)
kf = KFold(n_splits=5, shuffle=True, random_state=42)
for fold, (train_idx, val_idx) in enumerate(kf.split(X)):
print(f"Fold {fold+1}:")
X_train_fold = X[train_idx]
X_val_fold = X[val_idx]
y_train_fold = y[train_idx]
y_val_fold = y[val_idx]
# 训练和验证...
2. 数据集来源
2.1 公开数据集
2.1.1 计算机视觉数据集
ImageNet
- 链接:https://www.image-net.org/
- 规模:1400万张图像,2万个类别
- 常用子集:ILSVRC(120万训练图像,1000类)
- 下载方式:需要注册账号,通过学术许可下载
- 用途:图像分类、目标检测预训练
# 使用torchvision下载ImageNet子集示例
from torchvision.datasets import ImageNet
import torchvision.transforms as transforms
transform = transforms.Compose([
transforms.Resize(256),
transforms.CenterCrop(224),
transforms.ToTensor(),
transforms.Normalize(mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225])
])
# 需要手动下载数据后指定路径
imagenet_data = ImageNet(
root='/path/to/imagenet',
split='train',
transform=transform
)
COCO (Common Objects in Context)
- 链接:https://cocodataset.org/
- 规模:33万张图像,约150万个标注实例
- 类别:80个物体类别
- 任务:目标检测、实例分割、关键点检测、图像描述
- 下载:
# 下载2017版本数据集
wget http://images.cocodataset.org/zips/train2017.zip
wget http://images.cocodataset.org/zips/val2017.zip
wget http://images.cocodataset.org/annotations/annotations_trainval2017.zip
MNIST
- 链接:http://yann.lecun.com/exdb/mnist/
- 规模:7万张手写数字图像(28x28灰度)
- 用途:入门级图像分类
from torchvision.datasets import MNIST
import torchvision.transforms as transforms
# 自动下载MNIST数据集
mnist_train = MNIST(
root='./data',
train=True,
download=True,
transform=transforms.ToTensor()
)
mnist_test = MNIST(
root='./data',
train=False,
download=True,
transform=transforms.ToTensor()
)
print(f"训练集大小: {len(mnist_train)}")
print(f"测试集大小: {len(mnist_test)}")
print(f"图像形状: {mnist_train[0][0].shape}")
CIFAR-10/100
- 链接:https://www.cs.toronto.edu/~kriz/cifar.html
- CIFAR-10:6万张32x32彩色图像,10个类别
- CIFAR-100:6万张32x32彩色图像,100个类别
from torchvision.datasets import CIFAR10, CIFAR100
cifar10 = CIFAR10(root='./data', train=True, download=True)
print(f"类别: {cifar10.classes}")
# ['airplane', 'automobile', 'bird', 'cat', 'deer',
# 'dog', 'frog', 'horse', 'ship', 'truck']
2.1.2 自然语言处理数据集
Common Crawl
- 链接:https://commoncrawl.org/
- 规模:数PB的网页数据
- 用途:大规模语言模型预训练
- 下载:通过AWS S3访问
# 下载示例(需要AWS CLI)
aws s3 ls s3://commoncrawl/
aws s3 cp s3://commoncrawl/crawl-data/CC-MAIN-2024-10/segments/ . --recursive
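下载得到的是WARC归档文件,可以用warcio库流式解析。下面是读取WARC、取出网页URL和原始HTML的简单示意(假设已pip install warcio,文件名仅为示例):
from warcio.archiveiterator import ArchiveIterator
# 逐条遍历WARC文件中的抓取记录(文件名为示例)
with open('CC-MAIN-example.warc.gz', 'rb') as stream:
    for record in ArchiveIterator(stream):
        if record.rec_type == 'response':  # 只看网页响应记录
            url = record.rec_headers.get_header('WARC-Target-URI')
            html_bytes = record.content_stream().read()  # 原始HTML字节
            print(url, len(html_bytes))
            break  # 示例只看第一条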
The Pile
- 链接:https://pile.eleuther.ai/
- 规模:825GB高质量文本数据
- 组成:22个不同来源的数据集
- 下载:
# 使用wget下载
wget https://the-eye.eu/public/AI/pile/train/00.jsonl.zst
# 或使用torrent
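下载得到的.jsonl.zst是zstd压缩的JSONL文件,可以用zstandard库流式解压、逐行解析,示意如下(假设已pip install zstandard):
import io
import json
import zstandard as zstd
def read_pile_shard(path, limit=3):
    """流式读取.jsonl.zst分片,逐行解析JSON(仅作示意)"""
    with open(path, 'rb') as fh:
        dctx = zstd.ZstdDecompressor()
        with dctx.stream_reader(fh) as reader:
            text_stream = io.TextIOWrapper(reader, encoding='utf-8')
            for i, line in enumerate(text_stream):
                item = json.loads(line)
                print(item.get('text', '')[:80])
                if i + 1 >= limit:
                    break
read_pile_shard('00.jsonl.zst')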
WikiText
- 链接:https://blog.salesforceairesearch.com/the-wikitext-long-term-dependency-language-modeling-dataset/
- 规模:WikiText-103(1.03亿词)
- 用途:语言模型训练
from torchtext.datasets import WikiText2, WikiText103
# 下载WikiText-2
train_iter, val_iter, test_iter = WikiText2()
for text in train_iter:
print(text)
break
IMDb电影评论
- 规模:5万条电影评论(正负各25000)
- 用途:情感分析
from torchtext.datasets import IMDB
train_iter, test_iter = IMDB(split=('train', 'test'))
for label, text in train_iter:
print(f"标签: {label}")
print(f"文本: {text[:100]}...")
break
2.1.3 其他领域数据集
Kaggle数据集
- 链接:https://www.kaggle.com/datasets
- 特点:涵盖各个领域,可直接下载
- 下载方式:
# 安装Kaggle CLI
pip install kaggle
# 配置API token(从Kaggle账户设置下载kaggle.json)
mkdir -p ~/.kaggle
cp kaggle.json ~/.kaggle/
chmod 600 ~/.kaggle/kaggle.json
# 搜索数据集
kaggle datasets list -s "sentiment analysis"
# 下载数据集
kaggle datasets download -d username/dataset-name
UCI机器学习库
- 链接:https://archive.ics.uci.edu/ml/
- 规模:600+个数据集
- 特点:经典机器学习数据集
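UCI上的很多经典数据集可以直接用pandas按URL读取。下面以Iris为例(URL为常见的公开下载地址,若失效请到官网搜索对应数据集):
import pandas as pd
# Iris数据集的经典下载地址(若失效,请在UCI官网查找)
url = 'https://archive.ics.uci.edu/ml/machine-learning-databases/iris/iris.data'
columns = ['sepal_length', 'sepal_width', 'petal_length', 'petal_width', 'label']
df = pd.read_csv(url, header=None, names=columns)
print(df.shape)                    # (150, 5)
print(df['label'].value_counts())  # 每个类别50条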
HuggingFace Datasets
- 链接:https://huggingface.co/datasets
- 规模:10000+个数据集
- 特点:NLP和多模态数据集
from datasets import load_dataset
# 加载数据集
dataset = load_dataset("squad")
print(dataset)
# 加载特定配置
dataset = load_dataset("glue", "mrpc")
2.2 数据采集
2.2.1 网络爬虫
使用Requests + BeautifulSoup
import requests
from bs4 import BeautifulSoup
import time
def scrape_texts(url):
"""爬取网页文本"""
headers = {
'User-Agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36'
}
try:
response = requests.get(url, headers=headers, timeout=10)
response.raise_for_status()
response.encoding = response.apparent_encoding
soup = BeautifulSoup(response.text, 'html.parser')
# 提取文本
paragraphs = soup.find_all('p')
texts = [p.get_text().strip() for p in paragraphs if p.get_text().strip()]
return texts
except Exception as e:
print(f"爬取失败: {e}")
return []
# 使用示例
url = "https://example.com/articles"
texts = scrape_texts(url)
print(f"爬取到 {len(texts)} 段文本")
使用Scrapy框架
# 创建Scrapy项目
# scrapy startproject myspider
# spiders/article_spider.py
import scrapy
class ArticleSpider(scrapy.Spider):
name = "articles"
start_urls = ['https://example.com/articles']
def parse(self, response):
# 提取文章列表
for article in response.css('div.article'):
yield {
'title': article.css('h2::text').get(),
'content': article.css('p::text').getall(),
'url': article.css('a::attr(href)').get()
}
# 翻页
next_page = response.css('a.next::attr(href)').get()
if next_page:
yield response.follow(next_page, self.parse)
# 运行爬虫
# scrapy crawl articles -o articles.json
图像爬虫
import requests
import os
from urllib.parse import urljoin
from bs4 import BeautifulSoup
def download_images(url, save_dir='images'):
"""下载网页中的所有图片"""
os.makedirs(save_dir, exist_ok=True)
response = requests.get(url)
soup = BeautifulSoup(response.text, 'html.parser')
images = soup.find_all('img')
print(f"找到 {len(images)} 张图片")
for idx, img in enumerate(images):
img_url = img.get('src')
if not img_url:
continue
# 处理相对路径
img_url = urljoin(url, img_url)
try:
img_data = requests.get(img_url, timeout=10).content
img_name = f"{idx:04d}.jpg"
img_path = os.path.join(save_dir, img_name)
with open(img_path, 'wb') as f:
f.write(img_data)
print(f"下载: {img_name}")
except Exception as e:
print(f"下载失败 {img_url}: {e}")
# 使用示例
download_images('https://example.com/gallery')
2.2.2 API获取数据
Twitter API示例
import tweepy
import json
# Twitter API凭证
API_KEY = 'your_api_key'
API_SECRET = 'your_api_secret'
ACCESS_TOKEN = 'your_access_token'
ACCESS_SECRET = 'your_access_secret'
# 认证
auth = tweepy.OAuthHandler(API_KEY, API_SECRET)
auth.set_access_token(ACCESS_TOKEN, ACCESS_SECRET)
api = tweepy.API(auth)
# 搜索推文
tweets = api.search_tweets(q="machine learning", lang="en", count=100)
# 保存数据
data = []
for tweet in tweets:
data.append({
'text': tweet.text,
'created_at': str(tweet.created_at),
'user': tweet.user.screen_name,
'retweets': tweet.retweet_count,
'likes': tweet.favorite_count
})
with open('tweets.json', 'w', encoding='utf-8') as f:
json.dump(data, f, ensure_ascii=False, indent=2)
Reddit API示例
import praw
import json
# Reddit API
reddit = praw.Reddit(
client_id='your_client_id',
client_secret='your_client_secret',
user_agent='your_user_agent'
)
# 获取subreddit数据
subreddit = reddit.subreddit('MachineLearning')
posts = []
for post in subreddit.hot(limit=100):
posts.append({
'title': post.title,
'text': post.selftext,
'score': post.score,
'num_comments': post.num_comments,
'created_utc': post.created_utc
})
with open('reddit_ml.json', 'w') as f:
json.dump(posts, f, indent=2)
2.2.3 人工标注
标注工具
Label Studio
- 开源标注平台
- 支持图像、文本、音频、视频
- 安装:
pip install label-studio
label-studio start
Labelme
- 图像标注工具
- 输出JSON格式
pip install labelme
labelme
Prodigy
- 商业标注工具
- 支持主动学习
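以Labelme为例,它为每张图像生成一个JSON标注文件,下面是读取其中标注形状的简单示意(字段名以实际导出的文件为准):
import json
def load_labelme_annotation(json_path):
    """读取Labelme导出的JSON标注(仅作示意)"""
    with open(json_path, 'r', encoding='utf-8') as f:
        ann = json.load(f)
    print(f"图像文件: {ann.get('imagePath')}")
    for shape in ann.get('shapes', []):
        label = shape['label']    # 类别名
        points = shape['points']  # 多边形/矩形顶点 [[x, y], ...]
        print(f"{label}: {len(points)} 个顶点")
    return ann
# ann = load_labelme_annotation('img1.json')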
标注质量控制
from sklearn.metrics import cohen_kappa_score
# 多个标注员的标注结果
annotator1 = [1, 0, 1, 1, 0, 1, 0]
annotator2 = [1, 0, 1, 0, 0, 1, 0]
# 计算Cohen's Kappa(标注一致性)
kappa = cohen_kappa_score(annotator1, annotator2)
print(f"Kappa系数: {kappa:.3f}")
# > 0.8: 几乎完全一致
# 0.6-0.8: 一致性强
# 0.4-0.6: 一致性中等
2.3 数据购买和授权
数据购买渠道
商业数据提供商
- AWS Data Exchange
- Google Cloud Marketplace
- Datarade.ai
专业数据公司
- 图像数据:Getty Images, Shutterstock
- 文本数据:LexisNexis, Bloomberg
- 行业数据:各垂直领域数据公司
数据使用授权注意事项
- 确认数据使用许可(个人/商业/研究)
- 检查数据再分发限制
- 遵守隐私法规(GDPR, CCPA等)
- 引用数据集时注明来源
3. 数据集格式
3.1 图像数据格式
3.1.1 图像文件格式
JPEG (.jpg, .jpeg)
- 特点:有损压缩,适合照片
- 大小:相对较小
- 用途:大多数图像分类任务
PNG (.png)
- 特点:无损压缩,支持透明通道
- 大小:相对较大
- 用途:需要保留细节的任务
读取和转换
from PIL import Image
import numpy as np
# 读取图像
img = Image.open('example.jpg')
print(f"格式: {img.format}")
print(f"模式: {img.mode}") # RGB, RGBA, L (灰度)
print(f"大小: {img.size}") # (width, height)
# 转换为numpy数组
img_array = np.array(img)
print(f"数组形状: {img_array.shape}") # (H, W, C)
print(f"数据类型: {img_array.dtype}") # uint8
# 转换格式(保存为PNG,保存前统一成RGB模式)
img_rgb = img.convert('RGB')
img_rgb.save('example.png')
# 转换为灰度
img_gray = img.convert('L')
img_gray.save('example_gray.jpg')
3.1.2 图像数据集目录结构
ImageFolder结构(最常用)
dataset/
├── train/
│ ├── cat/
│ │ ├── cat001.jpg
│ │ ├── cat002.jpg
│ │ └── ...
│ ├── dog/
│ │ ├── dog001.jpg
│ │ ├── dog002.jpg
│ │ └── ...
│ └── bird/
│ ├── bird001.jpg
│ └── ...
├── val/
│ ├── cat/
│ ├── dog/
│ └── bird/
└── test/
├── cat/
├── dog/
└── bird/
使用PyTorch加载
from torchvision.datasets import ImageFolder
import torchvision.transforms as transforms
transform = transforms.Compose([
transforms.Resize(256),
transforms.CenterCrop(224),
transforms.ToTensor(),
transforms.Normalize(mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225])
])
train_dataset = ImageFolder(
root='dataset/train',
transform=transform
)
print(f"类别: {train_dataset.classes}")
print(f"类别索引: {train_dataset.class_to_idx}")
print(f"样本数: {len(train_dataset)}")
# 获取一个样本
image, label = train_dataset[0]
print(f"图像形状: {image.shape}") # [3, 224, 224]
print(f"标签: {label}")
3.1.3 LMDB格式
LMDB(Lightning Memory-Mapped Database)是一种高效的键值存储格式,适合大规模图像数据集。
创建LMDB数据集
import lmdb
import cv2
import pickle
import os
def create_lmdb_dataset(image_folder, lmdb_path):
"""将图像文件夹转换为LMDB格式"""
# map_size 是数据库的最大容量上限(这里直接给1TB,实际磁盘占用只取决于写入的数据量)
map_size = 1099511627776  # 1TB
env = lmdb.open(lmdb_path, map_size=map_size)
txn = env.begin(write=True)
image_files = sorted(os.listdir(image_folder))
for idx, img_name in enumerate(image_files):
img_path = os.path.join(image_folder, img_name)
# 读取图像
img = cv2.imread(img_path)
if img is None:
continue
# 序列化
img_bytes = pickle.dumps(img)
# 存储
key = f"{idx:08d}".encode()
txn.put(key, img_bytes)
if (idx + 1) % 1000 == 0:
txn.commit()
txn = env.begin(write=True)
print(f"已处理 {idx + 1} 张图像")
# 存储元数据
meta = {
'total': len(image_files),
'image_names': image_files
}
txn.put(b'__meta__', pickle.dumps(meta))
txn.commit()
env.close()
print(f"LMDB创建完成: {lmdb_path}")
# 创建LMDB
create_lmdb_dataset('images/', 'images.lmdb')
读取LMDB数据集
import lmdb
import pickle
import torch
import torchvision.transforms as transforms
from torch.utils.data import Dataset
class LMDBDataset(Dataset):
def __init__(self, lmdb_path, transform=None):
self.env = lmdb.open(lmdb_path, readonly=True, lock=False)
self.transform = transform
with self.env.begin() as txn:
meta = pickle.loads(txn.get(b'__meta__'))
self.total = meta['total']
def __len__(self):
return self.total
def __getitem__(self, idx):
with self.env.begin() as txn:
key = f"{idx:08d}".encode()
img_bytes = txn.get(key)
img = pickle.loads(img_bytes)
if self.transform:
img = self.transform(img)
return img
# 使用
dataset = LMDBDataset('images.lmdb', transform=transforms.ToTensor())
print(f"数据集大小: {len(dataset)}")
3.2 文本数据格式
3.2.1 纯文本格式 (.txt)
单文件格式
这是第一条文本数据
这是第二条文本数据
这是第三条文本数据
读取示例
def read_text_file(file_path):
with open(file_path, 'r', encoding='utf-8') as f:
lines = f.readlines()
lines = [line.strip() for line in lines if line.strip()]
return lines
texts = read_text_file('data.txt')
print(f"读取了 {len(texts)} 条数据")
3.2.2 JSON格式
单个JSON文件
{
"data": [
{
"text": "这是一条正面评论",
"label": "positive"
},
{
"text": "这是一条负面评论",
"label": "negative"
}
]
}
读取JSON
import json
with open('data.json', 'r', encoding='utf-8') as f:
data = json.load(f)
for item in data['data']:
print(f"文本: {item['text']}")
print(f"标签: {item['label']}")
3.2.3 JSONL格式(JSON Lines)
JSONL是每行一个JSON对象,适合大规模数据集。
JSONL文件示例
{"text": "这是第一条评论", "label": "positive", "score": 0.95}
{"text": "这是第二条评论", "label": "negative", "score": 0.12}
{"text": "这是第三条评论", "label": "neutral", "score": 0.50}
读取JSONL
import json
def read_jsonl(file_path):
data = []
with open(file_path, 'r', encoding='utf-8') as f:
for line in f:
data.append(json.loads(line))
return data
# 使用生成器节省内存
def read_jsonl_lazy(file_path):
with open(file_path, 'r', encoding='utf-8') as f:
for line in f:
yield json.loads(line)
# 逐行处理
for item in read_jsonl_lazy('data.jsonl'):
print(item['text'])
写入JSONL
import json
data = [
{"text": "评论1", "label": "positive"},
{"text": "评论2", "label": "negative"}
]
with open('output.jsonl', 'w', encoding='utf-8') as f:
for item in data:
f.write(json.dumps(item, ensure_ascii=False) + '\n')
3.2.4 Parquet格式
Parquet是列式存储格式,非常适合大规模数据集,比CSV和JSON更高效。
创建Parquet文件
import pandas as pd
# 创建数据
data = {
'text': ['文本1', '文本2', '文本3'] * 10000,
'label': [0, 1, 0] * 10000,
'score': [0.9, 0.1, 0.5] * 10000
}
df = pd.DataFrame(data)
# 保存为Parquet
df.to_parquet('data.parquet', compression='gzip')
print(f"保存了 {len(df)} 条数据")
读取Parquet文件
import pandas as pd
# 读取整个文件
df = pd.read_parquet('data.parquet')
print(df.head())
# 只读取部分列
df_subset = pd.read_parquet('data.parquet', columns=['text', 'label'])
# 使用PyArrow分块读取
import pyarrow.parquet as pq
parquet_file = pq.ParquetFile('data.parquet')
for batch in parquet_file.iter_batches(batch_size=1000):
df_batch = batch.to_pandas()
# 处理batch...
HuggingFace Datasets使用Parquet
from datasets import Dataset
# 从Parquet加载
dataset = Dataset.from_parquet('data.parquet')
print(dataset)
# 保存为Parquet
dataset.to_parquet('output.parquet')
3.3 标注格式
3.3.1 COCO JSON格式
COCO格式用于目标检测和实例分割。
COCO JSON结构
{
"images": [
{
"id": 1,
"file_name": "000001.jpg",
"width": 640,
"height": 480
}
],
"annotations": [
{
"id": 1,
"image_id": 1,
"category_id": 1,
"bbox": [100, 120, 200, 150],
"area": 30000,
"segmentation": [[100, 120, 300, 120, 300, 270, 100, 270]],
"iscrowd": 0
}
],
"categories": [
{
"id": 1,
"name": "cat",
"supercategory": "animal"
},
{
"id": 2,
"name": "dog",
"supercategory": "animal"
}
]
}
读取COCO格式
import json
from pycocotools.coco import COCO
import cv2
import matplotlib.pyplot as plt
# 使用pycocotools读取
coco = COCO('annotations.json')
# 获取所有图像ID
img_ids = coco.getImgIds()
print(f"图像数量: {len(img_ids)}")
# 获取所有类别
cats = coco.loadCats(coco.getCatIds())
cat_names = [cat['name'] for cat in cats]
print(f"类别: {cat_names}")
# 加载一张图像的标注
img_id = img_ids[0]
img_info = coco.loadImgs(img_id)[0]
ann_ids = coco.getAnnIds(imgIds=img_id)
anns = coco.loadAnns(ann_ids)
print(f"图像: {img_info['file_name']}")
print(f"标注数量: {len(anns)}")
for ann in anns:
print(f"类别: {ann['category_id']}")
print(f"边界框: {ann['bbox']}") # [x, y, width, height]
创建COCO格式标注
import json
import cv2
from datetime import datetime
def create_coco_annotation(image_files, annotations):
"""创建COCO格式标注文件"""
coco_format = {
"info": {
"description": "My Dataset",
"version": "1.0",
"year": 2024,
"date_created": datetime.now().strftime("%Y-%m-%d %H:%M:%S")
},
"images": [],
"annotations": [],
"categories": [
{"id": 1, "name": "cat", "supercategory": "animal"},
{"id": 2, "name": "dog", "supercategory": "animal"}
]
}
ann_id = 1
for img_id, img_file in enumerate(image_files, 1):
# 读取图像尺寸
img = cv2.imread(img_file)
h, w = img.shape[:2]
coco_format["images"].append({
"id": img_id,
"file_name": img_file,
"width": w,
"height": h
})
# 添加标注
for ann in annotations.get(img_id, []):
coco_format["annotations"].append({
"id": ann_id,
"image_id": img_id,
"category_id": ann['category_id'],
"bbox": ann['bbox'],
"area": ann['bbox'][2] * ann['bbox'][3],
"iscrowd": 0
})
ann_id += 1
with open('coco_annotations.json', 'w') as f:
json.dump(coco_format, f, indent=2)
# 使用示例
image_files = ['img1.jpg', 'img2.jpg']
annotations = {
1: [{'category_id': 1, 'bbox': [100, 100, 200, 150]}],
2: [{'category_id': 2, 'bbox': [50, 80, 180, 200]}]
}
create_coco_annotation(image_files, annotations)
3.3.2 YOLO TXT格式
YOLO格式每张图像对应一个txt文件,每行一个目标。
目录结构
dataset/
├── images/
│ ├── img1.jpg
│ ├── img2.jpg
│ └── ...
└── labels/
├── img1.txt
├── img2.txt
└── ...
YOLO标注格式
# img1.txt
# <class_id> <x_center> <y_center> <width> <height>
# 坐标都是归一化到[0,1]的
0 0.5 0.5 0.3 0.4
1 0.2 0.3 0.15 0.2
读取YOLO格式
import cv2
import numpy as np
def read_yolo_label(label_file, img_width, img_height):
"""读取YOLO格式标注"""
boxes = []
with open(label_file, 'r') as f:
for line in f:
parts = line.strip().split()
class_id = int(parts[0])
x_center = float(parts[1]) * img_width
y_center = float(parts[2]) * img_height
width = float(parts[3]) * img_width
height = float(parts[4]) * img_height
# 转换为左上角坐标
x1 = int(x_center - width / 2)
y1 = int(y_center - height / 2)
x2 = int(x_center + width / 2)
y2 = int(y_center + height / 2)
boxes.append({
'class_id': class_id,
'bbox': [x1, y1, x2, y2]
})
return boxes
# 使用示例
img = cv2.imread('images/img1.jpg')
h, w = img.shape[:2]
boxes = read_yolo_label('labels/img1.txt', w, h)
# 绘制边界框
for box in boxes:
x1, y1, x2, y2 = box['bbox']
cv2.rectangle(img, (x1, y1), (x2, y2), (0, 255, 0), 2)
cv2.putText(img, str(box['class_id']), (x1, y1-10),
cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 255, 0), 2)
cv2.imwrite('result.jpg', img)
创建YOLO格式标注
def bbox_to_yolo(bbox, img_width, img_height):
"""
将边界框转换为YOLO格式
bbox: [x1, y1, x2, y2]
返回: [x_center, y_center, width, height] (归一化)
"""
x1, y1, x2, y2 = bbox
x_center = ((x1 + x2) / 2) / img_width
y_center = ((y1 + y2) / 2) / img_height
width = (x2 - x1) / img_width
height = (y2 - y1) / img_height
return [x_center, y_center, width, height]
def create_yolo_label(annotations, img_width, img_height, output_file):
"""创建YOLO标注文件"""
with open(output_file, 'w') as f:
for ann in annotations:
class_id = ann['class_id']
bbox = ann['bbox'] # [x1, y1, x2, y2]
yolo_bbox = bbox_to_yolo(bbox, img_width, img_height)
line = f"{class_id} {' '.join(map(str, yolo_bbox))}\n"
f.write(line)
# 使用示例
annotations = [
{'class_id': 0, 'bbox': [100, 100, 300, 250]},
{'class_id': 1, 'bbox': [50, 80, 200, 280]}
]
create_yolo_label(annotations, 640, 480, 'labels/img1.txt')
3.3.3 CSV格式
CSV格式简单直观,适合表格型数据。
CSV标注示例
image_file,label,x1,y1,x2,y2
img1.jpg,cat,100,100,300,250
img1.jpg,dog,50,80,200,280
img2.jpg,cat,120,90,280,230
读取CSV标注
import pandas as pd
# 读取CSV
df = pd.read_csv('annotations.csv')
print(df.head())
# 按图像分组
grouped = df.groupby('image_file')
for img_file, group in grouped:
print(f"\n图像: {img_file}")
for idx, row in group.iterrows():
print(f" 标签: {row['label']}")
print(f" 边界框: ({row['x1']}, {row['y1']}, {row['x2']}, {row['y2']})")
创建CSV标注
import pandas as pd
data = {
'image_file': ['img1.jpg', 'img1.jpg', 'img2.jpg'],
'label': ['cat', 'dog', 'cat'],
'x1': [100, 50, 120],
'y1': [100, 80, 90],
'x2': [300, 200, 280],
'y2': [250, 280, 230]
}
df = pd.DataFrame(data)
df.to_csv('annotations.csv', index=False)
4. 数据预处理
4.1 数据清洗
4.1.1 去重
图像去重
import os
import hashlib
from PIL import Image
def get_image_hash(img_path):
"""计算图像的MD5哈希"""
with open(img_path, 'rb') as f:
return hashlib.md5(f.read()).hexdigest()
def remove_duplicate_images(image_folder):
"""删除重复图像"""
seen_hashes = {}
duplicates = []
for img_name in os.listdir(image_folder):
img_path = os.path.join(image_folder, img_name)
if not img_path.lower().endswith(('.jpg', '.jpeg', '.png')):
continue
img_hash = get_image_hash(img_path)
if img_hash in seen_hashes:
duplicates.append(img_path)
print(f"重复: {img_name} (与 {seen_hashes[img_hash]} 相同)")
else:
seen_hashes[img_hash] = img_name
print(f"\n找到 {len(duplicates)} 张重复图像")
# 删除重复项
for dup in duplicates:
os.remove(dup)
return len(duplicates)
# 使用
removed = remove_duplicate_images('images/')
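MD5只能找出字节级完全相同的文件;对于缩放、再压缩后的"同一张图",可以用感知哈希做近似去重。下面是基于imagehash库的示意(假设已pip install imagehash):
import os
from PIL import Image
import imagehash
def find_near_duplicate_images(image_folder, max_distance=5):
    """用感知哈希(pHash)找近似重复图像,汉明距离<=max_distance视为重复(示意)"""
    hashes = {}       # 已见哈希 -> 文件名
    duplicates = []
    for img_name in sorted(os.listdir(image_folder)):
        if not img_name.lower().endswith(('.jpg', '.jpeg', '.png')):
            continue
        h = imagehash.phash(Image.open(os.path.join(image_folder, img_name)))
        # 与已有哈希逐一比较(小数据集够用;更大规模可考虑BK树等索引)
        matched = next((name for known, name in hashes.items()
                        if h - known <= max_distance), None)
        if matched:
            duplicates.append((img_name, matched))
        else:
            hashes[h] = img_name
    return duplicates
# dups = find_near_duplicate_images('images/')
# print(f"近似重复: {len(dups)} 对")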
文本去重
def remove_duplicate_texts(texts):
"""去除重复文本"""
seen = set()
unique_texts = []
for text in texts:
# 标准化文本
normalized = text.strip().lower()
if normalized not in seen:
seen.add(normalized)
unique_texts.append(text)
print(f"原始: {len(texts)}, 去重后: {len(unique_texts)}")
return unique_texts
# 使用
texts = ["Hello World", "hello world", "Hello World", "Different text"]
unique = remove_duplicate_texts(texts)
使用MinHash进行近似去重
from datasketch import MinHash, MinHashLSH
def create_minhash(text, num_perm=128):
"""创建文本的MinHash"""
m = MinHash(num_perm=num_perm)
for word in text.split():
m.update(word.encode('utf-8'))
return m
def remove_similar_texts(texts, threshold=0.8):
"""删除相似文本"""
lsh = MinHashLSH(threshold=threshold, num_perm=128)
minhashes = {}
for idx, text in enumerate(texts):
m = create_minhash(text)
minhashes[idx] = m
lsh.insert(idx, m)
# 查找重复
duplicates = set()
for idx, text in enumerate(texts):
if idx in duplicates:
continue
similar = lsh.query(minhashes[idx])
for sim_idx in similar:
if sim_idx != idx:
duplicates.add(sim_idx)
# 保留非重复项
unique_texts = [text for idx, text in enumerate(texts) if idx not in duplicates]
print(f"原始: {len(texts)}, 去重后: {len(unique_texts)}")
return unique_texts
4.1.2 去噪
图像去噪
import cv2
import numpy as np
def denoise_image(image, method='nlm'):
    """图像去噪,可选三种方法"""
    if method == 'gaussian':
        # 高斯模糊: 简单快速
        return cv2.GaussianBlur(image, (5, 5), 0)
    if method == 'bilateral':
        # 双边滤波: 去噪的同时保留边缘
        return cv2.bilateralFilter(image, 9, 75, 75)
    # 非局部均值去噪: 效果较好但较慢
    return cv2.fastNlMeansDenoisingColored(image, None, 10, 10, 7, 21)
# 使用
img = cv2.imread('noisy_image.jpg')
clean_img = denoise_image(img)
cv2.imwrite('clean_image.jpg', clean_img)
文本去噪
import re
def clean_text(text):
"""清洗文本数据"""
# 去除HTML标签
text = re.sub(r'<[^>]+>', '', text)
# 去除URL
text = re.sub(r'http[s]?://(?:[a-zA-Z]|[0-9]|[$-_@.&+]|[!*\\(\\),]|(?:%[0-9a-fA-F][0-9a-fA-F]))+', '', text)
# 去除邮箱
text = re.sub(r'\S+@\S+', '', text)
# 去除特殊字符(保留中英文、数字、基本标点)
text = re.sub(r'[^\w\s\u4e00-\u9fff.,!?;:,。!?;:]', '', text)
# 去除多余空白
text = re.sub(r'\s+', ' ', text).strip()
return text
# 使用
raw_text = "<p>访问 https://example.com 或邮件 user@example.com</p> 多余空格 "
cleaned = clean_text(raw_text)
print(cleaned) # "访问 或邮件 多余空格"
4.1.3 异常值处理
检测图像异常
import os
from PIL import Image
def check_image_integrity(image_folder):
"""检查图像完整性"""
corrupted = []
for img_name in os.listdir(image_folder):
img_path = os.path.join(image_folder, img_name)
try:
img = Image.open(img_path)
img.verify() # 验证图像完整性
# 重新打开进行更多检查
img = Image.open(img_path)
img.load() # 实际加载图像数据
# 检查尺寸
if img.size[0] < 10 or img.size[1] < 10:
print(f"图像过小: {img_name} {img.size}")
corrupted.append(img_path)
except Exception as e:
print(f"损坏的图像: {img_name} - {e}")
corrupted.append(img_path)
return corrupted
# 使用
corrupted_images = check_image_integrity('images/')
print(f"发现 {len(corrupted_images)} 张问题图像")
检测文本异常
def detect_text_outliers(texts, min_length=10, max_length=1000):
"""检测异常文本"""
outliers = []
for idx, text in enumerate(texts):
# 长度异常
if len(text) < min_length or len(text) > max_length:
outliers.append((idx, f"长度异常: {len(text)}"))
continue
# 字符异常(如全是标点符号)
alpha_ratio = sum(c.isalnum() for c in text) / len(text)
if alpha_ratio < 0.5:
outliers.append((idx, f"字符异常: {alpha_ratio:.2%}是字母数字"))
continue
return outliers
# 使用(示例文本较短,这里把最小长度阈值放宽到3)
texts = ["正常文本", "a", "!!!!!!", "这是一段正常的文本"]
outliers = detect_text_outliers(texts, min_length=3)
for idx, reason in outliers:
print(f"文本 {idx}: {reason}")
4.2 数据增强
4.2.1 图像数据增强
基础图像增强
import torchvision.transforms as transforms
from PIL import Image
# 定义增强操作
train_transform = transforms.Compose([
# 随机裁剪
transforms.RandomResizedCrop(224, scale=(0.8, 1.0)),
# 随机水平翻转
transforms.RandomHorizontalFlip(p=0.5),
# 随机旋转
transforms.RandomRotation(degrees=15),
# 颜色抖动
transforms.ColorJitter(
brightness=0.2, # 亮度
contrast=0.2, # 对比度
saturation=0.2, # 饱和度
hue=0.1 # 色调
),
# 转换为张量
transforms.ToTensor(),
# 归一化
transforms.Normalize(
mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225]
)
])
# 应用增强
img = Image.open('image.jpg')
augmented_img = train_transform(img)
高级图像增强(Albumentations)
import albumentations as A
from albumentations.pytorch import ToTensorV2
import cv2
# 定义增强管道
transform = A.Compose([
A.RandomResizedCrop(height=224, width=224, scale=(0.8, 1.0)),
A.HorizontalFlip(p=0.5),
A.VerticalFlip(p=0.2),
A.Rotate(limit=30, p=0.5),
# 高级变换
A.OneOf([
A.GaussNoise(var_limit=(10.0, 50.0)),
A.GaussianBlur(blur_limit=(3, 7)),
A.MotionBlur(blur_limit=5),
], p=0.3),
A.OneOf([
A.OpticalDistortion(distort_limit=0.5),
A.GridDistortion(num_steps=5, distort_limit=0.3),
A.ElasticTransform(alpha=1, sigma=50, alpha_affine=50),
], p=0.3),
A.ShiftScaleRotate(shift_limit=0.1, scale_limit=0.2, rotate_limit=30, p=0.5),
A.RandomBrightnessContrast(p=0.5),
A.HueSaturationValue(p=0.3),
A.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
ToTensorV2()
])
# 使用
image = cv2.imread('image.jpg')
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
augmented = transform(image=image)
augmented_image = augmented['image']
数据增强可视化
import matplotlib.pyplot as plt
import torchvision.transforms as transforms
from PIL import Image
# 加载原始图像
img = Image.open('image.jpg')
# 定义增强
transform = transforms.Compose([
transforms.RandomResizedCrop(224),
transforms.RandomHorizontalFlip(),
transforms.ColorJitter(brightness=0.3, contrast=0.3)
])
# 生成多个增强版本
fig, axes = plt.subplots(2, 4, figsize=(12, 6))
axes[0, 0].imshow(img)
axes[0, 0].set_title('Original')
axes[0, 0].axis('off')
for i in range(7):
row = (i + 1) // 4
col = (i + 1) % 4
augmented = transform(img)
axes[row, col].imshow(augmented)
axes[row, col].set_title(f'Augmented {i+1}')
axes[row, col].axis('off')
plt.tight_layout()
plt.savefig('augmentation_examples.png')
4.2.2 文本数据增强
同义词替换
import random
import nltk
from nltk.corpus import wordnet
nltk.download('wordnet')
nltk.download('averaged_perceptron_tagger')
def get_synonyms(word):
"""获取单词的同义词"""
synonyms = set()
for syn in wordnet.synsets(word):
for lemma in syn.lemmas():
synonym = lemma.name().replace('_', ' ')
if synonym != word:
synonyms.add(synonym)
return list(synonyms)
def synonym_replacement(text, n=1):
"""同义词替换增强"""
words = text.split()
# 随机选择n个词进行替换
random_words = list(set([word for word in words if word.isalnum()]))
random.shuffle(random_words)
num_replaced = 0
for word in random_words:
synonyms = get_synonyms(word)
if synonyms:
synonym = random.choice(synonyms)
words = [synonym if w == word else w for w in words]
num_replaced += 1
if num_replaced >= n:
break
return ' '.join(words)
# 使用
text = "The quick brown fox jumps over the lazy dog"
augmented = synonym_replacement(text, n=2)
print(f"原文: {text}")
print(f"增强: {augmented}")
回译增强
from googletrans import Translator
def back_translation(text, intermediate_lang='zh-cn'):
"""回译增强"""
translator = Translator()
# 翻译到中间语言
translated = translator.translate(text, dest=intermediate_lang)
# 翻译回原语言
back_translated = translator.translate(translated.text, dest='en')
return back_translated.text
# 使用
text = "This is a great product"
augmented = back_translation(text)
print(f"原文: {text}")
print(f"增强: {augmented}")
随机插入、交换、删除
import random
def random_insertion(text, n=1):
"""随机插入同义词"""
words = text.split()
for _ in range(n):
add_word(words)
return ' '.join(words)
def add_word(words):
"""添加一个同义词"""
synonyms = []
counter = 0
while len(synonyms) < 1:
random_word = words[random.randint(0, len(words)-1)]
synonyms = get_synonyms(random_word)
counter += 1
if counter >= 10:
return
random_synonym = random.choice(synonyms)
random_idx = random.randint(0, len(words)-1)
words.insert(random_idx, random_synonym)
def random_swap(text, n=1):
"""随机交换单词位置"""
words = text.split()
for _ in range(n):
words = swap_word(words)
return ' '.join(words)
def swap_word(words):
"""交换两个单词"""
random_idx_1 = random.randint(0, len(words)-1)
random_idx_2 = random_idx_1
counter = 0
while random_idx_2 == random_idx_1:
random_idx_2 = random.randint(0, len(words)-1)
counter += 1
if counter > 3:
return words
words[random_idx_1], words[random_idx_2] = words[random_idx_2], words[random_idx_1]
return words
def random_deletion(text, p=0.1):
"""随机删除单词"""
words = text.split()
if len(words) == 1:
return words[0]
new_words = []
for word in words:
if random.uniform(0, 1) > p:
new_words.append(word)
if len(new_words) == 0:
return random.choice(words)
return ' '.join(new_words)
# 使用
text = "The quick brown fox jumps over the lazy dog"
print(f"原文: {text}")
print(f"插入: {random_insertion(text, 2)}")
print(f"交换: {random_swap(text, 2)}")
print(f"删除: {random_deletion(text, 0.2)}")
4.3 文本预处理
4.3.1 分词
英文分词
import nltk
from nltk.tokenize import word_tokenize, sent_tokenize
nltk.download('punkt')
text = "Hello world! This is a test. How are you?"
# 句子分词
sentences = sent_tokenize(text)
print("句子:", sentences)
# 单词分词
words = word_tokenize(text)
print("单词:", words)
中文分词
import jieba
text = "我爱自然语言处理和深度学习"
# 精确模式
seg_list = jieba.cut(text, cut_all=False)
print("精确模式:", " / ".join(seg_list))
# 全模式
seg_list = jieba.cut(text, cut_all=True)
print("全模式:", " / ".join(seg_list))
# 搜索引擎模式
seg_list = jieba.cut_for_search(text)
print("搜索引擎模式:", " / ".join(seg_list))
# 添加自定义词典
jieba.add_word("深度学习")
使用HuggingFace分词器
from transformers import BertTokenizer
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
text = "Hello, how are you?"
# 分词
tokens = tokenizer.tokenize(text)
print("Token:", tokens)
# 转换为ID
input_ids = tokenizer.encode(text)
print("IDs:", input_ids)
# 解码
decoded = tokenizer.decode(input_ids)
print("解码:", decoded)
# 完整编码(包含attention mask等)
encoded = tokenizer(
text,
padding='max_length',
max_length=20,
truncation=True,
return_tensors='pt'
)
print(encoded)
4.3.2 去停用词
import nltk
from nltk.corpus import stopwords
nltk.download('stopwords')
# 英文停用词
stop_words = set(stopwords.words('english'))
text = "This is a sample sentence with some stop words"
words = text.split()
filtered_words = [word for word in words if word.lower() not in stop_words]
print("原文:", text)
print("过滤后:", " ".join(filtered_words))
# 中文停用词(需要自己准备停用词表)
cn_stop_words = set(['的', '了', '在', '是', '我', '有', '和', '就', '不', '人'])
cn_text = "我在北京天安门广场看升旗"
cn_words = jieba.cut(cn_text)
filtered_cn = [word for word in cn_words if word not in cn_stop_words]
print("中文过滤:", " / ".join(filtered_cn))
4.3.3 词干提取和词形还原
词干提取(Stemming)
from nltk.stem import PorterStemmer, SnowballStemmer
# Porter Stemmer
porter = PorterStemmer()
words = ["running", "runs", "ran", "runner", "easily", "fairly"]
print("Porter Stemmer:")
for word in words:
print(f"{word} -> {porter.stem(word)}")
# Snowball Stemmer
snowball = SnowballStemmer("english")
print("\nSnowball Stemmer:")
for word in words:
print(f"{word} -> {snowball.stem(word)}")
词形还原(Lemmatization)
from nltk.stem import WordNetLemmatizer
import nltk
nltk.download('wordnet')
lemmatizer = WordNetLemmatizer()
words = ["running", "runs", "ran", "runner", "better", "worse"]
print("Lemmatization:")
for word in words:
# 动词
print(f"{word} (v) -> {lemmatizer.lemmatize(word, pos='v')}")
# 名词
print(f"{word} (n) -> {lemmatizer.lemmatize(word, pos='n')}")
# 形容词
print(f"{word} (a) -> {lemmatizer.lemmatize(word, pos='a')}")
4.3.4 文本归一化
import re
import unicodedata
def normalize_text(text):
"""文本归一化"""
# 转小写
text = text.lower()
# Unicode归一化
text = unicodedata.normalize('NFKD', text)
# 去除重音符号
text = ''.join([c for c in text if not unicodedata.combining(c)])
# 去除标点
text = re.sub(r'[^\w\s]', '', text)
# 去除多余空格
text = re.sub(r'\s+', ' ', text).strip()
return text
# 使用
text = " Héllo WORLD!!! Multiple spaces "
normalized = normalize_text(text)
print(f"原文: '{text}'")
print(f"归一化: '{normalized}'")
5. DataLoader实现
5.1 PyTorch Dataset和DataLoader
5.1.1 自定义Dataset
图像Dataset
import torch
from torch.utils.data import Dataset
from PIL import Image
import os
class CustomImageDataset(Dataset):
def __init__(self, image_folder, labels_file, transform=None):
"""
Args:
image_folder: 图像文件夹路径
labels_file: 标签文件路径(CSV格式)
transform: 数据变换
"""
self.image_folder = image_folder
self.transform = transform
# 读取标签
self.data = []
with open(labels_file, 'r') as f:
next(f) # 跳过表头
for line in f:
img_name, label = line.strip().split(',')
self.data.append((img_name, int(label)))
def __len__(self):
return len(self.data)
def __getitem__(self, idx):
img_name, label = self.data[idx]
img_path = os.path.join(self.image_folder, img_name)
# 加载图像
image = Image.open(img_path).convert('RGB')
# 应用变换
if self.transform:
image = self.transform(image)
return image, label
# 使用示例
import torchvision.transforms as transforms
transform = transforms.Compose([
transforms.Resize(256),
transforms.CenterCrop(224),
transforms.ToTensor(),
transforms.Normalize(mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225])
])
dataset = CustomImageDataset(
image_folder='images/',
labels_file='labels.csv',
transform=transform
)
print(f"数据集大小: {len(dataset)}")
image, label = dataset[0]
print(f"图像形状: {image.shape}")
print(f"标签: {label}")
文本Dataset
import torch
from torch.utils.data import Dataset
import json
class TextClassificationDataset(Dataset):
def __init__(self, data_file, tokenizer, max_length=128):
"""
Args:
data_file: JSONL格式数据文件
tokenizer: 分词器
max_length: 最大序列长度
"""
self.data = []
self.tokenizer = tokenizer
self.max_length = max_length
# 读取数据
with open(data_file, 'r', encoding='utf-8') as f:
for line in f:
item = json.loads(line)
self.data.append(item)
def __len__(self):
return len(self.data)
def __getitem__(self, idx):
item = self.data[idx]
text = item['text']
label = item['label']
# 分词和编码
encoding = self.tokenizer(
text,
max_length=self.max_length,
padding='max_length',
truncation=True,
return_tensors='pt'
)
return {
'input_ids': encoding['input_ids'].squeeze(0),
'attention_mask': encoding['attention_mask'].squeeze(0),
'label': torch.tensor(label, dtype=torch.long)
}
# 使用示例
from transformers import BertTokenizer
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
dataset = TextClassificationDataset(
data_file='data.jsonl',
tokenizer=tokenizer,
max_length=128
)
print(f"数据集大小: {len(dataset)}")
sample = dataset[0]
print(f"Input IDs shape: {sample['input_ids'].shape}")
print(f"Attention mask shape: {sample['attention_mask'].shape}")
print(f"Label: {sample['label']}")
5.1.2 DataLoader配置
基础DataLoader
from torch.utils.data import DataLoader
# 创建DataLoader
train_loader = DataLoader(
dataset=train_dataset,
batch_size=32,
shuffle=True, # 训练时打乱
num_workers=4, # 多线程加载
pin_memory=True, # 加速GPU传输
drop_last=True # 丢弃最后不完整的batch
)
val_loader = DataLoader(
dataset=val_dataset,
batch_size=64,
shuffle=False, # 验证时不打乱
num_workers=4,
pin_memory=True
)
# 使用
for batch_idx, (images, labels) in enumerate(train_loader):
print(f"Batch {batch_idx}")
print(f"Images shape: {images.shape}") # [32, 3, 224, 224]
print(f"Labels shape: {labels.shape}") # [32]
break
自定义collate_fn
def custom_collate_fn(batch):
"""
自定义batch整理函数
用于处理不同长度的序列(这里假设每个样本的 'text' 已是 token id 列表)
"""
# batch是一个列表,每个元素是dataset[i]的返回值
texts = [item['text'] for item in batch]
labels = [item['label'] for item in batch]
# 找到最大长度
max_len = max(len(text) for text in texts)
# 填充到相同长度
padded_texts = []
attention_masks = []
for text in texts:
# 填充
pad_len = max_len - len(text)
padded_text = text + [0] * pad_len
attention_mask = [1] * len(text) + [0] * pad_len
padded_texts.append(padded_text)
attention_masks.append(attention_mask)
return {
'input_ids': torch.tensor(padded_texts),
'attention_mask': torch.tensor(attention_masks),
'labels': torch.tensor(labels)
}
# 使用自定义collate_fn
loader = DataLoader(
dataset=dataset,
batch_size=32,
collate_fn=custom_collate_fn,
num_workers=4
)
5.2 批处理和多线程加载
5.2.1 优化DataLoader性能
import torch
from torch.utils.data import DataLoader
import time
def benchmark_dataloader(loader, num_batches=100):
"""测试DataLoader性能"""
start_time = time.time()
for i, batch in enumerate(loader):
if i >= num_batches:
break
# 模拟处理
_ = batch
elapsed = time.time() - start_time
throughput = num_batches / elapsed
print(f"处理 {num_batches} 个batch用时: {elapsed:.2f}秒")
print(f"吞吐量: {throughput:.2f} batches/秒")
return throughput
# 测试不同配置
configs = [
{'num_workers': 0, 'pin_memory': False},
{'num_workers': 2, 'pin_memory': False},
{'num_workers': 4, 'pin_memory': False},
{'num_workers': 4, 'pin_memory': True},
{'num_workers': 8, 'pin_memory': True},
]
for config in configs:
print(f"\n配置: {config}")
loader = DataLoader(dataset, batch_size=32, **config)
benchmark_dataloader(loader)
5.2.2 预取和缓存
class PrefetchLoader:
"""预取DataLoader,提前将数据加载到GPU"""
def __init__(self, loader):
self.loader = loader
self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
def __iter__(self):
stream = torch.cuda.Stream()
first = True
for next_input, next_target in self.loader:
with torch.cuda.stream(stream):
next_input = next_input.to(self.device, non_blocking=True)
next_target = next_target.to(self.device, non_blocking=True)
if not first:
yield input, target
else:
first = False
torch.cuda.current_stream().wait_stream(stream)
input = next_input
target = next_target
yield input, target
def __len__(self):
return len(self.loader)
# 使用
base_loader = DataLoader(dataset, batch_size=32, num_workers=4, pin_memory=True)
prefetch_loader = PrefetchLoader(base_loader)
for images, labels in prefetch_loader:
# 数据已经在GPU上
pass
5.3 内存映射和缓存策略
5.3.1 内存映射数据集
import numpy as np
import os
import torch
from torch.utils.data import Dataset
class MemoryMappedDataset(Dataset):
"""使用内存映射的大规模数据集"""
def __init__(self, data_path, shape, dtype=np.float32):
self.data_path = data_path
self.shape = shape
self.dtype = dtype
# 打开内存映射文件
self.data = np.memmap(
data_path,
dtype=dtype,
mode='r',
shape=shape
)
def __len__(self):
return self.shape[0]
def __getitem__(self, idx):
# 只在需要时加载数据
sample = self.data[idx]
return torch.from_numpy(sample.copy())
# 创建内存映射文件
def create_memmap_dataset(data, save_path):
"""将数据保存为内存映射文件"""
memmap_data = np.memmap(
save_path,
dtype=data.dtype,
mode='w+',
shape=data.shape
)
memmap_data[:] = data[:]
memmap_data.flush()
print(f"保存内存映射文件: {save_path}")
# 示例
# data = np.random.rand(10000, 3, 224, 224).astype(np.float32)
# create_memmap_dataset(data, 'data.mmap')
# dataset = MemoryMappedDataset('data.mmap', shape=(10000, 3, 224, 224))
5.3.2 磁盘缓存
import pickle
import hashlib
from pathlib import Path
from torch.utils.data import Dataset
class CachedDataset(Dataset):
"""带磁盘缓存的数据集"""
def __init__(self, data_source, cache_dir='cache', transform=None):
self.data_source = data_source
self.cache_dir = Path(cache_dir)
self.cache_dir.mkdir(exist_ok=True)
self.transform = transform
def _get_cache_path(self, idx):
"""获取缓存文件路径"""
cache_key = hashlib.md5(str(idx).encode()).hexdigest()
return self.cache_dir / f"{cache_key}.pkl"
def __len__(self):
return len(self.data_source)
def __getitem__(self, idx):
cache_path = self._get_cache_path(idx)
# 检查缓存
if cache_path.exists():
with open(cache_path, 'rb') as f:
data = pickle.load(f)
else:
# 加载原始数据
data = self.data_source[idx]
# 应用transform
if self.transform:
data = self.transform(data)
# 保存到缓存
with open(cache_path, 'wb') as f:
pickle.dump(data, f)
return data
6. 完整代码示例
6.1 图像分类数据集完整示例
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as transforms
from PIL import Image
import os
import pandas as pd
from sklearn.model_selection import train_test_split
# ============ 1. 数据集类 ============
class ImageClassificationDataset(Dataset):
def __init__(self, image_paths, labels, transform=None):
self.image_paths = image_paths
self.labels = labels
self.transform = transform
def __len__(self):
return len(self.image_paths)
def __getitem__(self, idx):
# 加载图像
img_path = self.image_paths[idx]
image = Image.open(img_path).convert('RGB')
label = self.labels[idx]
# 应用变换
if self.transform:
image = self.transform(image)
return image, label
# ============ 2. 数据准备 ============
def prepare_data(data_dir, test_size=0.2):
    """准备数据集: 先切出 test_size 比例,再对半分为验证集和测试集(默认 80%/10%/10%)"""
# 收集所有图像路径和标签
image_paths = []
labels = []
class_names = sorted(os.listdir(data_dir))
class_to_idx = {cls_name: idx for idx, cls_name in enumerate(class_names)}
for cls_name in class_names:
cls_dir = os.path.join(data_dir, cls_name)
if not os.path.isdir(cls_dir):
continue
for img_name in os.listdir(cls_dir):
if img_name.lower().endswith(('.jpg', '.jpeg', '.png')):
image_paths.append(os.path.join(cls_dir, img_name))
labels.append(class_to_idx[cls_name])
# 划分数据集
X_train, X_temp, y_train, y_temp = train_test_split(
image_paths, labels, test_size=test_size, stratify=labels, random_state=42
)
X_val, X_test, y_val, y_test = train_test_split(
X_temp, y_temp, test_size=0.5, stratify=y_temp, random_state=42
)
print(f"训练集: {len(X_train)} 样本")
print(f"验证集: {len(X_val)} 样本")
print(f"测试集: {len(X_test)} 样本")
print(f"类别数: {len(class_names)}")
print(f"类别: {class_names}")
return (X_train, y_train), (X_val, y_val), (X_test, y_test), class_names
# ============ 3. 数据变换 ============
train_transform = transforms.Compose([
transforms.RandomResizedCrop(224),
transforms.RandomHorizontalFlip(),
transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2),
transforms.ToTensor(),
transforms.Normalize(mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225])
])
val_transform = transforms.Compose([
transforms.Resize(256),
transforms.CenterCrop(224),
transforms.ToTensor(),
transforms.Normalize(mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225])
])
# ============ 4. 创建数据集和DataLoader ============
def create_dataloaders(data_dir, batch_size=32, num_workers=4):
"""创建DataLoader"""
# 准备数据
(X_train, y_train), (X_val, y_val), (X_test, y_test), class_names = prepare_data(data_dir)
# 创建Dataset
train_dataset = ImageClassificationDataset(X_train, y_train, train_transform)
val_dataset = ImageClassificationDataset(X_val, y_val, val_transform)
test_dataset = ImageClassificationDataset(X_test, y_test, val_transform)
# 创建DataLoader
train_loader = DataLoader(
train_dataset,
batch_size=batch_size,
shuffle=True,
num_workers=num_workers,
pin_memory=True,
drop_last=True
)
val_loader = DataLoader(
val_dataset,
batch_size=batch_size,
shuffle=False,
num_workers=num_workers,
pin_memory=True
)
test_loader = DataLoader(
test_dataset,
batch_size=batch_size,
shuffle=False,
num_workers=num_workers,
pin_memory=True
)
return train_loader, val_loader, test_loader, class_names
# ============ 5. 使用示例 ============
if __name__ == '__main__':
# 数据目录结构:
# data/
# ├── cat/
# │ ├── img1.jpg
# │ └── img2.jpg
# └── dog/
# ├── img1.jpg
# └── img2.jpg
data_dir = 'data/'
train_loader, val_loader, test_loader, class_names = create_dataloaders(
data_dir,
batch_size=32,
num_workers=4
)
# 查看一个batch
images, labels = next(iter(train_loader))
print(f"\nBatch shape:")
print(f"Images: {images.shape}") # [32, 3, 224, 224]
print(f"Labels: {labels.shape}") # [32]
# 可视化
import matplotlib.pyplot as plt
import numpy as np
def denormalize(tensor):
"""反归一化"""
mean = torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1)
std = torch.tensor([0.229, 0.224, 0.225]).view(3, 1, 1)
return tensor * std + mean
fig, axes = plt.subplots(2, 4, figsize=(12, 6))
for i, ax in enumerate(axes.flat):
if i >= len(images):
break
img = denormalize(images[i]).permute(1, 2, 0).numpy()
img = np.clip(img, 0, 1)
ax.imshow(img)
ax.set_title(class_names[labels[i]])
ax.axis('off')
plt.tight_layout()
plt.savefig('batch_visualization.png')
print("\n保存可视化: batch_visualization.png")
6.2 文本分类数据集完整示例
import torch
from torch.utils.data import Dataset, DataLoader
from transformers import BertTokenizer
import json
from sklearn.model_selection import train_test_split
import numpy as np
# ============ 1. 数据集类 ============
class TextClassificationDataset(Dataset):
def __init__(self, texts, labels, tokenizer, max_length=128):
self.texts = texts
self.labels = labels
self.tokenizer = tokenizer
self.max_length = max_length
def __len__(self):
return len(self.texts)
def __getitem__(self, idx):
text = self.texts[idx]
label = self.labels[idx]
# 编码文本
encoding = self.tokenizer(
text,
max_length=self.max_length,
padding='max_length',
truncation=True,
return_tensors='pt'
)
return {
'input_ids': encoding['input_ids'].squeeze(0),
'attention_mask': encoding['attention_mask'].squeeze(0),
'label': torch.tensor(label, dtype=torch.long)
}
# ============ 2. 数据加载 ============
def load_data(data_file):
"""从JSONL文件加载数据"""
texts = []
labels = []
label_set = set()
with open(data_file, 'r', encoding='utf-8') as f:
for line in f:
item = json.loads(line)
texts.append(item['text'])
labels.append(item['label'])
label_set.add(item['label'])
# 标签转换
label_to_idx = {label: idx for idx, label in enumerate(sorted(label_set))}
labels = [label_to_idx[label] for label in labels]
return texts, labels, label_to_idx
# ============ 3. 数据准备 ============
def prepare_text_data(data_file, test_size=0.2):
"""准备文本数据"""
texts, labels, label_to_idx = load_data(data_file)
# 划分数据集
X_train, X_temp, y_train, y_temp = train_test_split(
texts, labels, test_size=test_size, stratify=labels, random_state=42
)
X_val, X_test, y_val, y_test = train_test_split(
X_temp, y_temp, test_size=0.5, stratify=y_temp, random_state=42
)
print(f"训练集: {len(X_train)} 样本")
print(f"验证集: {len(X_val)} 样本")
print(f"测试集: {len(X_test)} 样本")
print(f"类别数: {len(label_to_idx)}")
print(f"标签映射: {label_to_idx}")
return (X_train, y_train), (X_val, y_val), (X_test, y_test), label_to_idx
# ============ 4. 创建DataLoader ============
def create_text_dataloaders(data_file, tokenizer_name='bert-base-uncased',
batch_size=32, max_length=128, num_workers=4):
"""创建文本DataLoader"""
# 加载分词器
tokenizer = BertTokenizer.from_pretrained(tokenizer_name)
# 准备数据
(X_train, y_train), (X_val, y_val), (X_test, y_test), label_to_idx = prepare_text_data(data_file)
# 创建Dataset
train_dataset = TextClassificationDataset(X_train, y_train, tokenizer, max_length)
val_dataset = TextClassificationDataset(X_val, y_val, tokenizer, max_length)
test_dataset = TextClassificationDataset(X_test, y_test, tokenizer, max_length)
# 创建DataLoader
train_loader = DataLoader(
train_dataset,
batch_size=batch_size,
shuffle=True,
num_workers=num_workers,
pin_memory=True
)
val_loader = DataLoader(
val_dataset,
batch_size=batch_size,
shuffle=False,
num_workers=num_workers,
pin_memory=True
)
test_loader = DataLoader(
test_dataset,
batch_size=batch_size,
shuffle=False,
num_workers=num_workers,
pin_memory=True
)
return train_loader, val_loader, test_loader, label_to_idx, tokenizer
# ============ 5. 使用示例 ============
if __name__ == '__main__':
# 创建示例数据
sample_data = [
{"text": "This movie is great!", "label": "positive"},
{"text": "I love this product", "label": "positive"},
{"text": "Terrible experience", "label": "negative"},
{"text": "Worst movie ever", "label": "negative"},
] * 100 # 扩展数据
# 保存为JSONL
with open('sentiment_data.jsonl', 'w', encoding='utf-8') as f:
for item in sample_data:
f.write(json.dumps(item, ensure_ascii=False) + '\n')
# 创建DataLoader
train_loader, val_loader, test_loader, label_to_idx, tokenizer = create_text_dataloaders(
'sentiment_data.jsonl',
batch_size=16,
max_length=64
)
# 查看一个batch
batch = next(iter(train_loader))
print(f"\nBatch keys: {batch.keys()}")
print(f"Input IDs shape: {batch['input_ids'].shape}") # [16, 64]
print(f"Attention mask shape: {batch['attention_mask'].shape}") # [16, 64]
print(f"Labels shape: {batch['label'].shape}") # [16]
# 解码示例
print("\n第一个样本:")
decoded_text = tokenizer.decode(batch['input_ids'][0], skip_special_tokens=True)
print(f"文本: {decoded_text}")
print(f"标签: {batch['label'][0].item()}")