第14章 AI编程助手原理与实现
AI编程助手已经成为现代软件开发的重要工具,从GitHub Copilot到Claude Code,这些工具极大地提升了开发效率。本章将深入探讨AI编程助手的核心原理、代码理解技术、实现方法,以及如何开发VSCode插件。
14.1 AI编程助手概述
14.1.1 编程助手的发展历程
AI编程助手的发展经历了几个重要阶段:
1. 传统代码补全时代(2000-2015)
- 基于规则的语法补全
- IDE内置的智能感知(IntelliSense)
- 依赖静态分析和类型信息
2. 统计学习时代(2015-2020)
- 基于n-gram的代码补全
- 使用RNN/LSTM预测下一个token
- TabNine等工具的出现
3. 预训练模型时代(2020-2021)
- GPT-2/GPT-3应用于代码生成
- CodeBERT、GraphCodeBERT等专用模型
- 代码-文本双向理解
4. 大模型编程助手时代(2021-至今)
- GitHub Copilot(基于Codex)
- Claude Code、Cursor等
- 支持多文件上下文、自然语言交互
14.1.2 核心功能与能力
现代AI编程助手通常具备以下核心能力:
class AICodeAssistant:
"""AI编程助手核心能力抽象"""
def __init__(self):
self.capabilities = {
# 代码生成能力
'code_generation': {
'inline_completion': True, # 行内补全
'multi_line_suggestion': True, # 多行建议
'function_generation': True, # 函数生成
'class_generation': True, # 类生成
},
# 代码理解能力
'code_understanding': {
'syntax_parsing': True, # 语法解析
'semantic_analysis': True, # 语义分析
'call_graph': True, # 调用图分析
'dependency_tracking': True, # 依赖追踪
},
# 代码编辑能力
'code_editing': {
'refactoring': True, # 重构
'bug_fixing': True, # 修复
'optimization': True, # 优化
'documentation': True, # 文档生成
},
# 交互能力
'interaction': {
'chat': True, # 对话交互
'context_aware': True, # 上下文感知
'multi_file': True, # 多文件理解
'terminal_integration': True, # 终端集成
}
}
def generate_code(self, context, intent):
"""代码生成接口"""
raise NotImplementedError
def understand_code(self, code, language):
"""代码理解接口"""
raise NotImplementedError
def edit_code(self, code, instruction):
"""代码编辑接口"""
raise NotImplementedError
14.1.3 技术架构
典型的AI编程助手技术架构:
┌─────────────────────────────────────────────────────────────┐
│ 用户界面层 │
│ ┌──────────┐ ┌──────────┐ ┌──────────┐ ┌──────────┐ │
│ │ IDE插件 │ │ Chat UI │ │ Terminal │ │ Web UI │ │
│ └──────────┘ └──────────┘ └──────────┘ └──────────┘ │
└─────────────────────────────────────────────────────────────┘
↓
┌─────────────────────────────────────────────────────────────┐
│ 编辑器集成层 │
│ ┌────────────────┐ ┌────────────────┐ ┌──────────────┐ │
│ │ LSP Client │ │ Event Capture │ │ Diff Engine │ │
│ └────────────────┘ └────────────────┘ └──────────────┘ │
└─────────────────────────────────────────────────────────────┘
↓
┌─────────────────────────────────────────────────────────────┐
│ 上下文管理层 │
│ ┌────────────┐ ┌──────────┐ ┌──────────┐ ┌──────────┐ │
│ │ 代码索引 │ │ 文件系统 │ │ Git集成 │ │ 依赖图 │ │
│ └────────────┘ └──────────┘ └──────────┘ └──────────┘ │
└─────────────────────────────────────────────────────────────┘
↓
┌─────────────────────────────────────────────────────────────┐
│ AI推理层 │
│ ┌────────────┐ ┌──────────┐ ┌──────────┐ ┌──────────┐ │
│ │ 代码模型 │ │ 检索增强 │ │ 工具调用 │ │ 推理引擎 │ │
│ └────────────┘ └──────────┘ └──────────┘ └──────────┘ │
└─────────────────────────────────────────────────────────────┘
↓
┌─────────────────────────────────────────────────────────────┐
│ 基础设施层 │
│ ┌────────────┐ ┌──────────┐ ┌──────────┐ ┌──────────┐ │
│ │ 模型服务 │ │ 向量数据库│ │ 缓存系统 │ │ 监控日志 │ │
│ └────────────┘ └──────────┘ └──────────┘ └──────────┘ │
└─────────────────────────────────────────────────────────────┘
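为了更直观地理解各层之间的协作方式,下面给出一个极简的流程示意:一次行内补全请求从编辑器集成层出发,经上下文管理层补充相关信息,最终交给AI推理层生成结果。其中的类名与函数均为演示用的假设实现,并非某个真实产品的接口:
from dataclasses import dataclass

@dataclass
class CompletionRequest:
    """编辑器集成层捕获到的一次补全请求(假设的数据结构)"""
    file_path: str   # 当前文件路径
    prefix: str      # 光标前文本
    suffix: str      # 光标后文本

class ContextManager:
    """上下文管理层:结合代码索引、依赖图等补充相关片段(此处仅为占位实现)"""
    def enrich(self, req: CompletionRequest) -> str:
        related = f"# related snippets for {req.file_path}\n"  # 假设从索引中检索到的片段
        return related + req.prefix

class InferenceEngine:
    """AI推理层:调用底层模型服务(此处用固定结果代替真实模型)"""
    def complete(self, prompt: str, suffix: str) -> str:
        return "    return sum(numbers) / len(numbers)"  # 演示用的固定补全

def handle_completion(req: CompletionRequest) -> str:
    prompt = ContextManager().enrich(req)                    # 上下文管理层
    return InferenceEngine().complete(prompt, req.suffix)    # AI推理层

print(handle_completion(CompletionRequest("demo.py", "def avg(numbers):\n", "")))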
14.2 代码理解技术
14.2.1 代码解析与AST
抽象语法树(AST)是代码理解的基础:
import ast
import json
from typing import Dict, List, Any
class CodeParser:
"""代码解析器 - 将代码转换为结构化表示"""
def __init__(self):
self.ast_cache = {}
def parse_python(self, code: str) -> Dict[str, Any]:
"""解析Python代码"""
try:
tree = ast.parse(code)
return {
'success': True,
'ast': tree,
'structure': self._extract_structure(tree),
'symbols': self._extract_symbols(tree),
'dependencies': self._extract_dependencies(tree)
}
except SyntaxError as e:
return {
'success': False,
'error': str(e),
'line': e.lineno,
'offset': e.offset
}
def _extract_structure(self, tree: ast.AST) -> Dict[str, List]:
"""提取代码结构"""
structure = {
'classes': [],
'functions': [],
'imports': [],
'global_vars': []
}
for node in ast.walk(tree):
if isinstance(node, ast.ClassDef):
structure['classes'].append({
'name': node.name,
'line': node.lineno,
'methods': [m.name for m in node.body if isinstance(m, ast.FunctionDef)],
'bases': [self._get_name(base) for base in node.bases],
'decorators': [self._get_name(d) for d in node.decorator_list]
})
elif isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)):
# 只记录顶层函数
if not self._is_nested(node, tree):
structure['functions'].append({
'name': node.name,
'line': node.lineno,
'args': [arg.arg for arg in node.args.args],
'returns': self._get_annotation(node.returns),
'decorators': [self._get_name(d) for d in node.decorator_list],
'is_async': isinstance(node, ast.AsyncFunctionDef)
})
elif isinstance(node, (ast.Import, ast.ImportFrom)):
structure['imports'].append(self._extract_import(node))
return structure
def _extract_symbols(self, tree: ast.AST) -> Dict[str, List[str]]:
"""提取符号表"""
symbols = {
'defined': [], # 定义的符号
'used': [], # 使用的符号
'assigned': [], # 赋值的符号
'called': [] # 调用的符号
}
for node in ast.walk(tree):
if isinstance(node, ast.Name):
if isinstance(node.ctx, ast.Store):
symbols['assigned'].append(node.id)
elif isinstance(node.ctx, ast.Load):
symbols['used'].append(node.id)
elif isinstance(node, ast.Call):
func_name = self._get_name(node.func)
if func_name:
symbols['called'].append(func_name)
elif isinstance(node, (ast.FunctionDef, ast.ClassDef)):
symbols['defined'].append(node.name)
# 去重
for key in symbols:
symbols[key] = list(set(symbols[key]))
return symbols
def _extract_dependencies(self, tree: ast.AST) -> List[str]:
"""提取依赖关系"""
dependencies = []
for node in ast.walk(tree):
if isinstance(node, ast.Import):
for alias in node.names:
dependencies.append(alias.name.split('.')[0])
elif isinstance(node, ast.ImportFrom):
if node.module:
dependencies.append(node.module.split('.')[0])
return list(set(dependencies))
def _get_name(self, node) -> str:
"""获取节点名称"""
if isinstance(node, ast.Name):
return node.id
elif isinstance(node, ast.Attribute):
value = self._get_name(node.value)
return f"{value}.{node.attr}" if value else node.attr
elif isinstance(node, ast.Call):
return self._get_name(node.func)
return ""
def _get_annotation(self, node) -> str:
"""获取类型注解"""
if node is None:
return None
return ast.unparse(node)
def _is_nested(self, func_node, tree) -> bool:
"""检查函数是否嵌套在类或其他函数中"""
for node in ast.walk(tree):
if isinstance(node, (ast.ClassDef, ast.FunctionDef)):
if func_node in ast.walk(node) and func_node != node:
return True
return False
def _extract_import(self, node) -> Dict[str, Any]:
"""提取导入信息"""
if isinstance(node, ast.Import):
return {
'type': 'import',
'names': [alias.name for alias in node.names],
'aliases': {alias.name: alias.asname for alias in node.names if alias.asname}
}
else: # ImportFrom
return {
'type': 'from_import',
'module': node.module,
'names': [alias.name for alias in node.names],
'level': node.level
}
# 使用示例
parser = CodeParser()
sample_code = '''
import os
from typing import List, Dict
class DataProcessor:
"""数据处理器"""
def __init__(self, config: Dict):
self.config = config
def process(self, data: List[str]) -> List[str]:
"""处理数据"""
result = []
for item in data:
processed = self._transform(item)
result.append(processed)
return result
def _transform(self, item: str) -> str:
return item.upper()
def main():
processor = DataProcessor({'debug': True})
data = ['hello', 'world']
result = processor.process(data)
print(result)
'''
parsed = parser.parse_python(sample_code)
if parsed['success']:
print("代码结构:")
print(json.dumps(parsed['structure'], indent=2, ensure_ascii=False))
print("\n符号表:")
print(json.dumps(parsed['symbols'], indent=2, ensure_ascii=False))
print("\n依赖:")
print(parsed['dependencies'])
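解析得到的结构化信息不仅可以用于分析,还可以压缩成提示词中的文件概要,供后续的代码生成模型使用。下面是一个最小示意(假设输入即上面 parse_python 返回的 structure 字典):
def build_file_summary(structure: Dict[str, List]) -> str:
    """把 CodeParser 的结构化结果压缩成可放入提示词的文件概要(示意实现)"""
    lines = []
    for imp in structure.get('imports', []):
        lines.append(f"import: {', '.join(imp.get('names', []))}")
    for cls in structure.get('classes', []):
        lines.append(f"class {cls['name']} (methods: {', '.join(cls['methods'])})")
    for func in structure.get('functions', []):
        lines.append(f"def {func['name']}({', '.join(func['args'])})")
    return '\n'.join(lines)

# 续接上面的解析结果
if parsed['success']:
    print("\n文件概要:")
    print(build_file_summary(parsed['structure']))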
14.2.2 语义理解与代码嵌入
使用预训练模型进行代码语义理解:
import torch
import torch.nn as nn
from transformers import AutoTokenizer, AutoModel
from typing import List, Dict
import numpy as np
class CodeEmbedder:
"""代码嵌入器 - 将代码转换为向量表示"""
def __init__(self, model_name: str = "microsoft/codebert-base"):
self.tokenizer = AutoTokenizer.from_pretrained(model_name)
self.model = AutoModel.from_pretrained(model_name)
self.model.eval()
# 使用GPU(如果可用)
self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
self.model.to(self.device)
def embed_code(self, code: str, max_length: int = 512) -> np.ndarray:
"""将代码转换为嵌入向量"""
# Tokenize
inputs = self.tokenizer(
code,
max_length=max_length,
padding='max_length',
truncation=True,
return_tensors='pt'
)
# 移到设备
inputs = {k: v.to(self.device) for k, v in inputs.items()}
# 获取嵌入
with torch.no_grad():
outputs = self.model(**inputs)
# 使用[CLS] token的嵌入作为代码表示
embeddings = outputs.last_hidden_state[:, 0, :].cpu().numpy()
return embeddings[0]
def embed_batch(self, codes: List[str], max_length: int = 512) -> np.ndarray:
"""批量嵌入"""
inputs = self.tokenizer(
codes,
max_length=max_length,
padding='max_length',
truncation=True,
return_tensors='pt'
)
inputs = {k: v.to(self.device) for k, v in inputs.items()}
with torch.no_grad():
outputs = self.model(**inputs)
embeddings = outputs.last_hidden_state[:, 0, :].cpu().numpy()
return embeddings
def compute_similarity(self, code1: str, code2: str) -> float:
"""计算两段代码的相似度"""
emb1 = self.embed_code(code1)
emb2 = self.embed_code(code2)
# 余弦相似度
similarity = np.dot(emb1, emb2) / (np.linalg.norm(emb1) * np.linalg.norm(emb2))
return float(similarity)
def find_similar_code(
self,
query_code: str,
code_database: List[str],
top_k: int = 5
) -> List[tuple]:
"""在代码库中查找相似代码"""
query_emb = self.embed_code(query_code)
db_embs = self.embed_batch(code_database)
# 计算相似度
similarities = []
for i, db_emb in enumerate(db_embs):
sim = np.dot(query_emb, db_emb) / (
np.linalg.norm(query_emb) * np.linalg.norm(db_emb)
)
similarities.append((i, float(sim)))
# 排序并返回top_k
similarities.sort(key=lambda x: x[1], reverse=True)
return similarities[:top_k]
class CodeUnderstandingModel(nn.Module):
"""代码理解模型 - 在CodeBERT基础上进行任务特定训练"""
def __init__(self, base_model_name: str = "microsoft/codebert-base", num_labels: int = 2):
super().__init__()
self.encoder = AutoModel.from_pretrained(base_model_name)
self.dropout = nn.Dropout(0.1)
self.classifier = nn.Linear(self.encoder.config.hidden_size, num_labels)
def forward(self, input_ids, attention_mask):
outputs = self.encoder(input_ids=input_ids, attention_mask=attention_mask)
pooled_output = outputs.last_hidden_state[:, 0, :] # [CLS] token
pooled_output = self.dropout(pooled_output)
logits = self.classifier(pooled_output)
return logits
class CodeClassifier:
"""代码分类器 - 用于代码意图理解"""
def __init__(self, model_path: str = None):
self.tokenizer = AutoTokenizer.from_pretrained("microsoft/codebert-base")
self.model = CodeUnderstandingModel(num_labels=4) # 假设4种意图
if model_path:
self.model.load_state_dict(torch.load(model_path))
self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
self.model.to(self.device)
self.model.eval()
self.intent_labels = {
0: 'data_processing',
1: 'api_call',
2: 'algorithm',
3: 'ui_logic'
}
def classify_intent(self, code: str) -> Dict[str, float]:
"""分类代码意图"""
inputs = self.tokenizer(
code,
max_length=512,
padding='max_length',
truncation=True,
return_tensors='pt'
)
inputs = {k: v.to(self.device) for k, v in inputs.items()}
with torch.no_grad():
logits = self.model(**inputs)
probs = torch.softmax(logits, dim=1)[0].cpu().numpy()
return {
self.intent_labels[i]: float(prob)
for i, prob in enumerate(probs)
}
# 使用示例
embedder = CodeEmbedder()
# 示例代码
code1 = """
def bubble_sort(arr):
n = len(arr)
for i in range(n):
for j in range(0, n-i-1):
if arr[j] > arr[j+1]:
arr[j], arr[j+1] = arr[j+1], arr[j]
return arr
"""
code2 = """
def selection_sort(arr):
for i in range(len(arr)):
min_idx = i
for j in range(i+1, len(arr)):
if arr[min_idx] > arr[j]:
min_idx = j
arr[i], arr[min_idx] = arr[min_idx], arr[i]
return arr
"""
code3 = """
def fetch_user_data(user_id):
response = requests.get(f'/api/users/{user_id}')
return response.json()
"""
# 计算相似度
similarity = embedder.compute_similarity(code1, code2)
print(f"冒泡排序 vs 选择排序 相似度: {similarity:.4f}")
similarity2 = embedder.compute_similarity(code1, code3)
print(f"冒泡排序 vs API调用 相似度: {similarity2:.4f}")
# 代码搜索
code_db = [code1, code2, code3]
query = "def quick_sort(arr): ..."
results = embedder.find_similar_code(query, code_db, top_k=2)
print("\n最相似的代码:")
for idx, sim in results:
print(f"索引 {idx}, 相似度: {sim:.4f}")
14.2.3 调用图与依赖分析
import ast
from collections import defaultdict
from typing import Dict, List, Set, Tuple
import networkx as nx
import matplotlib.pyplot as plt
class CallGraphAnalyzer:
"""调用图分析器"""
def __init__(self):
self.call_graph = nx.DiGraph()
self.function_defs = {}
self.class_methods = defaultdict(list)
def analyze_file(self, filepath: str) -> nx.DiGraph:
"""分析文件并构建调用图"""
with open(filepath, 'r', encoding='utf-8') as f:
code = f.read()
tree = ast.parse(code)
# 第一遍:收集所有函数和方法定义
self._collect_definitions(tree)
# 第二遍:分析调用关系
self._analyze_calls(tree)
return self.call_graph
def _collect_definitions(self, tree: ast.AST):
"""收集函数和方法定义"""
for node in ast.walk(tree):
if isinstance(node, ast.FunctionDef):
# 检查是否是类方法
parent_class = self._find_parent_class(node, tree)
if parent_class:
full_name = f"{parent_class}.{node.name}"
self.class_methods[parent_class].append(node.name)
else:
full_name = node.name
self.function_defs[full_name] = {
'node': node,
'lineno': node.lineno,
'args': [arg.arg for arg in node.args.args]
}
# 添加到图中
self.call_graph.add_node(
full_name,
type='function',
lineno=node.lineno
)
def _analyze_calls(self, tree: ast.AST):
"""分析函数调用关系"""
for func_name, func_info in self.function_defs.items():
func_node = func_info['node']
# 遍历函数体中的所有调用
for node in ast.walk(func_node):
if isinstance(node, ast.Call):
called_name = self._get_call_name(node.func)
if called_name:
# 添加边
self.call_graph.add_edge(
func_name,
called_name,
lineno=node.lineno
)
def _find_parent_class(self, func_node, tree) -> str:
"""查找函数的父类"""
for node in ast.walk(tree):
if isinstance(node, ast.ClassDef):
if func_node in node.body:
return node.name
return None
def _get_call_name(self, node) -> str:
"""获取调用的函数名"""
if isinstance(node, ast.Name):
return node.id
elif isinstance(node, ast.Attribute):
# 处理方法调用 obj.method()
if isinstance(node.value, ast.Name):
# 简单情况:self.method() 或 obj.method()
if node.value.id == 'self':
# 查找当前类
return f"*.{node.attr}" # 使用通配符表示任意类
return f"{node.value.id}.{node.attr}"
return node.attr
return None
def find_callers(self, function_name: str) -> List[str]:
"""查找调用指定函数的所有函数"""
return list(self.call_graph.predecessors(function_name))
def find_callees(self, function_name: str) -> List[str]:
"""查找指定函数调用的所有函数"""
return list(self.call_graph.successors(function_name))
def find_call_chain(self, from_func: str, to_func: str) -> List[List[str]]:
"""查找从一个函数到另一个函数的所有调用链"""
try:
paths = nx.all_simple_paths(self.call_graph, from_func, to_func)
return list(paths)
except nx.NodeNotFound:
return []
def detect_cycles(self) -> List[List[str]]:
"""检测循环调用"""
try:
cycles = list(nx.simple_cycles(self.call_graph))
return cycles
except:
return []
def get_complexity_metrics(self) -> Dict[str, Dict]:
"""获取复杂度指标"""
metrics = {}
for node in self.call_graph.nodes():
in_degree = self.call_graph.in_degree(node) # 被调用次数
out_degree = self.call_graph.out_degree(node) # 调用其他函数次数
# 尝试计算圈复杂度
try:
# 简化版:使用出度作为复杂度的近似
complexity = out_degree + 1
except:
complexity = 1
metrics[node] = {
'callers': in_degree,
'callees': out_degree,
'complexity': complexity,
'centrality': in_degree + out_degree
}
return metrics
def visualize(self, output_file: str = 'call_graph.png'):
"""可视化调用图"""
plt.figure(figsize=(12, 8))
# 使用弹簧布局(spring layout)
pos = nx.spring_layout(self.call_graph, k=2, iterations=50)
# 绘制节点
nx.draw_networkx_nodes(
self.call_graph, pos,
node_color='lightblue',
node_size=500,
alpha=0.9
)
# 绘制边
nx.draw_networkx_edges(
self.call_graph, pos,
edge_color='gray',
arrows=True,
arrowsize=20,
alpha=0.6
)
# 绘制标签
nx.draw_networkx_labels(
self.call_graph, pos,
font_size=8,
font_weight='bold'
)
plt.title("Call Graph")
plt.axis('off')
plt.tight_layout()
plt.savefig(output_file, dpi=300, bbox_inches='tight')
print(f"调用图已保存到: {output_file}")
class DependencyAnalyzer:
"""依赖分析器"""
def __init__(self, project_root: str):
self.project_root = project_root
self.dep_graph = nx.DiGraph()
self.module_imports = defaultdict(set)
def analyze_project(self) -> nx.DiGraph:
"""分析整个项目的依赖"""
import os
for root, dirs, files in os.walk(self.project_root):
# 跳过虚拟环境等
dirs[:] = [d for d in dirs if d not in ['venv', '__pycache__', '.git']]
for file in files:
if file.endswith('.py'):
filepath = os.path.join(root, file)
self._analyze_file_dependencies(filepath)
return self.dep_graph
def _analyze_file_dependencies(self, filepath: str):
"""分析单个文件的依赖"""
try:
with open(filepath, 'r', encoding='utf-8') as f:
tree = ast.parse(f.read())
module_name = self._get_module_name(filepath)
self.dep_graph.add_node(module_name, path=filepath)
# 提取导入
for node in ast.walk(tree):
if isinstance(node, ast.Import):
for alias in node.names:
imported = alias.name.split('.')[0]
self.module_imports[module_name].add(imported)
self.dep_graph.add_edge(module_name, imported)
elif isinstance(node, ast.ImportFrom):
if node.module:
imported = node.module.split('.')[0]
self.module_imports[module_name].add(imported)
self.dep_graph.add_edge(module_name, imported)
except Exception as e:
print(f"分析文件 {filepath} 时出错: {e}")
def _get_module_name(self, filepath: str) -> str:
"""从文件路径获取模块名"""
import os
relpath = os.path.relpath(filepath, self.project_root)
module = relpath.replace(os.sep, '.').replace('.py', '')
return module
def find_circular_dependencies(self) -> List[List[str]]:
"""查找循环依赖"""
cycles = list(nx.simple_cycles(self.dep_graph))
return cycles
def get_dependency_depth(self, module: str) -> int:
"""获取模块的依赖深度"""
if module not in self.dep_graph:
return 0
try:
# 从该模块可达的最长路径
lengths = nx.single_source_shortest_path_length(
self.dep_graph, module
)
return max(lengths.values()) if lengths else 0
except:
return 0
def get_impact_analysis(self, module: str) -> Dict[str, Set[str]]:
"""分析修改某个模块的影响范围"""
direct_dependents = set(self.dep_graph.predecessors(module))
all_dependents = set()
# BFS找到所有依赖该模块的模块
to_visit = list(direct_dependents)
visited = set([module])
while to_visit:
current = to_visit.pop(0)
if current in visited:
continue
visited.add(current)
all_dependents.add(current)
for predecessor in self.dep_graph.predecessors(current):
if predecessor not in visited:
to_visit.append(predecessor)
return {
'direct': direct_dependents,
'all': all_dependents,
'count': len(all_dependents)
}
# 使用示例
print("=== 调用图分析示例 ===")
# 创建示例代码文件
sample_code = '''
class Calculator:
def __init__(self):
self.result = 0
def add(self, a, b):
self.result = self._compute(a, b, '+')
return self.result
def subtract(self, a, b):
self.result = self._compute(a, b, '-')
return self.result
def _compute(self, a, b, op):
if op == '+':
return self._add_impl(a, b)
elif op == '-':
return self._subtract_impl(a, b)
def _add_impl(self, a, b):
return a + b
def _subtract_impl(self, a, b):
return a - b
def main():
calc = Calculator()
result1 = calc.add(10, 5)
result2 = calc.subtract(10, 5)
print_results(result1, result2)
def print_results(r1, r2):
print(f"Results: {r1}, {r2}")
'''
# 保存示例文件
with open('/tmp/sample_calc.py', 'w') as f:
f.write(sample_code)
# 分析调用图
analyzer = CallGraphAnalyzer()
call_graph = analyzer.analyze_file('/tmp/sample_calc.py')
print(f"函数数量: {len(call_graph.nodes())}")
print(f"调用关系数量: {len(call_graph.edges())}")
# 查找调用关系
print("\nmain函数调用的函数:")
print(analyzer.find_callees('main'))
print("\n复杂度指标:")
metrics = analyzer.get_complexity_metrics()
for func, metric in sorted(metrics.items(), key=lambda x: x[1]['complexity'], reverse=True):
print(f"{func}: 复杂度={metric['complexity']}, 调用数={metric['callees']}")
# 检测循环
cycles = analyzer.detect_cycles()
if cycles:
print(f"\n发现循环调用: {cycles}")
else:
print("\n未发现循环调用")
14.3 AI代码生成技术
14.3.1 基于语言模型的代码生成
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
from typing import Any, Dict, List, Optional
class CodeGenerator:
"""代码生成器"""
def __init__(
self,
model_name: str = "Salesforce/codegen-350M-mono",
device: str = None
):
"""
初始化代码生成器
Args:
model_name: 模型名称,可选:
- Salesforce/codegen-350M-mono (Python)
- Salesforce/codegen-2B-mono (Python, 更大)
- bigcode/starcoder (多语言)
device: 设备 ('cuda' 或 'cpu')
"""
self.tokenizer = AutoTokenizer.from_pretrained(model_name)
self.model = AutoModelForCausalLM.from_pretrained(model_name)
if device is None:
self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
else:
self.device = torch.device(device)
self.model.to(self.device)
self.model.eval()
def generate(
self,
prompt: str,
max_length: int = 256,
num_return_sequences: int = 1,
temperature: float = 0.2,
top_p: float = 0.95,
top_k: int = 50,
stop_tokens: List[str] = None
) -> List[str]:
"""
生成代码
Args:
prompt: 提示代码
max_length: 最大生成长度
num_return_sequences: 返回序列数量
temperature: 温度参数(越低越确定)
top_p: nucleus sampling参数
top_k: top-k sampling参数
stop_tokens: 停止标记列表
"""
# Tokenize输入
inputs = self.tokenizer(prompt, return_tensors="pt")
inputs = {k: v.to(self.device) for k, v in inputs.items()}
# 设置停止条件
if stop_tokens is None:
stop_tokens = ["\n\n", "\nclass ", "\ndef ", "\nif __name__"]
# 生成
with torch.no_grad():
outputs = self.model.generate(
**inputs,
max_length=max_length,
num_return_sequences=num_return_sequences,
temperature=temperature,
top_p=top_p,
top_k=top_k,
do_sample=temperature > 0,
pad_token_id=self.tokenizer.eos_token_id,
eos_token_id=self.tokenizer.eos_token_id
)
# 解码
generated_texts = []
for output in outputs:
text = self.tokenizer.decode(output, skip_special_tokens=True)
# 移除prompt
if text.startswith(prompt):
text = text[len(prompt):]
# 在停止标记处截断
for stop_token in stop_tokens:
if stop_token in text:
text = text[:text.index(stop_token)]
generated_texts.append(text.strip())
return generated_texts
def complete_function(
self,
function_signature: str,
docstring: str = None,
context: str = None,
num_samples: int = 3
) -> List[str]:
"""
补全函数实现
Args:
function_signature: 函数签名
docstring: 函数文档字符串
context: 上下文代码
num_samples: 生成样本数
"""
# 构建prompt
prompt_parts = []
if context:
prompt_parts.append(context)
prompt_parts.append("\n")
prompt_parts.append(function_signature)
if docstring:
prompt_parts.append(f'\n """{docstring}"""')
prompt_parts.append("\n ")
prompt = "".join(prompt_parts)
# 生成
completions = self.generate(
prompt,
max_length=512,
num_return_sequences=num_samples,
temperature=0.3
)
# 格式化结果
results = []
for completion in completions:
full_function = prompt + completion
results.append(full_function)
return results
def generate_from_comment(
self,
comment: str,
language: str = "python",
num_samples: int = 1
) -> List[str]:
"""
从注释生成代码
Args:
comment: 描述性注释
language: 编程语言
num_samples: 生成样本数
"""
if language.lower() == "python":
prompt = f"# {comment}\n"
else:
prompt = f"// {comment}\n"
return self.generate(
prompt,
max_length=256,
num_return_sequences=num_samples,
temperature=0.2
)
def infill(
self,
prefix: str,
suffix: str,
num_samples: int = 1
) -> List[str]:
"""
填充中间代码(适用于支持FIM的模型)
Args:
prefix: 前缀代码
suffix: 后缀代码
num_samples: 生成样本数
"""
# 对于不支持FIM的模型,退化为仅基于前缀生成,并以后缀首行作为停止标记
completions = self.generate(
prefix,
max_length=256,
num_return_sequences=num_samples,
temperature=0.2,
stop_tokens=[suffix.split('\n')[0] if '\n' in suffix else suffix]
)
return completions
class IntelligentCodeCompletion:
"""智能代码补全系统"""
def __init__(self, generator: CodeGenerator):
self.generator = generator
self.completion_cache = {}
def get_completion(
self,
cursor_context: Dict[str, str],
use_cache: bool = True
) -> List[Dict[str, Any]]:
"""
获取代码补全建议
Args:
cursor_context: 光标上下文
- before: 光标前的代码
- after: 光标后的代码
- file_context: 文件其他部分
- language: 编程语言
use_cache: 是否使用缓存
Returns:
补全建议列表
"""
before = cursor_context.get('before', '')
after = cursor_context.get('after', '')
# 检查缓存
cache_key = (before, after)
if use_cache and cache_key in self.completion_cache:
return self.completion_cache[cache_key]
# 分析上下文确定补全类型
completion_type = self._detect_completion_type(before, after)
suggestions = []
if completion_type == 'function_body':
# 函数体补全
suggestions = self._complete_function_body(before, after)
elif completion_type == 'line':
# 行级补全
suggestions = self._complete_line(before, after)
elif completion_type == 'multiline':
# 多行补全
suggestions = self._complete_multiline(before, after)
# 缓存结果
if use_cache:
self.completion_cache[cache_key] = suggestions
return suggestions
def _detect_completion_type(self, before: str, after: str) -> str:
"""检测补全类型"""
before_lines = before.split('\n')
last_line = before_lines[-1] if before_lines else ''
# 检查是否在函数定义后
if last_line.strip().endswith(':'):
if 'def ' in last_line:
return 'function_body'
elif 'class ' in last_line:
return 'class_body'
# 检查是否在行中
if not last_line.strip().endswith(':'):
return 'line'
return 'multiline'
def _complete_function_body(
self,
before: str,
after: str
) -> List[Dict[str, Any]]:
"""补全函数体"""
# 提取函数签名
lines = before.split('\n')
signature_line = lines[-1] if lines else ''
# 生成函数体
completions = self.generator.generate(
before + "\n ",
max_length=256,
num_return_sequences=3,
temperature=0.2
)
suggestions = []
for i, completion in enumerate(completions):
suggestions.append({
'text': completion,
'type': 'function_body',
'confidence': 1.0 - (i * 0.1), # 第一个建议置信度最高
'display_text': completion.split('\n')[0] # 显示第一行
})
return suggestions
def _complete_line(
self,
before: str,
after: str
) -> List[Dict[str, Any]]:
"""补全当前行"""
completions = self.generator.generate(
before,
max_length=50,
num_return_sequences=5,
temperature=0.1,
stop_tokens=['\n']
)
suggestions = []
for i, completion in enumerate(completions):
# 只取第一行
line = completion.split('\n')[0]
suggestions.append({
'text': line,
'type': 'line',
'confidence': 1.0 - (i * 0.15),
'display_text': line
})
return suggestions
def _complete_multiline(
self,
before: str,
after: str
) -> List[Dict[str, Any]]:
"""多行补全"""
completions = self.generator.generate(
before,
max_length=128,
num_return_sequences=3,
temperature=0.3
)
suggestions = []
for i, completion in enumerate(completions):
suggestions.append({
'text': completion,
'type': 'multiline',
'confidence': 1.0 - (i * 0.2),
'display_text': completion.split('\n')[0] + '...'
})
return suggestions
# 使用示例
print("=== 代码生成示例 ===\n")
# 初始化生成器(使用较小的模型演示)
generator = CodeGenerator(model_name="Salesforce/codegen-350M-mono")
# 示例1: 从注释生成代码
print("1. 从注释生成代码:")
comment = "Calculate the factorial of a number using recursion"
code = generator.generate_from_comment(comment, num_samples=1)
print(f"注释: {comment}")
print(f"生成代码:\n{code[0]}\n")
# 示例2: 补全函数
print("2. 补全函数实现:")
signature = "def binary_search(arr, target):"
docstring = "Perform binary search on a sorted array"
completions = generator.complete_function(signature, docstring, num_samples=2)
print("函数签名:", signature)
print("第一个补全:")
print(completions[0])
print()
# 示例3: 智能补全
print("3. 智能代码补全:")
completion_system = IntelligentCodeCompletion(generator)
context = {
'before': '''def calculate_average(numbers):
"""Calculate the average of a list of numbers"""
total = sum(numbers)
''',
'after': ' return average',
'language': 'python'
}
suggestions = completion_system.get_completion(context)
print("上下文:")
print(context['before'])
print("\n补全建议:")
for i, sugg in enumerate(suggestions[:3], 1):
print(f"{i}. [{sugg['type']}] (置信度: {sugg['confidence']:.2f})")
print(f" {sugg['display_text']}")
14.3.2 代码质量评估与排序
import ast
from typing import Any, Dict, List, Optional, Tuple
import re
class CodeQualityEvaluator:
"""代码质量评估器"""
def __init__(self):
self.weights = {
'syntax': 0.3, # 语法正确性
'style': 0.2, # 代码风格
'complexity': 0.2, # 复杂度
'completeness': 0.15, # 完整性
'efficiency': 0.15 # 效率
}
def evaluate(self, code: str, language: str = 'python') -> Dict[str, float]:
"""
评估代码质量
Returns:
评分字典,包含各项指标和总分
"""
scores = {}
# 语法检查
scores['syntax'] = self._check_syntax(code, language)
# 风格检查
scores['style'] = self._check_style(code, language)
# 复杂度检查
scores['complexity'] = self._check_complexity(code, language)
# 完整性检查
scores['completeness'] = self._check_completeness(code, language)
# 效率检查
scores['efficiency'] = self._check_efficiency(code, language)
# 计算总分
scores['total'] = sum(
scores[key] * self.weights[key]
for key in self.weights.keys()
)
return scores
def _check_syntax(self, code: str, language: str) -> float:
"""检查语法正确性"""
if language == 'python':
try:
ast.parse(code)
return 1.0
except SyntaxError:
return 0.0
return 0.5 # 其他语言暂不支持
def _check_style(self, code: str, language: str) -> float:
"""检查代码风格"""
if language != 'python':
return 0.5
score = 1.0
issues = []
# 检查缩进(应为4空格)
lines = code.split('\n')
for i, line in enumerate(lines):
if line and not line.startswith('#'):
# 计算前导空格
leading_spaces = len(line) - len(line.lstrip(' '))
if leading_spaces % 4 != 0:
issues.append(f"Line {i+1}: 缩进不是4的倍数")
score -= 0.1
# 检查命名规范
try:
tree = ast.parse(code)
for node in ast.walk(tree):
# 函数名应为小写+下划线
if isinstance(node, ast.FunctionDef):
if not re.match(r'^[a-z_][a-z0-9_]*$', node.name):
if not node.name.startswith('_'):
issues.append(f"函数名 {node.name} 不符合规范")
score -= 0.05
# 类名应为驼峰
elif isinstance(node, ast.ClassDef):
if not re.match(r'^[A-Z][a-zA-Z0-9]*$', node.name):
issues.append(f"类名 {node.name} 不符合规范")
score -= 0.05
except:
pass
# 检查行长度(不超过120字符)
for i, line in enumerate(lines):
if len(line) > 120:
issues.append(f"Line {i+1}: 超过120字符")
score -= 0.05
return max(0.0, min(1.0, score))
def _check_complexity(self, code: str, language: str) -> float:
"""检查代码复杂度"""
if language != 'python':
return 0.5
try:
tree = ast.parse(code)
max_complexity = 0
for node in ast.walk(tree):
if isinstance(node, ast.FunctionDef):
complexity = self._calculate_cyclomatic_complexity(node)
max_complexity = max(max_complexity, complexity)
# 复杂度评分:1-5优秀,6-10良好,11-20一般,>20差
if max_complexity <= 5:
return 1.0
elif max_complexity <= 10:
return 0.8
elif max_complexity <= 20:
return 0.6
else:
return 0.4
except:
return 0.5
def _calculate_cyclomatic_complexity(self, node: ast.FunctionDef) -> int:
"""计算圈复杂度"""
complexity = 1 # 基础复杂度
for child in ast.walk(node):
# 每个决策点+1
if isinstance(child, (ast.If, ast.While, ast.For, ast.ExceptHandler)):
complexity += 1
elif isinstance(child, ast.BoolOp):
# 布尔操作符(and/or)
complexity += len(child.values) - 1
return complexity
def _check_completeness(self, code: str, language: str) -> float:
"""检查代码完整性"""
if language != 'python':
return 0.5
score = 1.0
# 检查是否有未完成的标记
incomplete_markers = ['...', 'pass', 'TODO', 'FIXME', 'NotImplementedError']
for marker in incomplete_markers:
if marker in code:
score -= 0.2
# 检查是否有函数体
try:
tree = ast.parse(code)
for node in ast.walk(tree):
if isinstance(node, ast.FunctionDef):
# 检查函数体是否只有pass
if len(node.body) == 1 and isinstance(node.body[0], ast.Pass):
score -= 0.3
except:
pass
return max(0.0, score)
def _check_efficiency(self, code: str, language: str) -> float:
"""检查代码效率"""
if language != 'python':
return 0.5
score = 1.0
# 检查一些常见的低效模式
inefficient_patterns = [
(r'for .+ in .+:\s+if .+:\s+.+\.append\(', '使用列表推导式'),
(r'\.append\([^)]+\)\s+\.append\(', '多次append可以优化'),
]
for pattern, reason in inefficient_patterns:
if re.search(pattern, code):
score -= 0.1
# 检查是否有嵌套循环(可能的性能问题)
try:
tree = ast.parse(code)
for node in ast.walk(tree):
if isinstance(node, (ast.For, ast.While)):
# 检查循环体中是否还有循环
for child in ast.walk(node):
if child != node and isinstance(child, (ast.For, ast.While)):
score -= 0.15
break
except:
pass
return max(0.0, score)
class CodeRanker:
"""代码排序器"""
def __init__(self, evaluator: CodeQualityEvaluator = None):
self.evaluator = evaluator or CodeQualityEvaluator()
def rank_candidates(
self,
candidates: List[str],
context: Optional[Dict[str, Any]] = None
) -> List[Tuple[str, float, Dict]]:
"""
对候选代码进行排序
Args:
candidates: 候选代码列表
context: 上下文信息(可选)
Returns:
排序后的列表 [(code, score, details), ...]
"""
ranked = []
for code in candidates:
# 评估代码质量
scores = self.evaluator.evaluate(code)
# 如果有上下文,进行额外评分
if context:
context_score = self._evaluate_context_fit(code, context)
# 将上下文适配性作为额外因素
final_score = scores['total'] * 0.7 + context_score * 0.3
else:
final_score = scores['total']
ranked.append((code, final_score, scores))
# 按分数降序排序
ranked.sort(key=lambda x: x[1], reverse=True)
return ranked
def _evaluate_context_fit(self, code: str, context: Dict) -> float:
"""评估代码与上下文的适配性"""
score = 1.0
# 检查是否使用了上下文中的变量
if 'available_vars' in context:
used_vars = self._extract_used_variables(code)
available_vars = set(context['available_vars'])
# 使用了不存在的变量,扣分
undefined_vars = used_vars - available_vars
score -= len(undefined_vars) * 0.1
# 检查返回类型是否匹配
if 'expected_return_type' in context:
inferred_type = self._infer_return_type(code)
if inferred_type != context['expected_return_type']:
score -= 0.2
return max(0.0, score)
def _extract_used_variables(self, code: str) -> set:
"""提取代码中使用的变量"""
try:
tree = ast.parse(code)
used_vars = set()
for node in ast.walk(tree):
if isinstance(node, ast.Name) and isinstance(node.ctx, ast.Load):
used_vars.add(node.id)
return used_vars
except:
return set()
def _infer_return_type(self, code: str) -> str:
"""推断返回类型(简化版)"""
try:
tree = ast.parse(code)
for node in ast.walk(tree):
if isinstance(node, ast.Return) and node.value:
# 简单类型推断
if isinstance(node.value, ast.Constant):
return type(node.value.value).__name__
elif isinstance(node.value, ast.List):
return 'list'
elif isinstance(node.value, ast.Dict):
return 'dict'
return 'unknown'
except:
return 'unknown'
# 使用示例
print("=== 代码质量评估示例 ===\n")
evaluator = CodeQualityEvaluator()
ranker = CodeRanker(evaluator)
# 准备几个候选代码
candidates = [
# 候选1: 质量较高
"""def calculate_sum(numbers):
'''Calculate the sum of a list of numbers'''
total = 0
for num in numbers:
total += num
return total""",
# 候选2: 使用内置函数,更简洁
"""def calculate_sum(numbers):
'''Calculate the sum of a list of numbers'''
return sum(numbers)""",
# 候选3: 有语法错误
"""def calculate_sum(numbers):
'''Calculate the sum of a list of numbers'''
total = 0
for num in numbers
total += num
return total""",
# 候选4: 复杂度较高
"""def calculate_sum(numbers):
'''Calculate the sum of a list of numbers'''
result = []
for i in range(len(numbers)):
if i == 0:
result.append(numbers[i])
else:
result.append(result[i-1] + numbers[i])
return result[-1] if result else 0"""
]
print("对候选代码进行评估和排序:\n")
ranked_candidates = ranker.rank_candidates(candidates)
for i, (code, score, details) in enumerate(ranked_candidates, 1):
print(f"排名 {i} (总分: {score:.3f})")
print(f"详细评分:")
for key, value in details.items():
if key != 'total':
print(f" {key}: {value:.3f}")
print(f"代码:")
print(code)
print("-" * 60)
14.4 VSCode插件开发
14.4.1 插件架构设计
VSCode插件的基本结构:
my-ai-assistant/
├── package.json # 插件清单
├── tsconfig.json # TypeScript配置
├── .vscodeignore # 发布时忽略的文件
├── src/
│ ├── extension.ts # 插件入口
│ ├── completionProvider.ts # 代码补全提供者
│ ├── chatProvider.ts # 聊天界面提供者
│ ├── codeActionProvider.ts # 代码操作提供者
│ ├── apiClient.ts # API客户端
│ ├── contextBuilder.ts # 上下文构建器
│ └── utils/
│ ├── ast.ts # AST工具
│ ├── cache.ts # 缓存管理
│ └── logger.ts # 日志
├── media/ # 静态资源
│ ├── icon.png
│ └── styles.css
└── test/ # 测试
└── extension.test.ts
package.json 配置:
{
"name": "ai-code-assistant",
"displayName": "AI Code Assistant",
"description": "AI-powered code completion and assistance",
"version": "0.1.0",
"engines": {
"vscode": "^1.80.0"
},
"categories": [
"Programming Languages",
"Machine Learning",
"Other"
],
"activationEvents": [
"onStartupFinished"
],
"main": "./out/extension.js",
"contributes": {
"commands": [
{
"command": "aiAssistant.chat",
"title": "AI Assistant: Open Chat"
},
{
"command": "aiAssistant.explainCode",
"title": "AI Assistant: Explain Code"
},
{
"command": "aiAssistant.refactorCode",
"title": "AI Assistant: Refactor Code"
},
{
"command": "aiAssistant.generateTests",
"title": "AI Assistant: Generate Tests"
}
],
"keybindings": [
{
"command": "aiAssistant.chat",
"key": "ctrl+shift+a",
"mac": "cmd+shift+a"
}
],
"configuration": {
"title": "AI Code Assistant",
"properties": {
"aiAssistant.apiEndpoint": {
"type": "string",
"default": "http://localhost:8000",
"description": "API endpoint for the AI model"
},
"aiAssistant.modelName": {
"type": "string",
"default": "codegen",
"description": "Model name to use for code generation"
},
"aiAssistant.maxTokens": {
"type": "number",
"default": 256,
"description": "Maximum tokens to generate"
},
"aiAssistant.temperature": {
"type": "number",
"default": 0.2,
"description": "Temperature for generation (0-1)"
},
"aiAssistant.enableInlineCompletion": {
"type": "boolean",
"default": true,
"description": "Enable inline code completion"
}
}
},
"viewsContainers": {
"activitybar": [
{
"id": "ai-assistant",
"title": "AI Assistant",
"icon": "media/icon.svg"
}
]
},
"views": {
"ai-assistant": [
{
"id": "aiAssistant.chatView",
"name": "Chat",
"type": "webview"
}
]
}
},
"scripts": {
"vscode:prepublish": "npm run compile",
"compile": "tsc -p ./",
"watch": "tsc -watch -p ./",
"pretest": "npm run compile",
"test": "node ./out/test/runTest.js"
},
"devDependencies": {
"@types/node": "^18.0.0",
"@types/vscode": "^1.80.0",
"typescript": "^5.0.0"
},
"dependencies": {
"axios": "^1.4.0"
}
}
14.4.2 核心功能实现
1. 插件入口 (extension.ts):
import * as vscode from 'vscode';
import { CompletionProvider } from './completionProvider';
import { ChatProvider } from './chatProvider';
import { CodeActionProvider } from './codeActionProvider';
import { APIClient } from './apiClient';
import { Logger } from './utils/logger';
let completionProvider: CompletionProvider;
let chatProvider: ChatProvider;
let codeActionProvider: CodeActionProvider;
let apiClient: APIClient;
export function activate(context: vscode.ExtensionContext) {
Logger.info('AI Code Assistant is activating...');
// 初始化API客户端
const config = vscode.workspace.getConfiguration('aiAssistant');
apiClient = new APIClient(
config.get('apiEndpoint', 'http://localhost:8000'),
config.get('modelName', 'codegen')
);
// 注册代码补全提供者
if (config.get('enableInlineCompletion', true)) {
completionProvider = new CompletionProvider(apiClient);
const completionDisposable = vscode.languages.registerInlineCompletionItemProvider(
{ pattern: '**' },
completionProvider
);
context.subscriptions.push(completionDisposable);
}
// 注册聊天视图
chatProvider = new ChatProvider(context, apiClient);
const chatDisposable = vscode.window.registerWebviewViewProvider(
'aiAssistant.chatView',
chatProvider
);
context.subscriptions.push(chatDisposable);
// 注册代码操作提供者
codeActionProvider = new CodeActionProvider(apiClient);
const codeActionDisposable = vscode.languages.registerCodeActionsProvider(
{ pattern: '**' },
codeActionProvider,
{
providedCodeActionKinds: CodeActionProvider.providedCodeActionKinds
}
);
context.subscriptions.push(codeActionDisposable);
// 注册命令
registerCommands(context);
Logger.info('AI Code Assistant activated successfully');
}
function registerCommands(context: vscode.ExtensionContext) {
// 打开聊天命令
const chatCommand = vscode.commands.registerCommand(
'aiAssistant.chat',
() => {
vscode.commands.executeCommand('aiAssistant.chatView.focus');
}
);
// 解释代码命令
const explainCommand = vscode.commands.registerCommand(
'aiAssistant.explainCode',
async () => {
const editor = vscode.window.activeTextEditor;
if (!editor) {
vscode.window.showWarningMessage('No active editor');
return;
}
const selection = editor.selection;
const code = editor.document.getText(selection);
if (!code) {
vscode.window.showWarningMessage('No code selected');
return;
}
const explanation = await apiClient.explainCode(code);
// 在新文档中显示解释
const doc = await vscode.workspace.openTextDocument({
content: explanation,
language: 'markdown'
});
await vscode.window.showTextDocument(doc);
}
);
// 重构代码命令
const refactorCommand = vscode.commands.registerCommand(
'aiAssistant.refactorCode',
async () => {
const editor = vscode.window.activeTextEditor;
if (!editor) return;
const selection = editor.selection;
const code = editor.document.getText(selection);
if (!code) {
vscode.window.showWarningMessage('No code selected');
return;
}
// 显示进度
await vscode.window.withProgress(
{
location: vscode.ProgressLocation.Notification,
title: 'Refactoring code...',
cancellable: false
},
async (progress) => {
const refactored = await apiClient.refactorCode(code);
// 替换选中的代码
await editor.edit(editBuilder => {
editBuilder.replace(selection, refactored);
});
vscode.window.showInformationMessage('Code refactored successfully');
}
);
}
);
// 生成测试命令
const generateTestsCommand = vscode.commands.registerCommand(
'aiAssistant.generateTests',
async () => {
const editor = vscode.window.activeTextEditor;
if (!editor) return;
const selection = editor.selection;
const code = editor.document.getText(selection);
if (!code) {
vscode.window.showWarningMessage('No code selected');
return;
}
const tests = await apiClient.generateTests(code);
// 在新文档中显示测试
const doc = await vscode.workspace.openTextDocument({
content: tests,
language: editor.document.languageId
});
await vscode.window.showTextDocument(doc, { viewColumn: vscode.ViewColumn.Beside });
}
);
context.subscriptions.push(
chatCommand,
explainCommand,
refactorCommand,
generateTestsCommand
);
}
export function deactivate() {
Logger.info('AI Code Assistant is deactivating...');
}
2. 代码补全提供者 (completionProvider.ts):
import * as vscode from 'vscode';
import { APIClient } from './apiClient';
import { ContextBuilder } from './contextBuilder';
import { CacheManager } from './utils/cache';
export class CompletionProvider implements vscode.InlineCompletionItemProvider {
private contextBuilder: ContextBuilder;
private cache: CacheManager;
private debounceTimer?: NodeJS.Timeout;
constructor(private apiClient: APIClient) {
this.contextBuilder = new ContextBuilder();
this.cache = new CacheManager(60000); // 1分钟缓存
}
async provideInlineCompletionItems(
document: vscode.TextDocument,
position: vscode.Position,
context: vscode.InlineCompletionContext,
token: vscode.CancellationToken
): Promise<vscode.InlineCompletionItem[] | vscode.InlineCompletionList | undefined> {
// 检查是否应该触发补全
if (!this.shouldTriggerCompletion(document, position, context)) {
return undefined;
}
// 构建上下文
const codeContext = this.contextBuilder.buildContext(document, position);
// 检查缓存
const cacheKey = this.getCacheKey(codeContext);
const cached = this.cache.get(cacheKey);
if (cached) {
return this.createCompletionItems(cached);
}
try {
// 调用API获取补全
const completions = await this.apiClient.getCompletions(codeContext);
// 缓存结果
this.cache.set(cacheKey, completions);
return this.createCompletionItems(completions);
} catch (error) {
console.error('Error getting completions:', error);
return undefined;
}
}
private shouldTriggerCompletion(
document: vscode.TextDocument,
position: vscode.Position,
context: vscode.InlineCompletionContext
): boolean {
// 获取当前行
const line = document.lineAt(position.line);
const textBeforeCursor = line.text.substring(0, position.character);
// 不在注释中触发
if (textBeforeCursor.trim().startsWith('#') ||
textBeforeCursor.trim().startsWith('//')) {
return false;
}
// 不在字符串中触发
const inString = this.isInString(textBeforeCursor);
if (inString) {
return false;
}
// 至少输入了一些字符
if (textBeforeCursor.trim().length < 2) {
return false;
}
return true;
}
private isInString(text: string): boolean {
// 简单检查是否在字符串中
const singleQuotes = (text.match(/'/g) || []).length;
const doubleQuotes = (text.match(/"/g) || []).length;
return (singleQuotes % 2 !== 0) || (doubleQuotes % 2 !== 0);
}
private getCacheKey(context: any): string {
// 使用前后代码的哈希作为缓存键
return `${context.before.slice(-100)}_${context.after.slice(0, 100)}`;
}
private createCompletionItems(
completions: Array<{ text: string; confidence: number }>
): vscode.InlineCompletionItem[] {
return completions.map(completion => {
const item = new vscode.InlineCompletionItem(completion.text);
// 可以添加额外的信息
item.range = undefined; // 使用默认范围
return item;
});
}
}
3. 聊天界面提供者 (chatProvider.ts):
import * as vscode from 'vscode';
import { APIClient } from './apiClient';
export class ChatProvider implements vscode.WebviewViewProvider {
private view?: vscode.WebviewView;
private messages: Array<{ role: string; content: string }> = [];
constructor(
private readonly context: vscode.ExtensionContext,
private readonly apiClient: APIClient
) {}
resolveWebviewView(
webviewView: vscode.WebviewView,
context: vscode.WebviewViewResolveContext,
token: vscode.CancellationToken
): void | Thenable<void> {
this.view = webviewView;
webviewView.webview.options = {
enableScripts: true,
localResourceRoots: [this.context.extensionUri]
};
webviewView.webview.html = this.getHtmlForWebview(webviewView.webview);
// 处理来自webview的消息
webviewView.webview.onDidReceiveMessage(async (data) => {
switch (data.type) {
case 'sendMessage':
await this.handleUserMessage(data.message);
break;
case 'clearChat':
this.messages = [];
this.updateChat();
break;
}
});
}
private async handleUserMessage(message: string) {
// 添加用户消息
this.messages.push({ role: 'user', content: message });
this.updateChat();
try {
// 调用API获取响应
const response = await this.apiClient.chat(this.messages);
// 添加助手响应
this.messages.push({ role: 'assistant', content: response });
this.updateChat();
} catch (error) {
vscode.window.showErrorMessage('Failed to get response from AI');
console.error(error);
}
}
private updateChat() {
if (this.view) {
this.view.webview.postMessage({
type: 'updateMessages',
messages: this.messages
});
}
}
private getHtmlForWebview(webview: vscode.Webview): string {
return `<!DOCTYPE html>
<html lang="en">
<head>
<meta charset="UTF-8">
<meta name="viewport" content="width=device-width, initial-scale=1.0">
<title>AI Assistant Chat</title>
<style>
body {
padding: 10px;
font-family: var(--vscode-font-family);
color: var(--vscode-foreground);
}
#messages {
height: calc(100vh - 120px);
overflow-y: auto;
padding: 10px;
margin-bottom: 10px;
}
.message {
margin-bottom: 15px;
padding: 10px;
border-radius: 5px;
}
.user {
background-color: var(--vscode-input-background);
margin-left: 20px;
}
.assistant {
background-color: var(--vscode-editor-background);
margin-right: 20px;
}
.role {
font-weight: bold;
margin-bottom: 5px;
font-size: 0.9em;
}
#input-container {
display: flex;
gap: 10px;
}
#message-input {
flex: 1;
padding: 8px;
background-color: var(--vscode-input-background);
color: var(--vscode-input-foreground);
border: 1px solid var(--vscode-input-border);
border-radius: 3px;
}
button {
padding: 8px 15px;
background-color: var(--vscode-button-background);
color: var(--vscode-button-foreground);
border: none;
border-radius: 3px;
cursor: pointer;
}
button:hover {
background-color: var(--vscode-button-hoverBackground);
}
pre {
background-color: var(--vscode-textCodeBlock-background);
padding: 10px;
border-radius: 3px;
overflow-x: auto;
}
code {
font-family: var(--vscode-editor-font-family);
}
</style>
</head>
<body>
<div id="messages"></div>
<div id="input-container">
<input
type="text"
id="message-input"
placeholder="Ask me anything about your code..."
autocomplete="off"
/>
<button onclick="sendMessage()">Send</button>
<button onclick="clearChat()">Clear</button>
</div>
<script>
const vscode = acquireVsCodeApi();
function sendMessage() {
const input = document.getElementById('message-input');
const message = input.value.trim();
if (message) {
vscode.postMessage({
type: 'sendMessage',
message: message
});
input.value = '';
}
}
function clearChat() {
vscode.postMessage({ type: 'clearChat' });
}
// 监听回车键
document.getElementById('message-input').addEventListener('keypress', (e) => {
if (e.key === 'Enter') {
sendMessage();
}
});
// 接收消息更新
window.addEventListener('message', (event) => {
const message = event.data;
if (message.type === 'updateMessages') {
updateMessagesUI(message.messages);
}
});
function updateMessagesUI(messages) {
const container = document.getElementById('messages');
container.innerHTML = '';
messages.forEach(msg => {
const div = document.createElement('div');
div.className = \`message \${msg.role}\`;
const roleDiv = document.createElement('div');
roleDiv.className = 'role';
roleDiv.textContent = msg.role === 'user' ? 'You' : 'AI Assistant';
const contentDiv = document.createElement('div');
contentDiv.innerHTML = formatMessage(msg.content);
div.appendChild(roleDiv);
div.appendChild(contentDiv);
container.appendChild(div);
});
// 滚动到底部
container.scrollTop = container.scrollHeight;
}
function formatMessage(content) {
// 简单的Markdown格式化
// 代码块
content = content.replace(/\`\`\`([\\s\\S]*?)\`\`\`/g, '<pre><code>$1</code></pre>');
// 行内代码
content = content.replace(/\`([^\`]+)\`/g, '<code>$1</code>');
// 换行
content = content.replace(/\\n/g, '<br>');
return content;
}
</script>
</body>
</html>`;
}
}
4. API客户端 (apiClient.ts):
import axios, { AxiosInstance } from 'axios';
export interface CompletionRequest {
before: string;
after: string;
language: string;
maxTokens?: number;
temperature?: number;
}
export interface CompletionResponse {
completions: Array<{
text: string;
confidence: number;
}>;
}
export class APIClient {
private client: AxiosInstance;
constructor(
private baseURL: string,
private modelName: string
) {
this.client = axios.create({
baseURL: this.baseURL,
timeout: 30000,
headers: {
'Content-Type': 'application/json'
}
});
}
async getCompletions(context: any): Promise<Array<{ text: string; confidence: number }>> {
try {
const response = await this.client.post<CompletionResponse>('/completions', {
model: this.modelName,
before: context.before,
after: context.after,
language: context.language,
maxTokens: context.maxTokens || 256,
temperature: context.temperature || 0.2
});
return response.data.completions;
} catch (error) {
console.error('API request failed:', error);
throw error;
}
}
async chat(messages: Array<{ role: string; content: string }>): Promise<string> {
try {
const response = await this.client.post('/chat', {
model: this.modelName,
messages: messages
});
return response.data.response;
} catch (error) {
console.error('Chat request failed:', error);
throw error;
}
}
async explainCode(code: string): Promise<string> {
try {
const response = await this.client.post('/explain', {
model: this.modelName,
code: code
});
return response.data.explanation;
} catch (error) {
console.error('Explain request failed:', error);
throw error;
}
}
async refactorCode(code: string): Promise<string> {
try {
const response = await this.client.post('/refactor', {
model: this.modelName,
code: code
});
return response.data.refactored;
} catch (error) {
console.error('Refactor request failed:', error);
throw error;
}
}
async generateTests(code: string): Promise<string> {
try {
const response = await this.client.post('/generate-tests', {
model: this.modelName,
code: code
});
return response.data.tests;
} catch (error) {
console.error('Generate tests request failed:', error);
throw error;
}
}
}
5. 上下文构建器 (contextBuilder.ts):
import * as vscode from 'vscode';
export interface CodeContext {
before: string;
after: string;
language: string;
currentLine: string;
fileName: string;
imports: string[];
}
export class ContextBuilder {
private readonly maxContextLength = 2000; // 字符
buildContext(
document: vscode.TextDocument,
position: vscode.Position
): CodeContext {
const text = document.getText();
const offset = document.offsetAt(position);
// 获取前后文本
let before = text.substring(0, offset);
let after = text.substring(offset);
// 限制长度
if (before.length > this.maxContextLength) {
before = this.truncateFromStart(before, this.maxContextLength);
}
if (after.length > this.maxContextLength) {
after = after.substring(0, this.maxContextLength);
}
// 获取当前行
const currentLine = document.lineAt(position.line).text;
// 提取导入语句
const imports = this.extractImports(document);
return {
before,
after,
language: document.languageId,
currentLine,
fileName: document.fileName,
imports
};
}
private truncateFromStart(text: string, maxLength: number): string {
if (text.length <= maxLength) {
return text;
}
// 从最近的换行符开始截断,保持代码结构
const truncated = text.substring(text.length - maxLength);
const firstNewline = truncated.indexOf('\n');
if (firstNewline !== -1) {
return truncated.substring(firstNewline + 1);
}
return truncated;
}
private extractImports(document: vscode.TextDocument): string[] {
const imports: string[] = [];
const text = document.getText();
const lines = text.split('\n');
for (const line of lines) {
const trimmed = line.trim();
// Python导入
if (trimmed.startsWith('import ') || trimmed.startsWith('from ')) {
imports.push(trimmed);
}
// TypeScript/JavaScript导入
else if (trimmed.startsWith('import ') || trimmed.startsWith('require(')) {
imports.push(trimmed);
}
// 如果已经过了导入区域(遇到非导入非空行),停止
else if (trimmed && !trimmed.startsWith('#') && !trimmed.startsWith('//')) {
break;
}
}
return imports;
}
}
14.4.3 后端API服务
为了支持VSCode插件,我们需要一个后端API服务:
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from typing import List, Optional
import uvicorn
# 假设我们已经有了前面实现的代码生成器
from code_generator import CodeGenerator, IntelligentCodeCompletion
app = FastAPI(title="AI Code Assistant API")
# 初始化模型
generator = CodeGenerator()
completion_system = IntelligentCodeCompletion(generator)
class CompletionRequest(BaseModel):
model: str
before: str
after: str
language: str
maxTokens: Optional[int] = 256
temperature: Optional[float] = 0.2
class CompletionResponse(BaseModel):
completions: List[dict]
class ChatMessage(BaseModel):
role: str
content: str
class ChatRequest(BaseModel):
model: str
messages: List[ChatMessage]
class ChatResponse(BaseModel):
response: str
class CodeRequest(BaseModel):
model: str
code: str
class CodeResponse(BaseModel):
result: str
@app.post("/completions", response_model=CompletionResponse)
async def get_completions(request: CompletionRequest):
"""获取代码补全"""
try:
context = {
'before': request.before,
'after': request.after,
'language': request.language
}
suggestions = completion_system.get_completion(context)
return CompletionResponse(
completions=[
{
'text': sugg['text'],
'confidence': sugg['confidence']
}
for sugg in suggestions[:5] # 返回前5个
]
)
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
@app.post("/chat", response_model=ChatResponse)
async def chat(request: ChatRequest):
"""聊天接口"""
try:
# 这里简化处理,实际应该使用对话模型
last_message = request.messages[-1].content
# 简单响应
response = f"收到您的消息:{last_message}。这是一个演示响应。"
return ChatResponse(response=response)
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
@app.post("/explain")
async def explain_code(request: CodeRequest):
"""解释代码"""
try:
# 使用代码生成器生成解释
prompt = f"Explain the following code:\n\n{request.code}\n\nExplanation:"
explanation = generator.generate(prompt, max_length=512, temperature=0.3)[0]
return {"explanation": explanation}
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
@app.post("/refactor")
async def refactor_code(request: CodeRequest):
"""重构代码"""
try:
prompt = f"Refactor the following code to make it cleaner and more efficient:\n\n{request.code}\n\nRefactored code:\n"
refactored = generator.generate(prompt, max_length=512, temperature=0.2)[0]
return {"refactored": refactored}
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
@app.post("/generate-tests")
async def generate_tests(request: CodeRequest):
"""生成测试"""
try:
prompt = f"Generate unit tests for the following code:\n\n{request.code}\n\nTests:\n"
tests = generator.generate(prompt, max_length=512, temperature=0.3)[0]
return {"tests": tests}
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
@app.get("/health")
async def health_check():
"""健康检查"""
return {"status": "ok", "model": "codegen"}
if __name__ == "__main__":
uvicorn.run(app, host="0.0.0.0", port=8000)
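服务启动后,可以用一个简单的客户端脚本验证补全接口是否工作正常(URL 与字段名需与上面的 FastAPI 定义保持一致,这里假设服务运行在本机 8000 端口):
import requests

resp = requests.post(
    "http://localhost:8000/completions",
    json={
        "model": "codegen",
        "before": "def fibonacci(n):\n    ",
        "after": "",
        "language": "python",
    },
    timeout=30,
)
resp.raise_for_status()
for item in resp.json()["completions"]:
    print(f"[置信度 {item['confidence']:.2f}] {item['text']}")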
14.5 性能优化与最佳实践
14.5.1 缓存策略
import hashlib
import pickle
import time
from typing import Any, Optional
from collections import OrderedDict
class LRUCache:
"""LRU缓存实现"""
def __init__(self, capacity: int = 100, ttl: int = 3600):
"""
Args:
capacity: 最大缓存项数
ttl: 过期时间(秒)
"""
self.cache = OrderedDict()
self.capacity = capacity
self.ttl = ttl
self.timestamps = {}
def get(self, key: str) -> Optional[Any]:
"""获取缓存值"""
if key not in self.cache:
return None
# 检查是否过期
if time.time() - self.timestamps[key] > self.ttl:
self.cache.pop(key)
self.timestamps.pop(key)
return None
# 移到末尾(最近使用)
self.cache.move_to_end(key)
return self.cache[key]
def set(self, key: str, value: Any):
"""设置缓存值"""
if key in self.cache:
self.cache.move_to_end(key)
else:
if len(self.cache) >= self.capacity:
# 移除最久未使用的项
oldest = next(iter(self.cache))
self.cache.pop(oldest)
self.timestamps.pop(oldest)
self.cache[key] = value
self.timestamps[key] = time.time()
def clear(self):
"""清空缓存"""
self.cache.clear()
self.timestamps.clear()
class SmartCache:
"""智能缓存系统"""
def __init__(self, capacity: int = 1000):
self.completion_cache = LRUCache(capacity, ttl=300) # 5分钟
self.embedding_cache = LRUCache(capacity, ttl=3600) # 1小时
self.analysis_cache = LRUCache(capacity, ttl=1800) # 30分钟
def _hash_key(self, data: str) -> str:
"""生成哈希键"""
return hashlib.md5(data.encode()).hexdigest()
def get_completion(self, context: str) -> Optional[list]:
"""获取补全缓存"""
key = self._hash_key(context)
return self.completion_cache.get(key)
def set_completion(self, context: str, completions: list):
"""设置补全缓存"""
key = self._hash_key(context)
self.completion_cache.set(key, completions)
def get_embedding(self, code: str) -> Optional[Any]:
"""获取嵌入缓存"""
key = self._hash_key(code)
return self.embedding_cache.get(key)
def set_embedding(self, code: str, embedding: Any):
"""设置嵌入缓存"""
key = self._hash_key(code)
self.embedding_cache.set(key, embedding)
# 使用示例
cache = SmartCache()
# 缓存补全结果
context = "def calculate_sum(numbers):\n "
completions = ["total = sum(numbers)", "result = 0; for n in numbers: result += n"]
cache.set_completion(context, completions)
# 获取缓存
cached = cache.get_completion(context)
if cached:
print("使用缓存的补全结果")
else:
print("需要重新生成")
14.5.2 批处理与并发
import asyncio
import concurrent.futures
import time
from typing import Any, Callable, List
class BatchProcessor:
"""批处理器"""
def __init__(self, batch_size: int = 8, max_workers: int = 4):
self.batch_size = batch_size
self.max_workers = max_workers
self.executor = concurrent.futures.ThreadPoolExecutor(max_workers=max_workers)
async def process_batch(
self,
items: List[Any],
process_func: Callable
) -> List[Any]:
"""异步批处理"""
results = []
# 分批处理
for i in range(0, len(items), self.batch_size):
batch = items[i:i + self.batch_size]
# 并发处理批次中的项
loop = asyncio.get_event_loop()
batch_results = await asyncio.gather(*[
loop.run_in_executor(self.executor, process_func, item)
for item in batch
])
results.extend(batch_results)
return results
def shutdown(self):
"""关闭执行器"""
self.executor.shutdown(wait=True)
# 使用示例
async def main():
processor = BatchProcessor(batch_size=4, max_workers=2)
def process_code(code):
# 模拟处理
time.sleep(0.1)
return f"Processed: {code[:20]}..."
codes = [f"code_sample_{i}" for i in range(10)]
results = await processor.process_batch(codes, process_code)
print(f"处理了 {len(results)} 个代码样本")
processor.shutdown()
# asyncio.run(main())
14.6 总结
本章深入介绍了AI编程助手的核心技术:
- 代码理解:AST解析、语义嵌入、调用图分析
- 代码生成:基于语言模型的生成、质量评估、智能排序
- VSCode插件:完整的插件架构和实现
- 性能优化:缓存策略、批处理、并发处理
这些技术构成了现代AI编程助手的基础。通过理解和实践这些内容,你可以构建自己的AI编程工具,或者更好地理解和使用现有的工具。
下一章我们将探讨RAG(检索增强生成)系统的设计与实现。