17-多模态大模型
多模态基础
什么是多模态
多模态(Multimodal)是指模型能够理解和处理多种类型的数据,如文本、图像、音频、视频等。多模态大模型通过融合不同模态的信息,实现更全面的理解和生成能力。
常见的模态组合:
- 视觉+语言:图像描述、视觉问答、图文检索
- 音频+语言:语音识别、语音合成、音乐生成
- 视频+语言:视频理解、视频描述、视频问答
- 多模态融合:视频+音频+文本的综合理解
多模态的优势:
- 更丰富的信息:图像包含文字无法完全表达的信息
- 更自然的交互:人类本身就是多模态的
- 更强的泛化能力:跨模态的知识迁移
- 更广的应用场景:从纯文本扩展到真实世界
视觉+语言结合
视觉语言模型(Vision-Language Model, VLM)是最常见的多模态模型。
典型任务:
# 1. 图像描述生成(Image Captioning)
输入: 一张猫的图片
输出: "一只橙色的猫坐在窗台上看着外面"
# 2. 视觉问答(Visual Question Answering, VQA)
输入: 图片 + "图中有几只猫?"
输出: "两只"
# 3. 图文检索(Image-Text Retrieval)
输入: "一只猫在睡觉"
输出: 相关图片列表
# 4. 视觉推理(Visual Reasoning)
输入: 图片 + "如果移走椅子会发生什么?"
输出: "猫会掉到地上"
基础架构:
[图像] ---> 视觉编码器 ---> 视觉特征 ---+
[文本] ---> 文本编码器 ---> 文本特征 ---+---> 跨模态融合层 ---> 融合特征 ---> 解码器 ---> 输出
模态对齐问题
模态对齐(Modal Alignment)是多模态学习的核心挑战,指的是如何让模型理解不同模态之间的对应关系。
对齐的层次:
- 语义对齐:整体语义层面的对应
- 对象对齐:图像中的对象与文本中的名词对应
- 关系对齐:对象间的关系与文本中的动词、介词对应
- 细粒度对齐:图像区域与文本片段的精确对应
对齐方法:
import torch
import torch.nn as nn
class ContrastiveLoss(nn.Module):
"""对比学习损失 - 用于模态对齐"""
def __init__(self, temperature: float = 0.07):
super().__init__()
self.temperature = temperature
def forward(self, image_features, text_features):
"""
image_features: [batch_size, feature_dim]
text_features: [batch_size, feature_dim]
"""
# 归一化
image_features = image_features / image_features.norm(dim=-1, keepdim=True)
text_features = text_features / text_features.norm(dim=-1, keepdim=True)
# 计算相似度矩阵
logits = torch.matmul(image_features, text_features.t()) / self.temperature
# 对角线元素是正样本
labels = torch.arange(len(logits), device=logits.device)
# 双向对比损失
loss_i2t = nn.functional.cross_entropy(logits, labels)
loss_t2i = nn.functional.cross_entropy(logits.t(), labels)
loss = (loss_i2t + loss_t2i) / 2
return loss
# 使用示例
batch_size = 32
feature_dim = 512
image_features = torch.randn(batch_size, feature_dim)
text_features = torch.randn(batch_size, feature_dim)
loss_fn = ContrastiveLoss()
loss = loss_fn(image_features, text_features)
print(f"对比损失: {loss.item():.4f}")
对齐训练示例:
class AlignmentTrainer:
"""模态对齐训练器"""
def __init__(self, image_encoder, text_encoder, device='cuda'):
self.image_encoder = image_encoder
self.text_encoder = text_encoder
self.device = device
self.loss_fn = ContrastiveLoss()
def train_step(self, images, texts):
"""训练一步"""
# 编码
image_features = self.image_encoder(images)
text_features = self.text_encoder(texts)
# 计算损失
loss = self.loss_fn(image_features, text_features)
return loss
def compute_similarity(self, image, text):
"""计算图像-文本相似度"""
with torch.no_grad():
image_features = self.image_encoder(image.unsqueeze(0))
text_features = self.text_encoder(text)
# 归一化
image_features = image_features / image_features.norm(dim=-1, keepdim=True)
text_features = text_features / text_features.norm(dim=-1, keepdim=True)
# 余弦相似度
similarity = torch.matmul(image_features, text_features.t())
return similarity.item()
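上面的对比损失和对齐训练器都工作在整体语义层面。如果要做前面提到的细粒度对齐(图像区域与文本片段的对应),常见做法是在图像patch特征与文本token特征之间做跨模态注意力,再在token级别计算匹配。下面是一个简化示意,并不对应某篇论文的完整实现:
class FineGrainedAlignment(nn.Module):
    """细粒度对齐示意:以文本token为query,在图像patch上做注意力"""
    def __init__(self, dim: int = 512, num_heads: int = 8):
        super().__init__()
        self.cross_attn = nn.MultiheadAttention(dim, num_heads, batch_first=True)
    def forward(self, patch_features, token_features):
        """
        patch_features: [batch, num_patches, dim] 图像patch特征
        token_features: [batch, num_tokens, dim]  文本token特征
        返回:每个文本token聚合到的视觉上下文及注意力权重,可用于token级匹配损失
        """
        attended, attn_weights = self.cross_attn(
            query=token_features, key=patch_features, value=patch_features
        )
        return attended, attn_weights
# 形状检查
align = FineGrainedAlignment(dim=512)
patch_feats = torch.randn(2, 196, 512)
token_feats = torch.randn(2, 16, 512)
attended, weights = align(patch_feats, token_feats)
print(attended.shape, weights.shape)  # [2, 16, 512] [2, 16, 196]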
CLIP模型
CLIP(Contrastive Language-Image Pre-training)是OpenAI开发的视觉语言模型,通过对比学习实现强大的跨模态理解能力。
对比学习原理
CLIP使用对比学习在大规模图像-文本对上进行训练。
核心思想:
- 匹配的图像-文本对应该有高相似度
- 不匹配的图像-文本对应该有低相似度
训练过程:
import torch
import torch.nn as nn
from torchvision import models
from transformers import BertModel, BertTokenizer
class CLIPModel(nn.Module):
"""CLIP模型实现"""
def __init__(
self,
image_encoder_name: str = "resnet50",
text_encoder_name: str = "bert-base-uncased",
embed_dim: int = 512
):
super().__init__()
# 图像编码器
if image_encoder_name == "resnet50":
resnet = models.resnet50(pretrained=True)
self.image_encoder = nn.Sequential(*list(resnet.children())[:-1])
image_feature_dim = 2048
else:
raise ValueError(f"不支持的图像编码器: {image_encoder_name}")
# 文本编码器
self.text_encoder = BertModel.from_pretrained(text_encoder_name)
text_feature_dim = self.text_encoder.config.hidden_size
# 投影层
self.image_projection = nn.Linear(image_feature_dim, embed_dim)
self.text_projection = nn.Linear(text_feature_dim, embed_dim)
        # 可学习的温度参数(简化写法:直接除以温度;官方CLIP参数化为log(1/0.07)再取exp,见后文完整实现)
        self.temperature = nn.Parameter(torch.ones([]) * 0.07)
def encode_image(self, images):
"""编码图像"""
# images: [batch_size, 3, 224, 224]
features = self.image_encoder(images)
features = features.squeeze(-1).squeeze(-1) # [batch_size, 2048]
embeddings = self.image_projection(features) # [batch_size, embed_dim]
# 归一化
embeddings = embeddings / embeddings.norm(dim=-1, keepdim=True)
return embeddings
def encode_text(self, input_ids, attention_mask):
"""编码文本"""
# input_ids: [batch_size, seq_len]
outputs = self.text_encoder(input_ids=input_ids, attention_mask=attention_mask)
features = outputs.last_hidden_state[:, 0, :] # [CLS] token
embeddings = self.text_projection(features) # [batch_size, embed_dim]
# 归一化
embeddings = embeddings / embeddings.norm(dim=-1, keepdim=True)
return embeddings
def forward(self, images, input_ids, attention_mask):
"""前向传播"""
image_embeddings = self.encode_image(images)
text_embeddings = self.encode_text(input_ids, attention_mask)
# 计算相似度矩阵
logits = torch.matmul(image_embeddings, text_embeddings.t()) / self.temperature
return logits, image_embeddings, text_embeddings
# 创建模型
model = CLIPModel()
print(f"参数量: {sum(p.numel() for p in model.parameters()) / 1e6:.2f}M")
训练代码:
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from PIL import Image
import torchvision.transforms as transforms
class ImageTextDataset(Dataset):
"""图像-文本数据集"""
def __init__(self, image_paths, captions, tokenizer, transform=None):
self.image_paths = image_paths
self.captions = captions
self.tokenizer = tokenizer
self.transform = transform or transforms.Compose([
transforms.Resize((224, 224)),
transforms.ToTensor(),
transforms.Normalize(mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225])
])
def __len__(self):
return len(self.image_paths)
def __getitem__(self, idx):
# 加载图像
image = Image.open(self.image_paths[idx]).convert('RGB')
image = self.transform(image)
# 编码文本
caption = self.captions[idx]
encoding = self.tokenizer(
caption,
padding='max_length',
truncation=True,
max_length=77,
return_tensors='pt'
)
return {
'image': image,
'input_ids': encoding['input_ids'].squeeze(0),
'attention_mask': encoding['attention_mask'].squeeze(0)
}
def train_clip(model, dataloader, epochs=10, device='cuda'):
"""训练CLIP模型"""
model = model.to(device)
optimizer = optim.AdamW(model.parameters(), lr=1e-4)
for epoch in range(epochs):
model.train()
total_loss = 0
for batch in dataloader:
images = batch['image'].to(device)
input_ids = batch['input_ids'].to(device)
attention_mask = batch['attention_mask'].to(device)
# 前向传播
logits, _, _ = model(images, input_ids, attention_mask)
# 计算损失
batch_size = len(images)
labels = torch.arange(batch_size, device=device)
loss_i2t = nn.functional.cross_entropy(logits, labels)
loss_t2i = nn.functional.cross_entropy(logits.t(), labels)
loss = (loss_i2t + loss_t2i) / 2
# 反向传播
optimizer.zero_grad()
loss.backward()
optimizer.step()
total_loss += loss.item()
avg_loss = total_loss / len(dataloader)
print(f"Epoch {epoch + 1}/{epochs}, Loss: {avg_loss:.4f}")
# 使用示例
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
# 模拟数据
image_paths = ['image1.jpg', 'image2.jpg', 'image3.jpg'] * 100
captions = ['a dog', 'a cat', 'a bird'] * 100
dataset = ImageTextDataset(image_paths, captions, tokenizer)
dataloader = DataLoader(dataset, batch_size=32, shuffle=True)
model = CLIPModel()
# train_clip(model, dataloader, epochs=10)
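实际应用中通常不需要从零训练,直接加载OpenAI发布的预训练CLIP权重即可。下面用Hugging Face transformers加载官方CLIP做一次图文匹配(与上面的自实现CLIPModel相互独立,用别名避免类名冲突;图片路径为示例):
from transformers import CLIPModel as HFCLIPModel, CLIPProcessor
hf_model = HFCLIPModel.from_pretrained("openai/clip-vit-base-patch32")
hf_processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
image = Image.open("pet.jpg")
texts = ["a photo of a dog", "a photo of a cat"]
inputs = hf_processor(text=texts, images=image, return_tensors="pt", padding=True)
with torch.no_grad():
    outputs = hf_model(**inputs)
probs = outputs.logits_per_image.softmax(dim=-1)
print(probs)  # 每个文本与图像匹配的概率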
图像-文本对齐
CLIP通过对比学习实现图像和文本在同一语义空间中的对齐。
class CLIPInference:
"""CLIP推理"""
def __init__(self, model, tokenizer, transform, device='cuda'):
self.model = model.to(device)
self.model.eval()
self.tokenizer = tokenizer
self.transform = transform
self.device = device
@torch.no_grad()
def compute_similarity(self, image_path, texts):
"""计算图像与多个文本的相似度"""
# 加载和处理图像
image = Image.open(image_path).convert('RGB')
image = self.transform(image).unsqueeze(0).to(self.device)
# 编码图像
image_embedding = self.model.encode_image(image)
# 编码文本
text_encodings = self.tokenizer(
texts,
padding='max_length',
truncation=True,
max_length=77,
return_tensors='pt'
).to(self.device)
text_embeddings = self.model.encode_text(
text_encodings['input_ids'],
text_encodings['attention_mask']
)
# 计算相似度
similarities = torch.matmul(image_embedding, text_embeddings.t())
return similarities.cpu().numpy()[0]
def image_classification(self, image_path, classes):
"""零样本图像分类"""
# 构建提示
prompts = [f"a photo of a {cls}" for cls in classes]
# 计算相似度
similarities = self.compute_similarity(image_path, prompts)
        # 应用softmax(余弦相似度在[-1, 1]之间,先乘以100放大差异,与后文ZeroShotClassifier的做法一致)
        probs = torch.softmax(torch.tensor(similarities) * 100, dim=0).numpy()
# 排序
sorted_indices = probs.argsort()[::-1]
results = [
{"class": classes[i], "probability": probs[i]}
for i in sorted_indices
]
return results
def image_retrieval(self, query_text, image_paths, top_k=5):
"""文本到图像检索"""
# 编码查询文本
text_encoding = self.tokenizer(
query_text,
padding='max_length',
truncation=True,
max_length=77,
return_tensors='pt'
).to(self.device)
text_embedding = self.model.encode_text(
text_encoding['input_ids'],
text_encoding['attention_mask']
)
# 编码所有图像
image_embeddings = []
for image_path in image_paths:
image = Image.open(image_path).convert('RGB')
image = self.transform(image).unsqueeze(0).to(self.device)
image_embedding = self.model.encode_image(image)
image_embeddings.append(image_embedding)
image_embeddings = torch.cat(image_embeddings, dim=0)
# 计算相似度
similarities = torch.matmul(text_embedding, image_embeddings.t())
# 获取top k
top_k_values, top_k_indices = torch.topk(similarities, k=top_k, dim=1)
results = [
{"image": image_paths[i], "score": top_k_values[0, j].item()}
for j, i in enumerate(top_k_indices[0])
]
return results
# 使用示例
transform = transforms.Compose([
transforms.Resize((224, 224)),
transforms.ToTensor(),
transforms.Normalize(mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225])
])
inference = CLIPInference(model, tokenizer, transform)
# 零样本分类
classes = ['dog', 'cat', 'bird', 'fish']
results = inference.image_classification('pet.jpg', classes)
print("分类结果:")
for r in results:
print(f" {r['class']}: {r['probability']:.3f}")
# 图像检索
image_paths = ['img1.jpg', 'img2.jpg', 'img3.jpg']
results = inference.image_retrieval("a cute cat", image_paths)
print("\n检索结果:")
for r in results:
print(f" {r['image']}: {r['score']:.3f}")
Zero-shot图像分类
CLIP的一个强大能力是零样本图像分类,无需任何训练数据即可分类新类别。
class ZeroShotClassifier:
"""零样本分类器"""
def __init__(self, clip_model, tokenizer, transform, device='cuda'):
self.model = clip_model.to(device)
self.model.eval()
self.tokenizer = tokenizer
self.transform = transform
self.device = device
# 提示模板
self.prompt_templates = [
"a photo of a {}",
"a picture of a {}",
"an image of a {}",
"{} in the photo",
"a {} in the image",
]
@torch.no_grad()
def classify(self, image_path, classes, use_ensemble=True):
"""分类图像"""
# 加载图像
image = Image.open(image_path).convert('RGB')
image = self.transform(image).unsqueeze(0).to(self.device)
image_embedding = self.model.encode_image(image)
# 生成所有提示
all_prompts = []
if use_ensemble:
for cls in classes:
for template in self.prompt_templates:
all_prompts.append(template.format(cls))
else:
all_prompts = [f"a photo of a {cls}" for cls in classes]
# 编码文本
text_encodings = self.tokenizer(
all_prompts,
padding='max_length',
truncation=True,
max_length=77,
return_tensors='pt'
).to(self.device)
text_embeddings = self.model.encode_text(
text_encodings['input_ids'],
text_encodings['attention_mask']
)
# 计算相似度
similarities = torch.matmul(image_embedding, text_embeddings.t())
if use_ensemble:
# 平均每个类别的多个提示的相似度
n_templates = len(self.prompt_templates)
similarities = similarities.view(1, len(classes), n_templates)
similarities = similarities.mean(dim=2)
# Softmax
probs = torch.softmax(similarities[0] * 100, dim=0)
# 排序
sorted_indices = torch.argsort(probs, descending=True)
results = [
{
"class": classes[i],
"probability": probs[i].item(),
"logit": similarities[0][i].item()
}
for i in sorted_indices
]
return results
# 使用示例
classifier = ZeroShotClassifier(model, tokenizer, transform)
# 细粒度分类
dog_breeds = [
'golden retriever',
'german shepherd',
'bulldog',
'poodle',
'husky'
]
results = classifier.classify('dog.jpg', dog_breeds, use_ensemble=True)
print("狗的品种分类:")
for r in results[:3]:
print(f" {r['class']}: {r['probability']:.3f}")
完整代码实现
"""
完整的CLIP实现,包括训练和推理
"""
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import models, transforms
from transformers import BertModel, BertTokenizer
from PIL import Image
import numpy as np
from tqdm import tqdm
import json
class CLIPConfig:
"""CLIP配置"""
def __init__(self):
# 模型配置
self.image_encoder = "resnet50"
self.text_encoder = "bert-base-uncased"
self.embed_dim = 512
# 训练配置
self.batch_size = 128
self.learning_rate = 1e-4
self.weight_decay = 0.2
self.epochs = 30
self.warmup_epochs = 1
# 数据配置
self.image_size = 224
self.max_text_length = 77
class CLIP(nn.Module):
"""完整的CLIP模型"""
def __init__(self, config: CLIPConfig):
super().__init__()
self.config = config
# 图像编码器
self.image_encoder = self._build_image_encoder()
# 文本编码器
self.text_encoder = BertModel.from_pretrained(config.text_encoder)
# 投影层
image_dim = 2048 if config.image_encoder == "resnet50" else 768
text_dim = self.text_encoder.config.hidden_size
self.image_projection = nn.Sequential(
nn.Linear(image_dim, config.embed_dim),
nn.LayerNorm(config.embed_dim)
)
self.text_projection = nn.Sequential(
nn.Linear(text_dim, config.embed_dim),
nn.LayerNorm(config.embed_dim)
)
# 温度参数
self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
def _build_image_encoder(self):
"""构建图像编码器"""
if self.config.image_encoder == "resnet50":
resnet = models.resnet50(pretrained=True)
return nn.Sequential(*list(resnet.children())[:-1])
else:
raise ValueError(f"不支持的图像编码器: {self.config.image_encoder}")
def encode_image(self, images):
"""编码图像"""
features = self.image_encoder(images)
features = features.view(features.size(0), -1)
embeddings = self.image_projection(features)
# L2归一化
embeddings = embeddings / embeddings.norm(dim=-1, keepdim=True)
return embeddings
def encode_text(self, input_ids, attention_mask):
"""编码文本"""
outputs = self.text_encoder(
input_ids=input_ids,
attention_mask=attention_mask
)
features = outputs.last_hidden_state[:, 0, :]
embeddings = self.text_projection(features)
# L2归一化
embeddings = embeddings / embeddings.norm(dim=-1, keepdim=True)
return embeddings
def forward(self, images, input_ids, attention_mask):
"""前向传播"""
image_embeddings = self.encode_image(images)
text_embeddings = self.encode_text(input_ids, attention_mask)
# 缩放相似度
logit_scale = self.logit_scale.exp()
logits_per_image = logit_scale * image_embeddings @ text_embeddings.t()
logits_per_text = logits_per_image.t()
return logits_per_image, logits_per_text
class CLIPTrainer:
"""CLIP训练器"""
def __init__(self, model, config: CLIPConfig, device='cuda'):
self.model = model.to(device)
self.config = config
self.device = device
# 优化器
self.optimizer = optim.AdamW(
model.parameters(),
lr=config.learning_rate,
weight_decay=config.weight_decay
)
# 学习率调度
self.scheduler = optim.lr_scheduler.CosineAnnealingLR(
self.optimizer,
T_max=config.epochs
)
def train_epoch(self, dataloader):
"""训练一个epoch"""
self.model.train()
total_loss = 0
progress_bar = tqdm(dataloader, desc="Training")
for batch in progress_bar:
images = batch['image'].to(self.device)
input_ids = batch['input_ids'].to(self.device)
attention_mask = batch['attention_mask'].to(self.device)
# 前向传播
logits_per_image, logits_per_text = self.model(
images, input_ids, attention_mask
)
# 计算损失
batch_size = len(images)
labels = torch.arange(batch_size, device=self.device)
loss_i2t = nn.functional.cross_entropy(logits_per_image, labels)
loss_t2i = nn.functional.cross_entropy(logits_per_text, labels)
loss = (loss_i2t + loss_t2i) / 2
# 反向传播
self.optimizer.zero_grad()
loss.backward()
self.optimizer.step()
total_loss += loss.item()
progress_bar.set_postfix({'loss': loss.item()})
return total_loss / len(dataloader)
def train(self, train_loader, val_loader=None):
"""完整训练流程"""
best_val_loss = float('inf')
for epoch in range(self.config.epochs):
print(f"\nEpoch {epoch + 1}/{self.config.epochs}")
# 训练
train_loss = self.train_epoch(train_loader)
print(f"Train Loss: {train_loss:.4f}")
# 验证
if val_loader:
val_loss = self.validate(val_loader)
print(f"Val Loss: {val_loss:.4f}")
# 保存最佳模型
if val_loss < best_val_loss:
best_val_loss = val_loss
self.save_checkpoint('best_model.pt')
# 更新学习率
self.scheduler.step()
def validate(self, dataloader):
"""验证"""
self.model.eval()
total_loss = 0
with torch.no_grad():
for batch in tqdm(dataloader, desc="Validation"):
images = batch['image'].to(self.device)
input_ids = batch['input_ids'].to(self.device)
attention_mask = batch['attention_mask'].to(self.device)
logits_per_image, logits_per_text = self.model(
images, input_ids, attention_mask
)
batch_size = len(images)
labels = torch.arange(batch_size, device=self.device)
loss_i2t = nn.functional.cross_entropy(logits_per_image, labels)
loss_t2i = nn.functional.cross_entropy(logits_per_text, labels)
loss = (loss_i2t + loss_t2i) / 2
total_loss += loss.item()
return total_loss / len(dataloader)
def save_checkpoint(self, path):
"""保存检查点"""
torch.save({
'model_state_dict': self.model.state_dict(),
'optimizer_state_dict': self.optimizer.state_dict(),
'config': self.config
}, path)
print(f"模型已保存到 {path}")
def load_checkpoint(self, path):
"""加载检查点"""
checkpoint = torch.load(path)
self.model.load_state_dict(checkpoint['model_state_dict'])
self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
print(f"模型已从 {path} 加载")
# 使用示例
def main():
# 配置
config = CLIPConfig()
# 创建模型
model = CLIP(config)
# 创建训练器
trainer = CLIPTrainer(model, config)
# 准备数据(这里需要实际的数据集)
# train_loader = DataLoader(train_dataset, batch_size=config.batch_size)
# val_loader = DataLoader(val_dataset, batch_size=config.batch_size)
# 训练
# trainer.train(train_loader, val_loader)
if __name__ == "__main__":
main()
视觉语言模型
LLaVA架构
LLaVA(Large Language and Vision Assistant)是结合视觉编码器和大语言模型的多模态架构。
整体架构:
输入图像 → 视觉编码器(CLIP ViT) → 视觉特征 → 投影层 → 视觉Token
输入文本 → Tokenizer → 文本Token
视觉Token + 文本Token → 拼接 → LLM → 输出文本
import torch
import torch.nn as nn
from transformers import CLIPVisionModel, LlamaForCausalLM, LlamaTokenizer
class LLaVA(nn.Module):
"""LLaVA模型实现"""
def __init__(
self,
vision_model_name: str = "openai/clip-vit-large-patch14",
llm_model_name: str = "meta-llama/Llama-2-7b-hf",
projection_dim: int = 4096
):
super().__init__()
# 视觉编码器(冻结)
self.vision_model = CLIPVisionModel.from_pretrained(vision_model_name)
for param in self.vision_model.parameters():
param.requires_grad = False
# 投影层(可训练)
vision_hidden_size = self.vision_model.config.hidden_size
self.vision_projection = nn.Linear(vision_hidden_size, projection_dim)
# LLM(可选择冻结或微调)
self.llm = LlamaForCausalLM.from_pretrained(llm_model_name)
self.tokenizer = LlamaTokenizer.from_pretrained(llm_model_name)
        # 特殊token:Llama词表默认没有<image>,需要先添加并扩展embedding,
        # 否则convert_tokens_to_ids会返回unk的id
        if "<image>" not in self.tokenizer.get_vocab():
            self.tokenizer.add_tokens(["<image>"], special_tokens=True)
            self.llm.resize_token_embeddings(len(self.tokenizer))
        self.image_token_id = self.tokenizer.convert_tokens_to_ids("<image>")
def encode_image(self, images):
"""编码图像"""
# images: [batch_size, 3, 224, 224]
vision_outputs = self.vision_model(pixel_values=images)
# 取所有patch的特征
image_features = vision_outputs.last_hidden_state # [batch_size, num_patches, hidden_size]
# 投影到LLM的维度
image_features = self.vision_projection(image_features) # [batch_size, num_patches, projection_dim]
return image_features
def forward(self, images, input_ids, attention_mask, labels=None):
"""前向传播"""
batch_size = images.size(0)
# 编码图像
image_features = self.encode_image(images) # [batch_size, num_patches, embed_dim]
# 获取文本embeddings
text_embeds = self.llm.get_input_embeddings()(input_ids) # [batch_size, seq_len, embed_dim]
# 找到<image> token的位置并替换
image_token_mask = (input_ids == self.image_token_id)
# 构建最终的输入embeddings
final_embeds = text_embeds.clone()
for i in range(batch_size):
image_positions = torch.where(image_token_mask[i])[0]
if len(image_positions) > 0:
# 将图像特征插入到<image> token的位置
pos = image_positions[0]
# 这里简化处理,实际需要更复杂的拼接逻辑
final_embeds[i, pos:pos + image_features.size(1)] = image_features[i]
# 通过LLM
outputs = self.llm(
inputs_embeds=final_embeds,
attention_mask=attention_mask,
labels=labels
)
return outputs
class LLaVATrainer:
"""LLaVA训练器"""
def __init__(self, model, device='cuda'):
self.model = model.to(device)
self.device = device
# 只训练投影层和LLM的LoRA
trainable_params = [
self.model.vision_projection.parameters(),
# 这里可以添加LoRA参数
]
self.optimizer = torch.optim.AdamW(
[p for params in trainable_params for p in params],
lr=2e-5
)
def train_step(self, batch):
"""训练一步"""
self.model.train()
images = batch['images'].to(self.device)
input_ids = batch['input_ids'].to(self.device)
attention_mask = batch['attention_mask'].to(self.device)
labels = batch['labels'].to(self.device)
# 前向传播
outputs = self.model(
images=images,
input_ids=input_ids,
attention_mask=attention_mask,
labels=labels
)
loss = outputs.loss
# 反向传播
self.optimizer.zero_grad()
loss.backward()
self.optimizer.step()
return loss.item()
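LLaVA原论文采用两阶段训练:第一阶段只训练投影层做视觉-语言特征对齐(视觉编码器和LLM都冻结),第二阶段放开LLM(实践中常用LoRA)做视觉指令微调。基于上面的LLaVA类,可以用一个小函数切换冻结状态(示意,llava_model 为假设已构建的LLaVA实例):
def set_llava_trainable(model: LLaVA, stage: int):
    """按训练阶段设置可训练参数:stage=1 只训投影层,stage=2 投影层+LLM"""
    for p in model.vision_model.parameters():
        p.requires_grad = False  # 视觉编码器始终冻结
    for p in model.vision_projection.parameters():
        p.requires_grad = True   # 投影层两个阶段都训练
    for p in model.llm.parameters():
        p.requires_grad = (stage == 2)
# 阶段1:特征对齐预训练(图像-描述对)
# set_llava_trainable(llava_model, stage=1)
# 阶段2:视觉指令微调(多轮问答/指令数据)
# set_llava_trainable(llava_model, stage=2)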
视觉编码器(CLIP ViT)
Vision Transformer (ViT) 将图像分割成patches并用Transformer处理。
class VisionTransformer(nn.Module):
"""Vision Transformer实现"""
def __init__(
self,
image_size: int = 224,
patch_size: int = 16,
num_channels: int = 3,
hidden_size: int = 768,
num_layers: int = 12,
num_heads: int = 12,
mlp_dim: int = 3072
):
super().__init__()
self.image_size = image_size
self.patch_size = patch_size
self.num_patches = (image_size // patch_size) ** 2
# Patch embedding
self.patch_embedding = nn.Conv2d(
num_channels,
hidden_size,
kernel_size=patch_size,
stride=patch_size
)
# Position embedding
self.position_embedding = nn.Parameter(
torch.randn(1, self.num_patches + 1, hidden_size)
)
# CLS token
self.cls_token = nn.Parameter(torch.randn(1, 1, hidden_size))
# Transformer
encoder_layer = nn.TransformerEncoderLayer(
d_model=hidden_size,
nhead=num_heads,
dim_feedforward=mlp_dim,
batch_first=True
)
self.transformer = nn.TransformerEncoder(encoder_layer, num_layers)
# Layer norm
self.layer_norm = nn.LayerNorm(hidden_size)
def forward(self, images):
"""前向传播"""
batch_size = images.size(0)
# Patch embedding
patches = self.patch_embedding(images) # [B, hidden_size, H/P, W/P]
patches = patches.flatten(2).transpose(1, 2) # [B, num_patches, hidden_size]
# 添加CLS token
cls_tokens = self.cls_token.expand(batch_size, -1, -1)
x = torch.cat([cls_tokens, patches], dim=1) # [B, num_patches+1, hidden_size]
# 添加position embedding
x = x + self.position_embedding
# Transformer
x = self.transformer(x)
# Layer norm
x = self.layer_norm(x)
return x
# 测试
vit = VisionTransformer()
images = torch.randn(2, 3, 224, 224)
features = vit(images)
print(f"输出形状: {features.shape}") # [2, 197, 768]
投影层设计
投影层将视觉特征映射到LLM的输入空间。
class MLPProjector(nn.Module):
"""MLP投影层"""
def __init__(self, input_dim: int, output_dim: int, hidden_dim: int = None):
super().__init__()
if hidden_dim is None:
hidden_dim = (input_dim + output_dim) // 2
self.projector = nn.Sequential(
nn.Linear(input_dim, hidden_dim),
nn.GELU(),
nn.Linear(hidden_dim, output_dim)
)
def forward(self, x):
return self.projector(x)
class QFormerProjector(nn.Module):
"""Q-Former投影层(参考BLIP-2)"""
def __init__(
self,
vision_dim: int,
llm_dim: int,
num_query_tokens: int = 32,
num_layers: int = 6
):
super().__init__()
# Query tokens
self.query_tokens = nn.Parameter(torch.randn(1, num_query_tokens, vision_dim))
# Cross-attention layers
self.cross_attention_layers = nn.ModuleList([
nn.MultiheadAttention(vision_dim, num_heads=8, batch_first=True)
for _ in range(num_layers)
])
# 最终投影
self.final_projection = nn.Linear(vision_dim, llm_dim)
def forward(self, vision_features):
"""
vision_features: [batch_size, num_patches, vision_dim]
"""
batch_size = vision_features.size(0)
# 扩展query tokens
queries = self.query_tokens.expand(batch_size, -1, -1)
# 多层cross-attention
for layer in self.cross_attention_layers:
queries, _ = layer(queries, vision_features, vision_features)
# 投影到LLM维度
output = self.final_projection(queries)
return output
# 对比两种投影层
vision_features = torch.randn(2, 196, 1024) # CLIP ViT-L features
# MLP投影
mlp_proj = MLPProjector(input_dim=1024, output_dim=4096)
mlp_output = mlp_proj(vision_features)
print(f"MLP输出: {mlp_output.shape}") # [2, 196, 4096]
# Q-Former投影
qformer_proj = QFormerProjector(vision_dim=1024, llm_dim=4096, num_query_tokens=32)
qformer_output = qformer_proj(vision_features)
print(f"Q-Former输出: {qformer_output.shape}") # [2, 32, 4096]
LLM解码器
class MultimodalLLM(nn.Module):
"""多模态LLM包装器"""
def __init__(self, llm, vision_encoder, projector):
super().__init__()
self.vision_encoder = vision_encoder
self.projector = projector
self.llm = llm
def prepare_inputs(self, images, text_tokens):
"""准备多模态输入"""
# 编码图像
vision_features = self.vision_encoder(images)
# 投影
vision_embeds = self.projector(vision_features)
# 获取文本embeddings
text_embeds = self.llm.get_input_embeddings()(text_tokens)
# 拼接(图像在前,文本在后)
combined_embeds = torch.cat([vision_embeds, text_embeds], dim=1)
return combined_embeds
def generate(self, images, text_tokens, max_length=100, **kwargs):
"""生成文本"""
# 准备输入
inputs_embeds = self.prepare_inputs(images, text_tokens)
# 生成
outputs = self.llm.generate(
inputs_embeds=inputs_embeds,
max_length=max_length,
**kwargs
)
return outputs
训练流程
class MultimodalTrainingPipeline:
"""多模态训练流程"""
def __init__(self, model, tokenizer, device='cuda'):
self.model = model.to(device)
self.tokenizer = tokenizer
self.device = device
# 配置优化器 - 分层学习率
vision_params = list(model.vision_encoder.parameters())
projector_params = list(model.projector.parameters())
llm_params = list(model.llm.parameters())
self.optimizer = torch.optim.AdamW([
{'params': vision_params, 'lr': 1e-5}, # 视觉编码器较小学习率
{'params': projector_params, 'lr': 2e-4}, # 投影层较大学习率
{'params': llm_params, 'lr': 2e-5}, # LLM中等学习率
])
def prepare_batch(self, batch):
"""准备批次数据"""
images = batch['images'].to(self.device)
questions = batch['questions']
answers = batch['answers']
# 构建对话格式
prompts = []
for q, a in zip(questions, answers):
prompt = f"<image>\nQuestion: {q}\nAnswer: {a}"
prompts.append(prompt)
# Tokenize
encodings = self.tokenizer(
prompts,
padding=True,
truncation=True,
max_length=512,
return_tensors='pt'
)
input_ids = encodings['input_ids'].to(self.device)
attention_mask = encodings['attention_mask'].to(self.device)
        # Labels(用于计算loss):padding位置置为-100,不参与损失
        labels = input_ids.clone()
        labels[attention_mask == 0] = -100
        # 更严格的做法是只对answer部分计算loss:
        # 需要按样本定位"Answer:"之后的token起点,把之前的token也置为-100,这里从简
return images, input_ids, attention_mask, labels
def train_epoch(self, dataloader):
"""训练一个epoch"""
self.model.train()
total_loss = 0
for batch in tqdm(dataloader, desc="Training"):
images, input_ids, attention_mask, labels = self.prepare_batch(batch)
# 前向传播
outputs = self.model(
images=images,
input_ids=input_ids,
attention_mask=attention_mask,
labels=labels
)
loss = outputs.loss
# 反向传播
self.optimizer.zero_grad()
loss.backward()
torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0)
self.optimizer.step()
total_loss += loss.item()
return total_loss / len(dataloader)
def evaluate(self, dataloader):
"""评估"""
self.model.eval()
total_loss = 0
with torch.no_grad():
for batch in tqdm(dataloader, desc="Evaluating"):
images, input_ids, attention_mask, labels = self.prepare_batch(batch)
outputs = self.model(
images=images,
input_ids=input_ids,
attention_mask=attention_mask,
labels=labels
)
total_loss += outputs.loss.item()
return total_loss / len(dataloader)
其他多模态模型
GPT-4V
GPT-4 Vision是OpenAI的多模态模型,支持图像理解。
import openai
import base64
from typing import List
class GPT4VisionAPI:
"""GPT-4 Vision API包装"""
def __init__(self, api_key: str):
openai.api_key = api_key
def encode_image(self, image_path: str) -> str:
"""编码图像为base64"""
with open(image_path, "rb") as image_file:
return base64.b64encode(image_file.read()).decode('utf-8')
def analyze_image(self, image_path: str, prompt: str, max_tokens: int = 300):
"""分析图像"""
base64_image = self.encode_image(image_path)
response = openai.chat.completions.create(
model="gpt-4-vision-preview",
messages=[
{
"role": "user",
"content": [
{"type": "text", "text": prompt},
{
"type": "image_url",
"image_url": {
"url": f"data:image/jpeg;base64,{base64_image}"
}
}
]
}
],
max_tokens=max_tokens
)
return response.choices[0].message.content
def multi_image_analysis(self, image_paths: List[str], prompt: str):
"""多图像分析"""
content = [{"type": "text", "text": prompt}]
for image_path in image_paths:
base64_image = self.encode_image(image_path)
content.append({
"type": "image_url",
"image_url": {
"url": f"data:image/jpeg;base64,{base64_image}"
}
})
response = openai.chat.completions.create(
model="gpt-4-vision-preview",
messages=[{"role": "user", "content": content}],
max_tokens=500
)
return response.choices[0].message.content
# 使用示例
gpt4v = GPT4VisionAPI(api_key="your-api-key")
# 单图分析
result = gpt4v.analyze_image(
"photo.jpg",
"请详细描述这张图片中的内容"
)
print(result)
# 多图比较
result = gpt4v.multi_image_analysis(
["photo1.jpg", "photo2.jpg"],
"比较这两张图片的异同"
)
print(result)
Gemini
Google的Gemini模型支持原生多模态。
import google.generativeai as genai
from PIL import Image
from typing import Dict, List
class GeminiAPI:
"""Gemini API包装"""
def __init__(self, api_key: str):
genai.configure(api_key=api_key)
self.model = genai.GenerativeModel('gemini-pro-vision')
def analyze_image(self, image_path: str, prompt: str):
"""分析图像"""
image = Image.open(image_path)
response = self.model.generate_content([prompt, image])
return response.text
def chat_with_image(self, image_path: str, messages: List[Dict]):
"""与图像对话"""
image = Image.open(image_path)
chat = self.model.start_chat(history=[])
# 第一条消息包含图像
first_message = messages[0]['content']
response = chat.send_message([first_message, image])
results = [{"role": "model", "content": response.text}]
# 后续对话
for message in messages[1:]:
response = chat.send_message(message['content'])
results.append({"role": "model", "content": response.text})
return results
# 使用示例
gemini = GeminiAPI(api_key="your-api-key")
# 图像分析
result = gemini.analyze_image(
"document.jpg",
"提取这份文档中的所有文字"
)
print(result)
# 对话
messages = [
{"content": "这张图片中有什么?"},
{"content": "图中的人在做什么?"},
{"content": "背景是哪里?"}
]
results = gemini.chat_with_image("photo.jpg", messages)
for r in results:
print(f"{r['role']}: {r['content']}")
Qwen-VL
阿里的通义千问视觉语言模型。
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers.generation import GenerationConfig
from typing import List
class QwenVL:
"""Qwen-VL模型"""
def __init__(self, model_path: str = "Qwen/Qwen-VL-Chat"):
self.tokenizer = AutoTokenizer.from_pretrained(
model_path,
trust_remote_code=True
)
self.model = AutoModelForCausalLM.from_pretrained(
model_path,
device_map="auto",
trust_remote_code=True
).eval()
self.model.generation_config = GenerationConfig.from_pretrained(
model_path,
trust_remote_code=True
)
def chat(self, image_path: str, query: str):
"""单轮对话"""
query_with_image = f'<img>{image_path}</img>{query}'
response, history = self.model.chat(
self.tokenizer,
query=query_with_image,
history=None
)
return response
def multi_turn_chat(self, image_path: str, queries: List[str]):
"""多轮对话"""
# 第一轮包含图像
first_query = f'<img>{image_path}</img>{queries[0]}'
response, history = self.model.chat(
self.tokenizer,
query=first_query,
history=None
)
results = [{"query": queries[0], "response": response}]
# 后续轮次
for query in queries[1:]:
response, history = self.model.chat(
self.tokenizer,
query=query,
history=history
)
results.append({"query": query, "response": response})
return results
# 使用示例
qwen_vl = QwenVL()
# 单轮对话
response = qwen_vl.chat("image.jpg", "描述这张图片")
print(response)
# 多轮对话
queries = [
"图中有什么?",
"它们在做什么?",
"这可能是什么场景?"
]
results = qwen_vl.multi_turn_chat("image.jpg", queries)
for r in results:
print(f"Q: {r['query']}")
print(f"A: {r['response']}\n")
CogVLM
清华(THUDM)与智谱AI开源的CogVLM模型。
class CogVLM:
"""CogVLM模型"""
def __init__(self, model_path: str = "THUDM/cogvlm-chat-hf"):
from transformers import AutoModelForCausalLM, LlamaTokenizer
self.tokenizer = LlamaTokenizer.from_pretrained(model_path)
self.model = AutoModelForCausalLM.from_pretrained(
model_path,
torch_dtype=torch.bfloat16,
low_cpu_mem_usage=True,
trust_remote_code=True
).to('cuda').eval()
def generate(self, image_path: str, query: str):
"""生成响应"""
from PIL import Image
image = Image.open(image_path).convert('RGB')
inputs = self.model.build_conversation_input_ids(
self.tokenizer,
query=query,
history=[],
images=[image]
)
inputs = {
'input_ids': inputs['input_ids'].unsqueeze(0).to('cuda'),
'token_type_ids': inputs['token_type_ids'].unsqueeze(0).to('cuda'),
'attention_mask': inputs['attention_mask'].unsqueeze(0).to('cuda'),
'images': [[inputs['images'][0].to('cuda').to(torch.bfloat16)]],
}
gen_kwargs = {"max_length": 2048, "do_sample": False}
with torch.no_grad():
outputs = self.model.generate(**inputs, **gen_kwargs)
outputs = outputs[:, inputs['input_ids'].shape[1]:]
response = self.tokenizer.decode(outputs[0])
return response
# 使用示例
cogvlm = CogVLM()
response = cogvlm.generate("image.jpg", "请描述图片内容")
print(response)
多模态应用
图像描述生成
class ImageCaptioning:
"""图像描述生成"""
def __init__(self, model, processor, device='cuda'):
self.model = model.to(device)
self.processor = processor
self.device = device
def generate_caption(self, image_path: str, num_beams: int = 5):
"""生成图像描述"""
from PIL import Image
image = Image.open(image_path).convert('RGB')
# 预处理
inputs = self.processor(images=image, return_tensors="pt").to(self.device)
# 生成
with torch.no_grad():
outputs = self.model.generate(
**inputs,
num_beams=num_beams,
max_length=100,
early_stopping=True
)
# 解码
caption = self.processor.decode(outputs[0], skip_special_tokens=True)
return caption
def generate_diverse_captions(self, image_path: str, num_captions: int = 5):
"""生成多样化的描述"""
from PIL import Image
image = Image.open(image_path).convert('RGB')
inputs = self.processor(images=image, return_tensors="pt").to(self.device)
captions = []
with torch.no_grad():
for _ in range(num_captions):
outputs = self.model.generate(
**inputs,
do_sample=True,
temperature=0.7,
top_p=0.9,
max_length=100
)
caption = self.processor.decode(outputs[0], skip_special_tokens=True)
captions.append(caption)
return captions
# 使用BLIP模型
from transformers import BlipProcessor, BlipForConditionalGeneration
processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-base")
captioning = ImageCaptioning(model, processor)
# 生成单个描述
caption = captioning.generate_caption("photo.jpg")
print(f"描述: {caption}")
# 生成多个描述
captions = captioning.generate_diverse_captions("photo.jpg", num_captions=3)
for i, caption in enumerate(captions, 1):
print(f"描述{i}: {caption}")
视觉问答(VQA)
class VisualQuestionAnswering:
"""视觉问答系统"""
def __init__(self, model, processor, device='cuda'):
self.model = model.to(device)
self.processor = processor
self.device = device
def answer_question(self, image_path: str, question: str):
"""回答问题"""
from PIL import Image
image = Image.open(image_path).convert('RGB')
# 预处理
inputs = self.processor(
images=image,
text=question,
return_tensors="pt"
).to(self.device)
# 生成答案
with torch.no_grad():
outputs = self.model.generate(**inputs, max_length=50)
# 解码
answer = self.processor.decode(outputs[0], skip_special_tokens=True)
return answer
def batch_qa(self, image_path: str, questions: List[str]):
"""批量问答"""
answers = {}
for question in questions:
answer = self.answer_question(image_path, question)
answers[question] = answer
return answers
# 使用BLIP-2
from transformers import Blip2Processor, Blip2ForConditionalGeneration
processor = Blip2Processor.from_pretrained("Salesforce/blip2-opt-2.7b")
model = Blip2ForConditionalGeneration.from_pretrained("Salesforce/blip2-opt-2.7b")
vqa = VisualQuestionAnswering(model, processor)
# 单个问题
answer = vqa.answer_question("scene.jpg", "图中有几个人?")
print(f"回答: {answer}")
# 多个问题
questions = [
"图中有什么?",
"天气怎么样?",
"这是什么地方?",
"图中的人在做什么?"
]
answers = vqa.batch_qa("scene.jpg", questions)
for q, a in answers.items():
print(f"Q: {q}")
print(f"A: {a}\n")
OCR和文档理解
import pytesseract
from PIL import Image
import cv2
import numpy as np
from typing import Dict
class DocumentUnderstanding:
"""文档理解系统"""
def __init__(self):
# OCR配置
self.tesseract_config = '--oem 3 --psm 6'
def preprocess_image(self, image_path: str):
"""图像预处理"""
image = cv2.imread(image_path)
# 灰度化
gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
# 二值化
_, binary = cv2.threshold(gray, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)
# 降噪
denoised = cv2.medianBlur(binary, 3)
return denoised
def extract_text(self, image_path: str) -> str:
"""提取文本"""
# 预处理
processed = self.preprocess_image(image_path)
# OCR
text = pytesseract.image_to_string(processed, config=self.tesseract_config)
return text
def extract_structured_data(self, image_path: str) -> Dict:
"""提取结构化数据"""
# OCR with box info
image = Image.open(image_path)
data = pytesseract.image_to_data(image, output_type=pytesseract.Output.DICT)
# 组织数据
structured_data = {
'text': [],
'boxes': [],
'confidences': []
}
n_boxes = len(data['text'])
for i in range(n_boxes):
if int(data['conf'][i]) > 60: # 置信度阈值
text = data['text'][i]
if text.strip():
structured_data['text'].append(text)
structured_data['boxes'].append({
'x': data['left'][i],
'y': data['top'][i],
'w': data['width'][i],
'h': data['height'][i]
})
structured_data['confidences'].append(data['conf'][i])
return structured_data
def understand_document(self, image_path: str, vqa_model):
"""文档理解(结合OCR和VQA)"""
# 提取文本
text = self.extract_text(image_path)
# 使用VQA理解文档类型和内容
questions = [
"这是什么类型的文档?",
"文档的主题是什么?",
"有什么重要信息?"
]
answers = {}
for question in questions:
answer = vqa_model.answer_question(image_path, question)
answers[question] = answer
return {
'extracted_text': text,
'understanding': answers
}
# 使用示例
doc_understanding = DocumentUnderstanding()
# 提取文本
text = doc_understanding.extract_text("document.jpg")
print("提取的文本:")
print(text)
# 提取结构化数据
structured = doc_understanding.extract_structured_data("document.jpg")
print(f"\n检测到 {len(structured['text'])} 个文本块")
# 文档理解
understanding = doc_understanding.understand_document("document.jpg", vqa)
print("\n文档理解:")
for question, answer in understanding['understanding'].items():
print(f"{question}: {answer}")
视频理解
import cv2
import numpy as np
from typing import List
class VideoUnderstanding:
"""视频理解系统"""
def __init__(self, vqa_model, captioning_model):
self.vqa_model = vqa_model
self.captioning_model = captioning_model
def extract_frames(self, video_path: str, num_frames: int = 10) -> List[np.ndarray]:
"""提取关键帧"""
cap = cv2.VideoCapture(video_path)
total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
# 均匀采样
frame_indices = np.linspace(0, total_frames - 1, num_frames, dtype=int)
frames = []
for idx in frame_indices:
cap.set(cv2.CAP_PROP_POS_FRAMES, idx)
ret, frame = cap.read()
if ret:
frames.append(frame)
cap.release()
return frames
def describe_video(self, video_path: str, num_frames: int = 10) -> List[str]:
"""描述视频内容"""
# 提取帧
frames = self.extract_frames(video_path, num_frames)
# 为每一帧生成描述
descriptions = []
for i, frame in enumerate(frames):
# 保存临时图像
temp_path = f"temp_frame_{i}.jpg"
cv2.imwrite(temp_path, frame)
# 生成描述
caption = self.captioning_model.generate_caption(temp_path)
descriptions.append(caption)
# 删除临时文件
import os
os.remove(temp_path)
return descriptions
def answer_video_question(self, video_path: str, question: str, num_frames: int = 10):
"""回答关于视频的问题"""
# 提取帧
frames = self.extract_frames(video_path, num_frames)
# 为每一帧回答问题
answers = []
for i, frame in enumerate(frames):
temp_path = f"temp_frame_{i}.jpg"
cv2.imwrite(temp_path, frame)
answer = self.vqa_model.answer_question(temp_path, question)
answers.append(answer)
import os
os.remove(temp_path)
# 聚合答案(简单的多数投票)
from collections import Counter
most_common = Counter(answers).most_common(1)[0][0]
return {
'final_answer': most_common,
'frame_answers': answers
}
def summarize_video(self, video_path: str, num_frames: int = 10) -> str:
"""总结视频内容"""
descriptions = self.describe_video(video_path, num_frames)
# 使用LLM总结
from openai import OpenAI
client = OpenAI()
frame_descriptions = "\n".join([
f"帧{i+1}: {desc}"
for i, desc in enumerate(descriptions)
])
prompt = f"""基于视频的关键帧描述,生成视频的整体摘要。
帧描述:
{frame_descriptions}
视频摘要:"""
response = client.chat.completions.create(
model="gpt-4",
messages=[{"role": "user", "content": prompt}]
)
return response.choices[0].message.content
# 使用示例
video_understanding = VideoUnderstanding(vqa, captioning)
# 描述视频
descriptions = video_understanding.describe_video("video.mp4", num_frames=8)
print("视频帧描述:")
for i, desc in enumerate(descriptions, 1):
print(f" 帧{i}: {desc}")
# 视频问答
result = video_understanding.answer_video_question(
"video.mp4",
"视频中的人在做什么?",
num_frames=8
)
print(f"\n问答结果: {result['final_answer']}")
# 视频摘要
summary = video_understanding.summarize_video("video.mp4", num_frames=10)
print(f"\n视频摘要:\n{summary}")
实战项目
构建图像问答系统
"""
完整的图像问答系统实现
支持:图像上传、问题解答、对话历史
"""
from flask import Flask, request, jsonify
import torch
from transformers import Blip2Processor, Blip2ForConditionalGeneration
from PIL import Image
import io
import base64
app = Flask(__name__)
class ImageQASystem:
"""图像问答系统"""
def __init__(self, model_name="Salesforce/blip2-opt-2.7b"):
print("加载模型...")
self.processor = Blip2Processor.from_pretrained(model_name)
self.model = Blip2ForConditionalGeneration.from_pretrained(
model_name,
torch_dtype=torch.float16
).to('cuda')
print("模型加载完成")
# 会话存储
self.sessions = {}
def create_session(self, session_id: str, image: Image.Image):
"""创建会话"""
self.sessions[session_id] = {
'image': image,
'history': []
}
def answer(self, session_id: str, question: str) -> str:
"""回答问题"""
if session_id not in self.sessions:
return "会话不存在"
session = self.sessions[session_id]
image = session['image']
# 构建上下文(包含历史)
context = ""
for qa in session['history'][-3:]: # 最近3轮对话
context += f"Q: {qa['question']}\nA: {qa['answer']}\n"
full_question = f"{context}Q: {question}\nA:"
# 处理
inputs = self.processor(
images=image,
text=full_question,
return_tensors="pt"
).to('cuda', torch.float16)
# 生成
with torch.no_grad():
outputs = self.model.generate(
**inputs,
max_length=100,
num_beams=5
)
answer = self.processor.decode(outputs[0], skip_special_tokens=True)
# 保存历史
session['history'].append({
'question': question,
'answer': answer
})
return answer
# 初始化系统
qa_system = ImageQASystem()
@app.route('/upload', methods=['POST'])
def upload_image():
"""上传图像"""
try:
# 获取图像
file = request.files['image']
image = Image.open(file.stream).convert('RGB')
# 创建会话
session_id = str(hash(file.filename))
qa_system.create_session(session_id, image)
return jsonify({
'success': True,
'session_id': session_id,
'message': '图像上传成功'
})
except Exception as e:
return jsonify({
'success': False,
'error': str(e)
}), 400
@app.route('/ask', methods=['POST'])
def ask_question():
"""提问"""
try:
data = request.json
session_id = data['session_id']
question = data['question']
answer = qa_system.answer(session_id, question)
return jsonify({
'success': True,
'answer': answer
})
except Exception as e:
return jsonify({
'success': False,
'error': str(e)
}), 400
@app.route('/history', methods=['GET'])
def get_history():
"""获取对话历史"""
try:
session_id = request.args.get('session_id')
if session_id not in qa_system.sessions:
return jsonify({
'success': False,
'error': '会话不存在'
}), 404
history = qa_system.sessions[session_id]['history']
return jsonify({
'success': True,
'history': history
})
except Exception as e:
return jsonify({
'success': False,
'error': str(e)
}), 400
if __name__ == '__main__':
app.run(host='0.0.0.0', port=5000, debug=True)
# 客户端示例
"""
import requests
# 上传图像
with open('image.jpg', 'rb') as f:
response = requests.post(
'http://localhost:5000/upload',
files={'image': f}
)
session_id = response.json()['session_id']
# 提问
questions = [
"图中有什么?",
"它们的颜色是什么?",
"这是在哪里拍的?"
]
for question in questions:
response = requests.post(
'http://localhost:5000/ask',
json={'session_id': session_id, 'question': question}
)
print(f"Q: {question}")
print(f"A: {response.json()['answer']}\n")
# 获取历史
response = requests.get(
f'http://localhost:5000/history?session_id={session_id}'
)
print("对话历史:")
for qa in response.json()['history']:
print(f"Q: {qa['question']}")
print(f"A: {qa['answer']}\n")
"""
文档解析助手
"""
智能文档解析助手
功能:OCR、表格提取、信息抽取、问答
"""
import pytesseract
from PIL import Image
import cv2
import numpy as np
from typing import Dict, List
import pandas as pd
import json
class DocumentParser:
"""文档解析助手"""
def __init__(self, vqa_model=None):
self.vqa_model = vqa_model
def parse_document(self, image_path: str) -> Dict:
"""解析文档"""
result = {
'text': self.extract_text(image_path),
'tables': self.extract_tables(image_path),
'layout': self.analyze_layout(image_path),
'entities': self.extract_entities(image_path)
}
return result
def extract_text(self, image_path: str) -> str:
"""提取文本"""
image = cv2.imread(image_path)
gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
_, binary = cv2.threshold(gray, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)
text = pytesseract.image_to_string(binary, lang='chi_sim+eng')
return text
    def extract_tables(self, image_path: str) -> List[Dict]:
"""提取表格"""
# 这里使用简化的方法,实际应该使用专门的表格检测模型
image = cv2.imread(image_path)
gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
# 检测水平和垂直线
horizontal_kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (40, 1))
vertical_kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (1, 40))
horizontal_lines = cv2.morphologyEx(gray, cv2.MORPH_OPEN, horizontal_kernel)
vertical_lines = cv2.morphologyEx(gray, cv2.MORPH_OPEN, vertical_kernel)
# 组合
table_mask = cv2.add(horizontal_lines, vertical_lines)
# 查找轮廓
contours, _ = cv2.findContours(table_mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
tables = []
for contour in contours:
x, y, w, h = cv2.boundingRect(contour)
if w > 100 and h > 100: # 过滤小的区域
table_region = gray[y:y+h, x:x+w]
# 这里应该进一步处理提取表格数据
# 简化处理,返回区域坐标
tables.append({
'bbox': (x, y, w, h),
'data': None # 实际应该提取表格数据
})
return tables
def analyze_layout(self, image_path: str) -> Dict:
"""分析布局"""
# 使用OCR获取文字区域
image = Image.open(image_path)
data = pytesseract.image_to_data(image, output_type=pytesseract.Output.DICT)
# 聚类文字区域
blocks = {}
for i in range(len(data['text'])):
if int(data['conf'][i]) > 60:
block_num = data['block_num'][i]
if block_num not in blocks:
blocks[block_num] = {
'text': [],
'bbox': [data['left'][i], data['top'][i],
data['width'][i], data['height'][i]]
}
blocks[block_num]['text'].append(data['text'][i])
layout = {
'num_blocks': len(blocks),
'blocks': [
{
'id': k,
'text': ' '.join(v['text']),
'bbox': v['bbox']
}
for k, v in blocks.items()
]
}
return layout
def extract_entities(self, image_path: str) -> Dict:
"""提取实体(使用VQA)"""
if not self.vqa_model:
return {}
# 定义要提取的实体
entity_questions = {
'date': '文档的日期是什么?',
'title': '文档的标题是什么?',
'author': '文档的作者是谁?',
'organization': '文档来自哪个组织?',
'amount': '文档中提到的金额是多少?'
}
entities = {}
for entity_type, question in entity_questions.items():
answer = self.vqa_model.answer_question(image_path, question)
entities[entity_type] = answer
return entities
def answer_question(self, image_path: str, question: str) -> str:
"""回答关于文档的问题"""
# 结合OCR文本和VQA
text = self.extract_text(image_path)
# 先在文本中搜索
# 这里简化处理,实际应该使用更复杂的检索方法
if self.vqa_model:
answer = self.vqa_model.answer_question(image_path, question)
else:
answer = "无法回答(需要VQA模型)"
return answer
# 使用示例
parser = DocumentParser(vqa_model=vqa)
# 解析文档
result = parser.parse_document("contract.jpg")
print("文档解析结果:")
print(f"文本长度: {len(result['text'])} 字符")
print(f"检测到 {result['layout']['num_blocks']} 个文本块")
print(f"检测到 {len(result['tables'])} 个表格")
print("\n提取的实体:")
for entity_type, value in result['entities'].items():
print(f" {entity_type}: {value}")
# 文档问答
questions = [
"合同的甲方是谁?",
"合同金额是多少?",
"合同有效期到什么时候?"
]
print("\n文档问答:")
for question in questions:
answer = parser.answer_question("contract.jpg", question)
print(f"Q: {question}")
print(f"A: {answer}\n")