扩展开发指南
1. 扩展开发概述
1.1 系统扩展点
AI UI生成系统采用模块化设计,提供了多个扩展点,允许开发者根据需求添加新功能:UI组件(第2节)、主题(第3节)、输出格式/渲染器(第4节)、AI模型(第5节),以及API端点、中间件与插件(第6节)。
1.2 开发规范
代码规范:
- 遵循PEP 8 Python编码规范
- 使用类型提示(Type Hints)
- 编写完整的文档字符串
- 包含单元测试
扩展规范:
- 保持向后兼容性
- 提供配置选项
- 支持错误处理
- 包含使用示例
2. 添加新UI组件
2.1 组件系统架构
基于 render/render_to_image.py 的组件系统:
# 组件基类定义
from abc import ABC, abstractmethod
from typing import Dict, Any, Tuple
from PIL import Image, ImageDraw
class UIComponent(ABC):
    """Abstract base class shared by all UI components.

    Subclasses must implement :meth:`render`. ``get_height`` and
    ``validate_props`` provide permissive defaults that may be overridden.
    """

    def __init__(self, config: Dict[str, Any]):
        """Store ``config`` and derive the component type tag."""
        # The type tag defaults to the lower-cased concrete class name.
        self.component_type = type(self).__name__.lower()
        self.config = config

    @abstractmethod
    def render(self, draw: ImageDraw.Draw, theme_colors: Dict[str, str],
               x: int, y: int, width: int, height: int) -> int:
        """Draw the component onto ``draw``; concrete classes override this.

        The base implementation does nothing and returns ``None``.
        """

    def get_height(self, width: int) -> int:
        """Return the component height for ``width``; the default is 50."""
        default_height = 50
        return default_height

    def validate_props(self, props: Dict[str, Any]) -> bool:
        """Report whether ``props`` are acceptable; the base accepts all."""
        return True
2.2 实现自定义组件
示例:添加进度条组件
# 进度条组件实现
class ProgressBarComponent(UIComponent):
    """Horizontal progress bar with an optional centred percentage label.

    Recognised ``config["props"]`` keys:
        progress: Fraction in ``[0, 1]`` (default ``0.5``). Out-of-range
            values are clamped when drawing.
        show_text: Whether to draw the percentage label (default ``True``).
    """

    # Used when the theme omits a colour key, so a partial theme still
    # renders instead of raising ``KeyError``.
    _FALLBACK_COLORS = {
        "border": "#C6C6C8",
        "primary": "#007AFF",
        "text": "#000000",
    }

    def __init__(self, config: Dict[str, Any]):
        super().__init__(config)
        self.component_type = "progress-bar"
        # Optional font table keyed by size name. A missing "small" entry
        # passes ``font=None`` to Pillow, which then uses its default font.
        self.fonts = getattr(self, "fonts", {})

    @staticmethod
    def _hex_to_rgb(color: str) -> Tuple[int, int, int]:
        """Convert ``"#RRGGBB"`` or ``"#RGB"`` into an ``(r, g, b)`` tuple.

        Raises:
            ValueError: If ``color`` is not a 3- or 6-digit hex colour.
        """
        digits = color.strip().lstrip("#")
        if len(digits) == 3:
            digits = "".join(ch * 2 for ch in digits)
        if len(digits) != 6:
            raise ValueError(f"Invalid hex colour: {color!r}")
        return tuple(int(digits[i:i + 2], 16) for i in (0, 2, 4))

    def _color(self, theme_colors: Dict[str, str], key: str) -> Tuple[int, int, int]:
        """Resolve a theme colour, falling back to the built-in default."""
        return self._hex_to_rgb(theme_colors.get(key, self._FALLBACK_COLORS[key]))

    @staticmethod
    def _clamp_progress(value: Any) -> float:
        """Coerce ``value`` to a float in ``[0, 1]``; non-numeric input gives 0."""
        try:
            number = float(value)
        except (TypeError, ValueError):
            return 0.0
        if number != number:  # NaN compares unequal to itself
            return 0.0
        return max(0.0, min(1.0, number))

    def render(self, draw: ImageDraw.Draw, theme_colors: Dict[str, str],
               x: int, y: int, width: int, height: int) -> int:
        """Draw the track, the filled portion and (optionally) the label.

        Args:
            draw: Pillow drawing context to paint on.
            theme_colors: Hex colours; the ``border``, ``primary`` and
                ``text`` keys are read.
            x, y: Top-left corner of the bar.
            width, height: Size of the bar in pixels.

        Returns:
            ``height``, unchanged.
        """
        props = self.config.get("props", {})
        progress = self._clamp_progress(props.get("progress", 0.5))
        show_text = props.get("show_text", True)
        radius = height // 2

        # Background track.
        draw.rounded_rectangle(
            (x, y, x + width, y + height),
            radius=radius,
            fill=self._color(theme_colors, "border"),
        )

        # Filled portion; skipped entirely at 0% so no sliver is drawn.
        progress_width = int(width * progress)
        if progress_width > 0:
            draw.rounded_rectangle(
                (x, y, x + progress_width, y + height),
                radius=radius,
                fill=self._color(theme_colors, "primary"),
            )

        if show_text:
            # round() keeps the label in step with the Vue component, which
            # uses Math.round; int() would truncate 0.29 to "28%".
            text = f"{round(progress * 100)}%"
            font = self.fonts.get("small")
            left, top, right, bottom = draw.textbbox((0, 0), text, font=font)
            # The bbox is measured from the drawing origin, so subtract its
            # own offset to centre the inked area rather than the origin.
            text_x = x + (width - (right - left)) // 2 - left
            text_y = y + (height - (bottom - top)) // 2 - top
            draw.text((text_x, text_y), text,
                      fill=self._color(theme_colors, "text"), font=font)

        return height

    def get_height(self, width: int) -> int:
        """Return the fixed bar height (20), independent of ``width``."""
        return 20

    def validate_props(self, props: Dict[str, Any]) -> bool:
        """Return ``True`` when ``progress`` is a real number in ``[0, 1]``.

        A missing ``progress`` is valid (the default 0.5 applies); booleans
        and non-numeric values are rejected instead of raising ``TypeError``.
        """
        progress = props.get("progress", 0.5)
        if isinstance(progress, bool) or not isinstance(progress, (int, float)):
            return False
        return 0 <= progress <= 1
2.3 注册新组件
# 在UIRenderer中注册新组件
class UIRenderer:
    """Excerpt showing how a custom component is wired into the renderer.

    Only the registration-related parts are shown; the existing
    initialisation code is elided.
    """

    # Horizontal margin (px) kept on each side of a rendered component.
    CONTENT_MARGIN = 16

    def __init__(self, config_path: str = "config/model_config.yaml",
                 tokens_path: str = "config/ui_tokens.json"):
        # ... existing initialisation code ...
        # Register custom components. Map a type only to a method that
        # actually exists on this class: a reference to a missing method
        # raises AttributeError as soon as the renderer is constructed.
        self.component_renderers.update({
            "progress-bar": self._render_progress_bar,
            # "chart": self._render_chart,        # enable once implemented
            # "timeline": self._render_timeline,  # enable once implemented
            # add more components...
        })

    def _render_progress_bar(self, draw: ImageDraw.Draw, section: Dict,
                             theme_colors: Dict, y_offset: int) -> int:
        """Render a progress-bar section and return the component's result.

        The bar spans the page width minus ``CONTENT_MARGIN`` on each side
        and starts at ``y_offset``.
        """
        component = ProgressBarComponent(section)
        content_width = self.width - 2 * self.CONTENT_MARGIN
        return component.render(
            draw, theme_colors, self.CONTENT_MARGIN, y_offset,
            content_width, component.get_height(content_width),
        )
2.4 Vue组件对应实现
<!-- ProgressBar.vue component -->
<template>
<!-- Root element; themeClass adds a `theme-<name>` class -->
<view class="progress-bar" :class="themeClass">
<view class="progress-track">
<!-- Filled portion: its width is the progress expressed as a percentage -->
<view
class="progress-fill"
:style="{ width: progressPercentage + '%' }"
></view>
</view>
<!-- Percentage label, rendered only when showText is true -->
<text v-if="showText" class="progress-text">
{{ progressPercentage }}%
</text>
</view>
</template>
<script>
export default {
name: 'ProgressBar',
props: {
// Completion fraction; the validator accepts values from 0 to 1 inclusive
progress: {
type: Number,
default: 0.5,
validator: value => value >= 0 && value <= 1
},
// Toggles the percentage label
showText: {
type: Boolean,
default: true
},
// Theme name used to build the root element's theme class
theme: {
type: String,
default: 'obsidian-gold'
}
},
computed: {
// Progress as a whole-number percentage (rounded to nearest)
progressPercentage() {
return Math.round(this.progress * 100)
},
// CSS class of the form `theme-<theme>`
themeClass() {
return `theme-${this.theme}`
}
}
}
</script>
<style lang="scss" scoped>
// Track and label sit side by side, vertically centred
.progress-bar {
display: flex;
align-items: center;
gap: 8px;
}
// Track takes the remaining width; overflow is hidden so the fill is
// clipped to the rounded corners
.progress-track {
flex: 1;
height: 20px;
background-color: var(--border-color);
border-radius: 10px;
overflow: hidden;
}
// Width changes are animated over 0.3s
.progress-fill {
height: 100%;
background-color: var(--primary-color);
transition: width 0.3s ease;
}
.progress-text {
font-size: 12px;
color: var(--text-color);
min-width: 40px;
text-align: right;
}
</style>
3. 添加新主题
3.1 主题系统架构
基于 config/ui_tokens.json 的主题系统:
# 主题管理器
import json


class ThemeManager:
    """Load, look up and persist UI themes stored in a design-token JSON file."""

    # Theme returned by get_theme() when the requested one does not exist.
    DEFAULT_THEME = "obsidian-gold"

    def __init__(self, tokens_path: str = "config/ui_tokens.json"):
        """Read the token file at ``tokens_path``.

        Raises:
            OSError: If the file cannot be opened.
            json.JSONDecodeError: If the file is not valid JSON.
        """
        # Keep the path: _save_themes() writes back to the same file.
        self.tokens_path = tokens_path
        with open(tokens_path, 'r', encoding='utf-8') as f:
            self.tokens = json.load(f)
        self.themes = self.tokens.get("themes", {})

    def get_theme(self, theme_name: str) -> Dict[str, Any]:
        """Return the config for ``theme_name``, or the default theme.

        The default is looked up only when needed, so a known theme is
        returned even if the token file has no default theme.

        Raises:
            KeyError: If neither ``theme_name`` nor the default theme exists.
        """
        if theme_name in self.themes:
            return self.themes[theme_name]
        return self.themes[self.DEFAULT_THEME]

    def add_theme(self, theme_name: str, theme_config: Dict[str, Any]):
        """Add (or overwrite) a theme and persist the token file."""
        self.themes[theme_name] = theme_config
        self._save_themes()

    def _save_themes(self):
        """Write all tokens, including the current themes, back to disk."""
        self.tokens["themes"] = self.themes
        with open(self.tokens_path, 'w', encoding='utf-8') as f:
            json.dump(self.tokens, f, ensure_ascii=False, indent=2)
3.2 定义新主题
示例:添加科技蓝主题
{
"themes": {
"tech-blue": {
"name": "科技蓝主题",
"colors": {
"primary": "#007AFF",
"secondary": "#5AC8FA",
"background": "#F2F2F7",
"surface": "#FFFFFF",
"text": "#000000",
"text_secondary": "#8E8E93",
"border": "#C6C6C8",
"accent": "#FF3B30",
"success": "#34C759",
"warning": "#FF9500",
"error": "#FF3B30"
},
"typography": {
"font_family": "SF Pro Display, -apple-system, BlinkMacSystemFont, sans-serif",
"font_size": {
"xs": "12px",
"sm": "14px",
"base": "16px",
"lg": "18px",
"xl": "20px",
"2xl": "24px",
"3xl": "30px"
},
"font_weight": {
"normal": "400",
"medium": "500",
"semibold": "600",
"bold": "700"
}
},
"spacing": {
"xs": "4px",
"sm": "8px",
"md": "16px",
"lg": "24px",
"xl": "32px",
"2xl": "48px"
},
"border_radius": {
"sm": "6px",
"md": "12px",
"lg": "18px",
"xl": "24px",
"full": "50%"
},
"shadows": {
"sm": "0 2px 4px rgba(0, 122, 255, 0.1)",
"md": "0 4px 8px rgba(0, 122, 255, 0.15)",
"lg": "0 8px 16px rgba(0, 122, 255, 0.2)",
"xl": "0 16px 32px rgba(0, 122, 255, 0.25)"
}
}
}
}
3.3 主题切换机制
# 主题切换实现
import copy


class ThemeSwitcher:
    """Track the active theme and apply it to page DSL documents."""

    def __init__(self, theme_manager: ThemeManager):
        self.theme_manager = theme_manager
        self.current_theme = "obsidian-gold"

    def switch_theme(self, theme_name: str) -> bool:
        """Activate ``theme_name`` if the manager knows it.

        Returns:
            ``True`` when the theme was switched, ``False`` when
            ``theme_name`` is unknown (the current theme is kept).
        """
        if theme_name not in self.theme_manager.themes:
            return False
        self.current_theme = theme_name
        return True

    def get_current_theme_colors(self) -> Dict[str, str]:
        """Return the ``colors`` mapping of the active theme (``{}`` if absent)."""
        theme = self.theme_manager.get_theme(self.current_theme)
        return theme.get("colors", {})

    def apply_theme_to_dsl(self, dsl: Dict[str, Any]) -> Dict[str, Any]:
        """Return a copy of ``dsl`` whose ``page.theme`` is the active theme.

        A deep copy is used because ``page`` is a nested dict: with a shallow
        ``dict.copy()`` the assignment below would also modify the caller's
        original DSL. A DSL without a dict-valued ``page`` is returned
        unchanged (as a copy).
        """
        dsl_copy = copy.deepcopy(dsl)
        page = dsl_copy.get("page")
        if isinstance(page, dict):
            page["theme"] = self.current_theme
        return dsl_copy
4. 添加新输出格式
4.1 渲染器接口设计
# 渲染器基类
from abc import ABC, abstractmethod
from typing import Dict, Any, Union
class Renderer(ABC):
    """Abstract base class for DSL renderers."""

    def __init__(self, config: Dict[str, Any]):
        """Store ``config`` and cache the subclass's output format name."""
        self.config = config
        # Resolved once at construction via the subclass hook.
        self.output_format = self.get_output_format()

    @abstractmethod
    def get_output_format(self) -> str:
        """Return the name of the output format."""

    @abstractmethod
    def render(self, dsl: Dict[str, Any]) -> Union[str, bytes]:
        """Render ``dsl`` into the target format."""

    @abstractmethod
    def validate_dsl(self, dsl: Dict[str, Any]) -> bool:
        """Check whether ``dsl`` is well-formed for this renderer."""

    def get_file_extension(self) -> str:
        """Return the output format name prefixed with a dot."""
        return ".{}".format(self.output_format)
4.2 实现自定义渲染器
示例:添加React Native渲染器
# React Native渲染器实现
import json
import textwrap
from typing import Any, Dict, List


class ReactNativeRenderer(Renderer):
    """Render a page DSL into the source code of a React Native component.

    The generated file is self-contained: it imports ``StyleSheet`` and
    declares its own ``styles`` object, and all section JSX is placed inside
    the component's ``<ScrollView>``.
    """

    def get_output_format(self) -> str:
        """Return the format key, ``"jsx"``."""
        return "jsx"

    def render(self, dsl: Dict[str, Any]) -> str:
        """Render ``dsl`` to React Native source code.

        Reads ``dsl["page"]["theme"]`` (default ``"obsidian-gold"``) and
        ``dsl["page"]["sections"]`` (default ``[]``).
        """
        page = dsl.get("page", {})
        theme = page.get("theme", "obsidian-gold")
        sections = page.get("sections", [])
        return self._generate_component(sections, theme)

    def _generate_component(self, sections: List[Dict], theme: str) -> str:
        """Assemble imports, the component function and the style sheet."""
        imports = self._generate_imports()
        content = self._generate_render_method(sections)
        component_def = self._generate_component_definition(theme, content)
        styles = self._generate_styles(theme)
        return f"{imports}\n\n{component_def}\n\n{styles}\n"

    def _generate_imports(self) -> str:
        """Return the import statements of the generated file.

        ``StyleSheet`` is imported because the generated file defines
        ``styles`` itself; importing ``styles`` from another module as well
        would declare the identifier twice.
        """
        return (
            "import React from 'react';\n"
            "import { View, Text, ScrollView, TouchableOpacity, Image, "
            "StyleSheet } from 'react-native';"
        )

    def _generate_component_definition(self, theme: str, content: str = "") -> str:
        """Return the component function with ``content`` inside the scroll view.

        Args:
            theme: Theme name (currently unused by the generated code).
            content: Pre-indented JSX for the sections; when blank, a
                placeholder JSX comment is emitted instead.
        """
        body = content if content.strip() else "      {/* 组件内容 */}"
        return (
            "export default function GeneratedUI() {\n"
            "  return (\n"
            "    <ScrollView style={styles.container}>\n"
            f"{body}\n"
            "    </ScrollView>\n"
            "  );\n"
            "}"
        )

    def _generate_render_method(self, sections: List[Dict]) -> str:
        """Return the JSX of all sections, indented for the scroll view."""
        rendered = [self._render_section(section) for section in sections]
        return textwrap.indent("\n".join(rendered), " " * 6)

    def _render_section(self, section: Dict) -> str:
        """Return the JSX for one section, dispatching on its ``type``.

        Types without a registered builder produce an empty ``<View>``.
        TODO: add a builder for ``"card-list"``; none exists yet, so it
        currently falls back to the empty view instead of raising
        ``AttributeError``.
        """
        builders = {
            "topbar": self._render_topbar_jsx,
            # add more component types...
        }
        builder = builders.get(section.get("type"))
        if builder is None:
            return "<View></View>"
        return builder(section.get("props") or {})

    @staticmethod
    def _jsx_text(value: Any) -> str:
        """Return ``value`` as a JSX string expression, e.g. ``{"Title"}``.

        Emitting a quoted string expression instead of raw text keeps
        characters such as ``{``, ``}`` and ``<`` in DSL-provided text from
        being parsed as JSX.
        """
        return "{" + json.dumps(str(value), ensure_ascii=False) + "}"

    def _render_topbar_jsx(self, props: Dict) -> str:
        """Return the JSX for a top bar with a title and action buttons.

        Reads ``props["title"]`` (default ``""``) and ``props["actions"]``
        (default ``[]``); each action is rendered as a button label.
        """
        title = props.get("title", "")
        actions = props.get("actions", [])
        lines = [
            "<View style={styles.topbar}>",
            f"  <Text style={{styles.title}}>{self._jsx_text(title)}</Text>",
            "  <View style={styles.actions}>",
        ]
        for action in actions:
            lines.extend([
                "    <TouchableOpacity style={styles.actionButton}>",
                f"      <Text style={{styles.actionText}}>{self._jsx_text(action)}</Text>",
                "    </TouchableOpacity>",
            ])
        lines.extend(["  </View>", "</View>"])
        return "\n".join(lines)

    def _generate_styles(self, theme: str) -> str:
        """Return the ``StyleSheet.create`` declaration.

        NOTE(review): colours are fixed values; ``theme`` is not applied yet.
        """
        return """const styles = StyleSheet.create({
  container: {
    flex: 1,
    backgroundColor: '#F2F2F7',
  },
  topbar: {
    height: 44,
    flexDirection: 'row',
    alignItems: 'center',
    justifyContent: 'space-between',
    paddingHorizontal: 16,
    backgroundColor: '#FFFFFF',
  },
  title: {
    fontSize: 18,
    fontWeight: '600',
    color: '#000000',
  },
  actions: {
    flexDirection: 'row',
    gap: 8,
  },
  actionButton: {
    padding: 8,
  },
  actionText: {
    fontSize: 16,
    color: '#007AFF',
  },
});"""

    def validate_dsl(self, dsl: Dict[str, Any]) -> bool:
        """Return ``True`` when ``dsl`` has a dict ``page`` with ``sections``.

        Non-dict input, or a non-dict ``page``, yields ``False`` rather than
        an exception.
        """
        if not isinstance(dsl, dict):
            return False
        page = dsl.get("page")
        return isinstance(page, dict) and "sections" in page
4.3 注册新渲染器
# 在渲染系统中注册新渲染器
from typing import List, Optional


class RendererRegistry:
    """Registry mapping an output-format name to its renderer instance."""

    def __init__(self):
        # Maps format name -> renderer; insertion order is preserved.
        self.renderers = {}

    def register_renderer(self, renderer: Renderer):
        """Register ``renderer`` under the name from ``get_output_format()``.

        A renderer registered later for the same format replaces the
        earlier one.
        """
        self.renderers[renderer.get_output_format()] = renderer

    def get_renderer(self, output_format: str) -> Optional[Renderer]:
        """Return the renderer for ``output_format``, or ``None`` if unknown."""
        return self.renderers.get(output_format)

    def list_formats(self) -> List[str]:
        """Return all registered format names in registration order."""
        return list(self.renderers)
# Usage example
registry = RendererRegistry()
registry.register_renderer(ReactNativeRenderer({}))
# NOTE(review): FlutterRenderer and SwiftUIRenderer are not defined in the
# visible part of this guide; they presumably stand for further Renderer
# subclasses that must be implemented before these two lines can run.
registry.register_renderer(FlutterRenderer({}))
registry.register_renderer(SwiftUIRenderer({}))
5. 集成新AI模型
5.1 模型接口设计
# 模型接口定义
from abc import ABC, abstractmethod
from typing import Dict, Any, Optional, List
class AIModel(ABC):
    """Abstract base class for AI model backends."""

    def __init__(self, config: Dict[str, Any]):
        """Store ``config``; model and tokenizer start out as ``None``."""
        self.config = config
        # Both handles stay None until a subclass populates them.
        self.model = self.tokenizer = None

    @abstractmethod
    def load_model(self, model_path: str, **kwargs):
        """Load the model from ``model_path``."""

    @abstractmethod
    def generate(self, prompt: str, **kwargs) -> str:
        """Generate text for a single ``prompt``."""

    @abstractmethod
    def batch_generate(self, prompts: List[str], **kwargs) -> List[str]:
        """Generate text for each prompt in ``prompts``."""

    @abstractmethod
    def get_model_info(self) -> Dict[str, Any]:
        """Return descriptive information about the model."""
5.2 实现新模型支持
示例:集成ChatGLM模型
# ChatGLM模型实现
from transformers import AutoTokenizer, AutoModel
import torch
class ChatGLMModel(AIModel):
    """``AIModel`` implementation backed by a ChatGLM checkpoint."""

    def __init__(self, config: Dict[str, Any]):
        super().__init__(config)
        self.model_name = "THUDM/chatglm-6b"
        # Device requested in load_model(); None until a model is loaded.
        self.device = None

    def load_model(self, model_path: str, **kwargs):
        """Load tokenizer and model weights.

        Args:
            model_path: Local path or hub id; falls back to
                ``self.model_name`` when empty.
            **kwargs: ``device`` selects ``"cuda"`` or ``"cpu"`` (defaults to
                CUDA when available).
        """
        device = kwargs.get("device", "cuda" if torch.cuda.is_available() else "cpu")
        source = model_path or self.model_name

        # SECURITY NOTE: trust_remote_code=True executes Python code shipped
        # with the checkpoint; only load checkpoints from trusted sources.
        self.tokenizer = AutoTokenizer.from_pretrained(
            source,
            trust_remote_code=True
        )
        self.model = AutoModel.from_pretrained(
            source,
            trust_remote_code=True,
            torch_dtype=torch.float16 if device == "cuda" else torch.float32,
            device_map="auto" if device == "cuda" else None
        )
        self.model.eval()
        self.device = device

    def generate(self, prompt: str, **kwargs) -> str:
        """Generate text for ``prompt``.

        Args:
            prompt: Input text.
            **kwargs: ``max_length`` (default 512) truncates the prompt and,
                unless ``max_new_tokens`` is given, also caps the total
                sequence length; ``max_new_tokens`` (optional) caps only the
                generated part; ``temperature`` (default 0.7).

        Returns:
            The decoded first output sequence.

        Raises:
            RuntimeError: If ``load_model`` has not been called.
        """
        if self.model is None or self.tokenizer is None:
            raise RuntimeError("Model is not loaded; call load_model() first")

        max_length = kwargs.get("max_length", 512)
        max_new_tokens = kwargs.get("max_new_tokens")
        temperature = kwargs.get("temperature", 0.7)

        inputs = self.tokenizer(
            prompt,
            return_tensors="pt",
            max_length=max_length,
            truncation=True
        )
        # The tokenizer returns CPU tensors; move them to the device holding
        # the model weights, otherwise generation fails on a GPU model.
        inputs = inputs.to(self.model.device)

        # A prompt that already fills max_length leaves no room to generate,
        # so callers may bound only the new tokens instead.
        if max_new_tokens is not None:
            length_args = {"max_new_tokens": max_new_tokens}
        else:
            length_args = {"max_length": max_length}

        with torch.no_grad():
            outputs = self.model.generate(
                **inputs,
                **length_args,
                temperature=temperature,
                do_sample=True,
                pad_token_id=self.tokenizer.eos_token_id
            )

        return self.tokenizer.decode(outputs[0], skip_special_tokens=True)

    def batch_generate(self, prompts: List[str], **kwargs) -> List[str]:
        """Generate sequentially for every prompt, preserving order."""
        return [self.generate(prompt, **kwargs) for prompt in prompts]

    def get_model_info(self) -> Dict[str, Any]:
        """Return the model name, type, configured ``max_length`` and device.

        ``device`` is the one used by ``load_model``; before loading it
        reports the device that would be chosen by default.
        """
        default_device = "cuda" if torch.cuda.is_available() else "cpu"
        return {
            "model_name": self.model_name,
            "model_type": "chatglm",
            "max_length": self.config.get("max_length", 512),
            "device": self.device or default_device
        }
5.3 模型管理器
# 模型管理器
class ModelManager:
    """Registry of model types plus a cache of loaded model instances."""

    def __init__(self):
        self.model_configs = {}  # model type -> {"class": ..., "config": ...}
        self.models = {}         # model type -> loaded model instance

    def register_model(self, model_type: str, model_class: type, config: Dict[str, Any]):
        """Remember the class and config to use for ``model_type``."""
        entry = {"class": model_class, "config": config}
        self.model_configs[model_type] = entry

    def load_model(self, model_type: str, model_path: str, **kwargs) -> AIModel:
        """Instantiate, load and cache the model registered as ``model_type``.

        Raises:
            ValueError: If ``model_type`` has not been registered.
        """
        if model_type not in self.model_configs:
            raise ValueError(f"Unknown model type: {model_type}")
        entry = self.model_configs[model_type]
        config = entry["config"]
        instance = entry["class"](config)
        instance.load_model(model_path, **kwargs)
        self.models[model_type] = instance
        return instance

    def get_model(self, model_type: str) -> AIModel:
        """Return the loaded model for ``model_type``, or ``None`` if not loaded."""
        return self.models.get(model_type)

    def list_models(self) -> List[str]:
        """Return every registered model type name."""
        return [name for name in self.model_configs.keys()]
# Usage example
model_manager = ModelManager()
model_manager.register_model("chatglm", ChatGLMModel, {"max_length": 512})
# NOTE(review): LLaMAModel is not defined in the visible part of this guide;
# it presumably stands for another AIModel subclass that must be implemented.
model_manager.register_model("llama", LLaMAModel, {"max_length": 1024})
6. API扩展
6.1 添加新端点
基于 api/main.py 的扩展:
# 新API端点示例
from fastapi import APIRouter, HTTPException, Depends
from typing import List, Optional
# Create a sub-router; its routes are prefixed with /api/v1/extensions and
# tagged "extensions"
extension_router = APIRouter(prefix="/api/v1/extensions", tags=["extensions"])
@extension_router.post("/custom-render")
async def custom_render(
    dsl: Dict[str, Any],
    output_format: str,
    custom_config: Optional[Dict[str, Any]] = None
):
    """Render ``dsl`` with the renderer registered for ``output_format``.

    Raises:
        HTTPException: 400 when the format is unsupported or the DSL is
            invalid; 500 when applying the config or rendering fails.
    """
    # Client errors are raised outside the try block so that they are not
    # caught by the generic handler below and reported as 500.
    renderer = renderer_registry.get_renderer(output_format)
    if not renderer:
        raise HTTPException(status_code=400, detail=f"Unsupported format: {output_format}")

    if not renderer.validate_dsl(dsl):
        raise HTTPException(status_code=400, detail="Invalid DSL format")

    try:
        # Apply the caller's overrides, if any, before rendering.
        if custom_config:
            dsl = apply_custom_config(dsl, custom_config)
        result = renderer.render(dsl)
    except HTTPException:
        # Preserve the status code of HTTP errors raised by helpers.
        raise
    except Exception as e:
        # NOTE(review): str(e) is returned to the client and may expose
        # internal details; consider logging it and sending a generic message.
        raise HTTPException(status_code=500, detail=str(e)) from e

    return {
        "success": True,
        "data": result,
        "format": output_format
    }
@extension_router.get("/themes")
async def list_themes():
    """Return a summary (name, colors, description) of every theme."""
    manager = ThemeManager()
    # Missing fields fall back to the theme key, {} and "" respectively.
    summaries = {
        key: {
            "name": cfg.get("name", key),
            "colors": cfg.get("colors", {}),
            "description": cfg.get("description", "")
        }
        for key, cfg in manager.themes.items()
    }
    return {"success": True, "themes": summaries}
@extension_router.post("/themes")
async def create_theme(theme_name: str, theme_config: Dict[str, Any]):
    """Store ``theme_config`` under ``theme_name`` via ``ThemeManager``.

    Raises:
        HTTPException: 500 with the error text when loading or saving fails.
    """
    try:
        ThemeManager().add_theme(theme_name, theme_config)
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
    return {
        "success": True,
        "message": f"Theme '{theme_name}' created successfully"
    }
# Register the router on the main application
# (assumes `app` is the FastAPI instance created in api/main.py — TODO confirm)
app.include_router(extension_router)
6.2 中间件开发
# 自定义中间件
from fastapi import Request, Response
from starlette.middleware.base import BaseHTTPMiddleware
import time
import logging
class PerformanceMiddleware(BaseHTTPMiddleware):
    """Middleware that logs each request's duration and exposes it in a header."""

    def __init__(self, app):
        super().__init__(app)
        self.logger = logging.getLogger(__name__)

    async def dispatch(self, request: Request, call_next):
        """Time the downstream handling of ``request``.

        The elapsed seconds are logged at INFO level and added to the
        response as the ``X-Process-Time`` header.
        """
        # perf_counter() is monotonic, so the measurement cannot be skewed
        # (or turn negative) by system clock adjustments as time.time() can.
        start_time = time.perf_counter()

        response = await call_next(request)

        process_time = time.perf_counter() - start_time

        # Lazy %-style arguments: the message is only formatted when INFO
        # logging is enabled.
        self.logger.info(
            "Request %s %s completed in %.3fs",
            request.method, request.url.path, process_time,
        )

        response.headers["X-Process-Time"] = str(process_time)
        return response
import secrets
from typing import Optional


class AuthenticationMiddleware(BaseHTTPMiddleware):
    """Middleware enforcing a static bearer API key.

    When ``api_key`` is ``None`` or empty, authentication is disabled and
    every request passes through.
    """

    # Paths served without credentials. "/openapi.json" is listed because
    # the /docs and /redoc pages load the schema from it (FastAPI's default
    # schema URL); without it the documentation pages cannot render.
    PUBLIC_PATHS = frozenset({"/health", "/docs", "/redoc", "/openapi.json"})

    def __init__(self, app, api_key: Optional[str] = None):
        super().__init__(app)
        self.api_key = api_key

    def _is_authorized(self, auth_header: Optional[str]) -> bool:
        """Return ``True`` when ``auth_header`` equals ``Bearer <api_key>``."""
        if not auth_header:
            return False
        expected = f"Bearer {self.api_key}"
        # Constant-time comparison avoids leaking how many leading
        # characters matched; bytes are used because compare_digest()
        # rejects non-ASCII str arguments.
        return secrets.compare_digest(
            auth_header.encode("utf-8"), expected.encode("utf-8")
        )

    async def dispatch(self, request: Request, call_next):
        """Reject requests without the expected ``Authorization`` header.

        Returns a 401 JSON response on failure; otherwise forwards the
        request to the next handler.
        """
        # Health check and documentation endpoints are always reachable.
        if request.url.path in self.PUBLIC_PATHS:
            return await call_next(request)

        if self.api_key and not self._is_authorized(
            request.headers.get("Authorization")
        ):
            return Response(
                content='{"error": "Unauthorized"}',
                status_code=401,
                media_type="application/json"
            )

        return await call_next(request)
# Register the middleware
app.add_middleware(PerformanceMiddleware)
# NOTE(review): "your-api-key" is a placeholder literal; presumably the real
# key should be loaded from configuration or the environment, not hard-coded.
app.add_middleware(AuthenticationMiddleware, api_key="your-api-key")
6.3 插件机制
# 插件系统
class PluginManager:
    """Holds plugin instances and dispatches named hook events."""

    def __init__(self):
        self.plugins = {}  # plugin name -> plugin instance
        self.hooks = {}    # event name -> list of callbacks

    def register_plugin(self, name: str, plugin_class: type):
        """Instantiate ``plugin_class`` (no arguments) and store it as ``name``.

        The instance's ``on_load()`` is invoked when it defines one, which
        honours the "called when the plugin is loaded" contract of
        ``Plugin.on_load``.
        """
        plugin = plugin_class()
        self.plugins[name] = plugin
        on_load = getattr(plugin, "on_load", None)
        if callable(on_load):
            on_load()

    def register_hook(self, event: str, callback):
        """Append ``callback`` to the callbacks run for ``event``."""
        self.hooks.setdefault(event, []).append(callback)

    def trigger_hook(self, event: str, *args, **kwargs):
        """Call every callback registered for ``event``, in registration order.

        Unknown events are ignored. Return values are discarded; an
        exception raised by a callback propagates to the caller.
        """
        # Iterate over a snapshot so that a callback may register further
        # hooks for the same event without affecting this dispatch.
        for callback in tuple(self.hooks.get(event, ())):
            callback(*args, **kwargs)

    def get_plugin(self, name: str):
        """Return the plugin stored as ``name``, or ``None`` if absent."""
        return self.plugins.get(name)
# 插件基类
class Plugin:
    """Base class for plugins; every lifecycle hook is a no-op by default."""

    def __init__(self, name: str):
        """Store the plugin's ``name``."""
        self.name = name

    def on_load(self):
        """Hook for plugin load; does nothing by default."""
        return None

    def on_unload(self):
        """Hook for plugin unload; does nothing by default."""
        return None

    def on_request(self, request: Request):
        """Hook run before a request is handled; does nothing by default."""
        return None

    def on_response(self, response: Response):
        """Hook run after a response is produced; does nothing by default."""
        return None
# 示例插件:日志记录插件
class LoggingPlugin(Plugin):
    """Example plugin that logs requests and responses at INFO level."""

    def __init__(self):
        # Registers under the fixed plugin name "logging".
        super().__init__("logging")
        self.logger = logging.getLogger(__name__)

    def on_request(self, request: Request):
        """Log the HTTP method and URL path of ``request``."""
        message = f"Request: {request.method} {request.url.path}"
        self.logger.info(message)

    def on_response(self, response: Response):
        """Log the status code of ``response``."""
        message = f"Response: {response.status_code}"
        self.logger.info(message)
# Using the plugin system
plugin_manager = PluginManager()
plugin_manager.register_plugin("logging", LoggingPlugin)
# Bind the plugin's on_request method to the "request" event.
# NOTE(review): nothing in the visible code calls trigger_hook("request", ...);
# presumably the request pipeline is expected to trigger it — confirm.
plugin_manager.register_hook("request", plugin_manager.get_plugin("logging").on_request)
7. 扩展开发最佳实践
7.1 开发流程
- 需求分析:明确扩展需求和功能范围
- 接口设计:设计清晰的接口和数据结构
- 实现开发:按照规范实现功能
- 测试验证:编写单元测试和集成测试
- 文档编写:编写使用文档和API文档
- 代码审查:进行代码审查和优化
- 部署发布:部署到测试和生产环境
7.2 测试策略
# 扩展功能测试示例
import pytest
from unittest.mock import Mock, patch
class TestProgressBarComponent:
    """Unit tests for ``ProgressBarComponent``."""

    def test_render_progress_bar(self):
        """Rendering draws the track, the fill and the label."""
        component = ProgressBarComponent({
            "props": {"progress": 0.7, "show_text": True}
        })
        # Stub the collaborators that render() reads from the instance so
        # the test exercises only the drawing calls.
        component.fonts = {"small": Mock()}
        component._hex_to_rgb = Mock(return_value=(0, 0, 0))

        # Stand-in for the PIL drawing context.
        mock_draw = Mock()
        mock_draw.textbbox.return_value = (0, 0, 50, 20)
        # render() reads "border" for the track as well as "primary"/"text".
        theme_colors = {
            "primary": "#007AFF",
            "text": "#000000",
            "border": "#C6C6C8",
        }

        height = component.render(mock_draw, theme_colors, 0, 0, 100, 20)

        assert height == 20
        assert mock_draw.rounded_rectangle.call_count == 2  # track and fill
        mock_draw.text.assert_called_once()  # percentage label

    def test_validate_props(self):
        """``progress`` must lie within [0, 1]."""
        component = ProgressBarComponent({})
        # Valid values, including both boundaries.
        assert component.validate_props({"progress": 0.5}) is True
        assert component.validate_props({"progress": 0.0}) is True
        assert component.validate_props({"progress": 1.0}) is True
        # Out-of-range values.
        assert component.validate_props({"progress": 1.5}) is False
        assert component.validate_props({"progress": -0.1}) is False
class TestThemeManager:
    """Unit tests for ``ThemeManager``."""

    def test_add_theme(self, tmp_path):
        """A newly added theme can be read back and is written to disk."""
        import json

        # Work on a throw-away token file: add_theme() persists its change,
        # so using the default path would modify the real project config.
        tokens_file = tmp_path / "ui_tokens.json"
        tokens_file.write_text(
            json.dumps({"themes": {"obsidian-gold": {"name": "Obsidian Gold", "colors": {}}}}),
            encoding="utf-8",
        )
        theme_manager = ThemeManager(tokens_path=str(tokens_file))
        new_theme = {
            "name": "测试主题",
            "colors": {"primary": "#FF0000"}
        }

        theme_manager.add_theme("test-theme", new_theme)

        # The theme is available in memory ...
        theme = theme_manager.get_theme("test-theme")
        assert theme["name"] == "测试主题"
        assert theme["colors"]["primary"] == "#FF0000"
        # ... and has been saved to the token file.
        saved = json.loads(tokens_file.read_text(encoding="utf-8"))
        assert saved["themes"]["test-theme"] == new_theme
7.3 性能考虑
- 内存管理:及时释放不需要的资源
- 缓存策略:合理使用缓存减少重复计算
- 异步处理:使用异步处理提升并发能力
- 资源池化:使用连接池等资源池化技术
- 监控指标:添加性能监控和指标收集
7.4 安全考虑
- 输入验证:严格验证所有输入参数
- 权限控制:实现适当的权限控制机制
- 数据加密:对敏感数据进行加密处理
- 日志审计:记录关键操作和访问日志
- 错误处理:避免敏感信息泄露
通过以上扩展开发指南,开发者可以根据具体需求灵活扩展AI UI生成系统的功能,实现定制化的UI生成解决方案。