第14章 AI编程助手原理与实现

AI编程助手已经成为现代软件开发的重要工具,从GitHub Copilot到Claude Code,这些工具极大地提升了开发效率。本章将深入探讨AI编程助手的核心原理、代码理解技术、实现方法,以及如何开发VSCode插件。

14.1 AI编程助手概述

14.1.1 编程助手的发展历程

AI编程助手的发展经历了几个重要阶段:

1. 传统代码补全时代(2000-2015)

  • 基于规则的语法补全
  • IDE内置的智能感知(IntelliSense)
  • 依赖静态分析和类型信息

2. 统计学习时代(2015-2020)

  • 基于n-gram的代码补全
  • 使用RNN/LSTM预测下一个token
  • TabNine等工具的出现

3. 预训练模型时代(2020-2021)

  • GPT-2/GPT-3应用于代码生成
  • CodeBERT、GraphCodeBERT等专用模型
  • 代码-文本双向理解

4. 大模型编程助手时代(2021-至今)

  • GitHub Copilot(基于Codex)
  • Claude Code、Cursor等
  • 支持多文件上下文、自然语言交互

14.1.2 核心功能与能力

现代AI编程助手通常具备以下核心能力:

class AICodeAssistant:
    """AI编程助手核心能力抽象"""

    def __init__(self):
        self.capabilities = {
            # 代码生成能力
            'code_generation': {
                'inline_completion': True,      # 行内补全
                'multi_line_suggestion': True,  # 多行建议
                'function_generation': True,    # 函数生成
                'class_generation': True,       # 类生成
            },

            # 代码理解能力
            'code_understanding': {
                'syntax_parsing': True,         # 语法解析
                'semantic_analysis': True,      # 语义分析
                'call_graph': True,             # 调用图分析
                'dependency_tracking': True,    # 依赖追踪
            },

            # 代码编辑能力
            'code_editing': {
                'refactoring': True,            # 重构
                'bug_fixing': True,             # 修复
                'optimization': True,           # 优化
                'documentation': True,          # 文档生成
            },

            # 交互能力
            'interaction': {
                'chat': True,                   # 对话交互
                'context_aware': True,          # 上下文感知
                'multi_file': True,             # 多文件理解
                'terminal_integration': True,   # 终端集成
            }
        }

    def generate_code(self, context, intent):
        """代码生成接口"""
        raise NotImplementedError

    def understand_code(self, code, language):
        """代码理解接口"""
        raise NotImplementedError

    def edit_code(self, code, instruction):
        """代码编辑接口"""
        raise NotImplementedError

14.1.3 技术架构

典型的AI编程助手技术架构:

┌─────────────────────────────────────────────────────────────┐
│                        用户界面层                              │
│  ┌──────────┐  ┌──────────┐  ┌──────────┐  ┌──────────┐   │
│  │ IDE插件  │  │  Chat UI │  │ Terminal │  │  Web UI  │   │
│  └──────────┘  └──────────┘  └──────────┘  └──────────┘   │
└─────────────────────────────────────────────────────────────┘
                              ↓
┌─────────────────────────────────────────────────────────────┐
│                        编辑器集成层                            │
│  ┌────────────────┐  ┌────────────────┐  ┌──────────────┐  │
│  │  LSP Client    │  │  Event Capture │  │  Diff Engine │  │
│  └────────────────┘  └────────────────┘  └──────────────┘  │
└─────────────────────────────────────────────────────────────┘
                              ↓
┌─────────────────────────────────────────────────────────────┐
│                        上下文管理层                            │
│  ┌────────────┐  ┌──────────┐  ┌──────────┐  ┌──────────┐ │
│  │ 代码索引   │  │ 文件系统 │  │ Git集成  │  │ 依赖图   │ │
│  └────────────┘  └──────────┘  └──────────┘  └──────────┘ │
└─────────────────────────────────────────────────────────────┘
                              ↓
┌─────────────────────────────────────────────────────────────┐
│                        AI推理层                               │
│  ┌────────────┐  ┌──────────┐  ┌──────────┐  ┌──────────┐ │
│  │ 代码模型   │  │ 检索增强 │  │ 工具调用 │  │ 推理引擎 │ │
│  └────────────┘  └──────────┘  └──────────┘  └──────────┘ │
└─────────────────────────────────────────────────────────────┘
                              ↓
┌─────────────────────────────────────────────────────────────┐
│                        基础设施层                              │
│  ┌────────────┐  ┌──────────┐  ┌──────────┐  ┌──────────┐ │
│  │ 模型服务   │  │ 向量数据库│  │ 缓存系统 │  │ 监控日志 │ │
│  └────────────┘  └──────────┘  └──────────┘  └──────────┘ │
└─────────────────────────────────────────────────────────────┘
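
为便于理解各层之间如何协作,下面给出一个示意性的 Python 草图:展示一次行内补全请求从用户界面层触发,经编辑器集成层与上下文管理层,最终到达 AI 推理层的大致流转。其中 EditorEvent、ContextManager、InferenceLayer 等类名与方法均为本节的假设,并非某个真实产品的接口。

# 示意代码:一次补全请求在分层架构中的流转(类名与方法均为假设)
from dataclasses import dataclass

@dataclass
class EditorEvent:
    """编辑器集成层捕获的光标事件"""
    file_path: str
    text_before: str   # 光标前文本
    text_after: str    # 光标后文本

class ContextManager:
    """上下文管理层:汇总当前文件、相关文件与依赖信息"""

    def build(self, event: EditorEvent) -> dict:
        return {
            'prefix': event.text_before,
            'suffix': event.text_after,
            'related_files': [],   # 真实实现中来自代码索引 / Git / 依赖图
        }

class InferenceLayer:
    """AI推理层:组合上下文与检索结果,调用底层模型服务"""

    def __init__(self, model_client):
        self.model_client = model_client   # 基础设施层提供的模型服务客户端

    def complete(self, context: dict) -> str:
        prompt = context['prefix']         # 真实系统还会拼接检索到的相关代码
        return self.model_client.generate(prompt)

def handle_completion_request(event: EditorEvent,
                              context_manager: ContextManager,
                              inference: InferenceLayer) -> str:
    """编辑器集成层收到事件后,自上而下依次调用各层"""
    context = context_manager.build(event)
    return inference.complete(context)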

14.2 代码理解技术

14.2.1 代码解析与AST

抽象语法树(AST)是代码理解的基础:

import ast
import json
from typing import Dict, List, Any

class CodeParser:
    """代码解析器 - 将代码转换为结构化表示"""

    def __init__(self):
        self.ast_cache = {}

    def parse_python(self, code: str) -> Dict[str, Any]:
        """解析Python代码"""
        try:
            tree = ast.parse(code)
            return {
                'success': True,
                'ast': tree,
                'structure': self._extract_structure(tree),
                'symbols': self._extract_symbols(tree),
                'dependencies': self._extract_dependencies(tree)
            }
        except SyntaxError as e:
            return {
                'success': False,
                'error': str(e),
                'line': e.lineno,
                'offset': e.offset
            }

    def _extract_structure(self, tree: ast.AST) -> Dict[str, List]:
        """提取代码结构"""
        structure = {
            'classes': [],
            'functions': [],
            'imports': [],
            'global_vars': []
        }

        for node in ast.walk(tree):
            if isinstance(node, ast.ClassDef):
                structure['classes'].append({
                    'name': node.name,
                    'line': node.lineno,
                    'methods': [m.name for m in node.body if isinstance(m, ast.FunctionDef)],
                    'bases': [self._get_name(base) for base in node.bases],
                    'decorators': [self._get_name(d) for d in node.decorator_list]
                })

            elif isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)):
                # 只记录顶层函数;AsyncFunctionDef 不是 FunctionDef 的子类,需显式匹配才能正确标记 is_async
                if not self._is_nested(node, tree):
                    structure['functions'].append({
                        'name': node.name,
                        'line': node.lineno,
                        'args': [arg.arg for arg in node.args.args],
                        'returns': self._get_annotation(node.returns),
                        'decorators': [self._get_name(d) for d in node.decorator_list],
                        'is_async': isinstance(node, ast.AsyncFunctionDef)
                    })

            elif isinstance(node, (ast.Import, ast.ImportFrom)):
                structure['imports'].append(self._extract_import(node))

        return structure

    def _extract_symbols(self, tree: ast.AST) -> Dict[str, List[str]]:
        """提取符号表"""
        symbols = {
            'defined': [],      # 定义的符号
            'used': [],         # 使用的符号
            'assigned': [],     # 赋值的符号
            'called': []        # 调用的符号
        }

        for node in ast.walk(tree):
            if isinstance(node, ast.Name):
                if isinstance(node.ctx, ast.Store):
                    symbols['assigned'].append(node.id)
                elif isinstance(node.ctx, ast.Load):
                    symbols['used'].append(node.id)

            elif isinstance(node, ast.Call):
                func_name = self._get_name(node.func)
                if func_name:
                    symbols['called'].append(func_name)

            elif isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef, ast.ClassDef)):
                symbols['defined'].append(node.name)

        # 去重
        for key in symbols:
            symbols[key] = list(set(symbols[key]))

        return symbols

    def _extract_dependencies(self, tree: ast.AST) -> List[str]:
        """提取依赖关系"""
        dependencies = []

        for node in ast.walk(tree):
            if isinstance(node, ast.Import):
                for alias in node.names:
                    dependencies.append(alias.name.split('.')[0])

            elif isinstance(node, ast.ImportFrom):
                if node.module:
                    dependencies.append(node.module.split('.')[0])

        return list(set(dependencies))

    def _get_name(self, node) -> str:
        """获取节点名称"""
        if isinstance(node, ast.Name):
            return node.id
        elif isinstance(node, ast.Attribute):
            value = self._get_name(node.value)
            return f"{value}.{node.attr}" if value else node.attr
        elif isinstance(node, ast.Call):
            return self._get_name(node.func)
        return ""

    def _get_annotation(self, node) -> str:
        """获取类型注解"""
        if node is None:
            return None
        return ast.unparse(node)

    def _is_nested(self, func_node, tree) -> bool:
        """检查函数是否嵌套在类或其他函数中"""
        for node in ast.walk(tree):
            if isinstance(node, (ast.ClassDef, ast.FunctionDef, ast.AsyncFunctionDef)):
                if func_node in ast.walk(node) and func_node != node:
                    return True
        return False

    def _extract_import(self, node) -> Dict[str, Any]:
        """提取导入信息"""
        if isinstance(node, ast.Import):
            return {
                'type': 'import',
                'names': [alias.name for alias in node.names],
                'aliases': {alias.name: alias.asname for alias in node.names if alias.asname}
            }
        else:  # ImportFrom
            return {
                'type': 'from_import',
                'module': node.module,
                'names': [alias.name for alias in node.names],
                'level': node.level
            }

# 使用示例
parser = CodeParser()

sample_code = '''
import os
from typing import List, Dict

class DataProcessor:
    """数据处理器"""

    def __init__(self, config: Dict):
        self.config = config

    def process(self, data: List[str]) -> List[str]:
        """处理数据"""
        result = []
        for item in data:
            processed = self._transform(item)
            result.append(processed)
        return result

    def _transform(self, item: str) -> str:
        return item.upper()

def main():
    processor = DataProcessor({'debug': True})
    data = ['hello', 'world']
    result = processor.process(data)
    print(result)
'''

parsed = parser.parse_python(sample_code)
if parsed['success']:
    print("代码结构:")
    print(json.dumps(parsed['structure'], indent=2, ensure_ascii=False))
    print("\n符号表:")
    print(json.dumps(parsed['symbols'], indent=2, ensure_ascii=False))
    print("\n依赖:")
    print(parsed['dependencies'])

14.2.2 语义理解与代码嵌入

使用预训练模型进行代码语义理解:

import torch
import torch.nn as nn
from transformers import AutoTokenizer, AutoModel
from typing import List, Dict
import numpy as np

class CodeEmbedder:
    """代码嵌入器 - 将代码转换为向量表示"""

    def __init__(self, model_name: str = "microsoft/codebert-base"):
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = AutoModel.from_pretrained(model_name)
        self.model.eval()

        # 使用GPU(如果可用)
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.model.to(self.device)

    def embed_code(self, code: str, max_length: int = 512) -> np.ndarray:
        """将代码转换为嵌入向量"""
        # Tokenize
        inputs = self.tokenizer(
            code,
            max_length=max_length,
            padding='max_length',
            truncation=True,
            return_tensors='pt'
        )

        # 移到设备
        inputs = {k: v.to(self.device) for k, v in inputs.items()}

        # 获取嵌入
        with torch.no_grad():
            outputs = self.model(**inputs)
            # 使用[CLS] token的嵌入作为代码表示
            embeddings = outputs.last_hidden_state[:, 0, :].cpu().numpy()

        return embeddings[0]

    def embed_batch(self, codes: List[str], max_length: int = 512) -> np.ndarray:
        """批量嵌入"""
        inputs = self.tokenizer(
            codes,
            max_length=max_length,
            padding='max_length',
            truncation=True,
            return_tensors='pt'
        )

        inputs = {k: v.to(self.device) for k, v in inputs.items()}

        with torch.no_grad():
            outputs = self.model(**inputs)
            embeddings = outputs.last_hidden_state[:, 0, :].cpu().numpy()

        return embeddings

    def compute_similarity(self, code1: str, code2: str) -> float:
        """计算两段代码的相似度"""
        emb1 = self.embed_code(code1)
        emb2 = self.embed_code(code2)

        # 余弦相似度
        similarity = np.dot(emb1, emb2) / (np.linalg.norm(emb1) * np.linalg.norm(emb2))
        return float(similarity)

    def find_similar_code(
        self,
        query_code: str,
        code_database: List[str],
        top_k: int = 5
    ) -> List[tuple]:
        """在代码库中查找相似代码"""
        query_emb = self.embed_code(query_code)
        db_embs = self.embed_batch(code_database)

        # 计算相似度
        similarities = []
        for i, db_emb in enumerate(db_embs):
            sim = np.dot(query_emb, db_emb) / (
                np.linalg.norm(query_emb) * np.linalg.norm(db_emb)
            )
            similarities.append((i, float(sim)))

        # 排序并返回top_k
        similarities.sort(key=lambda x: x[1], reverse=True)
        return similarities[:top_k]

class CodeUnderstandingModel(nn.Module):
    """代码理解模型 - 在CodeBERT基础上进行任务特定训练"""

    def __init__(self, base_model_name: str = "microsoft/codebert-base", num_labels: int = 2):
        super().__init__()
        self.encoder = AutoModel.from_pretrained(base_model_name)
        self.dropout = nn.Dropout(0.1)
        self.classifier = nn.Linear(self.encoder.config.hidden_size, num_labels)

    def forward(self, input_ids, attention_mask):
        outputs = self.encoder(input_ids=input_ids, attention_mask=attention_mask)
        pooled_output = outputs.last_hidden_state[:, 0, :]  # [CLS] token
        pooled_output = self.dropout(pooled_output)
        logits = self.classifier(pooled_output)
        return logits

class CodeClassifier:
    """代码分类器 - 用于代码意图理解"""

    def __init__(self, model_path: str = None):
        self.tokenizer = AutoTokenizer.from_pretrained("microsoft/codebert-base")
        self.model = CodeUnderstandingModel(num_labels=4)  # 假设4种意图

        if model_path:
            self.model.load_state_dict(torch.load(model_path))

        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.model.to(self.device)
        self.model.eval()

        self.intent_labels = {
            0: 'data_processing',
            1: 'api_call',
            2: 'algorithm',
            3: 'ui_logic'
        }

    def classify_intent(self, code: str) -> Dict[str, float]:
        """分类代码意图"""
        inputs = self.tokenizer(
            code,
            max_length=512,
            padding='max_length',
            truncation=True,
            return_tensors='pt'
        )

        inputs = {k: v.to(self.device) for k, v in inputs.items()}

        with torch.no_grad():
            logits = self.model(**inputs)
            probs = torch.softmax(logits, dim=1)[0].cpu().numpy()

        return {
            self.intent_labels[i]: float(prob)
            for i, prob in enumerate(probs)
        }

# 使用示例
embedder = CodeEmbedder()

# 示例代码
code1 = """
def bubble_sort(arr):
    n = len(arr)
    for i in range(n):
        for j in range(0, n-i-1):
            if arr[j] > arr[j+1]:
                arr[j], arr[j+1] = arr[j+1], arr[j]
    return arr
"""

code2 = """
def selection_sort(arr):
    for i in range(len(arr)):
        min_idx = i
        for j in range(i+1, len(arr)):
            if arr[min_idx] > arr[j]:
                min_idx = j
        arr[i], arr[min_idx] = arr[min_idx], arr[i]
    return arr
"""

code3 = """
def fetch_user_data(user_id):
    response = requests.get(f'/api/users/{user_id}')
    return response.json()
"""

# 计算相似度
similarity = embedder.compute_similarity(code1, code2)
print(f"冒泡排序 vs 选择排序 相似度: {similarity:.4f}")

similarity2 = embedder.compute_similarity(code1, code3)
print(f"冒泡排序 vs API调用 相似度: {similarity2:.4f}")

# 代码搜索
code_db = [code1, code2, code3]
query = "def quick_sort(arr): ..."
results = embedder.find_similar_code(query, code_db, top_k=2)
print("\n最相似的代码:")
for idx, sim in results:
    print(f"索引 {idx}, 相似度: {sim:.4f}")

14.2.3 调用图与依赖分析

import ast
from collections import defaultdict
from typing import Dict, List, Set, Tuple
import networkx as nx
import matplotlib.pyplot as plt

class CallGraphAnalyzer:
    """调用图分析器"""

    def __init__(self):
        self.call_graph = nx.DiGraph()
        self.function_defs = {}
        self.class_methods = defaultdict(list)

    def analyze_file(self, filepath: str) -> nx.DiGraph:
        """分析文件并构建调用图"""
        with open(filepath, 'r', encoding='utf-8') as f:
            code = f.read()

        tree = ast.parse(code)

        # 第一遍:收集所有函数和方法定义
        self._collect_definitions(tree)

        # 第二遍:分析调用关系
        self._analyze_calls(tree)

        return self.call_graph

    def _collect_definitions(self, tree: ast.AST):
        """收集函数和方法定义"""
        for node in ast.walk(tree):
            if isinstance(node, ast.FunctionDef):
                # 检查是否是类方法
                parent_class = self._find_parent_class(node, tree)
                if parent_class:
                    full_name = f"{parent_class}.{node.name}"
                    self.class_methods[parent_class].append(node.name)
                else:
                    full_name = node.name

                self.function_defs[full_name] = {
                    'node': node,
                    'lineno': node.lineno,
                    'args': [arg.arg for arg in node.args.args]
                }

                # 添加到图中
                self.call_graph.add_node(
                    full_name,
                    type='function',
                    lineno=node.lineno
                )

    def _analyze_calls(self, tree: ast.AST):
        """分析函数调用关系"""
        for func_name, func_info in self.function_defs.items():
            func_node = func_info['node']

            # 遍历函数体中的所有调用
            for node in ast.walk(func_node):
                if isinstance(node, ast.Call):
                    called_name = self._get_call_name(node.func)
                    if called_name:
                        # 添加边
                        self.call_graph.add_edge(
                            func_name,
                            called_name,
                            lineno=node.lineno
                        )

    def _find_parent_class(self, func_node, tree) -> str:
        """查找函数的父类"""
        for node in ast.walk(tree):
            if isinstance(node, ast.ClassDef):
                if func_node in node.body:
                    return node.name
        return None

    def _get_call_name(self, node) -> str:
        """获取调用的函数名"""
        if isinstance(node, ast.Name):
            return node.id
        elif isinstance(node, ast.Attribute):
            # 处理方法调用 obj.method()
            if isinstance(node.value, ast.Name):
                # 简单情况:self.method() 或 obj.method()
                if node.value.id == 'self':
                    # 查找当前类
                    return f"*.{node.attr}"  # 使用通配符表示任意类
                return f"{node.value.id}.{node.attr}"
            return node.attr
        return None

    def find_callers(self, function_name: str) -> List[str]:
        """查找调用指定函数的所有函数"""
        return list(self.call_graph.predecessors(function_name))

    def find_callees(self, function_name: str) -> List[str]:
        """查找指定函数调用的所有函数"""
        return list(self.call_graph.successors(function_name))

    def find_call_chain(self, from_func: str, to_func: str) -> List[List[str]]:
        """查找从一个函数到另一个函数的所有调用链"""
        try:
            paths = nx.all_simple_paths(self.call_graph, from_func, to_func)
            return list(paths)
        except nx.NodeNotFound:
            return []

    def detect_cycles(self) -> List[List[str]]:
        """检测循环调用"""
        try:
            cycles = list(nx.simple_cycles(self.call_graph))
            return cycles
        except:
            return []

    def get_complexity_metrics(self) -> Dict[str, Dict]:
        """获取复杂度指标"""
        metrics = {}

        for node in self.call_graph.nodes():
            in_degree = self.call_graph.in_degree(node)   # 被调用次数
            out_degree = self.call_graph.out_degree(node)  # 调用其他函数次数

            # 简化版圈复杂度:用出度 + 1 近似(真实实现应统计函数体内的分支节点)
            complexity = out_degree + 1

            metrics[node] = {
                'callers': in_degree,
                'callees': out_degree,
                'complexity': complexity,
                'centrality': in_degree + out_degree
            }

        return metrics

    def visualize(self, output_file: str = 'call_graph.png'):
        """可视化调用图"""
        plt.figure(figsize=(12, 8))

        # 使用弹簧力导向布局
        pos = nx.spring_layout(self.call_graph, k=2, iterations=50)

        # 绘制节点
        nx.draw_networkx_nodes(
            self.call_graph, pos,
            node_color='lightblue',
            node_size=500,
            alpha=0.9
        )

        # 绘制边
        nx.draw_networkx_edges(
            self.call_graph, pos,
            edge_color='gray',
            arrows=True,
            arrowsize=20,
            alpha=0.6
        )

        # 绘制标签
        nx.draw_networkx_labels(
            self.call_graph, pos,
            font_size=8,
            font_weight='bold'
        )

        plt.title("Call Graph")
        plt.axis('off')
        plt.tight_layout()
        plt.savefig(output_file, dpi=300, bbox_inches='tight')
        print(f"调用图已保存到: {output_file}")

class DependencyAnalyzer:
    """依赖分析器"""

    def __init__(self, project_root: str):
        self.project_root = project_root
        self.dep_graph = nx.DiGraph()
        self.module_imports = defaultdict(set)

    def analyze_project(self) -> nx.DiGraph:
        """分析整个项目的依赖"""
        import os

        for root, dirs, files in os.walk(self.project_root):
            # 跳过虚拟环境等
            dirs[:] = [d for d in dirs if d not in ['venv', '__pycache__', '.git']]

            for file in files:
                if file.endswith('.py'):
                    filepath = os.path.join(root, file)
                    self._analyze_file_dependencies(filepath)

        return self.dep_graph

    def _analyze_file_dependencies(self, filepath: str):
        """分析单个文件的依赖"""
        try:
            with open(filepath, 'r', encoding='utf-8') as f:
                tree = ast.parse(f.read())

            module_name = self._get_module_name(filepath)
            self.dep_graph.add_node(module_name, path=filepath)

            # 提取导入
            for node in ast.walk(tree):
                if isinstance(node, ast.Import):
                    for alias in node.names:
                        imported = alias.name.split('.')[0]
                        self.module_imports[module_name].add(imported)
                        self.dep_graph.add_edge(module_name, imported)

                elif isinstance(node, ast.ImportFrom):
                    if node.module:
                        imported = node.module.split('.')[0]
                        self.module_imports[module_name].add(imported)
                        self.dep_graph.add_edge(module_name, imported)

        except Exception as e:
            print(f"分析文件 {filepath} 时出错: {e}")

    def _get_module_name(self, filepath: str) -> str:
        """从文件路径获取模块名"""
        import os
        relpath = os.path.relpath(filepath, self.project_root)
        module = relpath.replace(os.sep, '.').replace('.py', '')
        return module

    def find_circular_dependencies(self) -> List[List[str]]:
        """查找循环依赖"""
        cycles = list(nx.simple_cycles(self.dep_graph))
        return cycles

    def get_dependency_depth(self, module: str) -> int:
        """获取模块的依赖深度"""
        if module not in self.dep_graph:
            return 0

        try:
            # 以可达节点的最大最短路径长度近似依赖深度
            lengths = nx.single_source_shortest_path_length(
                self.dep_graph, module
            )
            return max(lengths.values()) if lengths else 0
        except:
            return 0

    def get_impact_analysis(self, module: str) -> Dict[str, Set[str]]:
        """分析修改某个模块的影响范围"""
        direct_dependents = set(self.dep_graph.predecessors(module))
        all_dependents = set()

        # BFS找到所有依赖该模块的模块
        to_visit = list(direct_dependents)
        visited = set([module])

        while to_visit:
            current = to_visit.pop(0)
            if current in visited:
                continue

            visited.add(current)
            all_dependents.add(current)

            for predecessor in self.dep_graph.predecessors(current):
                if predecessor not in visited:
                    to_visit.append(predecessor)

        return {
            'direct': direct_dependents,
            'all': all_dependents,
            'count': len(all_dependents)
        }

# 使用示例
print("=== 调用图分析示例 ===")

# 创建示例代码文件
sample_code = '''
class Calculator:
    def __init__(self):
        self.result = 0

    def add(self, a, b):
        self.result = self._compute(a, b, '+')
        return self.result

    def subtract(self, a, b):
        self.result = self._compute(a, b, '-')
        return self.result

    def _compute(self, a, b, op):
        if op == '+':
            return self._add_impl(a, b)
        elif op == '-':
            return self._subtract_impl(a, b)

    def _add_impl(self, a, b):
        return a + b

    def _subtract_impl(self, a, b):
        return a - b

def main():
    calc = Calculator()
    result1 = calc.add(10, 5)
    result2 = calc.subtract(10, 5)
    print_results(result1, result2)

def print_results(r1, r2):
    print(f"Results: {r1}, {r2}")
'''

# 保存示例文件
with open('/tmp/sample_calc.py', 'w') as f:
    f.write(sample_code)

# 分析调用图
analyzer = CallGraphAnalyzer()
call_graph = analyzer.analyze_file('/tmp/sample_calc.py')

print(f"函数数量: {len(call_graph.nodes())}")
print(f"调用关系数量: {len(call_graph.edges())}")

# 查找调用关系
print("\nmain函数调用的函数:")
print(analyzer.find_callees('main'))

print("\n复杂度指标:")
metrics = analyzer.get_complexity_metrics()
for func, metric in sorted(metrics.items(), key=lambda x: x[1]['complexity'], reverse=True):
    print(f"{func}: 复杂度={metric['complexity']}, 调用数={metric['callees']}")

# 检测循环
cycles = analyzer.detect_cycles()
if cycles:
    print(f"\n发现循环调用: {cycles}")
else:
    print("\n未发现循环调用")

14.3 AI代码生成技术

14.3.1 基于语言模型的代码生成

from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
from typing import List, Dict, Optional, Any

class CodeGenerator:
    """代码生成器"""

    def __init__(
        self,
        model_name: str = "Salesforce/codegen-350M-mono",
        device: str = None
    ):
        """
        初始化代码生成器

        Args:
            model_name: 模型名称,可选:
                - Salesforce/codegen-350M-mono (Python)
                - Salesforce/codegen-2B-mono (Python, 更大)
                - bigcode/starcoder (多语言)
            device: 设备 ('cuda' 或 'cpu')
        """
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = AutoModelForCausalLM.from_pretrained(model_name)

        if device is None:
            self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        else:
            self.device = torch.device(device)

        self.model.to(self.device)
        self.model.eval()

    def generate(
        self,
        prompt: str,
        max_length: int = 256,
        num_return_sequences: int = 1,
        temperature: float = 0.2,
        top_p: float = 0.95,
        top_k: int = 50,
        stop_tokens: List[str] = None
    ) -> List[str]:
        """
        生成代码

        Args:
            prompt: 提示代码
            max_length: 最大生成长度
            num_return_sequences: 返回序列数量
            temperature: 温度参数(越低越确定)
            top_p: nucleus sampling参数
            top_k: top-k sampling参数
            stop_tokens: 停止标记列表
        """
        # Tokenize输入
        inputs = self.tokenizer(prompt, return_tensors="pt")
        inputs = {k: v.to(self.device) for k, v in inputs.items()}

        # 设置停止条件
        if stop_tokens is None:
            stop_tokens = ["\n\n", "\nclass ", "\ndef ", "\nif __name__"]

        # 生成
        with torch.no_grad():
            outputs = self.model.generate(
                **inputs,
                max_length=max_length,
                num_return_sequences=num_return_sequences,
                temperature=temperature,
                top_p=top_p,
                top_k=top_k,
                do_sample=temperature > 0,
                pad_token_id=self.tokenizer.eos_token_id,
                eos_token_id=self.tokenizer.eos_token_id
            )

        # 解码
        generated_texts = []
        for output in outputs:
            text = self.tokenizer.decode(output, skip_special_tokens=True)

            # 移除prompt
            if text.startswith(prompt):
                text = text[len(prompt):]

            # 在停止标记处截断
            for stop_token in stop_tokens:
                if stop_token in text:
                    text = text[:text.index(stop_token)]

            generated_texts.append(text.strip())

        return generated_texts

    def complete_function(
        self,
        function_signature: str,
        docstring: str = None,
        context: str = None,
        num_samples: int = 3
    ) -> List[str]:
        """
        补全函数实现

        Args:
            function_signature: 函数签名
            docstring: 函数文档字符串
            context: 上下文代码
            num_samples: 生成样本数
        """
        # 构建prompt
        prompt_parts = []

        if context:
            prompt_parts.append(context)
            prompt_parts.append("\n")

        prompt_parts.append(function_signature)

        if docstring:
            prompt_parts.append(f'\n    """{docstring}"""')

        prompt_parts.append("\n    ")

        prompt = "".join(prompt_parts)

        # 生成
        completions = self.generate(
            prompt,
            max_length=512,
            num_return_sequences=num_samples,
            temperature=0.3
        )

        # 格式化结果
        results = []
        for completion in completions:
            full_function = prompt + completion
            results.append(full_function)

        return results

    def generate_from_comment(
        self,
        comment: str,
        language: str = "python",
        num_samples: int = 1
    ) -> List[str]:
        """
        从注释生成代码

        Args:
            comment: 描述性注释
            language: 编程语言
            num_samples: 生成样本数
        """
        if language.lower() == "python":
            prompt = f"# {comment}\n"
        else:
            prompt = f"// {comment}\n"

        return self.generate(
            prompt,
            max_length=256,
            num_return_sequences=num_samples,
            temperature=0.2
        )

    def infill(
        self,
        prefix: str,
        suffix: str,
        num_samples: int = 1
    ) -> List[str]:
        """
        填充中间代码(适用于支持FIM的模型)

        Args:
            prefix: 前缀代码
            suffix: 后缀代码
            num_samples: 生成样本数
        """
        # 对于不支持FIM的模型,退化为仅基于前缀生成,并在后缀的第一个非空行处截断
        suffix_lines = [line for line in suffix.split('\n') if line.strip()]
        stop = suffix_lines[0] if suffix_lines else None

        completions = self.generate(
            prefix,
            max_length=256,
            num_return_sequences=num_samples,
            temperature=0.2,
            stop_tokens=[stop] if stop else None
        )

        return completions

class IntelligentCodeCompletion:
    """智能代码补全系统"""

    def __init__(self, generator: CodeGenerator):
        self.generator = generator
        self.completion_cache = {}

    def get_completion(
        self,
        cursor_context: Dict[str, str],
        use_cache: bool = True
    ) -> List[Dict[str, Any]]:
        """
        获取代码补全建议

        Args:
            cursor_context: 光标上下文
                - before: 光标前的代码
                - after: 光标后的代码
                - file_context: 文件其他部分
                - language: 编程语言
            use_cache: 是否使用缓存

        Returns:
            补全建议列表
        """
        before = cursor_context.get('before', '')
        after = cursor_context.get('after', '')

        # 检查缓存
        cache_key = (before, after)
        if use_cache and cache_key in self.completion_cache:
            return self.completion_cache[cache_key]

        # 分析上下文确定补全类型
        completion_type = self._detect_completion_type(before, after)

        suggestions = []

        if completion_type == 'function_body':
            # 函数体补全
            suggestions = self._complete_function_body(before, after)

        elif completion_type == 'line':
            # 行级补全
            suggestions = self._complete_line(before, after)

        elif completion_type == 'multiline':
            # 多行补全
            suggestions = self._complete_multiline(before, after)

        # 缓存结果
        if use_cache:
            self.completion_cache[cache_key] = suggestions

        return suggestions

    def _detect_completion_type(self, before: str, after: str) -> str:
        """检测补全类型"""
        before_lines = before.split('\n')
        last_line = before_lines[-1] if before_lines else ''

        # 检查是否在函数定义后
        if last_line.strip().endswith(':'):
            if 'def ' in last_line:
                return 'function_body'
            elif 'class ' in last_line:
                return 'class_body'

        # 检查是否在行中
        if not last_line.strip().endswith(':'):
            return 'line'

        return 'multiline'

    def _complete_function_body(
        self,
        before: str,
        after: str
    ) -> List[Dict[str, Any]]:
        """补全函数体"""
        # 提取函数签名
        lines = before.split('\n')
        signature_line = lines[-1] if lines else ''

        # 生成函数体
        completions = self.generator.generate(
            before + "\n    ",
            max_length=256,
            num_return_sequences=3,
            temperature=0.2
        )

        suggestions = []
        for i, completion in enumerate(completions):
            suggestions.append({
                'text': completion,
                'type': 'function_body',
                'confidence': 1.0 - (i * 0.1),  # 第一个建议置信度最高
                'display_text': completion.split('\n')[0]  # 显示第一行
            })

        return suggestions

    def _complete_line(
        self,
        before: str,
        after: str
    ) -> List[Dict[str, Any]]:
        """补全当前行"""
        completions = self.generator.generate(
            before,
            max_length=50,
            num_return_sequences=5,
            temperature=0.1,
            stop_tokens=['\n']
        )

        suggestions = []
        for i, completion in enumerate(completions):
            # 只取第一行
            line = completion.split('\n')[0]
            suggestions.append({
                'text': line,
                'type': 'line',
                'confidence': 1.0 - (i * 0.15),
                'display_text': line
            })

        return suggestions

    def _complete_multiline(
        self,
        before: str,
        after: str
    ) -> List[Dict[str, Any]]:
        """多行补全"""
        completions = self.generator.generate(
            before,
            max_length=128,
            num_return_sequences=3,
            temperature=0.3
        )

        suggestions = []
        for i, completion in enumerate(completions):
            suggestions.append({
                'text': completion,
                'type': 'multiline',
                'confidence': 1.0 - (i * 0.2),
                'display_text': completion.split('\n')[0] + '...'
            })

        return suggestions

# 使用示例
print("=== 代码生成示例 ===\n")

# 初始化生成器(使用较小的模型演示)
generator = CodeGenerator(model_name="Salesforce/codegen-350M-mono")

# 示例1: 从注释生成代码
print("1. 从注释生成代码:")
comment = "Calculate the factorial of a number using recursion"
code = generator.generate_from_comment(comment, num_samples=1)
print(f"注释: {comment}")
print(f"生成代码:\n{code[0]}\n")

# 示例2: 补全函数
print("2. 补全函数实现:")
signature = "def binary_search(arr, target):"
docstring = "Perform binary search on a sorted array"
completions = generator.complete_function(signature, docstring, num_samples=2)
print("函数签名:", signature)
print("第一个补全:")
print(completions[0])
print()

# 示例3: 智能补全
print("3. 智能代码补全:")
completion_system = IntelligentCodeCompletion(generator)

context = {
    'before': '''def calculate_average(numbers):
    """Calculate the average of a list of numbers"""
    total = sum(numbers)
    ''',
    'after': '    return average',
    'language': 'python'
}

suggestions = completion_system.get_completion(context)
print("上下文:")
print(context['before'])
print("\n补全建议:")
for i, sugg in enumerate(suggestions[:3], 1):
    print(f"{i}. [{sugg['type']}] (置信度: {sugg['confidence']:.2f})")
    print(f"   {sugg['display_text']}")

14.3.2 代码质量评估与排序

import ast
from typing import List, Dict, Tuple, Any
import re

class CodeQualityEvaluator:
    """代码质量评估器"""

    def __init__(self):
        self.weights = {
            'syntax': 0.3,          # 语法正确性
            'style': 0.2,           # 代码风格
            'complexity': 0.2,      # 复杂度
            'completeness': 0.15,   # 完整性
            'efficiency': 0.15      # 效率
        }

    def evaluate(self, code: str, language: str = 'python') -> Dict[str, float]:
        """
        评估代码质量

        Returns:
            评分字典,包含各项指标和总分
        """
        scores = {}

        # 语法检查
        scores['syntax'] = self._check_syntax(code, language)

        # 风格检查
        scores['style'] = self._check_style(code, language)

        # 复杂度检查
        scores['complexity'] = self._check_complexity(code, language)

        # 完整性检查
        scores['completeness'] = self._check_completeness(code, language)

        # 效率检查
        scores['efficiency'] = self._check_efficiency(code, language)

        # 计算总分
        scores['total'] = sum(
            scores[key] * self.weights[key]
            for key in self.weights.keys()
        )

        return scores

    def _check_syntax(self, code: str, language: str) -> float:
        """检查语法正确性"""
        if language == 'python':
            try:
                ast.parse(code)
                return 1.0
            except SyntaxError:
                return 0.0
        return 0.5  # 其他语言暂不支持

    def _check_style(self, code: str, language: str) -> float:
        """检查代码风格"""
        if language != 'python':
            return 0.5

        score = 1.0
        issues = []

        # 检查缩进(应为4空格)
        lines = code.split('\n')
        for i, line in enumerate(lines):
            if line and not line.startswith('#'):
                # 计算前导空格
                leading_spaces = len(line) - len(line.lstrip(' '))
                if leading_spaces % 4 != 0:
                    issues.append(f"Line {i+1}: 缩进不是4的倍数")
                    score -= 0.1

        # 检查命名规范
        try:
            tree = ast.parse(code)
            for node in ast.walk(tree):
                # 函数名应为小写+下划线
                if isinstance(node, ast.FunctionDef):
                    if not re.match(r'^[a-z_][a-z0-9_]*$', node.name):
                        if not node.name.startswith('_'):
                            issues.append(f"函数名 {node.name} 不符合规范")
                            score -= 0.05

                # 类名应为驼峰
                elif isinstance(node, ast.ClassDef):
                    if not re.match(r'^[A-Z][a-zA-Z0-9]*$', node.name):
                        issues.append(f"类名 {node.name} 不符合规范")
                        score -= 0.05
        except:
            pass

        # 检查行长度(不超过120字符)
        for i, line in enumerate(lines):
            if len(line) > 120:
                issues.append(f"Line {i+1}: 超过120字符")
                score -= 0.05

        return max(0.0, min(1.0, score))

    def _check_complexity(self, code: str, language: str) -> float:
        """检查代码复杂度"""
        if language != 'python':
            return 0.5

        try:
            tree = ast.parse(code)
            max_complexity = 0

            for node in ast.walk(tree):
                if isinstance(node, ast.FunctionDef):
                    complexity = self._calculate_cyclomatic_complexity(node)
                    max_complexity = max(max_complexity, complexity)

            # 复杂度评分:1-5优秀,6-10良好,11-20一般,>20差
            if max_complexity <= 5:
                return 1.0
            elif max_complexity <= 10:
                return 0.8
            elif max_complexity <= 20:
                return 0.6
            else:
                return 0.4

        except:
            return 0.5

    def _calculate_cyclomatic_complexity(self, node: ast.FunctionDef) -> int:
        """计算圈复杂度"""
        complexity = 1  # 基础复杂度

        for child in ast.walk(node):
            # 每个决策点+1
            if isinstance(child, (ast.If, ast.While, ast.For, ast.ExceptHandler)):
                complexity += 1
            elif isinstance(child, ast.BoolOp):
                # 布尔操作符(and/or)
                complexity += len(child.values) - 1

        return complexity

    def _check_completeness(self, code: str, language: str) -> float:
        """检查代码完整性"""
        if language != 'python':
            return 0.5

        score = 1.0

        # 检查是否有未完成的标记
        incomplete_markers = ['...', 'pass', 'TODO', 'FIXME', 'NotImplementedError']
        for marker in incomplete_markers:
            if marker in code:
                score -= 0.2

        # 检查是否有函数体
        try:
            tree = ast.parse(code)
            for node in ast.walk(tree):
                if isinstance(node, ast.FunctionDef):
                    # 检查函数体是否只有pass
                    if len(node.body) == 1 and isinstance(node.body[0], ast.Pass):
                        score -= 0.3
        except:
            pass

        return max(0.0, score)

    def _check_efficiency(self, code: str, language: str) -> float:
        """检查代码效率"""
        if language != 'python':
            return 0.5

        score = 1.0

        # 检查一些常见的低效模式
        inefficient_patterns = [
            (r'for .+ in .+:\s+if .+:\s+.+\.append\(', '使用列表推导式'),
            (r'\.append\([^)]+\)\s+\.append\(', '多次append可以优化'),
        ]

        for pattern, reason in inefficient_patterns:
            if re.search(pattern, code):
                score -= 0.1

        # 检查是否有嵌套循环(可能的性能问题)
        try:
            tree = ast.parse(code)
            for node in ast.walk(tree):
                if isinstance(node, (ast.For, ast.While)):
                    # 检查循环体中是否还有循环
                    for child in ast.walk(node):
                        if child != node and isinstance(child, (ast.For, ast.While)):
                            score -= 0.15
                            break
        except:
            pass

        return max(0.0, score)

class CodeRanker:
    """代码排序器"""

    def __init__(self, evaluator: CodeQualityEvaluator = None):
        self.evaluator = evaluator or CodeQualityEvaluator()

    def rank_candidates(
        self,
        candidates: List[str],
        context: Dict[str, Any] = None
    ) -> List[Tuple[str, float, Dict]]:
        """
        对候选代码进行排序

        Args:
            candidates: 候选代码列表
            context: 上下文信息(可选)

        Returns:
            排序后的列表 [(code, score, details), ...]
        """
        ranked = []

        for code in candidates:
            # 评估代码质量
            scores = self.evaluator.evaluate(code)

            # 如果有上下文,进行额外评分
            if context:
                context_score = self._evaluate_context_fit(code, context)
                # 将上下文适配性作为额外因素
                final_score = scores['total'] * 0.7 + context_score * 0.3
            else:
                final_score = scores['total']

            ranked.append((code, final_score, scores))

        # 按分数降序排序
        ranked.sort(key=lambda x: x[1], reverse=True)

        return ranked

    def _evaluate_context_fit(self, code: str, context: Dict) -> float:
        """评估代码与上下文的适配性"""
        score = 1.0

        # 检查是否使用了上下文中的变量
        if 'available_vars' in context:
            used_vars = self._extract_used_variables(code)
            available_vars = set(context['available_vars'])

            # 使用了不存在的变量,扣分
            undefined_vars = used_vars - available_vars
            score -= len(undefined_vars) * 0.1

        # 检查返回类型是否匹配
        if 'expected_return_type' in context:
            inferred_type = self._infer_return_type(code)
            if inferred_type != context['expected_return_type']:
                score -= 0.2

        return max(0.0, score)

    def _extract_used_variables(self, code: str) -> set:
        """提取代码中使用的变量"""
        try:
            tree = ast.parse(code)
            used_vars = set()

            for node in ast.walk(tree):
                if isinstance(node, ast.Name) and isinstance(node.ctx, ast.Load):
                    used_vars.add(node.id)

            return used_vars
        except:
            return set()

    def _infer_return_type(self, code: str) -> str:
        """推断返回类型(简化版)"""
        try:
            tree = ast.parse(code)
            for node in ast.walk(tree):
                if isinstance(node, ast.Return) and node.value:
                    # 简单类型推断
                    if isinstance(node.value, ast.Constant):
                        return type(node.value.value).__name__
                    elif isinstance(node.value, ast.List):
                        return 'list'
                    elif isinstance(node.value, ast.Dict):
                        return 'dict'
            return 'unknown'
        except:
            return 'unknown'

# 使用示例
print("=== 代码质量评估示例 ===\n")

evaluator = CodeQualityEvaluator()
ranker = CodeRanker(evaluator)

# 准备几个候选代码
candidates = [
    # 候选1: 质量较高
    """def calculate_sum(numbers):
    '''Calculate the sum of a list of numbers'''
    total = 0
    for num in numbers:
        total += num
    return total""",

    # 候选2: 使用内置函数,更简洁
    """def calculate_sum(numbers):
    '''Calculate the sum of a list of numbers'''
    return sum(numbers)""",

    # 候选3: 有语法错误
    """def calculate_sum(numbers):
    '''Calculate the sum of a list of numbers'''
    total = 0
    for num in numbers
        total += num
    return total""",

    # 候选4: 复杂度较高
    """def calculate_sum(numbers):
    '''Calculate the sum of a list of numbers'''
    result = []
    for i in range(len(numbers)):
        if i == 0:
            result.append(numbers[i])
        else:
            result.append(result[i-1] + numbers[i])
    return result[-1] if result else 0"""
]

print("对候选代码进行评估和排序:\n")

ranked_candidates = ranker.rank_candidates(candidates)

for i, (code, score, details) in enumerate(ranked_candidates, 1):
    print(f"排名 {i} (总分: {score:.3f})")
    print(f"详细评分:")
    for key, value in details.items():
        if key != 'total':
            print(f"  {key}: {value:.3f}")
    print(f"代码:")
    print(code)
    print("-" * 60)

14.4 VSCode插件开发

14.4.1 插件架构设计

VSCode插件的基本结构:

my-ai-assistant/
├── package.json              # 插件清单
├── tsconfig.json            # TypeScript配置
├── .vscodeignore           # 发布时忽略的文件
├── src/
│   ├── extension.ts         # 插件入口
│   ├── completionProvider.ts # 代码补全提供者
│   ├── chatProvider.ts      # 聊天界面提供者
│   ├── codeActionProvider.ts # 代码操作提供者
│   ├── apiClient.ts         # API客户端
│   ├── contextBuilder.ts    # 上下文构建器
│   └── utils/
│       ├── ast.ts           # AST工具
│       ├── cache.ts         # 缓存管理
│       └── logger.ts        # 日志
├── media/                   # 静态资源
│   ├── icon.svg
│   └── styles.css
└── test/                    # 测试
    └── extension.test.ts

package.json 配置:

{
  "name": "ai-code-assistant",
  "displayName": "AI Code Assistant",
  "description": "AI-powered code completion and assistance",
  "version": "0.1.0",
  "engines": {
    "vscode": "^1.80.0"
  },
  "categories": [
    "Programming Languages",
    "Machine Learning",
    "Other"
  ],
  "activationEvents": [
    "onStartupFinished"
  ],
  "main": "./out/extension.js",
  "contributes": {
    "commands": [
      {
        "command": "aiAssistant.chat",
        "title": "AI Assistant: Open Chat"
      },
      {
        "command": "aiAssistant.explainCode",
        "title": "AI Assistant: Explain Code"
      },
      {
        "command": "aiAssistant.refactorCode",
        "title": "AI Assistant: Refactor Code"
      },
      {
        "command": "aiAssistant.generateTests",
        "title": "AI Assistant: Generate Tests"
      }
    ],
    "keybindings": [
      {
        "command": "aiAssistant.chat",
        "key": "ctrl+shift+a",
        "mac": "cmd+shift+a"
      }
    ],
    "configuration": {
      "title": "AI Code Assistant",
      "properties": {
        "aiAssistant.apiEndpoint": {
          "type": "string",
          "default": "http://localhost:8000",
          "description": "API endpoint for the AI model"
        },
        "aiAssistant.modelName": {
          "type": "string",
          "default": "codegen",
          "description": "Model name to use for code generation"
        },
        "aiAssistant.maxTokens": {
          "type": "number",
          "default": 256,
          "description": "Maximum tokens to generate"
        },
        "aiAssistant.temperature": {
          "type": "number",
          "default": 0.2,
          "description": "Temperature for generation (0-1)"
        },
        "aiAssistant.enableInlineCompletion": {
          "type": "boolean",
          "default": true,
          "description": "Enable inline code completion"
        }
      }
    },
    "viewsContainers": {
      "activitybar": [
        {
          "id": "ai-assistant",
          "title": "AI Assistant",
          "icon": "media/icon.svg"
        }
      ]
    },
    "views": {
      "ai-assistant": [
        {
          "id": "aiAssistant.chatView",
          "name": "Chat",
          "type": "webview"
        }
      ]
    }
  },
  "scripts": {
    "vscode:prepublish": "npm run compile",
    "compile": "tsc -p ./",
    "watch": "tsc -watch -p ./",
    "pretest": "npm run compile",
    "test": "node ./out/test/runTest.js"
  },
  "devDependencies": {
    "@types/node": "^18.0.0",
    "@types/vscode": "^1.80.0",
    "typescript": "^5.0.0"
  },
  "dependencies": {
    "axios": "^1.4.0"
  }
}
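
配置中的 aiAssistant.apiEndpoint(默认 http://localhost:8000)指向插件背后的模型服务。apiClient.ts 的实现此处不展开,下面给出一个与之配套的最小后端草图,使用 FastAPI 实现;路由路径与请求字段均为本文假设,需要与插件侧的 API 客户端保持一致,CodeGenerator 即 14.3.1 中实现的生成器:

# 最小后端草图(假设性实现):为插件提供 /completions 接口
from fastapi import FastAPI
from pydantic import BaseModel

app = FastAPI()
generator = None  # 启动时再加载模型,避免导入阶段占用显存

class CompletionRequest(BaseModel):
    before: str                 # 光标前代码
    after: str = ""             # 光标后代码
    language: str = "python"
    max_tokens: int = 256
    temperature: float = 0.2

@app.on_event("startup")
def load_model():
    global generator
    # 假设 14.3.1 的 CodeGenerator 保存在 code_generator.py 中
    from code_generator import CodeGenerator
    generator = CodeGenerator()

@app.post("/completions")
def completions(req: CompletionRequest):
    texts = generator.generate(
        req.before,
        max_length=req.max_tokens,
        temperature=req.temperature,
        num_return_sequences=3
    )
    # 返回结构与 completionProvider 期望的 {text, confidence} 列表对应
    return {
        "completions": [
            {"text": t, "confidence": 1.0 - 0.1 * i}
            for i, t in enumerate(texts)
        ]
    }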

14.4.2 核心功能实现

1. 插件入口 (extension.ts):

import * as vscode from 'vscode';
import { CompletionProvider } from './completionProvider';
import { ChatProvider } from './chatProvider';
import { CodeActionProvider } from './codeActionProvider';
import { APIClient } from './apiClient';
import { Logger } from './utils/logger';

let completionProvider: CompletionProvider;
let chatProvider: ChatProvider;
let codeActionProvider: CodeActionProvider;
let apiClient: APIClient;

export function activate(context: vscode.ExtensionContext) {
    Logger.info('AI Code Assistant is activating...');

    // 初始化API客户端
    const config = vscode.workspace.getConfiguration('aiAssistant');
    apiClient = new APIClient(
        config.get('apiEndpoint', 'http://localhost:8000'),
        config.get('modelName', 'codegen')
    );

    // 注册代码补全提供者
    if (config.get('enableInlineCompletion', true)) {
        completionProvider = new CompletionProvider(apiClient);

        const completionDisposable = vscode.languages.registerInlineCompletionItemProvider(
            { pattern: '**' },
            completionProvider
        );
        context.subscriptions.push(completionDisposable);
    }

    // 注册聊天视图
    chatProvider = new ChatProvider(context, apiClient);
    const chatDisposable = vscode.window.registerWebviewViewProvider(
        'aiAssistant.chatView',
        chatProvider
    );
    context.subscriptions.push(chatDisposable);

    // 注册代码操作提供者
    codeActionProvider = new CodeActionProvider(apiClient);
    const codeActionDisposable = vscode.languages.registerCodeActionsProvider(
        { pattern: '**' },
        codeActionProvider,
        {
            providedCodeActionKinds: CodeActionProvider.providedCodeActionKinds
        }
    );
    context.subscriptions.push(codeActionDisposable);

    // 注册命令
    registerCommands(context);

    Logger.info('AI Code Assistant activated successfully');
}

function registerCommands(context: vscode.ExtensionContext) {
    // 打开聊天命令
    const chatCommand = vscode.commands.registerCommand(
        'aiAssistant.chat',
        () => {
            vscode.commands.executeCommand('aiAssistant.chatView.focus');
        }
    );

    // 解释代码命令
    const explainCommand = vscode.commands.registerCommand(
        'aiAssistant.explainCode',
        async () => {
            const editor = vscode.window.activeTextEditor;
            if (!editor) {
                vscode.window.showWarningMessage('No active editor');
                return;
            }

            const selection = editor.selection;
            const code = editor.document.getText(selection);

            if (!code) {
                vscode.window.showWarningMessage('No code selected');
                return;
            }

            const explanation = await apiClient.explainCode(code);

            // 在新文档中显示解释
            const doc = await vscode.workspace.openTextDocument({
                content: explanation,
                language: 'markdown'
            });
            await vscode.window.showTextDocument(doc);
        }
    );

    // 重构代码命令
    const refactorCommand = vscode.commands.registerCommand(
        'aiAssistant.refactorCode',
        async () => {
            const editor = vscode.window.activeTextEditor;
            if (!editor) return;

            const selection = editor.selection;
            const code = editor.document.getText(selection);

            if (!code) {
                vscode.window.showWarningMessage('No code selected');
                return;
            }

            // 显示进度
            await vscode.window.withProgress(
                {
                    location: vscode.ProgressLocation.Notification,
                    title: 'Refactoring code...',
                    cancellable: false
                },
                async (progress) => {
                    const refactored = await apiClient.refactorCode(code);

                    // 替换选中的代码
                    await editor.edit(editBuilder => {
                        editBuilder.replace(selection, refactored);
                    });

                    vscode.window.showInformationMessage('Code refactored successfully');
                }
            );
        }
    );

    // 生成测试命令
    const generateTestsCommand = vscode.commands.registerCommand(
        'aiAssistant.generateTests',
        async () => {
            const editor = vscode.window.activeTextEditor;
            if (!editor) return;

            const selection = editor.selection;
            const code = editor.document.getText(selection);

            if (!code) {
                vscode.window.showWarningMessage('No code selected');
                return;
            }

            const tests = await apiClient.generateTests(code);

            // 在新文档中显示测试
            const doc = await vscode.workspace.openTextDocument({
                content: tests,
                language: editor.document.languageId
            });
            await vscode.window.showTextDocument(doc, { viewColumn: vscode.ViewColumn.Beside });
        }
    );

    context.subscriptions.push(
        chatCommand,
        explainCommand,
        refactorCommand,
        generateTestsCommand
    );
}

export function deactivate() {
    Logger.info('AI Code Assistant is deactivating...');
}

2. 代码补全提供者 (completionProvider.ts):

import * as vscode from 'vscode';
import { APIClient } from './apiClient';
import { ContextBuilder } from './contextBuilder';
import { CacheManager } from './utils/cache';

export class CompletionProvider implements vscode.InlineCompletionItemProvider {
    private contextBuilder: ContextBuilder;
    private cache: CacheManager;
    private debounceTimer?: NodeJS.Timeout;

    constructor(private apiClient: APIClient) {
        this.contextBuilder = new ContextBuilder();
        this.cache = new CacheManager(60000); // 1分钟缓存
    }

    async provideInlineCompletionItems(
        document: vscode.TextDocument,
        position: vscode.Position,
        context: vscode.InlineCompletionContext,
        token: vscode.CancellationToken
    ): Promise<vscode.InlineCompletionItem[] | vscode.InlineCompletionList | undefined> {

        // Decide whether a completion should be triggered at this position
        if (!this.shouldTriggerCompletion(document, position, context)) {
            return undefined;
        }

        // Build the surrounding code context
        const codeContext = this.contextBuilder.buildContext(document, position);

        // Check the cache first
        const cacheKey = this.getCacheKey(codeContext);
        const cached = this.cache.get(cacheKey);
        if (cached) {
            return this.createCompletionItems(cached);
        }

        try {
            // Call the API for completions
            const completions = await this.apiClient.getCompletions(codeContext);

            // Cache the result
            this.cache.set(cacheKey, completions);

            return this.createCompletionItems(completions);

        } catch (error) {
            console.error('Error getting completions:', error);
            return undefined;
        }
    }

    private shouldTriggerCompletion(
        document: vscode.TextDocument,
        position: vscode.Position,
        context: vscode.InlineCompletionContext
    ): boolean {
        // Get the current line
        const line = document.lineAt(position.line);
        const textBeforeCursor = line.text.substring(0, position.character);

        // Do not trigger inside comments
        if (textBeforeCursor.trim().startsWith('#') ||
            textBeforeCursor.trim().startsWith('//')) {
            return false;
        }

        // Do not trigger inside string literals
        const inString = this.isInString(textBeforeCursor);
        if (inString) {
            return false;
        }

        // Require at least a couple of characters before the cursor
        if (textBeforeCursor.trim().length < 2) {
            return false;
        }

        return true;
    }

    private isInString(text: string): boolean {
        // Simple heuristic: unbalanced quotes suggest the cursor is inside a string
        const singleQuotes = (text.match(/'/g) || []).length;
        const doubleQuotes = (text.match(/"/g) || []).length;

        return (singleQuotes % 2 !== 0) || (doubleQuotes % 2 !== 0);
    }

    private getCacheKey(context: any): string {
        // Use slices of the code before and after the cursor as the cache key
        return `${context.before.slice(-100)}_${context.after.slice(0, 100)}`;
    }

    private createCompletionItems(
        completions: Array<{ text: string; confidence: number }>
    ): vscode.InlineCompletionItem[] {
        return completions.map(completion => {
            const item = new vscode.InlineCompletionItem(completion.text);

            // Additional metadata could be attached here
            item.range = undefined; // use the default range

            return item;
        });
    }
}
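
The provider imports CacheManager from ./utils/cache, which is not listed above. A minimal sketch, assuming a plain TTL-based map whose constructor takes the time-to-live in milliseconds (matching new CacheManager(60000)), could be:

export class CacheManager {
    // key -> cached value plus its absolute expiration time
    private store = new Map<string, { value: any; expiresAt: number }>();

    constructor(private ttlMs: number) {}

    get(key: string): any | undefined {
        const entry = this.store.get(key);
        if (!entry) {
            return undefined;
        }
        if (Date.now() > entry.expiresAt) {
            // Expired: drop the entry and report a miss
            this.store.delete(key);
            return undefined;
        }
        return entry.value;
    }

    set(key: string, value: any): void {
        this.store.set(key, { value, expiresAt: Date.now() + this.ttlMs });
    }

    clear(): void {
        this.store.clear();
    }
}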

3. Chat view provider (chatProvider.ts):

import * as vscode from 'vscode';
import { APIClient } from './apiClient';

export class ChatProvider implements vscode.WebviewViewProvider {
    private view?: vscode.WebviewView;
    private messages: Array<{ role: string; content: string }> = [];

    constructor(
        private readonly context: vscode.ExtensionContext,
        private readonly apiClient: APIClient
    ) {}

    resolveWebviewView(
        webviewView: vscode.WebviewView,
        context: vscode.WebviewViewResolveContext,
        token: vscode.CancellationToken
    ): void | Thenable<void> {
        this.view = webviewView;

        webviewView.webview.options = {
            enableScripts: true,
            localResourceRoots: [this.context.extensionUri]
        };

        webviewView.webview.html = this.getHtmlForWebview(webviewView.webview);

        // Handle messages coming from the webview
        webviewView.webview.onDidReceiveMessage(async (data) => {
            switch (data.type) {
                case 'sendMessage':
                    await this.handleUserMessage(data.message);
                    break;
                case 'clearChat':
                    this.messages = [];
                    this.updateChat();
                    break;
            }
        });
    }

    private async handleUserMessage(message: string) {
        // Append the user message
        this.messages.push({ role: 'user', content: message });
        this.updateChat();

        try {
            // Call the API for a response
            const response = await this.apiClient.chat(this.messages);

            // Append the assistant response
            this.messages.push({ role: 'assistant', content: response });
            this.updateChat();

        } catch (error) {
            vscode.window.showErrorMessage('Failed to get response from AI');
            console.error(error);
        }
    }

    private updateChat() {
        if (this.view) {
            this.view.webview.postMessage({
                type: 'updateMessages',
                messages: this.messages
            });
        }
    }

    private getHtmlForWebview(webview: vscode.Webview): string {
        return `<!DOCTYPE html>
<html lang="en">
<head>
    <meta charset="UTF-8">
    <meta name="viewport" content="width=device-width, initial-scale=1.0">
    <title>AI Assistant Chat</title>
    <style>
        body {
            padding: 10px;
            font-family: var(--vscode-font-family);
            color: var(--vscode-foreground);
        }

        #messages {
            height: calc(100vh - 120px);
            overflow-y: auto;
            padding: 10px;
            margin-bottom: 10px;
        }

        .message {
            margin-bottom: 15px;
            padding: 10px;
            border-radius: 5px;
        }

        .user {
            background-color: var(--vscode-input-background);
            margin-left: 20px;
        }

        .assistant {
            background-color: var(--vscode-editor-background);
            margin-right: 20px;
        }

        .role {
            font-weight: bold;
            margin-bottom: 5px;
            font-size: 0.9em;
        }

        #input-container {
            display: flex;
            gap: 10px;
        }

        #message-input {
            flex: 1;
            padding: 8px;
            background-color: var(--vscode-input-background);
            color: var(--vscode-input-foreground);
            border: 1px solid var(--vscode-input-border);
            border-radius: 3px;
        }

        button {
            padding: 8px 15px;
            background-color: var(--vscode-button-background);
            color: var(--vscode-button-foreground);
            border: none;
            border-radius: 3px;
            cursor: pointer;
        }

        button:hover {
            background-color: var(--vscode-button-hoverBackground);
        }

        pre {
            background-color: var(--vscode-textCodeBlock-background);
            padding: 10px;
            border-radius: 3px;
            overflow-x: auto;
        }

        code {
            font-family: var(--vscode-editor-font-family);
        }
    </style>
</head>
<body>
    <div id="messages"></div>

    <div id="input-container">
        <input
            type="text"
            id="message-input"
            placeholder="Ask me anything about your code..."
            autocomplete="off"
        />
        <button onclick="sendMessage()">Send</button>
        <button onclick="clearChat()">Clear</button>
    </div>

    <script>
        const vscode = acquireVsCodeApi();

        function sendMessage() {
            const input = document.getElementById('message-input');
            const message = input.value.trim();

            if (message) {
                vscode.postMessage({
                    type: 'sendMessage',
                    message: message
                });
                input.value = '';
            }
        }

        function clearChat() {
            vscode.postMessage({ type: 'clearChat' });
        }

        // Send on Enter
        document.getElementById('message-input').addEventListener('keypress', (e) => {
            if (e.key === 'Enter') {
                sendMessage();
            }
        });

        // Receive message updates from the extension
        window.addEventListener('message', (event) => {
            const message = event.data;

            if (message.type === 'updateMessages') {
                updateMessagesUI(message.messages);
            }
        });

        function updateMessagesUI(messages) {
            const container = document.getElementById('messages');
            container.innerHTML = '';

            messages.forEach(msg => {
                const div = document.createElement('div');
                div.className = \`message \${msg.role}\`;

                const roleDiv = document.createElement('div');
                roleDiv.className = 'role';
                roleDiv.textContent = msg.role === 'user' ? 'You' : 'AI Assistant';

                const contentDiv = document.createElement('div');
                contentDiv.innerHTML = formatMessage(msg.content);

                div.appendChild(roleDiv);
                div.appendChild(contentDiv);
                container.appendChild(div);
            });

            // Scroll to the bottom
            container.scrollTop = container.scrollHeight;
        }

        function formatMessage(content) {
            // Minimal Markdown formatting
            // Code blocks
            content = content.replace(/\`\`\`([\\s\\S]*?)\`\`\`/g, '<pre><code>$1</code></pre>');
            // Inline code
            content = content.replace(/\`([^\`]+)\`/g, '<code>$1</code>');
            // Line breaks
            content = content.replace(/\\n/g, '<br>');

            return content;
        }
    </script>
</body>
</html>`;
    }
}
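
For the chat view to actually appear, the provider has to be registered during activation and declared under the views contribution in package.json. A sketch of the registration call, assuming a hypothetical view id aiAssistant.chatView and the apiClient created in activate():

// In activate() of extension.ts
const chatProvider = new ChatProvider(context, apiClient);
context.subscriptions.push(
    vscode.window.registerWebviewViewProvider('aiAssistant.chatView', chatProvider)
);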

4. API client (apiClient.ts):

import axios, { AxiosInstance } from 'axios';

export interface CompletionRequest {
    before: string;
    after: string;
    language: string;
    maxTokens?: number;
    temperature?: number;
}

export interface CompletionResponse {
    completions: Array<{
        text: string;
        confidence: number;
    }>;
}

export class APIClient {
    private client: AxiosInstance;

    constructor(
        private baseURL: string,
        private modelName: string
    ) {
        this.client = axios.create({
            baseURL: this.baseURL,
            timeout: 30000,
            headers: {
                'Content-Type': 'application/json'
            }
        });
    }

    async getCompletions(context: any): Promise<Array<{ text: string; confidence: number }>> {
        try {
            const response = await this.client.post<CompletionResponse>('/completions', {
                model: this.modelName,
                before: context.before,
                after: context.after,
                language: context.language,
                maxTokens: context.maxTokens || 256,
                temperature: context.temperature || 0.2
            });

            return response.data.completions;
        } catch (error) {
            console.error('API request failed:', error);
            throw error;
        }
    }

    async chat(messages: Array<{ role: string; content: string }>): Promise<string> {
        try {
            const response = await this.client.post('/chat', {
                model: this.modelName,
                messages: messages
            });

            return response.data.response;
        } catch (error) {
            console.error('Chat request failed:', error);
            throw error;
        }
    }

    async explainCode(code: string): Promise<string> {
        try {
            const response = await this.client.post('/explain', {
                model: this.modelName,
                code: code
            });

            return response.data.explanation;
        } catch (error) {
            console.error('Explain request failed:', error);
            throw error;
        }
    }

    async refactorCode(code: string): Promise<string> {
        try {
            const response = await this.client.post('/refactor', {
                model: this.modelName,
                code: code
            });

            return response.data.refactored;
        } catch (error) {
            console.error('Refactor request failed:', error);
            throw error;
        }
    }

    async generateTests(code: string): Promise<string> {
        try {
            const response = await this.client.post('/generate-tests', {
                model: this.modelName,
                code: code
            });

            return response.data.tests;
        } catch (error) {
            console.error('Generate tests request failed:', error);
            throw error;
        }
    }
}
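
How the client gets its base URL and model name is left open above. One option is to read them from user settings; the sketch below assumes hypothetical configuration keys aiAssistant.apiUrl and aiAssistant.model (which would need matching configuration entries in package.json), with defaults pointing at the local FastAPI service from section 14.4.3:

import * as vscode from 'vscode';
import { APIClient } from './apiClient';

export function createAPIClient(): APIClient {
    const config = vscode.workspace.getConfiguration('aiAssistant');
    const baseURL = config.get<string>('apiUrl', 'http://localhost:8000');
    const modelName = config.get<string>('model', 'codegen');

    return new APIClient(baseURL, modelName);
}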

5. Context builder (contextBuilder.ts):

import * as vscode from 'vscode';

export interface CodeContext {
    before: string;
    after: string;
    language: string;
    currentLine: string;
    fileName: string;
    imports: string[];
}

export class ContextBuilder {
    private readonly maxContextLength = 2000; // characters

    buildContext(
        document: vscode.TextDocument,
        position: vscode.Position
    ): CodeContext {
        const text = document.getText();
        const offset = document.offsetAt(position);

        // Get the text before and after the cursor
        let before = text.substring(0, offset);
        let after = text.substring(offset);

        // Limit the context length
        if (before.length > this.maxContextLength) {
            before = this.truncateFromStart(before, this.maxContextLength);
        }
        if (after.length > this.maxContextLength) {
            after = after.substring(0, this.maxContextLength);
        }

        // Get the current line
        const currentLine = document.lineAt(position.line).text;

        // Extract import statements
        const imports = this.extractImports(document);

        return {
            before,
            after,
            language: document.languageId,
            currentLine,
            fileName: document.fileName,
            imports
        };
    }

    private truncateFromStart(text: string, maxLength: number): string {
        if (text.length <= maxLength) {
            return text;
        }

        // Truncate at the nearest newline to preserve code structure
        const truncated = text.substring(text.length - maxLength);
        const firstNewline = truncated.indexOf('\n');

        if (firstNewline !== -1) {
            return truncated.substring(firstNewline + 1);
        }

        return truncated;
    }

    private extractImports(document: vscode.TextDocument): string[] {
        const imports: string[] = [];
        const text = document.getText();
        const lines = text.split('\n');

        for (const line of lines) {
            const trimmed = line.trim();

            // Python and TypeScript/JavaScript imports
            if (trimmed.startsWith('import ') ||
                trimmed.startsWith('from ') ||
                trimmed.startsWith('require(')) {
                imports.push(trimmed);
            }

            // Stop once we are past the import section
            // (the first non-empty line that is neither an import nor a comment)
            else if (trimmed && !trimmed.startsWith('#') && !trimmed.startsWith('//')) {
                break;
            }
        }

        return imports;
    }
}

14.4.3 Backend API Service

To support the VSCode extension, we need a backend API service:

from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from typing import List, Optional
import uvicorn

# Assume the code generator implemented earlier in this chapter is available
from code_generator import CodeGenerator, IntelligentCodeCompletion

app = FastAPI(title="AI Code Assistant API")

# Initialize the model
generator = CodeGenerator()
completion_system = IntelligentCodeCompletion(generator)

class CompletionRequest(BaseModel):
    model: str
    before: str
    after: str
    language: str
    maxTokens: Optional[int] = 256
    temperature: Optional[float] = 0.2

class CompletionResponse(BaseModel):
    completions: List[dict]

class ChatMessage(BaseModel):
    role: str
    content: str

class ChatRequest(BaseModel):
    model: str
    messages: List[ChatMessage]

class ChatResponse(BaseModel):
    response: str

class CodeRequest(BaseModel):
    model: str
    code: str

class CodeResponse(BaseModel):
    result: str

@app.post("/completions", response_model=CompletionResponse)
async def get_completions(request: CompletionRequest):
    """获取代码补全"""
    try:
        context = {
            'before': request.before,
            'after': request.after,
            'language': request.language
        }

        suggestions = completion_system.get_completion(context)

        return CompletionResponse(
            completions=[
                {
                    'text': sugg['text'],
                    'confidence': sugg['confidence']
                }
                for sugg in suggestions[:5]  # return the top 5
            ]
        )

    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))

@app.post("/chat", response_model=ChatResponse)
async def chat(request: ChatRequest):
    """聊天接口"""
    try:
        # 这里简化处理,实际应该使用对话模型
        last_message = request.messages[-1].content

        # 简单响应
        response = f"收到您的消息:{last_message}。这是一个演示响应。"

        return ChatResponse(response=response)

    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))

@app.post("/explain")
async def explain_code(request: CodeRequest):
    """解释代码"""
    try:
        # 使用代码生成器生成解释
        prompt = f"Explain the following code:\n\n{request.code}\n\nExplanation:"
        explanation = generator.generate(prompt, max_length=512, temperature=0.3)[0]

        return {"explanation": explanation}

    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))

@app.post("/refactor")
async def refactor_code(request: CodeRequest):
    """重构代码"""
    try:
        prompt = f"Refactor the following code to make it cleaner and more efficient:\n\n{request.code}\n\nRefactored code:\n"
        refactored = generator.generate(prompt, max_length=512, temperature=0.2)[0]

        return {"refactored": refactored}

    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))

@app.post("/generate-tests")
async def generate_tests(request: CodeRequest):
    """生成测试"""
    try:
        prompt = f"Generate unit tests for the following code:\n\n{request.code}\n\nTests:\n"
        tests = generator.generate(prompt, max_length=512, temperature=0.3)[0]

        return {"tests": tests}

    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))

@app.get("/health")
async def health_check():
    """健康检查"""
    return {"status": "ok", "model": "codegen"}

if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=8000)
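
Once this service is running, it can be useful for the extension to verify connectivity when it activates. A small sketch of a hypothetical helper on the extension side that probes the /health endpoint (the function name and the status-bar feedback are illustrative, not part of the service above):

import axios from 'axios';
import * as vscode from 'vscode';

export async function checkBackendHealth(baseURL: string): Promise<void> {
    try {
        // The /health endpoint returns {"status": "ok", "model": "..."}
        const response = await axios.get(`${baseURL}/health`, { timeout: 5000 });
        vscode.window.setStatusBarMessage(`AI Assistant backend: ${response.data.status}`, 5000);
    } catch {
        vscode.window.showWarningMessage('AI Assistant backend is not reachable');
    }
}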

14.5 Performance Optimization and Best Practices

14.5.1 Caching Strategy

import hashlib
import pickle
import time
from typing import Any, Optional
from collections import OrderedDict

class LRUCache:
    """LRU缓存实现"""

    def __init__(self, capacity: int = 100, ttl: int = 3600):
        """
        Args:
            capacity: 最大缓存项数
            ttl: 过期时间(秒)
        """
        self.cache = OrderedDict()
        self.capacity = capacity
        self.ttl = ttl
        self.timestamps = {}

    def get(self, key: str) -> Optional[Any]:
        """获取缓存值"""
        if key not in self.cache:
            return None

        # 检查是否过期
        if time.time() - self.timestamps[key] > self.ttl:
            self.cache.pop(key)
            self.timestamps.pop(key)
            return None

        # 移到末尾(最近使用)
        self.cache.move_to_end(key)
        return self.cache[key]

    def set(self, key: str, value: Any):
        """设置缓存值"""
        if key in self.cache:
            self.cache.move_to_end(key)
        else:
            if len(self.cache) >= self.capacity:
                # 移除最久未使用的项
                oldest = next(iter(self.cache))
                self.cache.pop(oldest)
                self.timestamps.pop(oldest)

        self.cache[key] = value
        self.timestamps[key] = time.time()

    def clear(self):
        """清空缓存"""
        self.cache.clear()
        self.timestamps.clear()

class SmartCache:
    """智能缓存系统"""

    def __init__(self, capacity: int = 1000):
        self.completion_cache = LRUCache(capacity, ttl=300)  # 5分钟
        self.embedding_cache = LRUCache(capacity, ttl=3600)  # 1小时
        self.analysis_cache = LRUCache(capacity, ttl=1800)   # 30分钟

    def _hash_key(self, data: str) -> str:
        """生成哈希键"""
        return hashlib.md5(data.encode()).hexdigest()

    def get_completion(self, context: str) -> Optional[list]:
        """获取补全缓存"""
        key = self._hash_key(context)
        return self.completion_cache.get(key)

    def set_completion(self, context: str, completions: list):
        """设置补全缓存"""
        key = self._hash_key(context)
        self.completion_cache.set(key, completions)

    def get_embedding(self, code: str) -> Optional[Any]:
        """获取嵌入缓存"""
        key = self._hash_key(code)
        return self.embedding_cache.get(key)

    def set_embedding(self, code: str, embedding: Any):
        """设置嵌入缓存"""
        key = self._hash_key(code)
        self.embedding_cache.set(key, embedding)

# Usage example
cache = SmartCache()

# Cache completion results
context = "def calculate_sum(numbers):\n    "
completions = ["total = sum(numbers)", "result = 0; for n in numbers: result += n"]
cache.set_completion(context, completions)

# Retrieve from the cache
cached = cache.get_completion(context)
if cached:
    print("Using cached completions")
else:
    print("Need to regenerate")

14.5.2 Batching and Concurrency

import asyncio
import time
from typing import Any, Callable, List
import concurrent.futures

class BatchProcessor:
    """Batch processor"""

    def __init__(self, batch_size: int = 8, max_workers: int = 4):
        self.batch_size = batch_size
        self.max_workers = max_workers
        self.executor = concurrent.futures.ThreadPoolExecutor(max_workers=max_workers)

    async def process_batch(
        self,
        items: List[Any],
        process_func: Callable
    ) -> List[Any]:
        """异步批处理"""
        results = []

        # 分批处理
        for i in range(0, len(items), self.batch_size):
            batch = items[i:i + self.batch_size]

            # 并发处理批次中的项
            loop = asyncio.get_event_loop()
            batch_results = await asyncio.gather(*[
                loop.run_in_executor(self.executor, process_func, item)
                for item in batch
            ])

            results.extend(batch_results)

        return results

    def shutdown(self):
        """关闭执行器"""
        self.executor.shutdown(wait=True)

# Usage example
async def main():
    processor = BatchProcessor(batch_size=4, max_workers=2)

    def process_code(code):
        # Simulate processing work
        time.sleep(0.1)
        return f"Processed: {code[:20]}..."

    codes = [f"code_sample_{i}" for i in range(10)]
    results = await processor.process_batch(codes, process_code)

    print(f"Processed {len(results)} code samples")
    processor.shutdown()

# asyncio.run(main())

14.6 Summary

This chapter took a deep dive into the core techniques behind AI coding assistants:

  1. Code understanding: AST parsing, semantic embeddings, call-graph analysis
  2. Code generation: language-model-based generation, quality evaluation, intelligent ranking
  3. VSCode extension: a complete plugin architecture and implementation
  4. Performance optimization: caching strategies, batching, and concurrent processing

These techniques form the foundation of modern AI coding assistants. By understanding and practicing them, you can build your own AI programming tools, or better understand and use the tools that already exist.

In the next chapter we will explore the design and implementation of RAG (Retrieval-Augmented Generation) systems.
