Compare commits
3 Commits
0e068c4d25
...
bf3d340aa8
| Author | SHA1 | Date | |
|---|---|---|---|
| bf3d340aa8 | |||
| 88793da902 | |||
| 43f3e0b746 |
39
.env
39
.env
@ -1,35 +1,22 @@
|
||||
# 报告生成独立服务环境配置(沿用原 eval_report 的外部服务)
|
||||
# 范文/模板解析子服务(section_reference_block)环境配置
|
||||
# 与网关 eval_report、report_generation 同库
|
||||
|
||||
# 数据库(MySQL,与原项目同库)
|
||||
# 远程 MySQL(章节内容入库目标,与主项目同库)
|
||||
DATABASE_URL=mysql+pymysql://root:Beidas0ft@192.168.4.177:3306/eval_report?charset=utf8mb4
|
||||
DB_POOL_SIZE=15
|
||||
DB_MAX_OVERFLOW=25
|
||||
DB_POOL_TIMEOUT=60
|
||||
DB_POOL_PRE_PING=true
|
||||
DB_AUTO_CREATE_TABLES=true
|
||||
|
||||
# 文档存储根目录(附图提取按 DOC_PAT/{project_uuid}/<相对路径> 定位 .docx)
|
||||
# 指向原项目的 docs 目录,保证附图能被找到
|
||||
DOC_PAT=D:/Git-Project/eval_report/docs
|
||||
# 远程文档解析服务(上传文档 → Markdown)
|
||||
FILE_PARSE_API_URL=http://192.168.4.194:8000/convert
|
||||
FILE_PARSE_FIELD_NAME=file
|
||||
FILE_PARSE_ENGINE=auto
|
||||
FILE_PARSE_HTTP_TIMEOUT_SEC=600
|
||||
|
||||
# Embedding 模型配置
|
||||
EMBEDDING_API_KEY=sk-xtbpjqekezfbttasgbrczzskqenygmhqwpsobpiagwrlacfr
|
||||
EMBEDDING_API_BASE=http://192.168.4.191:8001/v1
|
||||
EMBEDDING_BATCH_MAX_DOCS=4
|
||||
EMBEDDING_BATCH_MAX_CHARS=12000
|
||||
EMBEDDING_MAX_CHUNK_CHARS=4000
|
||||
|
||||
# LLM(OpenAI 兼容接口)
|
||||
# LLM(为每个目录生成"声明";留空则使用确定性兜底模板)
|
||||
LLM_API_BASE=http://192.168.4.197:8086/v1
|
||||
LLM_API_KEY=sk-99999999991234
|
||||
LLM_MODEL_NAME=Qwen3.6-27B
|
||||
LLM_HTTP_TIMEOUT_SEC=600
|
||||
# 报告章节单次 chat 读超时(秒),长章节建议 600+
|
||||
REPORT_LLM_HTTP_TIMEOUT_SEC=600
|
||||
DECLARATION_USE_LLM=true
|
||||
|
||||
# Milvus 向量数据库
|
||||
MILVUS_DB_URL=http://192.168.4.191:19530
|
||||
|
||||
# 服务监听(注意:8099 已被网关 eval_report 占用,本子服务改用 8101)
|
||||
# 服务监听
|
||||
HOST=0.0.0.0
|
||||
PORT=8101
|
||||
RELOAD=false
|
||||
PORT=8100
|
||||
|
||||
33
.env.example
33
.env.example
@ -1,24 +1,21 @@
|
||||
# 复制为 .env 后按实际环境填写。
|
||||
# 复制为 .env 并按实际环境修改
|
||||
|
||||
# 数据库(MySQL,与原 eval_report 共用同一库)
|
||||
DATABASE_URL=mysql+pymysql://root:123456@127.0.0.1:3306/post_eval_report?charset=utf8mb4
|
||||
# 远程 MySQL(章节内容入库目标)
|
||||
DATABASE_URL=mysql+pymysql://root:Beidas0ft@192.168.4.177:3306/eval_report?charset=utf8mb4
|
||||
DB_AUTO_CREATE_TABLES=true
|
||||
|
||||
# 文档存储根目录(附图提取按 DOC_PAT/{project_uuid}/<相对路径> 定位 .docx)
|
||||
DOC_PAT=./docpath
|
||||
# 远程文档解析服务(上传文档 → Markdown)
|
||||
FILE_PARSE_API_URL=http://192.168.4.194:8000/convert
|
||||
FILE_PARSE_FIELD_NAME=file
|
||||
FILE_PARSE_ENGINE=auto
|
||||
FILE_PARSE_HTTP_TIMEOUT_SEC=600
|
||||
|
||||
# LLM(OpenAI 兼容接口)
|
||||
LLM_API_BASE=
|
||||
LLM_API_KEY=
|
||||
LLM_MODEL_NAME=
|
||||
# 报告章节单次 chat 读超时(秒),长章节建议 600+
|
||||
REPORT_LLM_HTTP_TIMEOUT_SEC=600
|
||||
|
||||
# Embedding / Milvus(向量检索证据)
|
||||
EMBEDDING_API_BASE=
|
||||
EMBEDDING_API_KEY=
|
||||
MILVUS_DB_URL=
|
||||
# LLM(可选):为每个目录生成"声明"。留空则使用确定性兜底模板。
|
||||
LLM_API_BASE=http://192.168.4.197:8086/v1
|
||||
LLM_API_KEY=sk-99999999991234
|
||||
LLM_MODEL_NAME=Qwen3.6-27B
|
||||
DECLARATION_USE_LLM=true
|
||||
|
||||
# 服务监听
|
||||
HOST=0.0.0.0
|
||||
PORT=8099
|
||||
RELOAD=false
|
||||
PORT=8100
|
||||
|
||||
28
.gitignore
vendored
28
.gitignore
vendored
@ -1,21 +1,13 @@
|
||||
# Python-generated files
|
||||
# Python
|
||||
__pycache__/
|
||||
*.py[oc]
|
||||
build/
|
||||
dist/
|
||||
wheels/
|
||||
*.egg-info
|
||||
*.py[cod]
|
||||
.venv/
|
||||
venv/
|
||||
*.egg-info/
|
||||
|
||||
# Virtual environments
|
||||
.venv
|
||||
|
||||
# Environment / secrets
|
||||
# .env — tracked intentionally
|
||||
|
||||
# Local artifacts
|
||||
*.log
|
||||
.DS_Store
|
||||
comp/
|
||||
docpath/
|
||||
docs/
|
||||
# 环境与日志
|
||||
logs/
|
||||
|
||||
# IDE
|
||||
.idea/
|
||||
.vscode/
|
||||
|
||||
110
README.md
110
README.md
@ -1,52 +1,98 @@
|
||||
# 报告生成服务(独立抽取版)
|
||||
# 报告模板管理模块
|
||||
|
||||
从 `eval_report` 中抽取出的「后评价报告核心生成」链路,作为独立 FastAPI 服务运行。
|
||||
保留原有的证据装配(要素表 + Milvus 向量检索)、分章 LLM 生成、表格修复、报告合并与 SSE 流式进度,
|
||||
连接与原项目相同的 MySQL / Milvus / LLM 服务。
|
||||
上传一个文档,自动完成:
|
||||
|
||||
## 范围
|
||||
1. **远程解析**:调用 `http://192.168.4.194:8000/convert`(表单字段 `file` + `engine=auto`)将文档转换为 Markdown。
|
||||
2. **抽取目录**:从 Markdown 中识别章节标题层级(目录)。
|
||||
3. **生成声明**:为每个目录(章节)生成一段"章节声明"(撰写指引),存入模板。
|
||||
4. **脱敏入库**:按标题拆分正文,对每个章节正文**脱敏**(去掉精确数字/金额/日期/百分比等),再按远程 MySQL `report_section_references` 表格式写入,得到可复用的模板化范文。
|
||||
|
||||
- 包含:异步分章生成任务、进度查询、结果获取、SSE 实时事件、章节重试、任务取消。
|
||||
- 不含:鉴权、知识库 worker、模板/范文管理、Word(docx) 导出(这些仍在原 `eval_report` 中)。
|
||||
解析、目录抽取、正文拆分逻辑参考 `eval_report/routers/template.py` 与 `routers/reference.py`。
|
||||
|
||||
## 目录结构
|
||||
|
||||
```
|
||||
report_generation/
|
||||
config.py 全局配置(DB / 解析服务 / LLM)
|
||||
main.py FastAPI 入口
|
||||
config.py 配置(DB / LLM / Embedding / Milvus / DOC_PAT)
|
||||
database/ SQLAlchemy 引擎、Session、ORM 模型、建表
|
||||
schemas/ Pydantic 模型
|
||||
services/ 报告生成核心逻辑(含瘦身版 kb_service / docx_export_service / project_service)
|
||||
function/vector_store.py Milvus 向量库封装
|
||||
prompts/report_generation/ 提示词模板与章节合同
|
||||
routers/report.py 报告生成 HTTP 端点
|
||||
database/ 连接、ORM 模型、建表
|
||||
models.py report_templates / report_template_sections / report_section_references
|
||||
schemas/template.py 接口出入参
|
||||
services/
|
||||
file_parse_client.py 调用远程 /convert → Markdown
|
||||
section_extractor.py 目录抽取 + 正文按标题拆分(共用同一遍历)
|
||||
desensitize_service.py 章节正文脱敏(去精确数字等)
|
||||
declaration_service.py 为每个目录生成"声明"(LLM 可选 + 兜底模板)
|
||||
llm_client.py OpenAI 兼容 Chat 客户端(可选)
|
||||
routers/template.py 上传/列表/详情/删除
|
||||
```
|
||||
|
||||
## 快速开始
|
||||
## 配置
|
||||
|
||||
复制 `.env.example` 为 `.env` 并修改:
|
||||
|
||||
- `DATABASE_URL`:远程 MySQL(章节内容入库目标)。
|
||||
- `FILE_PARSE_API_URL`:远程文档解析服务(默认 `http://192.168.4.194:8000/convert`,文件字段 `FILE_PARSE_FIELD_NAME=file`,引擎 `FILE_PARSE_ENGINE=auto`)。
|
||||
- `LLM_*`:可选。配置后用 LLM 生成更贴合的章节声明;留空则使用确定性兜底模板。
|
||||
|
||||
启动时会按需在远程库中创建本模块用到的三张表(`DB_AUTO_CREATE_TABLES=true`,已存在则跳过)。
|
||||
|
||||
## 运行
|
||||
|
||||
```bash
|
||||
pip install -r requirements.txt
|
||||
cp .env.example .env # 按需填写 DATABASE_URL / LLM_* / EMBEDDING_* / MILVUS_DB_URL
|
||||
uvicorn main:app --reload
|
||||
python main.py
|
||||
# 或
|
||||
uvicorn main:app --host 0.0.0.0 --port 8100
|
||||
```
|
||||
|
||||
启动后访问 `http://127.0.0.1:8099/docs` 查看接口文档,`/health` 做健康检查。
|
||||
打开 `http://localhost:8100/docs` 查看接口文档。
|
||||
|
||||
## 主要接口(前缀 `/api/v1/write`)
|
||||
## 主要接口
|
||||
|
||||
| 方法 | 路径 | 说明 |
|
||||
|------|------|------|
|
||||
| GET | `/projects/{project_id}/generate-sections` | 预览模板章节提示词清单 |
|
||||
| POST | `/projects/{project_id}/generate-report-job` | 创建分章异步报告生成任务 |
|
||||
| GET | `/projects/{project_id}/generate-report-job/{job_id}` | 查询任务进度 |
|
||||
| GET | `/projects/{project_id}/generate-report-job/{job_id}/result` | 获取任务结果 |
|
||||
| GET | `/projects/{project_id}/generate-report-job/{job_id}/events` | 订阅实时事件(SSE) |
|
||||
| POST | `/projects/{project_id}/generate-report-job/{job_id}/retry-chapter` | 重试指定章节 |
|
||||
| POST | `/projects/{project_id}/generate-report-job/{job_id}/cancel` | 取消任务 |
|
||||
| --- | --- | --- |
|
||||
| POST | `/templates/upload` | 上传文档,解析为模板(目录+声明)并将章节内容入库 |
|
||||
| GET | `/templates` | 模板列表 |
|
||||
| GET | `/templates/{id}` | 模板详情(含目录与各章节声明) |
|
||||
| DELETE | `/templates/{id}` | 删除模板 |
|
||||
| GET | `/health` | 健康检查 |
|
||||
|
||||
## 依赖的外部数据
|
||||
### 上传示例
|
||||
|
||||
报告生成依赖原库中已有的项目数据:`projects`、`element_tables` / `element_cells`(要素表)、
|
||||
`report_templates` / `report_template_sections`(模板章节)、可选的 `report_section_references`(参考范文),
|
||||
以及 Milvus 中按项目 UUID 写入的文档向量。请确保新服务连接到已包含这些数据的 MySQL 与 Milvus。
|
||||
```bash
|
||||
curl -X POST "http://localhost:8100/templates/upload" \
|
||||
-F "file=@/path/to/报告.docx"
|
||||
```
|
||||
|
||||
返回包含:模板信息(每个目录的 `sectionDeclaration` 即声明)、入库章节数与各章节摘要。
|
||||
|
||||
## 日志
|
||||
|
||||
启动即初始化日志系统(`log/logger.py`),输出到控制台(强制 UTF-8,避免 Windows 中文乱码)并写入 `logs/`:
|
||||
|
||||
| 文件 | 内容 |
|
||||
| --- | --- |
|
||||
| `logs/app.log` | 全量日志(按大小轮转) |
|
||||
| `logs/error.log` | WARNING 及以上 |
|
||||
| `logs/upload.log` | 上传/解析/入库链路(`routers.template`、`services.*`) |
|
||||
|
||||
- 每个 HTTP 请求会记录方法、路径、状态码、耗时,并在响应头返回 `X-Request-ID`。
|
||||
- uvicorn 的 access/error 日志也统一汇入上述文件。
|
||||
- 可在 `.env` 调整:`LOG_LEVEL`、`LOG_DIR`、`LOG_TO_CONSOLE`、`LOG_MAX_BYTES`、`LOG_BACKUP_COUNT`、`LOG_HTTP_ACCESS`。
|
||||
|
||||
## 数据落点
|
||||
|
||||
- `report_templates`:一条模板记录。
|
||||
- `report_template_sections`:每个目录一条,`section_prompt` 字段存放该目录的**声明**。
|
||||
- `report_section_references`:每个章节一条,存放该章节**脱敏后的正文内容**(与远程库现有格式一致)。
|
||||
|
||||
### 脱敏规则
|
||||
|
||||
`services/desensitize_service.py`:
|
||||
|
||||
- 阿拉伯数字串(含小数/千分位/全角)→ 占位符(默认 `X`):`总投资10.5亿元` → `总投资X亿元`、`85.3%` → `X%`、`2020年3月` → `X年X月`。
|
||||
- 标题行(`#` 开头)整行保留,不动章节编号与标题。
|
||||
- 行首枚举序号(`1)`、`(2)` 等)保留,仅脱敏正文数字。
|
||||
- 表格分隔行保留;数据格数字默认脱敏(`DESENSITIZE_MASK_TABLE_NUMBERS`)。
|
||||
- 中文数字(一二三…)默认保留(多为序数/层级)。
|
||||
- 可在 `.env` 调整:`DESENSITIZE_ENABLED`、`DESENSITIZE_PLACEHOLDER`、`DESENSITIZE_MASK_TABLE_NUMBERS`。
|
||||
|
||||
94
config.py
94
config.py
@ -1,64 +1,86 @@
|
||||
"""
|
||||
config.py
|
||||
全局配置项。可通过 .env 文件或环境变量覆盖。
|
||||
|
||||
本项目为「报告生成」独立服务,仅保留报告生成链路所需配置:
|
||||
数据库(MySQL) / LLM / Embedding / Milvus / 文档存储路径。
|
||||
报告模板管理模块的全局配置。可通过 .env 或环境变量覆盖。
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from pydantic_settings import BaseSettings, SettingsConfigDict
|
||||
|
||||
|
||||
class Settings(BaseSettings):
|
||||
# 应用基本信息
|
||||
APP_TITLE: str = "智能报告生成服务 API"
|
||||
APP_TITLE: str = "报告模板管理模块 API"
|
||||
APP_VERSION: str = "0.1.0"
|
||||
APP_DESCRIPTION: str = "后评价报告分章异步生成后端服务(独立抽取版)"
|
||||
APP_DESCRIPTION: str = "上传文档 → 远程解析为 Markdown → 拆解目录/章节 → 入库远程 MySQL"
|
||||
|
||||
# 服务监听
|
||||
HOST: str = "0.0.0.0"
|
||||
PORT: int = 8099
|
||||
PORT: int = 8100
|
||||
RELOAD: bool = False
|
||||
|
||||
# CORS 允许的前端源(开发阶段放开,生产环境改为具体域名)
|
||||
CORS_ORIGINS: list[str] = ["*"]
|
||||
|
||||
# 数据库(MySQL)
|
||||
DATABASE_URL: str = "mysql+pymysql://root:123456@127.0.0.1:3306/post_eval_report?charset=utf8mb4"
|
||||
DB_POOL_SIZE: int = 15
|
||||
DB_MAX_OVERFLOW: int = 25
|
||||
# 日志
|
||||
LOG_LEVEL: str = "INFO" # DEBUG / INFO / WARNING / ERROR
|
||||
LOG_DIR: str = "logs" # 日志目录(相对启动目录或绝对路径)
|
||||
LOG_TO_CONSOLE: bool = True # 是否同时输出到控制台
|
||||
LOG_MAX_BYTES: int = 10 * 1024 * 1024 # 单文件最大字节数(轮转)
|
||||
LOG_BACKUP_COUNT: int = 7 # 轮转保留份数
|
||||
LOG_HTTP_ACCESS: bool = True # 是否记录每个 HTTP 请求
|
||||
|
||||
# 远程 MySQL:mysql+pymysql://用户:密码@主机:端口/库名?charset=utf8mb4
|
||||
DATABASE_URL: str = (
|
||||
"mysql+pymysql://root:Beidas0ft@192.168.4.177:3306/eval_report?charset=utf8mb4"
|
||||
)
|
||||
DB_POOL_SIZE: int = 10
|
||||
DB_MAX_OVERFLOW: int = 20
|
||||
DB_POOL_TIMEOUT: int = 60
|
||||
DB_POOL_PRE_PING: bool = True
|
||||
# 启动时自动建表(仅创建本模块用到的表,已存在则跳过)
|
||||
DB_AUTO_CREATE_TABLES: bool = True
|
||||
|
||||
# 文档存储根目录(附图提取时按 DOC_PAT/{project_uuid}/<相对路径> 定位 .docx)
|
||||
DOC_PAT: str = "./docpath"
|
||||
# 远程文档解析服务:上传文件 → Markdown
|
||||
FILE_PARSE_API_URL: str = "http://192.168.4.194:8000/convert"
|
||||
FILE_PARSE_FIELD_NAME: str = "file"
|
||||
# 解析引擎(随 multipart 一起提交的表单字段 engine)
|
||||
FILE_PARSE_ENGINE: str = "auto"
|
||||
FILE_PARSE_HTTP_TIMEOUT_SEC: int = 600
|
||||
FILE_PARSE_RETRY_COUNT: int = 3
|
||||
FILE_PARSE_RETRY_BACKOFF_SEC: float = 15.0
|
||||
|
||||
# LLM(OpenAI 兼容接口)
|
||||
# 章节正文:是否包含其下级小节内容(章/节聚合整棵子树正文,避免父章节正文为空)
|
||||
SECTION_CONTENT_INCLUDE_SUBSECTIONS: bool = True
|
||||
# 单章节正文入库字节上限(MySQL TEXT 列上限 65535 字节,留余量防止截断到半个字符)
|
||||
SECTION_CONTENT_MAX_BYTES: int = 60000
|
||||
|
||||
# 章节内容脱敏:入库前过滤精确数据(数字/金额/日期/百分比等)
|
||||
DESENSITIZE_ENABLED: bool = True
|
||||
DESENSITIZE_PLACEHOLDER: str = "X" # 数字脱敏后的占位符
|
||||
# 是否把表格中的数字也脱敏(表格通常是精确数据,默认开启)
|
||||
DESENSITIZE_MASK_TABLE_NUMBERS: bool = True
|
||||
|
||||
# LLM(可选):为每个目录生成"声明"。未配置时使用确定性兜底模板。
|
||||
LLM_API_BASE: str = ""
|
||||
LLM_API_KEY: str = ""
|
||||
LLM_MODEL_NAME: str = ""
|
||||
LLM_HTTP_TIMEOUT_SEC: int = 120
|
||||
LLM_CONNECT_TIMEOUT_SEC: int = 30
|
||||
LLM_RETRY_COUNT: int = 3
|
||||
LLM_RETRY_BACKOFF_SEC: float = 1.0
|
||||
LLM_RETRY_BACKOFF_MAX_SEC: float = 12.0
|
||||
# 报告章节单次 chat 读超时(秒)。0 表示沿用 LLM_HTTP_TIMEOUT_SEC;长章节建议 600+
|
||||
REPORT_LLM_HTTP_TIMEOUT_SEC: int = 600
|
||||
# 某章 LLM 仍失败时写入占位正文并继续后续章节,避免整份任务失败
|
||||
REPORT_LLM_CONTINUE_ON_TIMEOUT: bool = True
|
||||
# 表格抽取延迟补抽(首轮失败后进入队列,按轮次延迟重试)
|
||||
LLM_TABLE_DELAY_RETRY_ROUNDS: int = 2
|
||||
LLM_TABLE_DELAY_RETRY_SEC: float = 8.0
|
||||
LLM_TABLE_DELAY_RETRY_BACKOFF: float = 2.0
|
||||
LLM_TABLE_DELAY_RETRY_MAX_SEC: float = 60.0
|
||||
|
||||
# Embedding / Milvus(向量检索证据 L2/L3)
|
||||
EMBEDDING_API_KEY: str = ""
|
||||
EMBEDDING_API_BASE: str = ""
|
||||
EMBEDDING_BATCH_MAX_DOCS: int = 4
|
||||
EMBEDDING_BATCH_MAX_CHARS: int = 12000
|
||||
EMBEDDING_MAX_CHUNK_CHARS: int = 4000
|
||||
MILVUS_DB_URL: str = ""
|
||||
# 关闭思考模型的思维链输出(vLLM/Qwen3 等:chat_template_kwargs.enable_thinking=false)。
|
||||
# 既避免"思考过程"混入正文,又减少 token、降低截断与耗时。
|
||||
LLM_DISABLE_THINKING: bool = True
|
||||
# 是否调用 LLM 生成章节声明(关闭则始终使用兜底模板)
|
||||
DECLARATION_USE_LLM: bool = True
|
||||
# 上传模版时:用 LLM 匹配默认提示词 / 为无匹配章节生成提示词(复刻 eval_report)
|
||||
TEMPLATE_UPLOAD_LLM_PROMPT_MAPPING: bool = True
|
||||
# LLM 提示词匹配并发:把未匹配章节分批并行调用,缩短整体耗时。
|
||||
# 多卡 A100 + 连续批处理(vLLM/TGI,TP 或多副本)下,提高并发在飞请求数即可打满 GPU:
|
||||
# - 调小 BATCH_SIZE:请求更多更短,确保批次数 ≥ 线程数,单请求尾延迟更低
|
||||
# - 调大 MAX_WORKERS:同时在飞的序列更多,填满推理服务的批,decode 吞吐接近峰值
|
||||
# - 调小 MAX_TOKENS:每序列 KV 缓存预留更少,调度器可纳入更多并发序列
|
||||
# 2×A100:并发目标约 16(较单卡的 8 翻倍);BATCH_SIZE=2 保证常见规模也能跑满 16 路。
|
||||
TEMPLATE_UPLOAD_LLM_BATCH_SIZE: int = 2 # 每批未匹配章节数量
|
||||
TEMPLATE_UPLOAD_LLM_MAX_WORKERS: int = 16 # 并行线程数上限(在飞请求数)
|
||||
TEMPLATE_UPLOAD_LLM_MAX_TOKENS: int = 2048 # 单批最大输出 token
|
||||
|
||||
model_config = SettingsConfigDict(
|
||||
env_file=".env",
|
||||
|
||||
@ -1,27 +1,7 @@
|
||||
"""
|
||||
database
|
||||
数据库连接与 Session 管理。
|
||||
"""database package:连接、模型与依赖注入。"""
|
||||
|
||||
使用方式:
|
||||
from database import get_db, SessionLocal, init_database
|
||||
|
||||
# 依赖注入(FastAPI 路由)
|
||||
@router.get("/items")
|
||||
def list_items(db: Session = Depends(get_db)):
|
||||
...
|
||||
|
||||
# 上下文管理器(脚本、worker)
|
||||
with SessionLocal() as db:
|
||||
...
|
||||
"""
|
||||
|
||||
from database.core import engine, SessionLocal
|
||||
from database.core import SessionLocal, engine
|
||||
from database.dependencies import get_db
|
||||
from database.init_db import init_database
|
||||
|
||||
__all__ = [
|
||||
"engine",
|
||||
"SessionLocal",
|
||||
"get_db",
|
||||
"init_database",
|
||||
]
|
||||
__all__ = ["engine", "SessionLocal", "get_db", "init_database"]
|
||||
|
||||
@ -1,42 +1,33 @@
|
||||
"""
|
||||
database/core.py
|
||||
SQLAlchemy 引擎与 Session 工厂。
|
||||
|
||||
- 同步引擎,默认连接池(QueuePool)
|
||||
- 后续可替换为 create_async_engine 实现异步
|
||||
SQLAlchemy 引擎与 Session 工厂(同步引擎,连接远程 MySQL)。
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from sqlalchemy import create_engine
|
||||
from sqlalchemy.orm import sessionmaker, Session
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
|
||||
from config import settings
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# 引擎配置
|
||||
# -----------------------------------------------------------------------------
|
||||
|
||||
engine = create_engine(
|
||||
settings.DATABASE_URL,
|
||||
pool_size=settings.DB_POOL_SIZE,
|
||||
max_overflow=settings.DB_MAX_OVERFLOW,
|
||||
pool_timeout=settings.DB_POOL_TIMEOUT,
|
||||
pool_pre_ping=settings.DB_POOL_PRE_PING,
|
||||
pool_recycle=3600, # 1 小时回收空闲连接,避免 MySQL wait_timeout
|
||||
pool_recycle=3600,
|
||||
connect_args={
|
||||
"charset": "utf8mb4",
|
||||
"use_unicode": True,
|
||||
"init_command": "SET NAMES utf8mb4 COLLATE utf8mb4_unicode_ci",
|
||||
},
|
||||
echo=False, # 开发时可设为 True 打印 SQL
|
||||
echo=False,
|
||||
)
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Session 工厂
|
||||
# -----------------------------------------------------------------------------
|
||||
|
||||
SessionLocal = sessionmaker(
|
||||
bind=engine,
|
||||
autocommit=False,
|
||||
autoflush=False,
|
||||
expire_on_commit=False, # 提交后对象仍可访问属性,便于返回响应
|
||||
expire_on_commit=False,
|
||||
)
|
||||
|
||||
@ -1,11 +1,11 @@
|
||||
"""
|
||||
database/dependencies.py
|
||||
FastAPI 依赖注入:获取数据库 Session。
|
||||
|
||||
每个请求创建新 Session,请求结束后自动关闭。
|
||||
"""
|
||||
|
||||
from collections.abc import Generator
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Generator
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
@ -13,14 +13,6 @@ from database.core import SessionLocal
|
||||
|
||||
|
||||
def get_db() -> Generator[Session, None, None]:
|
||||
"""
|
||||
获取数据库 Session,用于 FastAPI Depends()。
|
||||
|
||||
用法:
|
||||
@router.get("/items")
|
||||
def list_items(db: Session = Depends(get_db)):
|
||||
...
|
||||
"""
|
||||
db = SessionLocal()
|
||||
try:
|
||||
yield db
|
||||
|
||||
@ -1,764 +1,52 @@
|
||||
"""
|
||||
database/init_db.py
|
||||
应用启动时初始化数据库表结构。
|
||||
|
||||
执行 init.sql 中的 DDL,使用 IF NOT EXISTS 保证幂等。
|
||||
按需建表:仅创建本模块用到的三张表,已存在则跳过(checkfirst=True)。
|
||||
"""
|
||||
|
||||
import re
|
||||
from pathlib import Path
|
||||
from __future__ import annotations
|
||||
|
||||
from sqlalchemy import text
|
||||
import logging
|
||||
|
||||
from sqlalchemy import inspect, text
|
||||
|
||||
from database.core import engine
|
||||
from database.models import Base
|
||||
|
||||
# DDL 与 init_db.py 同目录:database/init.sql
|
||||
INIT_SQL_PATH = Path(__file__).resolve().parent / "init.sql"
|
||||
|
||||
INIT_TABLES = [
|
||||
"projects",
|
||||
"kb_directories",
|
||||
"kb_documents",
|
||||
"write_documents",
|
||||
"doc_versions",
|
||||
"element_tables",
|
||||
"element_cells",
|
||||
"extraction_results",
|
||||
"element_extraction_results",
|
||||
"element_conflicts",
|
||||
"document_markdowns",
|
||||
"document_chunks",
|
||||
"report_templates",
|
||||
"report_template_sections",
|
||||
"report_generation_jobs",
|
||||
"report_generation_chapters",
|
||||
"departments",
|
||||
"users",
|
||||
"roles",
|
||||
"permissions",
|
||||
"role_permissions",
|
||||
"user_roles",
|
||||
"project_members",
|
||||
"project_departments",
|
||||
"fill_records",
|
||||
"report_section_references",
|
||||
]
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
_TARGET_TABLE_COLLATION = "utf8mb4_unicode_ci"
|
||||
def _ensure_reference_template_id_column() -> None:
|
||||
"""为已存在的 report_section_references 表补充 template_id 字段(幂等)。
|
||||
|
||||
|
||||
def _existing_tables(conn) -> set[str]:
|
||||
return {
|
||||
row[0]
|
||||
for row in conn.execute(
|
||||
text(
|
||||
"SELECT TABLE_NAME FROM information_schema.TABLES "
|
||||
"WHERE TABLE_SCHEMA = DATABASE()"
|
||||
)
|
||||
).fetchall()
|
||||
}
|
||||
|
||||
|
||||
def _table_collation(conn, table_name: str) -> str | None:
|
||||
row = conn.execute(
|
||||
text(
|
||||
"SELECT TABLE_COLLATION FROM information_schema.TABLES "
|
||||
"WHERE TABLE_SCHEMA = DATABASE() AND TABLE_NAME = :table_name"
|
||||
),
|
||||
{"table_name": table_name},
|
||||
).first()
|
||||
return str(row[0]).strip() if row and row[0] else None
|
||||
|
||||
|
||||
def _column_collation(conn, table_name: str, column_name: str) -> str | None:
|
||||
row = conn.execute(
|
||||
text(
|
||||
"SELECT COLLATION_NAME FROM information_schema.COLUMNS "
|
||||
"WHERE TABLE_SCHEMA = DATABASE() "
|
||||
"AND TABLE_NAME = :table_name AND COLUMN_NAME = :column_name"
|
||||
),
|
||||
{"table_name": table_name, "column_name": column_name},
|
||||
).first()
|
||||
return str(row[0]).strip() if row and row[0] else None
|
||||
|
||||
|
||||
def _normalize_projects_table(conn) -> None:
|
||||
create_all(checkfirst=True) 只建缺失的表,不会给已存在的表加列,
|
||||
因此这里对历史表做一次轻量级 ALTER(仅在缺列时执行)。
|
||||
"""
|
||||
将历史库表/列统一为 utf8mb4_unicode_ci(仅在实际不一致时执行 ALTER)。
|
||||
|
||||
切勿在每次启动时对已迁移库重复 CONVERT:会长时间持有 metadata lock,
|
||||
阻塞所有对 projects 等表的读写,并导致连接池耗尽。
|
||||
"""
|
||||
existing = _existing_tables(conn)
|
||||
tables_to_convert = [
|
||||
name
|
||||
for name in INIT_TABLES
|
||||
if name in existing and _table_collation(conn, name) != _TARGET_TABLE_COLLATION
|
||||
]
|
||||
projects_uuid_needs_fix = (
|
||||
"projects" in existing
|
||||
and _column_collation(conn, "projects", "uuid") != _TARGET_TABLE_COLLATION
|
||||
)
|
||||
if not tables_to_convert and not projects_uuid_needs_fix:
|
||||
insp = inspect(engine)
|
||||
if "report_section_references" not in insp.get_table_names():
|
||||
return
|
||||
|
||||
conn.execute(text("SET FOREIGN_KEY_CHECKS=0"))
|
||||
try:
|
||||
for table_name in tables_to_convert:
|
||||
columns = {c["name"] for c in insp.get_columns("report_section_references")}
|
||||
if "template_id" in columns:
|
||||
return
|
||||
|
||||
with engine.begin() as conn:
|
||||
conn.execute(
|
||||
text(
|
||||
f"ALTER TABLE `{table_name}` "
|
||||
"CONVERT TO CHARACTER SET utf8mb4 COLLATE utf8mb4_unicode_ci"
|
||||
"ALTER TABLE report_section_references "
|
||||
"ADD COLUMN template_id VARCHAR(64) NULL"
|
||||
)
|
||||
)
|
||||
if projects_uuid_needs_fix:
|
||||
conn.execute(
|
||||
text(
|
||||
"ALTER TABLE projects "
|
||||
"MODIFY uuid VARCHAR(32) CHARACTER SET utf8mb4 "
|
||||
"COLLATE utf8mb4_unicode_ci NOT NULL"
|
||||
"ALTER TABLE report_section_references "
|
||||
"ADD INDEX ix_report_section_references_template_id (template_id)"
|
||||
)
|
||||
)
|
||||
conn.commit()
|
||||
finally:
|
||||
conn.execute(text("SET FOREIGN_KEY_CHECKS=1"))
|
||||
|
||||
|
||||
def _split_sql_statements(content: str) -> list[str]:
|
||||
"""
|
||||
按分号拆分 SQL 语句,忽略注释和空行。
|
||||
简单实现:不处理字符串内的分号。
|
||||
"""
|
||||
# 移除单行注释
|
||||
content = re.sub(r"--[^\n]*", "", content)
|
||||
# 移除多行注释
|
||||
content = re.sub(r"/\*.*?\*/", "", content, flags=re.DOTALL)
|
||||
statements = [
|
||||
s.strip()
|
||||
for s in content.split(";")
|
||||
if s.strip() and not s.strip().startswith("--")
|
||||
]
|
||||
return statements
|
||||
logger.info("init_database: report_section_references.template_id 字段已补充")
|
||||
|
||||
|
||||
def init_database() -> None:
|
||||
"""
|
||||
执行 init.sql,创建表结构,并按需执行缺失字段迁移。
|
||||
|
||||
注意:init.sql 里使用了 `CREATE TABLE IF NOT EXISTS`,因此对“已存在但缺列”的旧库,
|
||||
需要额外执行对应迁移脚本(例如补齐 `kb_documents.factor`)。
|
||||
"""
|
||||
if not INIT_SQL_PATH.exists():
|
||||
return
|
||||
|
||||
sql_text = INIT_SQL_PATH.read_text(encoding="utf-8")
|
||||
statements = _split_sql_statements(sql_text)
|
||||
|
||||
with engine.connect() as conn:
|
||||
for stmt in statements:
|
||||
stmt = stmt.strip()
|
||||
if not stmt:
|
||||
continue
|
||||
try:
|
||||
conn.execute(text(stmt))
|
||||
conn.commit()
|
||||
except Exception as e:
|
||||
# 表/索引已存在时忽略(Duplicate key name、already exists)
|
||||
err_msg = str(e).lower()
|
||||
# 历史库可能缺列,导致 CREATE INDEX 报 "Key column ... doesn't exist in table"。
|
||||
# 这里先跳过,后续 migrate_extraction_results.sql 会补齐列并建索引。
|
||||
if (
|
||||
"already exists" in err_msg
|
||||
or "duplicate" in err_msg
|
||||
or ("key column" in err_msg and "doesn't exist in table" in err_msg)
|
||||
or "error 1072" in err_msg
|
||||
):
|
||||
conn.rollback()
|
||||
continue
|
||||
conn.rollback()
|
||||
raise
|
||||
# 仅在字符集未达标时执行 ALTER(勿在每次 CREATE TABLE projects 后重复调用)
|
||||
_normalize_projects_table(conn)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Missing-column migrations (idempotent via "duplicate column" ignore)
|
||||
# ------------------------------------------------------------------
|
||||
factor_migrate_path = Path(__file__).resolve().parent / "migrate_kb_documents_factor.sql"
|
||||
if factor_migrate_path.exists():
|
||||
factor_sql_text = factor_migrate_path.read_text(encoding="utf-8")
|
||||
factor_statements = _split_sql_statements(factor_sql_text)
|
||||
|
||||
with engine.connect() as conn:
|
||||
for stmt in factor_statements:
|
||||
stmt = stmt.strip()
|
||||
if not stmt:
|
||||
continue
|
||||
try:
|
||||
conn.execute(text(stmt))
|
||||
conn.commit()
|
||||
except Exception as e:
|
||||
# MySQL: Error 1060 "Duplicate column name 'factor'"
|
||||
err_msg = str(e).lower()
|
||||
if "duplicate column" in err_msg or "error 1060" in err_msg:
|
||||
conn.rollback()
|
||||
continue
|
||||
conn.rollback()
|
||||
raise
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Missing tables/columns migrations (kb_directories + directory_id)
|
||||
# ------------------------------------------------------------------
|
||||
kb_dirs_migrate_path = Path(__file__).resolve().parent / "migrate_kb_directories.sql"
|
||||
if kb_dirs_migrate_path.exists():
|
||||
kb_dirs_sql_text = kb_dirs_migrate_path.read_text(encoding="utf-8")
|
||||
kb_dirs_statements = _split_sql_statements(kb_dirs_sql_text)
|
||||
|
||||
with engine.connect() as conn:
|
||||
for stmt in kb_dirs_statements:
|
||||
stmt = stmt.strip()
|
||||
if not stmt:
|
||||
continue
|
||||
try:
|
||||
conn.execute(text(stmt))
|
||||
conn.commit()
|
||||
except Exception as e:
|
||||
err_msg = str(e).lower()
|
||||
# MySQL 常见“已存在/重复”错误:忽略以保证幂等
|
||||
if (
|
||||
"duplicate column" in err_msg
|
||||
or "error 1060" in err_msg
|
||||
or "already exists" in err_msg
|
||||
or "duplicate" in err_msg
|
||||
):
|
||||
conn.rollback()
|
||||
continue
|
||||
conn.rollback()
|
||||
raise
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Missing tables/columns migrations (extraction_results legacy schema)
|
||||
# ------------------------------------------------------------------
|
||||
extraction_migrate_path = Path(__file__).resolve().parent / "migrate_extraction_results.sql"
|
||||
if extraction_migrate_path.exists():
|
||||
extraction_sql_text = extraction_migrate_path.read_text(encoding="utf-8")
|
||||
extraction_statements = _split_sql_statements(extraction_sql_text)
|
||||
|
||||
with engine.connect() as conn:
|
||||
for stmt in extraction_statements:
|
||||
stmt = stmt.strip()
|
||||
if not stmt:
|
||||
continue
|
||||
try:
|
||||
conn.execute(text(stmt))
|
||||
conn.commit()
|
||||
except Exception as e:
|
||||
err_msg = str(e).lower()
|
||||
if (
|
||||
"duplicate column" in err_msg
|
||||
or "error 1060" in err_msg
|
||||
or "already exists" in err_msg
|
||||
or "duplicate" in err_msg
|
||||
or "check that column/key exists" in err_msg
|
||||
or "error 1072" in err_msg
|
||||
or "doesn't exist" in err_msg
|
||||
):
|
||||
conn.rollback()
|
||||
continue
|
||||
conn.rollback()
|
||||
raise
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# extraction_results:移除历史 run_id 外键/字段(收敛到 batch_id)
|
||||
# ------------------------------------------------------------------
|
||||
extraction_drop_run_id_path = (
|
||||
Path(__file__).resolve().parent / "migrate_extraction_results_drop_run_id.sql"
|
||||
)
|
||||
if extraction_drop_run_id_path.exists():
|
||||
extraction_drop_run_id_sql = extraction_drop_run_id_path.read_text(encoding="utf-8")
|
||||
extraction_drop_run_id_statements = _split_sql_statements(extraction_drop_run_id_sql)
|
||||
|
||||
with engine.connect() as conn:
|
||||
for stmt in extraction_drop_run_id_statements:
|
||||
stmt = stmt.strip()
|
||||
if not stmt:
|
||||
continue
|
||||
try:
|
||||
conn.execute(text(stmt))
|
||||
conn.commit()
|
||||
except Exception as e:
|
||||
err_msg = str(e).lower()
|
||||
if (
|
||||
"already exists" in err_msg
|
||||
or "duplicate" in err_msg
|
||||
or "doesn't exist" in err_msg
|
||||
or "check that column/key exists" in err_msg
|
||||
or "error 1091" in err_msg
|
||||
):
|
||||
conn.rollback()
|
||||
continue
|
||||
conn.rollback()
|
||||
raise
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# element_conflicts:补齐 table_id / cell_id(旧库缺列导致 ORM 查询 500)
|
||||
# ------------------------------------------------------------------
|
||||
ec_migrate_path = Path(__file__).resolve().parent / "migrate_element_conflicts.sql"
|
||||
if ec_migrate_path.exists():
|
||||
ec_sql_text = ec_migrate_path.read_text(encoding="utf-8")
|
||||
ec_statements = _split_sql_statements(ec_sql_text)
|
||||
|
||||
with engine.connect() as conn:
|
||||
for stmt in ec_statements:
|
||||
stmt = stmt.strip()
|
||||
if not stmt:
|
||||
continue
|
||||
try:
|
||||
conn.execute(text(stmt))
|
||||
conn.commit()
|
||||
except Exception as e:
|
||||
err_msg = str(e).lower()
|
||||
if (
|
||||
"duplicate column" in err_msg
|
||||
or "error 1060" in err_msg
|
||||
or "already exists" in err_msg
|
||||
or "errno 1060" in err_msg
|
||||
):
|
||||
conn.rollback()
|
||||
continue
|
||||
conn.rollback()
|
||||
raise
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# element_conflicts:兼容历史 project_element_id NOT NULL(改为 NULL)
|
||||
# ------------------------------------------------------------------
|
||||
ec_project_element_id_path = (
|
||||
Path(__file__).resolve().parent / "migrate_element_conflicts_project_element_id_nullable.sql"
|
||||
)
|
||||
if ec_project_element_id_path.exists():
|
||||
ec_peid_sql = ec_project_element_id_path.read_text(encoding="utf-8")
|
||||
ec_peid_stmts = _split_sql_statements(ec_peid_sql)
|
||||
with engine.connect() as conn:
|
||||
for stmt in ec_peid_stmts:
|
||||
stmt = stmt.strip()
|
||||
if not stmt:
|
||||
continue
|
||||
try:
|
||||
conn.execute(text(stmt))
|
||||
conn.commit()
|
||||
except Exception:
|
||||
conn.rollback()
|
||||
raise
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# element_conflicts:兼容历史 extraction_result_id NOT NULL(改为 NULL)
|
||||
# ------------------------------------------------------------------
|
||||
ec_extraction_result_id_path = (
|
||||
Path(__file__).resolve().parent / "migrate_element_conflicts_extraction_result_id_nullable.sql"
|
||||
)
|
||||
if ec_extraction_result_id_path.exists():
|
||||
ec_erid_sql = ec_extraction_result_id_path.read_text(encoding="utf-8")
|
||||
ec_erid_stmts = _split_sql_statements(ec_erid_sql)
|
||||
with engine.connect() as conn:
|
||||
for stmt in ec_erid_stmts:
|
||||
stmt = stmt.strip()
|
||||
if not stmt:
|
||||
continue
|
||||
try:
|
||||
conn.execute(text(stmt))
|
||||
conn.commit()
|
||||
except Exception:
|
||||
conn.rollback()
|
||||
raise
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# extraction_results:extracted_at / source_line_end
|
||||
# ------------------------------------------------------------------
|
||||
ext_time_path = Path(__file__).resolve().parent / "migrate_extraction_results_extracted_at.sql"
|
||||
if ext_time_path.exists():
|
||||
ext_sql = ext_time_path.read_text(encoding="utf-8")
|
||||
ext_stmts = _split_sql_statements(ext_sql)
|
||||
with engine.connect() as conn:
|
||||
for stmt in ext_stmts:
|
||||
stmt = stmt.strip()
|
||||
if not stmt:
|
||||
continue
|
||||
try:
|
||||
conn.execute(text(stmt))
|
||||
conn.commit()
|
||||
except Exception as e:
|
||||
err_msg = str(e).lower()
|
||||
if (
|
||||
"duplicate column" in err_msg
|
||||
or "error 1060" in err_msg
|
||||
or "already exists" in err_msg
|
||||
):
|
||||
conn.rollback()
|
||||
continue
|
||||
conn.rollback()
|
||||
raise
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# element_extraction_results:要素抽取结果明细表(若旧库缺表则补齐)
|
||||
# ------------------------------------------------------------------
|
||||
el_ext_path = Path(__file__).resolve().parent / "migrate_element_extraction_results.sql"
|
||||
if el_ext_path.exists():
|
||||
el_ext_sql = el_ext_path.read_text(encoding="utf-8")
|
||||
el_ext_stmts = _split_sql_statements(el_ext_sql)
|
||||
with engine.connect() as conn:
|
||||
for stmt in el_ext_stmts:
|
||||
stmt = stmt.strip()
|
||||
if not stmt:
|
||||
continue
|
||||
try:
|
||||
conn.execute(text(stmt))
|
||||
conn.commit()
|
||||
except Exception as e:
|
||||
err_msg = str(e).lower()
|
||||
if (
|
||||
"already exists" in err_msg
|
||||
or "duplicate" in err_msg
|
||||
):
|
||||
conn.rollback()
|
||||
continue
|
||||
conn.rollback()
|
||||
raise
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# project_departments:项目可见部门
|
||||
# ------------------------------------------------------------------
|
||||
proj_dept_path = Path(__file__).resolve().parent / "migrate_project_departments.sql"
|
||||
if proj_dept_path.exists():
|
||||
proj_dept_sql = proj_dept_path.read_text(encoding="utf-8")
|
||||
proj_dept_stmts = _split_sql_statements(proj_dept_sql)
|
||||
with engine.connect() as conn:
|
||||
for stmt in proj_dept_stmts:
|
||||
stmt = stmt.strip()
|
||||
if not stmt:
|
||||
continue
|
||||
try:
|
||||
conn.execute(text(stmt))
|
||||
conn.commit()
|
||||
except Exception as e:
|
||||
err_msg = str(e).lower()
|
||||
if "already exists" in err_msg or "duplicate" in err_msg:
|
||||
conn.rollback()
|
||||
continue
|
||||
conn.rollback()
|
||||
raise
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# report_template_sections:章节输出合同(section_output_contract)
|
||||
# ------------------------------------------------------------------
|
||||
template_section_contract_path = (
|
||||
Path(__file__).resolve().parent / "migrations" / "add_section_output_contract.sql"
|
||||
)
|
||||
if template_section_contract_path.exists():
|
||||
template_section_contract_sql = template_section_contract_path.read_text(encoding="utf-8")
|
||||
template_section_contract_stmts = _split_sql_statements(template_section_contract_sql)
|
||||
with engine.connect() as conn:
|
||||
for stmt in template_section_contract_stmts:
|
||||
stmt = stmt.strip()
|
||||
if not stmt:
|
||||
continue
|
||||
try:
|
||||
conn.execute(text(stmt))
|
||||
conn.commit()
|
||||
except Exception as e:
|
||||
err_msg = str(e).lower()
|
||||
if "duplicate column" in err_msg or "error 1060" in err_msg or "already exists" in err_msg:
|
||||
conn.rollback()
|
||||
continue
|
||||
conn.rollback()
|
||||
raise
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# report_section_references:补齐 template_id(按模板过滤参考范文)
|
||||
# ------------------------------------------------------------------
|
||||
ref_template_id_path = (
|
||||
Path(__file__).resolve().parent / "migrations" / "add_ref_template_id.sql"
|
||||
)
|
||||
if ref_template_id_path.exists():
|
||||
ref_template_id_sql = ref_template_id_path.read_text(encoding="utf-8")
|
||||
ref_template_id_stmts = _split_sql_statements(ref_template_id_sql)
|
||||
with engine.connect() as conn:
|
||||
for stmt in ref_template_id_stmts:
|
||||
stmt = stmt.strip()
|
||||
if not stmt:
|
||||
continue
|
||||
try:
|
||||
conn.execute(text(stmt))
|
||||
conn.commit()
|
||||
except Exception as e:
|
||||
err_msg = str(e).lower()
|
||||
if (
|
||||
"duplicate column" in err_msg
|
||||
or "error 1060" in err_msg
|
||||
or "duplicate key name" in err_msg
|
||||
or "error 1061" in err_msg
|
||||
or "already exists" in err_msg
|
||||
):
|
||||
conn.rollback()
|
||||
continue
|
||||
conn.rollback()
|
||||
raise
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# users:补齐 password_hash(登录注册)
|
||||
# ------------------------------------------------------------------
|
||||
users_pwd_path = Path(__file__).resolve().parent / "migrate_users_password_hash.sql"
|
||||
if users_pwd_path.exists():
|
||||
users_pwd_sql = users_pwd_path.read_text(encoding="utf-8")
|
||||
users_pwd_stmts = _split_sql_statements(users_pwd_sql)
|
||||
with engine.connect() as conn:
|
||||
for stmt in users_pwd_stmts:
|
||||
stmt = stmt.strip()
|
||||
if not stmt:
|
||||
continue
|
||||
try:
|
||||
conn.execute(text(stmt))
|
||||
conn.commit()
|
||||
except Exception as e:
|
||||
err_msg = str(e).lower()
|
||||
if "duplicate column" in err_msg or "error 1060" in err_msg or "already exists" in err_msg:
|
||||
conn.rollback()
|
||||
continue
|
||||
conn.rollback()
|
||||
raise
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# departments:补齐 description(部门描述)
|
||||
# ------------------------------------------------------------------
|
||||
dept_desc_path = Path(__file__).resolve().parent / "migrate_departments_description.sql"
|
||||
if dept_desc_path.exists():
|
||||
dept_desc_sql = dept_desc_path.read_text(encoding="utf-8")
|
||||
dept_desc_stmts = _split_sql_statements(dept_desc_sql)
|
||||
with engine.connect() as conn:
|
||||
for stmt in dept_desc_stmts:
|
||||
stmt = stmt.strip()
|
||||
if not stmt:
|
||||
continue
|
||||
try:
|
||||
conn.execute(text(stmt))
|
||||
conn.commit()
|
||||
except Exception as e:
|
||||
err_msg = str(e).lower()
|
||||
if "duplicate column" in err_msg or "error 1060" in err_msg or "already exists" in err_msg:
|
||||
conn.rollback()
|
||||
continue
|
||||
conn.rollback()
|
||||
raise
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# projects:用户删除的标准模版表不再被「同步模版」回补
|
||||
# ------------------------------------------------------------------
|
||||
proj_sup_path = Path(__file__).resolve().parent / "migrate_projects_sync_suppressed_tables.sql"
|
||||
if proj_sup_path.exists():
|
||||
proj_sup_sql = proj_sup_path.read_text(encoding="utf-8")
|
||||
proj_sup_stmts = _split_sql_statements(proj_sup_sql)
|
||||
with engine.connect() as conn:
|
||||
for stmt in proj_sup_stmts:
|
||||
stmt = stmt.strip()
|
||||
if not stmt:
|
||||
continue
|
||||
try:
|
||||
conn.execute(text(stmt))
|
||||
conn.commit()
|
||||
except Exception as e:
|
||||
err_msg = str(e).lower()
|
||||
if "duplicate column" in err_msg or "error 1060" in err_msg or "already exists" in err_msg:
|
||||
conn.rollback()
|
||||
continue
|
||||
conn.rollback()
|
||||
raise
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# element_tables:用户删行后不再被「同步模版」回补
|
||||
# ------------------------------------------------------------------
|
||||
et_sup_path = Path(__file__).resolve().parent / "migrate_element_tables_sync_suppressed.sql"
|
||||
if et_sup_path.exists():
|
||||
et_sup_sql = et_sup_path.read_text(encoding="utf-8")
|
||||
et_sup_stmts = _split_sql_statements(et_sup_sql)
|
||||
with engine.connect() as conn:
|
||||
for stmt in et_sup_stmts:
|
||||
stmt = stmt.strip()
|
||||
if not stmt:
|
||||
continue
|
||||
try:
|
||||
conn.execute(text(stmt))
|
||||
conn.commit()
|
||||
except Exception as e:
|
||||
err_msg = str(e).lower()
|
||||
if "duplicate column" in err_msg or "error 1060" in err_msg or "already exists" in err_msg:
|
||||
conn.rollback()
|
||||
continue
|
||||
conn.rollback()
|
||||
raise
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# element_tables:自定义行顺序(加行插在选中行下;刷新后仍保持)
|
||||
# ------------------------------------------------------------------
|
||||
et_row_order_path = Path(__file__).resolve().parent / "migrate_element_tables_custom_row_order.sql"
|
||||
if et_row_order_path.exists():
|
||||
et_row_order_sql = et_row_order_path.read_text(encoding="utf-8")
|
||||
et_row_order_stmts = _split_sql_statements(et_row_order_sql)
|
||||
with engine.connect() as conn:
|
||||
for stmt in et_row_order_stmts:
|
||||
stmt = stmt.strip()
|
||||
if not stmt:
|
||||
continue
|
||||
try:
|
||||
conn.execute(text(stmt))
|
||||
conn.commit()
|
||||
except Exception as e:
|
||||
err_msg = str(e).lower()
|
||||
if "duplicate column" in err_msg or "error 1060" in err_msg or "already exists" in err_msg:
|
||||
conn.rollback()
|
||||
continue
|
||||
conn.rollback()
|
||||
raise
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# kb_documents:status 语义 v2(0/2/3/4),仅旧库(仍有 status=1 且无 status=4)时执行
|
||||
# ------------------------------------------------------------------
|
||||
status_v2_path = Path(__file__).resolve().parent / "migrate_kb_doc_status_v2.sql"
|
||||
if status_v2_path.exists():
|
||||
with engine.connect() as conn:
|
||||
try:
|
||||
probe = conn.execute(
|
||||
text(
|
||||
"""
|
||||
SELECT
|
||||
SUM(CASE WHEN status = 1 THEN 1 ELSE 0 END) AS s1,
|
||||
SUM(CASE WHEN status = 4 THEN 1 ELSE 0 END) AS s4
|
||||
FROM kb_documents
|
||||
"""
|
||||
)
|
||||
).fetchone()
|
||||
s1 = int((probe[0] if probe else 0) or 0)
|
||||
s4 = int((probe[1] if probe else 0) or 0)
|
||||
if s1 > 0 and s4 == 0:
|
||||
status_v2_sql = status_v2_path.read_text(encoding="utf-8")
|
||||
for stmt in _split_sql_statements(status_v2_sql):
|
||||
stmt = stmt.strip()
|
||||
if not stmt:
|
||||
continue
|
||||
conn.execute(text(stmt))
|
||||
conn.commit()
|
||||
except Exception as e:
|
||||
err_msg = str(e).lower()
|
||||
if "doesn't exist" in err_msg and "kb_documents" in err_msg:
|
||||
conn.rollback()
|
||||
else:
|
||||
conn.rollback()
|
||||
raise
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# kb_documents:storage_rel_path + error_message
|
||||
# ------------------------------------------------------------------
|
||||
kb_storage_path = Path(__file__).resolve().parent / "migrate_kb_doc_storage_path.sql"
|
||||
if kb_storage_path.exists():
|
||||
kb_storage_sql = kb_storage_path.read_text(encoding="utf-8")
|
||||
kb_storage_stmts = _split_sql_statements(kb_storage_sql)
|
||||
with engine.connect() as conn:
|
||||
for stmt in kb_storage_stmts:
|
||||
stmt = stmt.strip()
|
||||
if not stmt:
|
||||
continue
|
||||
try:
|
||||
conn.execute(text(stmt))
|
||||
conn.commit()
|
||||
except Exception as e:
|
||||
err_msg = str(e).lower()
|
||||
if "duplicate column" in err_msg or "error 1060" in err_msg or "already exists" in err_msg:
|
||||
conn.rollback()
|
||||
continue
|
||||
conn.rollback()
|
||||
raise
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# kb_documents:category (文件分类)
|
||||
# ------------------------------------------------------------------
|
||||
category_migrate_path = Path(__file__).resolve().parent / "migrate_kb_documents_category.sql"
|
||||
if category_migrate_path.exists():
|
||||
category_sql_text = category_migrate_path.read_text(encoding="utf-8")
|
||||
category_statements = _split_sql_statements(category_sql_text)
|
||||
with engine.connect() as conn:
|
||||
for stmt in category_statements:
|
||||
stmt = stmt.strip()
|
||||
if not stmt:
|
||||
continue
|
||||
try:
|
||||
conn.execute(text(stmt))
|
||||
conn.commit()
|
||||
except Exception as e:
|
||||
err_msg = str(e).lower()
|
||||
if "duplicate column" in err_msg or "error 1060" in err_msg or "already exists" in err_msg:
|
||||
conn.rollback()
|
||||
continue
|
||||
conn.rollback()
|
||||
raise
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# kb_documents:旧版分类 → 资料清单六大类
|
||||
# ------------------------------------------------------------------
|
||||
category_checklist_path = Path(__file__).resolve().parent / "migrate_kb_category_checklist.sql"
|
||||
if category_checklist_path.exists():
|
||||
checklist_sql_text = category_checklist_path.read_text(encoding="utf-8")
|
||||
checklist_statements = _split_sql_statements(checklist_sql_text)
|
||||
with engine.connect() as conn:
|
||||
for stmt in checklist_statements:
|
||||
stmt = stmt.strip()
|
||||
if not stmt:
|
||||
continue
|
||||
try:
|
||||
conn.execute(text(stmt))
|
||||
conn.commit()
|
||||
except Exception:
|
||||
conn.rollback()
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# kb_documents:upload_filename(上传/解压原始文件名)
|
||||
# ------------------------------------------------------------------
|
||||
upload_fn_path = Path(__file__).resolve().parent / "migrate_kb_documents_upload_filename.sql"
|
||||
if upload_fn_path.exists():
|
||||
upload_fn_sql = upload_fn_path.read_text(encoding="utf-8")
|
||||
upload_fn_stmts = _split_sql_statements(upload_fn_sql)
|
||||
with engine.connect() as conn:
|
||||
for stmt in upload_fn_stmts:
|
||||
stmt = stmt.strip()
|
||||
if not stmt:
|
||||
continue
|
||||
try:
|
||||
conn.execute(text(stmt))
|
||||
conn.commit()
|
||||
except Exception as e:
|
||||
err_msg = str(e).lower()
|
||||
if "duplicate column" in err_msg or "error 1060" in err_msg or "already exists" in err_msg:
|
||||
conn.rollback()
|
||||
continue
|
||||
conn.rollback()
|
||||
raise
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# element_cells:source_type(文档抽取 / 手工输入)
|
||||
# ------------------------------------------------------------------
|
||||
ec_source_type_path = Path(__file__).resolve().parent / "migrate_element_cells_source_type.sql"
|
||||
if ec_source_type_path.exists():
|
||||
ec_source_sql = ec_source_type_path.read_text(encoding="utf-8")
|
||||
ec_source_stmts = _split_sql_statements(ec_source_sql)
|
||||
with engine.connect() as conn:
|
||||
for stmt in ec_source_stmts:
|
||||
stmt = stmt.strip()
|
||||
if not stmt:
|
||||
continue
|
||||
try:
|
||||
conn.execute(text(stmt))
|
||||
conn.commit()
|
||||
except Exception as e:
|
||||
err_msg = str(e).lower()
|
||||
if "duplicate column" in err_msg or "error 1060" in err_msg or "already exists" in err_msg:
|
||||
conn.rollback()
|
||||
continue
|
||||
conn.rollback()
|
||||
raise
|
||||
"""在远程 MySQL 中创建本模块所需表(若不存在)。"""
|
||||
Base.metadata.create_all(bind=engine, checkfirst=True)
|
||||
_ensure_reference_template_id_column()
|
||||
logger.info("init_database: report_templates / report_template_sections / report_section_references 已就绪")
|
||||
|
||||
@ -1,305 +1,24 @@
|
||||
"""
|
||||
database/models.py
|
||||
SQLAlchemy ORM 模型,与 db.md / init.sql 对应。
|
||||
ORM 模型,与远程 MySQL(eval_report 库)现有表结构一致:
|
||||
- report_templates 模板
|
||||
- report_template_sections 模板章节(目录 + 声明)
|
||||
- report_section_references 章节参考范文(章节内容入库目标)
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime
|
||||
from typing import Optional
|
||||
|
||||
from sqlalchemy import Boolean, DateTime, Float, ForeignKey, Integer, JSON, String, Text, UniqueConstraint
|
||||
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column, relationship
|
||||
from sqlalchemy import Boolean, DateTime, ForeignKey, Integer, String, Text
|
||||
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column
|
||||
|
||||
|
||||
class Base(DeclarativeBase):
|
||||
pass
|
||||
|
||||
|
||||
class Project(Base):
|
||||
"""项目表(统一:知识库 + 撰写)"""
|
||||
|
||||
__tablename__ = "projects"
|
||||
|
||||
id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True)
|
||||
uuid: Mapped[str] = mapped_column(
|
||||
String(32),
|
||||
unique=True,
|
||||
nullable=False,
|
||||
)
|
||||
name: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
description: Mapped[Optional[str]] = mapped_column(Text, nullable=True)
|
||||
created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False)
|
||||
updated_at: Mapped[datetime] = mapped_column(DateTime, nullable=False)
|
||||
doc_count: Mapped[int] = mapped_column(Integer, default=0)
|
||||
eval_reports_count: Mapped[int] = mapped_column(Integer, default=0)
|
||||
total_size: Mapped[str] = mapped_column(String(32), default="0 B")
|
||||
tags: Mapped[Optional[str]] = mapped_column(Text, nullable=True)
|
||||
status: Mapped[str] = mapped_column(String(16), default="active")
|
||||
color: Mapped[str] = mapped_column(String(16), default="#3b82f6")
|
||||
sync_suppressed_table_names: Mapped[Optional[str]] = mapped_column(Text, nullable=True)
|
||||
|
||||
kb_documents: Mapped[list["KbDocument"]] = relationship(
|
||||
"KbDocument", back_populates="project", cascade="all, delete-orphan"
|
||||
)
|
||||
kb_directories: Mapped[list["KbDirectory"]] = relationship(
|
||||
"KbDirectory", back_populates="project", cascade="all, delete-orphan"
|
||||
)
|
||||
write_documents: Mapped[list["WriteDocumentModel"]] = relationship(
|
||||
"WriteDocumentModel", back_populates="project", cascade="all, delete-orphan"
|
||||
)
|
||||
|
||||
|
||||
class WriteDocumentModel(Base):
|
||||
"""撰写文档表(后评价报告)。project_id 关联 projects.uuid"""
|
||||
|
||||
__tablename__ = "write_documents"
|
||||
|
||||
id: Mapped[str] = mapped_column(String(64), primary_key=True)
|
||||
project_id: Mapped[str] = mapped_column(ForeignKey("projects.uuid", ondelete="CASCADE"), nullable=False)
|
||||
title: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
content: Mapped[Optional[str]] = mapped_column(Text, nullable=True)
|
||||
word_count: Mapped[int] = mapped_column(Integer, default=0)
|
||||
created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False)
|
||||
updated_at: Mapped[datetime] = mapped_column(DateTime, nullable=False)
|
||||
status: Mapped[str] = mapped_column(String(16), default="draft")
|
||||
sort_order: Mapped[int] = mapped_column(Integer, default=0)
|
||||
|
||||
project: Mapped["Project"] = relationship("Project", back_populates="write_documents")
|
||||
doc_versions: Mapped[list["DocumentVersion"]] = relationship(
|
||||
"DocumentVersion", back_populates="document", cascade="all, delete-orphan"
|
||||
)
|
||||
|
||||
|
||||
class DocumentVersion(Base):
|
||||
"""撰写文档版本表(对应 doc_versions)。"""
|
||||
|
||||
__tablename__ = "doc_versions"
|
||||
|
||||
id: Mapped[str] = mapped_column(String(64), primary_key=True)
|
||||
document_id: Mapped[str] = mapped_column(
|
||||
ForeignKey("write_documents.id", ondelete="CASCADE"), nullable=False
|
||||
)
|
||||
version: Mapped[str] = mapped_column(String(32), nullable=False)
|
||||
content: Mapped[str] = mapped_column(Text, nullable=False)
|
||||
citation_payload: Mapped[Optional[str]] = mapped_column(Text, nullable=True)
|
||||
saved_at: Mapped[datetime] = mapped_column(DateTime, nullable=False)
|
||||
author: Mapped[str] = mapped_column(String(64), nullable=False)
|
||||
note: Mapped[Optional[str]] = mapped_column(Text, nullable=True)
|
||||
|
||||
document: Mapped["WriteDocumentModel"] = relationship("WriteDocumentModel", back_populates="doc_versions")
|
||||
|
||||
|
||||
class KbDocument(Base):
|
||||
"""知识库文档表。project_id 关联 projects.uuid。status: 0=失败 2=排队中 3=处理中 4=可用"""
|
||||
|
||||
__tablename__ = "kb_documents"
|
||||
|
||||
id: Mapped[str] = mapped_column(String(64), primary_key=True)
|
||||
project_id: Mapped[str] = mapped_column(ForeignKey("projects.uuid", ondelete="CASCADE"), nullable=False)
|
||||
directory_id: Mapped[Optional[str]] = mapped_column(
|
||||
ForeignKey("kb_directories.id", ondelete="SET NULL"), nullable=True
|
||||
)
|
||||
name: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
upload_filename: Mapped[Optional[str]] = mapped_column(
|
||||
String(255), nullable=True
|
||||
) # 上传/解压时的原始文件名(含扩展名),与智能展示名 name 区分
|
||||
size: Mapped[str] = mapped_column(String(32), nullable=False)
|
||||
file_path: Mapped[Optional[str]] = mapped_column(String(512), nullable=True) # 仅目录路径,不含文件名
|
||||
storage_rel_path: Mapped[Optional[str]] = mapped_column(
|
||||
String(512), nullable=True
|
||||
) # 项目内完整相对路径(含文件名),用于精确定位磁盘文件
|
||||
word_count: Mapped[int] = mapped_column(Integer, default=0)
|
||||
uploaded_at: Mapped[datetime] = mapped_column(DateTime, nullable=False)
|
||||
status: Mapped[int] = mapped_column(Integer, default=2) # 0=失败 2=排队中 3=处理中 4=可用
|
||||
error_message: Mapped[Optional[str]] = mapped_column(Text, nullable=True)
|
||||
category: Mapped[Optional[str]] = mapped_column(String(32), nullable=True, default=None)
|
||||
|
||||
project: Mapped["Project"] = relationship("Project", back_populates="kb_documents")
|
||||
directory: Mapped[Optional["KbDirectory"]] = relationship("KbDirectory", back_populates="documents")
|
||||
|
||||
|
||||
class KbDirectory(Base):
|
||||
"""知识库目录表。project_id 关联 projects.uuid;parent_id 形成目录树。"""
|
||||
|
||||
__tablename__ = "kb_directories"
|
||||
|
||||
id: Mapped[str] = mapped_column(String(64), primary_key=True)
|
||||
project_id: Mapped[str] = mapped_column(ForeignKey("projects.uuid", ondelete="CASCADE"), nullable=False)
|
||||
parent_id: Mapped[Optional[str]] = mapped_column(
|
||||
ForeignKey("kb_directories.id", ondelete="CASCADE"), nullable=True
|
||||
)
|
||||
name: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
full_path: Mapped[str] = mapped_column(String(1024), nullable=False)
|
||||
created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False)
|
||||
updated_at: Mapped[datetime] = mapped_column(DateTime, nullable=False)
|
||||
|
||||
project: Mapped["Project"] = relationship("Project", back_populates="kb_directories")
|
||||
documents: Mapped[list["KbDocument"]] = relationship("KbDocument", back_populates="directory")
|
||||
|
||||
|
||||
class Task(Base):
|
||||
"""独立后台任务表:pdf2md 转换和 element-agent 要素抽取。"""
|
||||
|
||||
__tablename__ = "tasks"
|
||||
|
||||
id: Mapped[str] = mapped_column(String(64), primary_key=True)
|
||||
project: Mapped[str] = mapped_column(String(64), nullable=False)
|
||||
task_type: Mapped[int] = mapped_column(Integer, nullable=False)
|
||||
file_id: Mapped[Optional[str]] = mapped_column(String(64), nullable=True)
|
||||
file_path: Mapped[Optional[str]] = mapped_column(String(1024), nullable=True)
|
||||
status: Mapped[int] = mapped_column(Integer, nullable=False, default=1)
|
||||
payload_json: Mapped[Optional[dict]] = mapped_column(JSON, nullable=True)
|
||||
result_path: Mapped[Optional[str]] = mapped_column(String(1024), nullable=True)
|
||||
error_message: Mapped[Optional[str]] = mapped_column(Text, nullable=True)
|
||||
add_time: Mapped[datetime] = mapped_column(DateTime, nullable=False)
|
||||
finish_time: Mapped[Optional[datetime]] = mapped_column(DateTime, nullable=True)
|
||||
|
||||
|
||||
class ElementTable(Base):
|
||||
__tablename__ = "element_tables"
|
||||
|
||||
id: Mapped[str] = mapped_column(String(64), primary_key=True)
|
||||
project_id: Mapped[str] = mapped_column(ForeignKey("projects.uuid", ondelete="CASCADE"), nullable=False)
|
||||
table_type: Mapped[str] = mapped_column(String(32), nullable=False) # global/time
|
||||
table_name: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
year: Mapped[Optional[int]] = mapped_column(Integer, nullable=True)
|
||||
is_time_dimension: Mapped[bool] = mapped_column(Boolean, default=False)
|
||||
sort_order: Mapped[int] = mapped_column(Integer, default=0)
|
||||
# JSON 数组字符串,row_key 列表;sync 模版时跳过为这些行补格子,避免用户删行后一同步又出现
|
||||
sync_suppressed_row_keys: Mapped[Optional[str]] = mapped_column(Text, nullable=True)
|
||||
# JSON 数组:界面行键展示顺序(含用户加行)
|
||||
custom_row_order: Mapped[Optional[str]] = mapped_column(Text, nullable=True)
|
||||
created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False)
|
||||
updated_at: Mapped[datetime] = mapped_column(DateTime, nullable=False)
|
||||
|
||||
|
||||
class ElementCell(Base):
|
||||
__tablename__ = "element_cells"
|
||||
|
||||
id: Mapped[str] = mapped_column(String(64), primary_key=True)
|
||||
table_id: Mapped[str] = mapped_column(ForeignKey("element_tables.id", ondelete="CASCADE"), nullable=False)
|
||||
project_id: Mapped[str] = mapped_column(ForeignKey("projects.uuid", ondelete="CASCADE"), nullable=False)
|
||||
row_key: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
col_key: Mapped[Optional[str]] = mapped_column(String(255), nullable=True)
|
||||
year: Mapped[Optional[int]] = mapped_column(Integer, nullable=True)
|
||||
value: Mapped[Optional[str]] = mapped_column(Text, nullable=True)
|
||||
source_document_id: Mapped[Optional[str]] = mapped_column(
|
||||
ForeignKey("kb_documents.id", ondelete="SET NULL"), nullable=True
|
||||
)
|
||||
source_line_no: Mapped[Optional[int]] = mapped_column(Integer, nullable=True)
|
||||
source_line_end: Mapped[Optional[int]] = mapped_column(Integer, nullable=True)
|
||||
source_quote: Mapped[Optional[str]] = mapped_column(Text, nullable=True)
|
||||
confidence: Mapped[Optional[float]] = mapped_column(Float, nullable=True)
|
||||
extraction_batch_id: Mapped[Optional[str]] = mapped_column(String(64), nullable=True)
|
||||
extraction_model: Mapped[Optional[str]] = mapped_column(String(128), nullable=True)
|
||||
source_type: Mapped[Optional[str]] = mapped_column(String(16), nullable=True) # extract | manual
|
||||
conflict_status: Mapped[str] = mapped_column(String(16), default="none")
|
||||
created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False)
|
||||
updated_at: Mapped[datetime] = mapped_column(DateTime, nullable=False)
|
||||
|
||||
|
||||
class ExtractionResult(Base):
|
||||
__tablename__ = "extraction_results"
|
||||
|
||||
id: Mapped[str] = mapped_column(String(64), primary_key=True)
|
||||
project_id: Mapped[str] = mapped_column(ForeignKey("projects.uuid", ondelete="CASCADE"), nullable=False)
|
||||
document_id: Mapped[str] = mapped_column(ForeignKey("kb_documents.id", ondelete="CASCADE"), nullable=False)
|
||||
batch_id: Mapped[str] = mapped_column(String(64), nullable=False)
|
||||
result_type: Mapped[str] = mapped_column(String(16), nullable=False) # table/element
|
||||
table_type: Mapped[Optional[str]] = mapped_column(String(32), nullable=True)
|
||||
table_name: Mapped[Optional[str]] = mapped_column(String(255), nullable=True)
|
||||
year: Mapped[Optional[int]] = mapped_column(Integer, nullable=True)
|
||||
item_key: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
item_value: Mapped[Optional[str]] = mapped_column(Text, nullable=True)
|
||||
source_line_no: Mapped[Optional[int]] = mapped_column(Integer, nullable=True)
|
||||
source_line_end: Mapped[Optional[int]] = mapped_column(Integer, nullable=True)
|
||||
confidence: Mapped[Optional[float]] = mapped_column(Float, nullable=True)
|
||||
raw_payload: Mapped[Optional[dict]] = mapped_column(JSON, nullable=True)
|
||||
extracted_at: Mapped[Optional[datetime]] = mapped_column(DateTime, nullable=True) # 抽取业务时间(旧库迁移前可为空)
|
||||
created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False)
|
||||
|
||||
|
||||
class ElementExtractionResult(Base):
|
||||
"""
|
||||
要素抽取结果明细表(面向“细则章节/小节提示词 -> 项目材料”抽取)。
|
||||
|
||||
字段对齐(用户侧语义):
|
||||
- 表类型 -> table_type
|
||||
- 年份 -> year
|
||||
- 表名称 -> table_name
|
||||
- 时间 -> extracted_at
|
||||
- 键 -> item_key
|
||||
- 值 -> item_value
|
||||
- 来源文档ID -> source_document_id
|
||||
- 来源行数 -> source_line_no / source_line_end
|
||||
"""
|
||||
|
||||
__tablename__ = "element_extraction_results"
|
||||
|
||||
id: Mapped[str] = mapped_column(String(64), primary_key=True)
|
||||
project_id: Mapped[str] = mapped_column(ForeignKey("projects.uuid", ondelete="CASCADE"), nullable=False)
|
||||
table_type: Mapped[str] = mapped_column(String(32), nullable=False)
|
||||
year: Mapped[Optional[int]] = mapped_column(Integer, nullable=True)
|
||||
table_name: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
extracted_at: Mapped[datetime] = mapped_column(DateTime, nullable=False)
|
||||
item_key: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
item_value: Mapped[Optional[str]] = mapped_column(Text, nullable=True)
|
||||
source_document_id: Mapped[Optional[str]] = mapped_column(
|
||||
ForeignKey("kb_documents.id", ondelete="SET NULL"), nullable=True
|
||||
)
|
||||
source_line_no: Mapped[Optional[int]] = mapped_column(Integer, nullable=True)
|
||||
source_line_end: Mapped[Optional[int]] = mapped_column(Integer, nullable=True)
|
||||
created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False)
|
||||
|
||||
|
||||
class ElementConflict(Base):
|
||||
__tablename__ = "element_conflicts"
|
||||
|
||||
id: Mapped[str] = mapped_column(String(64), primary_key=True)
|
||||
project_id: Mapped[str] = mapped_column(ForeignKey("projects.uuid", ondelete="CASCADE"), nullable=False)
|
||||
table_id: Mapped[Optional[str]] = mapped_column(ForeignKey("element_tables.id", ondelete="SET NULL"), nullable=True)
|
||||
cell_id: Mapped[Optional[str]] = mapped_column(ForeignKey("element_cells.id", ondelete="SET NULL"), nullable=True)
|
||||
item_key: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
old_value: Mapped[Optional[str]] = mapped_column(Text, nullable=True)
|
||||
new_value: Mapped[Optional[str]] = mapped_column(Text, nullable=True)
|
||||
selected_value: Mapped[Optional[str]] = mapped_column(Text, nullable=True)
|
||||
source_document_id: Mapped[Optional[str]] = mapped_column(
|
||||
ForeignKey("kb_documents.id", ondelete="SET NULL"), nullable=True
|
||||
)
|
||||
source_line_no: Mapped[Optional[int]] = mapped_column(Integer, nullable=True)
|
||||
status: Mapped[str] = mapped_column(String(16), default="pending")
|
||||
created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False)
|
||||
updated_at: Mapped[datetime] = mapped_column(DateTime, nullable=False)
|
||||
|
||||
|
||||
class DocumentMarkdown(Base):
|
||||
__tablename__ = "document_markdowns"
|
||||
|
||||
id: Mapped[str] = mapped_column(String(64), primary_key=True)
|
||||
project_id: Mapped[str] = mapped_column(ForeignKey("projects.uuid", ondelete="CASCADE"), nullable=False)
|
||||
document_id: Mapped[str] = mapped_column(ForeignKey("kb_documents.id", ondelete="CASCADE"), nullable=False)
|
||||
extracted_filename: Mapped[Optional[str]] = mapped_column(String(255), nullable=True)
|
||||
markdown_content: Mapped[str] = mapped_column(Text, nullable=False)
|
||||
content_hash: Mapped[Optional[str]] = mapped_column(String(64), nullable=True)
|
||||
created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False)
|
||||
updated_at: Mapped[datetime] = mapped_column(DateTime, nullable=False)
|
||||
|
||||
|
||||
class DocumentChunk(Base):
|
||||
__tablename__ = "document_chunks"
|
||||
|
||||
id: Mapped[str] = mapped_column(String(64), primary_key=True)
|
||||
project_id: Mapped[str] = mapped_column(ForeignKey("projects.uuid", ondelete="CASCADE"), nullable=False)
|
||||
document_id: Mapped[str] = mapped_column(ForeignKey("kb_documents.id", ondelete="CASCADE"), nullable=False)
|
||||
markdown_id: Mapped[Optional[str]] = mapped_column(ForeignKey("document_markdowns.id", ondelete="CASCADE"), nullable=True)
|
||||
heading: Mapped[Optional[str]] = mapped_column(String(512), nullable=True)
|
||||
chunk_text: Mapped[str] = mapped_column(Text, nullable=False)
|
||||
chunk_index: Mapped[int] = mapped_column(Integer, default=0)
|
||||
source_line_start: Mapped[Optional[int]] = mapped_column(Integer, nullable=True)
|
||||
source_line_end: Mapped[Optional[int]] = mapped_column(Integer, nullable=True)
|
||||
vector_id: Mapped[Optional[str]] = mapped_column(String(128), nullable=True)
|
||||
created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False)
|
||||
|
||||
|
||||
class ReportTemplate(Base):
|
||||
__tablename__ = "report_templates"
|
||||
|
||||
@ -316,9 +35,12 @@ class ReportTemplateSection(Base):
|
||||
__tablename__ = "report_template_sections"
|
||||
|
||||
id: Mapped[str] = mapped_column(String(64), primary_key=True)
|
||||
template_id: Mapped[str] = mapped_column(ForeignKey("report_templates.id", ondelete="CASCADE"), nullable=False)
|
||||
template_id: Mapped[str] = mapped_column(
|
||||
ForeignKey("report_templates.id", ondelete="CASCADE"), nullable=False
|
||||
)
|
||||
section_key: Mapped[str] = mapped_column(String(64), nullable=False)
|
||||
section_title: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
# 本模块语义:section_prompt 即为该目录生成的"声明"
|
||||
section_prompt: Mapped[Optional[str]] = mapped_column(Text, nullable=True)
|
||||
section_output_contract: Mapped[Optional[str]] = mapped_column(Text, nullable=True)
|
||||
section_order: Mapped[int] = mapped_column(Integer, default=0)
|
||||
@ -327,54 +49,15 @@ class ReportTemplateSection(Base):
|
||||
updated_at: Mapped[datetime] = mapped_column(DateTime, nullable=False)
|
||||
|
||||
|
||||
class ReportGenerationJob(Base):
|
||||
__tablename__ = "report_generation_jobs"
|
||||
|
||||
id: Mapped[str] = mapped_column(String(64), primary_key=True)
|
||||
project_id: Mapped[str] = mapped_column(ForeignKey("projects.uuid", ondelete="CASCADE"), nullable=False)
|
||||
template_id: Mapped[Optional[str]] = mapped_column(
|
||||
ForeignKey("report_templates.id", ondelete="SET NULL"), nullable=True
|
||||
)
|
||||
status: Mapped[str] = mapped_column(String(16), default="pending") # pending/running/completed/failed
|
||||
progress: Mapped[int] = mapped_column(Integer, default=0)
|
||||
current_section_key: Mapped[Optional[str]] = mapped_column(String(64), nullable=True)
|
||||
error_message: Mapped[Optional[str]] = mapped_column(Text, nullable=True)
|
||||
requested_by: Mapped[Optional[str]] = mapped_column(String(64), nullable=True)
|
||||
options: Mapped[Optional[dict]] = mapped_column(JSON, nullable=True)
|
||||
snapshot: Mapped[Optional[dict]] = mapped_column(JSON, nullable=True)
|
||||
created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False)
|
||||
updated_at: Mapped[datetime] = mapped_column(DateTime, nullable=False)
|
||||
completed_at: Mapped[Optional[datetime]] = mapped_column(DateTime, nullable=True)
|
||||
|
||||
|
||||
class ReportGenerationChapter(Base):
|
||||
__tablename__ = "report_generation_chapters"
|
||||
|
||||
id: Mapped[str] = mapped_column(String(64), primary_key=True)
|
||||
job_id: Mapped[str] = mapped_column(
|
||||
ForeignKey("report_generation_jobs.id", ondelete="CASCADE"), nullable=False
|
||||
)
|
||||
section_key: Mapped[str] = mapped_column(String(64), nullable=False)
|
||||
section_title: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
section_order: Mapped[int] = mapped_column(Integer, default=0)
|
||||
status: Mapped[str] = mapped_column(String(16), default="pending") # pending/running/completed/failed
|
||||
content: Mapped[Optional[str]] = mapped_column(Text, nullable=True)
|
||||
prompt_text: Mapped[Optional[str]] = mapped_column(Text, nullable=True)
|
||||
evidence_payload: Mapped[Optional[dict]] = mapped_column(JSON, nullable=True)
|
||||
validation_payload: Mapped[Optional[dict]] = mapped_column(JSON, nullable=True)
|
||||
error_message: Mapped[Optional[str]] = mapped_column(Text, nullable=True)
|
||||
created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False)
|
||||
updated_at: Mapped[datetime] = mapped_column(DateTime, nullable=False)
|
||||
completed_at: Mapped[Optional[datetime]] = mapped_column(DateTime, nullable=True)
|
||||
|
||||
class ReportSectionReference(Base):
|
||||
"""章节参考范文(独立于模板配置,用于报告生成时拼入 prompt)"""
|
||||
"""章节参考范文(章节内容入库目标,格式与远程 MySQL 现有表一致)。"""
|
||||
|
||||
__tablename__ = "report_section_references"
|
||||
|
||||
id: Mapped[str] = mapped_column(String(64), primary_key=True)
|
||||
# 关联模板(与 report_template_sections.template_id 一致);历史数据可能为空
|
||||
template_id: Mapped[Optional[str]] = mapped_column(
|
||||
ForeignKey("report_templates.id", ondelete="CASCADE"), nullable=True
|
||||
ForeignKey("report_templates.id", ondelete="CASCADE"), nullable=True, index=True
|
||||
)
|
||||
source_file: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
section_key: Mapped[str] = mapped_column(String(64), nullable=False)
|
||||
@ -383,121 +66,3 @@ class ReportSectionReference(Base):
|
||||
content: Mapped[str] = mapped_column(Text, nullable=False)
|
||||
created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False)
|
||||
updated_at: Mapped[datetime] = mapped_column(DateTime, nullable=False)
|
||||
|
||||
|
||||
class Department(Base):
|
||||
|
||||
__tablename__ = "department"
|
||||
|
||||
id: Mapped[str] = mapped_column(String(64), primary_key=True)
|
||||
name: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
parent_id: Mapped[Optional[str]] = mapped_column(ForeignKey("departments.id", ondelete="SET NULL"), nullable=True)
|
||||
description: Mapped[Optional[str]] = mapped_column(Text, nullable=True)
|
||||
created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False)
|
||||
updated_at: Mapped[datetime] = mapped_column(DateTime, nullable=False)
|
||||
|
||||
|
||||
class User(Base):
|
||||
__tablename__ = "users"
|
||||
|
||||
id: Mapped[str] = mapped_column(String(64), primary_key=True)
|
||||
username: Mapped[str] = mapped_column(String(64), nullable=False, unique=True)
|
||||
password_hash: Mapped[Optional[str]] = mapped_column(String(255), nullable=True)
|
||||
department_id: Mapped[Optional[str]] = mapped_column(ForeignKey("departments.id", ondelete="SET NULL"), nullable=True)
|
||||
created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False)
|
||||
updated_at: Mapped[datetime] = mapped_column(DateTime, nullable=False)
|
||||
|
||||
|
||||
class Role(Base):
|
||||
__tablename__ = "roles"
|
||||
|
||||
id: Mapped[str] = mapped_column(String(64), primary_key=True)
|
||||
name: Mapped[str] = mapped_column(String(64), nullable=False, unique=True)
|
||||
description: Mapped[Optional[str]] = mapped_column(Text, nullable=True)
|
||||
created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False)
|
||||
updated_at: Mapped[datetime] = mapped_column(DateTime, nullable=False)
|
||||
|
||||
|
||||
class Permission(Base):
|
||||
__tablename__ = "permissions"
|
||||
|
||||
id: Mapped[str] = mapped_column(String(64), primary_key=True)
|
||||
perm_key: Mapped[str] = mapped_column(String(128), nullable=False, unique=True)
|
||||
perm_type: Mapped[str] = mapped_column(String(32), nullable=False) # menu/project
|
||||
description: Mapped[Optional[str]] = mapped_column(Text, nullable=True)
|
||||
created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False)
|
||||
|
||||
|
||||
class RolePermission(Base):
|
||||
__tablename__ = "role_permissions"
|
||||
|
||||
id: Mapped[str] = mapped_column(String(64), primary_key=True)
|
||||
role_id: Mapped[str] = mapped_column(ForeignKey("roles.id", ondelete="CASCADE"), nullable=False)
|
||||
permission_id: Mapped[str] = mapped_column(ForeignKey("permissions.id", ondelete="CASCADE"), nullable=False)
|
||||
created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False)
|
||||
|
||||
|
||||
class UserRole(Base):
|
||||
__tablename__ = "user_roles"
|
||||
|
||||
id: Mapped[str] = mapped_column(String(64), primary_key=True)
|
||||
user_id: Mapped[str] = mapped_column(ForeignKey("users.id", ondelete="CASCADE"), nullable=False)
|
||||
role_id: Mapped[str] = mapped_column(ForeignKey("roles.id", ondelete="CASCADE"), nullable=False)
|
||||
created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False)
|
||||
|
||||
|
||||
class ProjectMember(Base):
|
||||
__tablename__ = "project_members"
|
||||
|
||||
id: Mapped[str] = mapped_column(String(64), primary_key=True)
|
||||
project_id: Mapped[str] = mapped_column(ForeignKey("projects.uuid", ondelete="CASCADE"), nullable=False)
|
||||
user_id: Mapped[str] = mapped_column(ForeignKey("users.id", ondelete="CASCADE"), nullable=False)
|
||||
role: Mapped[str] = mapped_column(String(32), default="editor")
|
||||
created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False)
|
||||
|
||||
|
||||
class ProjectDepartment(Base):
|
||||
"""项目可见部门:绑定后,仅这些部门下的用户可访问(另有管理员与 project_members 例外)。"""
|
||||
|
||||
__tablename__ = "project_departments"
|
||||
__table_args__ = (UniqueConstraint("project_id", "department_id", name="uq_project_department"),)
|
||||
|
||||
id: Mapped[str] = mapped_column(String(64), primary_key=True)
|
||||
project_id: Mapped[str] = mapped_column(ForeignKey("projects.uuid", ondelete="CASCADE"), nullable=False)
|
||||
department_id: Mapped[str] = mapped_column(ForeignKey("departments.id", ondelete="CASCADE"), nullable=False)
|
||||
created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False)
|
||||
|
||||
|
||||
class FillRecord(Base):
|
||||
"""回填记录:每次要素回填均留痕,支持证据追溯。"""
|
||||
|
||||
__tablename__ = "fill_records"
|
||||
|
||||
id: Mapped[str] = mapped_column(String(64), primary_key=True)
|
||||
project_id: Mapped[str] = mapped_column(ForeignKey("projects.uuid", ondelete="CASCADE"), nullable=False)
|
||||
cell_id: Mapped[Optional[str]] = mapped_column(
|
||||
ForeignKey("element_cells.id", ondelete="SET NULL"), nullable=True
|
||||
)
|
||||
table_id: Mapped[Optional[str]] = mapped_column(
|
||||
ForeignKey("element_tables.id", ondelete="SET NULL"), nullable=True
|
||||
)
|
||||
row_key: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
col_key: Mapped[Optional[str]] = mapped_column(String(255), nullable=True)
|
||||
year: Mapped[Optional[int]] = mapped_column(Integer, nullable=True)
|
||||
filled_value: Mapped[Optional[str]] = mapped_column(Text, nullable=True)
|
||||
previous_value: Mapped[Optional[str]] = mapped_column(Text, nullable=True)
|
||||
|
||||
source_document_id: Mapped[Optional[str]] = mapped_column(
|
||||
ForeignKey("kb_documents.id", ondelete="SET NULL"), nullable=True
|
||||
)
|
||||
source_document_name: Mapped[Optional[str]] = mapped_column(String(255), nullable=True)
|
||||
source_line_no: Mapped[Optional[int]] = mapped_column(Integer, nullable=True)
|
||||
source_line_end: Mapped[Optional[int]] = mapped_column(Integer, nullable=True)
|
||||
source_quote: Mapped[Optional[str]] = mapped_column(Text, nullable=True)
|
||||
|
||||
confidence: Mapped[Optional[float]] = mapped_column(Float, nullable=True)
|
||||
extraction_batch_id: Mapped[Optional[str]] = mapped_column(String(64), nullable=True)
|
||||
extraction_model: Mapped[Optional[str]] = mapped_column(String(128), nullable=True)
|
||||
|
||||
fill_type: Mapped[str] = mapped_column(String(16), nullable=False, default="auto")
|
||||
created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False)
|
||||
|
||||
@ -1,3 +1,5 @@
|
||||
from .logger import configure_logging, get_logger
|
||||
"""日志包:统一日志配置。"""
|
||||
|
||||
from log.logger import configure_logging, get_logger
|
||||
|
||||
__all__ = ["configure_logging", "get_logger"]
|
||||
|
||||
231
log/logger.py
231
log/logger.py
@ -1,38 +1,40 @@
|
||||
"""
|
||||
log/logger.py
|
||||
统一日志配置:
|
||||
|
||||
- 控制台输出(强制 UTF-8,修复 Windows 控制台中文乱码)
|
||||
- logs/app.log 全量日志(按大小轮转)
|
||||
- logs/error.log 仅 WARNING 及以上
|
||||
- logs/upload.log 上传/解析/入库链路(routers.template、services.*)
|
||||
- 接管 uvicorn 的 access/error 日志,统一落盘
|
||||
|
||||
幂等:重复调用只配置一次。
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import sys
|
||||
from logging.handlers import RotatingFileHandler
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
_CONFIGURED = False
|
||||
_FILE_PROCESSING_PREFIXES = (
|
||||
"worker.document_processing",
|
||||
"services.kb_service",
|
||||
"services.es_docs",
|
||||
"services.element_llm_extract_service",
|
||||
"routers.extract",
|
||||
"function.documents",
|
||||
"function.vector_store",
|
||||
"repo.kb_documents",
|
||||
"routers.reference",
|
||||
"services.doc_convert_service",
|
||||
"services.reference_service",
|
||||
)
|
||||
_DOCUMENT_GENERATION_PREFIXES = (
|
||||
"services.write_service",
|
||||
"services.report_generation_service",
|
||||
"services.markdown_stream_service",
|
||||
|
||||
_FORMAT = "%(asctime)s | %(levelname)-7s | %(name)s | %(message)s"
|
||||
_DATEFMT = "%Y-%m-%d %H:%M:%S"
|
||||
|
||||
# 上传/解析/入库链路相关的 logger 前缀(额外汇总到 upload.log)
|
||||
_UPLOAD_PREFIXES = (
|
||||
"routers.template",
|
||||
"services.file_parse_client",
|
||||
"services.section_extractor",
|
||||
"services.declaration_service",
|
||||
"services.llm_client",
|
||||
"services.llm_runner",
|
||||
"services.report_prompt_service",
|
||||
"services.report_runtime_store",
|
||||
)
|
||||
# 生成全过程追踪:完整记录输入 prompt / 调用模型 / 模型输出
|
||||
_GENERATION_TRACE_PREFIXES = (
|
||||
"generation.trace",
|
||||
)
|
||||
|
||||
# 交由 root 统一处理的第三方/框架 logger
|
||||
_DELEGATED_LOGGERS = ("uvicorn", "uvicorn.error", "uvicorn.access")
|
||||
|
||||
|
||||
class _PrefixFilter(logging.Filter):
|
||||
def __init__(self, prefixes: tuple[str, ...]) -> None:
|
||||
@ -41,144 +43,99 @@ class _PrefixFilter(logging.Filter):
|
||||
|
||||
def filter(self, record: logging.LogRecord) -> bool:
|
||||
name = str(record.name or "")
|
||||
return any(name == prefix or name.startswith(prefix + ".") for prefix in self.prefixes)
|
||||
return any(name == p or name.startswith(p + ".") for p in self.prefixes)
|
||||
|
||||
|
||||
class _OtherFilter(logging.Filter):
|
||||
def filter(self, record: logging.LogRecord) -> bool:
|
||||
name = str(record.name or "")
|
||||
if any(name == prefix or name.startswith(prefix + ".") for prefix in _FILE_PROCESSING_PREFIXES):
|
||||
return False
|
||||
if any(name == prefix or name.startswith(prefix + ".") for prefix in _DOCUMENT_GENERATION_PREFIXES):
|
||||
return False
|
||||
if any(name == prefix or name.startswith(prefix + ".") for prefix in _GENERATION_TRACE_PREFIXES):
|
||||
return False
|
||||
return True
|
||||
def _force_utf8_stream(stream):
|
||||
"""让控制台以 UTF-8 输出,避免 Windows GBK 控制台中文乱码。"""
|
||||
reconfigure = getattr(stream, "reconfigure", None)
|
||||
if callable(reconfigure):
|
||||
try:
|
||||
reconfigure(encoding="utf-8", errors="replace")
|
||||
except (ValueError, OSError):
|
||||
pass
|
||||
return stream
|
||||
|
||||
|
||||
def configure_logging(
|
||||
*,
|
||||
log_dir: str | Path = "logs",
|
||||
level: int = logging.INFO,
|
||||
log_dir: str | Path | None = None,
|
||||
level: str | int | None = None,
|
||||
to_console: bool | None = None,
|
||||
max_bytes: int | None = None,
|
||||
backup_count: int | None = None,
|
||||
) -> Path:
|
||||
"""配置全局日志。返回 app.log 路径。"""
|
||||
global _CONFIGURED
|
||||
|
||||
# 延迟导入,避免与 config 形成循环依赖问题
|
||||
from config import settings
|
||||
|
||||
log_dir = log_dir if log_dir is not None else settings.LOG_DIR
|
||||
level = level if level is not None else settings.LOG_LEVEL
|
||||
to_console = to_console if to_console is not None else settings.LOG_TO_CONSOLE
|
||||
max_bytes = max_bytes if max_bytes is not None else settings.LOG_MAX_BYTES
|
||||
backup_count = backup_count if backup_count is not None else settings.LOG_BACKUP_COUNT
|
||||
|
||||
if isinstance(level, str):
|
||||
level = getattr(logging, level.strip().upper(), logging.INFO)
|
||||
|
||||
target_dir = Path(log_dir).resolve()
|
||||
target_dir.mkdir(parents=True, exist_ok=True)
|
||||
other_log_path = target_dir / "other.log"
|
||||
app_log_path = target_dir / "app.log"
|
||||
|
||||
if _CONFIGURED:
|
||||
return other_log_path
|
||||
return app_log_path
|
||||
|
||||
formatter = logging.Formatter(
|
||||
"%(asctime)s | %(levelname)s | %(name)s | %(message)s"
|
||||
)
|
||||
formatter = logging.Formatter(_FORMAT, datefmt=_DATEFMT)
|
||||
|
||||
root_logger = logging.getLogger()
|
||||
root_logger.setLevel(level)
|
||||
|
||||
file_processing_handler = RotatingFileHandler(
|
||||
target_dir / "file_processing.log",
|
||||
maxBytes=10 * 1024 * 1024,
|
||||
backupCount=5,
|
||||
def _rotating(name: str, *, backups: int | None = None) -> RotatingFileHandler:
|
||||
h = RotatingFileHandler(
|
||||
target_dir / name,
|
||||
maxBytes=max_bytes,
|
||||
backupCount=backups if backups is not None else backup_count,
|
||||
encoding="utf-8",
|
||||
)
|
||||
file_processing_handler.setLevel(level)
|
||||
file_processing_handler.setFormatter(formatter)
|
||||
file_processing_handler.addFilter(_PrefixFilter(_FILE_PROCESSING_PREFIXES))
|
||||
h.setFormatter(formatter)
|
||||
return h
|
||||
|
||||
document_generation_handler = RotatingFileHandler(
|
||||
target_dir / "document_generation.log",
|
||||
maxBytes=10 * 1024 * 1024,
|
||||
backupCount=5,
|
||||
encoding="utf-8",
|
||||
)
|
||||
document_generation_handler.setLevel(level)
|
||||
document_generation_handler.setFormatter(formatter)
|
||||
document_generation_handler.addFilter(_PrefixFilter(_DOCUMENT_GENERATION_PREFIXES))
|
||||
# 全量日志
|
||||
app_handler = _rotating("app.log")
|
||||
app_handler.setLevel(level)
|
||||
|
||||
other_handler = RotatingFileHandler(
|
||||
other_log_path,
|
||||
maxBytes=10 * 1024 * 1024,
|
||||
backupCount=5,
|
||||
encoding="utf-8",
|
||||
)
|
||||
other_handler.setLevel(level)
|
||||
other_handler.setFormatter(formatter)
|
||||
other_handler.addFilter(_OtherFilter())
|
||||
# 错误日志(WARNING+)
|
||||
error_handler = _rotating("error.log")
|
||||
error_handler.setLevel(logging.WARNING)
|
||||
|
||||
# ── 要素抽取独立日志 ─────────────────────────────────────────────
|
||||
element_extract_handler = RotatingFileHandler(
|
||||
target_dir / "element_extract.log",
|
||||
maxBytes=10 * 1024 * 1024,
|
||||
backupCount=10,
|
||||
encoding="utf-8",
|
||||
)
|
||||
element_extract_handler.setLevel(level)
|
||||
element_extract_handler.setFormatter(formatter)
|
||||
element_extract_handler.addFilter(_PrefixFilter(("services.element_llm_extract_service", "routers.extract")))
|
||||
# 上传/解析链路日志
|
||||
upload_handler = _rotating("upload.log", backups=max(backup_count, 10))
|
||||
upload_handler.setLevel(level)
|
||||
upload_handler.addFilter(_PrefixFilter(_UPLOAD_PREFIXES))
|
||||
|
||||
# ── 文件上传/解析独立日志 ─────────────────────────────────────────
|
||||
file_upload_handler = RotatingFileHandler(
|
||||
target_dir / "file_upload.log",
|
||||
maxBytes=10 * 1024 * 1024,
|
||||
backupCount=10,
|
||||
encoding="utf-8",
|
||||
)
|
||||
file_upload_handler.setLevel(level)
|
||||
file_upload_handler.setFormatter(formatter)
|
||||
file_upload_handler.addFilter(_PrefixFilter(("routers.reference", "routers.template", "services.doc_convert_service", "services.reference_service", "services.kb_service", "routers.kb")))
|
||||
handlers: list[logging.Handler] = [app_handler, error_handler, upload_handler]
|
||||
|
||||
# ── 报告生成独立日志 ──────────────────────────────────────────────
|
||||
report_generation_handler = RotatingFileHandler(
|
||||
target_dir / "report_generation.log",
|
||||
maxBytes=10 * 1024 * 1024,
|
||||
backupCount=10,
|
||||
encoding="utf-8",
|
||||
)
|
||||
report_generation_handler.setLevel(level)
|
||||
report_generation_handler.setFormatter(formatter)
|
||||
report_generation_handler.addFilter(_PrefixFilter(("services.report_generation_service", "services.report_prompt_service", "services.report_runtime_store", "services.markdown_stream_service")))
|
||||
if to_console:
|
||||
console_handler = logging.StreamHandler(_force_utf8_stream(sys.stdout))
|
||||
console_handler.setLevel(level)
|
||||
console_handler.setFormatter(formatter)
|
||||
handlers.append(console_handler)
|
||||
|
||||
# ── LLM 调用独立日志 ──────────────────────────────────────────────
|
||||
llm_handler = RotatingFileHandler(
|
||||
target_dir / "llm.log",
|
||||
maxBytes=10 * 1024 * 1024,
|
||||
backupCount=10,
|
||||
encoding="utf-8",
|
||||
)
|
||||
llm_handler.setLevel(level)
|
||||
llm_handler.setFormatter(formatter)
|
||||
llm_handler.addFilter(_PrefixFilter(("services.llm_client", "services.llm_runner")))
|
||||
root = logging.getLogger()
|
||||
root.setLevel(level)
|
||||
root.handlers.clear()
|
||||
for h in handlers:
|
||||
root.addHandler(h)
|
||||
|
||||
# ── 生成全过程追踪日志(输入 prompt / 模型 / 输出,单条可能较大)────────
|
||||
generation_trace_handler = RotatingFileHandler(
|
||||
target_dir / "generation_trace.log",
|
||||
maxBytes=50 * 1024 * 1024,
|
||||
backupCount=10,
|
||||
encoding="utf-8",
|
||||
)
|
||||
generation_trace_handler.setLevel(level)
|
||||
generation_trace_handler.setFormatter(formatter)
|
||||
generation_trace_handler.addFilter(_PrefixFilter(_GENERATION_TRACE_PREFIXES))
|
||||
|
||||
stream_handler = logging.StreamHandler()
|
||||
stream_handler.setLevel(level)
|
||||
stream_handler.setFormatter(formatter)
|
||||
|
||||
root_logger.handlers.clear()
|
||||
root_logger.addHandler(file_processing_handler)
|
||||
root_logger.addHandler(document_generation_handler)
|
||||
root_logger.addHandler(other_handler)
|
||||
root_logger.addHandler(element_extract_handler)
|
||||
root_logger.addHandler(file_upload_handler)
|
||||
root_logger.addHandler(report_generation_handler)
|
||||
root_logger.addHandler(llm_handler)
|
||||
root_logger.addHandler(generation_trace_handler)
|
||||
root_logger.addHandler(stream_handler)
|
||||
# 让 uvicorn 的日志走 root 统一落盘
|
||||
for name in _DELEGATED_LOGGERS:
|
||||
lg = logging.getLogger(name)
|
||||
lg.handlers.clear()
|
||||
lg.propagate = True
|
||||
lg.setLevel(level)
|
||||
|
||||
_CONFIGURED = True
|
||||
return other_log_path
|
||||
logging.getLogger(__name__).info("日志系统已初始化 | dir=%s | level=%s", target_dir, logging.getLevelName(level))
|
||||
return app_log_path
|
||||
|
||||
|
||||
def get_logger(name: str) -> logging.Logger:
|
||||
|
||||
77
main.py
77
main.py
@ -1,43 +1,50 @@
|
||||
"""
|
||||
main.py
|
||||
报告生成独立服务 FastAPI 入口。
|
||||
报告模板管理模块 FastAPI 应用入口。
|
||||
|
||||
启动方式:
|
||||
uvicorn main:app --reload
|
||||
或:python main.py
|
||||
启动:
|
||||
uvicorn main:app --host 0.0.0.0 --port 8100
|
||||
或:
|
||||
python main.py
|
||||
"""
|
||||
|
||||
import logging
|
||||
from __future__ import annotations
|
||||
|
||||
import time
|
||||
import uuid
|
||||
from contextlib import asynccontextmanager
|
||||
|
||||
import uvicorn
|
||||
from fastapi import FastAPI
|
||||
from fastapi import FastAPI, Request
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
|
||||
from config import settings
|
||||
from database import engine, init_database
|
||||
from log import configure_logging
|
||||
from routers import report
|
||||
from database import init_database
|
||||
from log import configure_logging, get_logger
|
||||
from routers import template
|
||||
|
||||
# 在创建应用前完成日志配置
|
||||
configure_logging()
|
||||
_log = logging.getLogger(__name__)
|
||||
logger = get_logger("app")
|
||||
access_logger = get_logger("app.access")
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def lifespan(app: FastAPI):
|
||||
"""应用启动与关闭时执行。"""
|
||||
async def lifespan(_app: FastAPI):
|
||||
logger.info("应用启动 | %s v%s", settings.APP_TITLE, settings.APP_VERSION)
|
||||
if settings.DB_AUTO_CREATE_TABLES:
|
||||
try:
|
||||
init_database()
|
||||
except Exception as e: # noqa: BLE001
|
||||
logger.warning("启动建表失败(不影响已存在表的使用): %s", e)
|
||||
yield
|
||||
engine.dispose()
|
||||
logger.info("应用关闭")
|
||||
|
||||
|
||||
app = FastAPI(
|
||||
lifespan=lifespan,
|
||||
title=settings.APP_TITLE,
|
||||
version=settings.APP_VERSION,
|
||||
description=settings.APP_DESCRIPTION,
|
||||
docs_url="/docs",
|
||||
redoc_url="/redoc",
|
||||
lifespan=lifespan,
|
||||
)
|
||||
|
||||
app.add_middleware(
|
||||
@ -48,19 +55,47 @@ app.add_middleware(
|
||||
allow_headers=["*"],
|
||||
)
|
||||
|
||||
app.include_router(report.router, prefix="/api/v1")
|
||||
|
||||
@app.middleware("http")
|
||||
async def log_requests(request: Request, call_next):
|
||||
if not settings.LOG_HTTP_ACCESS:
|
||||
return await call_next(request)
|
||||
|
||||
req_id = uuid.uuid4().hex[:8]
|
||||
start = time.perf_counter()
|
||||
client = request.client.host if request.client else "-"
|
||||
access_logger.info("→ [%s] %s %s | client=%s", req_id, request.method, request.url.path, client)
|
||||
try:
|
||||
response = await call_next(request)
|
||||
except Exception:
|
||||
cost = (time.perf_counter() - start) * 1000
|
||||
access_logger.exception("✗ [%s] %s %s | %.1fms | 未处理异常", req_id, request.method, request.url.path, cost)
|
||||
raise
|
||||
cost = (time.perf_counter() - start) * 1000
|
||||
access_logger.info(
|
||||
"← [%s] %s %s | %s | %.1fms",
|
||||
req_id, request.method, request.url.path, response.status_code, cost,
|
||||
)
|
||||
response.headers["X-Request-ID"] = req_id
|
||||
return response
|
||||
|
||||
|
||||
@app.get("/health", tags=["系统"], summary="健康检查")
|
||||
def health_check():
|
||||
"""确认服务存活,返回版本信息。"""
|
||||
app.include_router(template.router)
|
||||
|
||||
|
||||
@app.get("/health", tags=["健康检查"])
|
||||
def health() -> dict:
|
||||
return {"status": "ok", "version": settings.APP_VERSION}
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import uvicorn
|
||||
|
||||
# log_config=None:沿用本模块 configure_logging() 的配置,避免被 uvicorn 覆盖
|
||||
uvicorn.run(
|
||||
"main:app",
|
||||
host=settings.HOST,
|
||||
port=settings.PORT,
|
||||
reload=settings.RELOAD,
|
||||
log_config=None,
|
||||
)
|
||||
|
||||
150
prompts/report_generation/template_prompt_rules.py
Normal file
150
prompts/report_generation/template_prompt_rules.py
Normal file
@ -0,0 +1,150 @@
|
||||
"""Default section prompt and example variables used by report templates."""
|
||||
|
||||
DEFAULT_SECTION_PROMPT = "按《炼油化工建设项目后评价报告编制细则(修订)》对应章节要求撰写,缺失信息标注“待补充”,禁止编造。"
|
||||
|
||||
SECTION_PROMPT_RULES: list[tuple[str, str]] = [
|
||||
(
|
||||
"1.1 项目基本情况",
|
||||
"严格按以下字段顺序输出:项目名称、建设单位、建设地点、建设类型、起止时间、建设内容、建设投资、占地面积。"
|
||||
"所有事实均须来自证据材料,缺失写“待补充”,禁止编造,禁止复刻示例原文。",
|
||||
),
|
||||
(
|
||||
"1.2 项目决策要点",
|
||||
"按“项目背景(1)2)3)+ 预期目标(规模/质量/效益)”撰写。"
|
||||
"证据依据用于内部校验,不在报告正文显示“【证据依据:...】”标记。"
|
||||
"背景每条先写 2~4 句书面语再视需要附表;表格须用 Markdown,禁止粘贴未对齐的原始文本表。"
|
||||
"预期目标须从证据归纳:证据中已有产能、产量(万吨/年)、辛烷值、国VI、收入、利润等时不得三条目标全写待补充。"
|
||||
"改扩建项目应补充原装置问题与改造动因。",
|
||||
),
|
||||
(
|
||||
"1.3 项目实施情况",
|
||||
"说明实际建成内容及与批复方案的差异,给出关键里程碑时间线(立项、批复、开工、中交、投产等)。",
|
||||
),
|
||||
(
|
||||
"1.4 项目运行情况",
|
||||
"说明投产以来装置运行负荷、产量及主要财务运行情况(营业收入、利润等),并与预期目标对照。",
|
||||
),
|
||||
(
|
||||
"2.1.1 资源与原料评价",
|
||||
"必须与《模版.doc》中本节结构一致:先简述可研原料来源与实际一致性;"
|
||||
"再给出「原料数量及组成对比表」(列含:序号、原料名称、规格、可研/初设/实际各自的「数量(万吨)」「占比(%)」、备注,须有合计行);"
|
||||
"再给出运行负荷与加工量等对比叙述;"
|
||||
"再给出「原料性质对比表(醚后碳四)」(列:序号、名称、可研报告、初步设计、实际生产、备注;行至少含密度、硫含量、氮含量等,可按项目增删);"
|
||||
"最后给出组成/性质变化分析及后评价判断。"
|
||||
"表格一律使用 Markdown 表头+分隔行,禁止粘贴未对齐的纯文本表。"
|
||||
"本节第一张主表表题须固定为「表1 原料数量及组成对比表」(与章节输出合同及要素管理默认表题一致,与正文节内表序一致);"
|
||||
"禁止用安评/专篇中「表2.6-1 原料选择加氢工艺技术对比」等同号异题表替代上述两张模版主表;"
|
||||
"本节只保留模版主表,不输出附录与“非模版主表”字样。"
|
||||
"数据来自证据包,缺失填待补充,禁止编造。",
|
||||
),
|
||||
(
|
||||
"2.1.2.1 产品方案评价",
|
||||
"按“事实依据—评价判断—问题与建议”组织内容,但不要在正文中显示"
|
||||
"“【事实依据】”“【评价判断】”“【问题与建议】”这三个标题标签;"
|
||||
"以自然段或编号表达即可。缺失信息写“待补充”,禁止编造。",
|
||||
),
|
||||
("2", "对照前期工作细则,评价可研、前评估、初设及决策程序的合规性、完整性与合理性。"),
|
||||
("3", "对照建设实施细则,评价建设管理、招投标、设计、采购、施工、监理、质量、HSE与竣工验收。"),
|
||||
("4", "对照生产运行细则,评价生产准备、联合试运、达标情况、工艺技术、设备与辅助系统运行效果。"),
|
||||
("5", "对照投资与经济效益细则,评价投资控制、资金到位、经营效益、后评价测算及不确定性分析。"),
|
||||
("6", "对照影响与持续性细则,评价环境、安全、科技、社会影响及资源、产品、技术经济竞争力持续性。"),
|
||||
("7", "给出综合评价结论、成功度评价、主要经验、问题与建议,结论须与前文证据保持一致。"),
|
||||
]
|
||||
|
||||
|
||||
SECTION_EXAMPLE_RULES: list[tuple[str, str]] = [
|
||||
(
|
||||
"1",
|
||||
"项目名称:宁夏石化分公司16万吨/年烷基化装置建设项目独立后评价。"
|
||||
"建设内容包括16万吨/年烷基化单元、1.5万吨/年废酸再生单元及配套公用工程。"
|
||||
"批复可研估算22812万元,批复初设概算25079万元,竣工决算32486万元。"
|
||||
"烷基化单元2018年11月投产,废酸再生单元2020年11月投运。\n"
|
||||
"---EXAMPLE---\n"
|
||||
"项目运行情况:投产后烷基化油收率保持在86%以上,高于设计值81.28%。"
|
||||
"受原油加工负荷影响,阶段性加工负荷与可研预期存在偏差;"
|
||||
"通过全厂优化运行后,烷基化装置加工负荷保持在90%以上,"
|
||||
"满足国VI汽油质量升级需要。",
|
||||
),
|
||||
(
|
||||
"2",
|
||||
"前期决策评价:项目可研由具备甲级资质单位编制,前评估提出57条意见并完成落实。"
|
||||
"原料来源为MTBE装置醚后碳四,来源与实际生产一致;"
|
||||
"产品全部调入汽油系统,由中石油西北销售公司统一销售,市场风险可控。"
|
||||
"工艺路线采用中石油自主硫酸法烷基化技术并配套P&P湿法废酸再生技术。\n"
|
||||
"---EXAMPLE---\n"
|
||||
"前期工作结论:项目在可研、前评估、初设及决策程序方面总体规范,"
|
||||
"对国VI质量升级任务支撑明确;"
|
||||
"后续应针对专利技术工程化经验不足、概算约束偏紧等问题,"
|
||||
"在施工图阶段加强工程量校核与投资风险预控。",
|
||||
),
|
||||
(
|
||||
"3",
|
||||
"建设实施评价:项目采用“业主+监理+E+PC”管理模式。"
|
||||
"单位工程质量合格率100%,HSE总体受控。"
|
||||
"但施工图设计进度与采购协同不足,废酸再生单元受工艺包与设备整改影响,"
|
||||
"中交与投产明显滞后。\n"
|
||||
"---EXAMPLE---\n"
|
||||
"招采与设计变更情况:共发生设计变更118份,变更费用约5033万元,"
|
||||
"对进度与投资控制产生不利影响。"
|
||||
"经验上应优先采用以设计为龙头的EPC协同模式,"
|
||||
"并提前锁定关键设备选材和高温腐蚀工况边界条件。",
|
||||
),
|
||||
(
|
||||
"4",
|
||||
"生产运行评价:烷基化单元2018年11月一次投产成功,"
|
||||
"废酸再生单元经四次整改后于2020年11月投运。"
|
||||
"运行期主要问题集中在原料杂质波动、局部腐蚀及部分指标偏离设计值,"
|
||||
"通过优化烷烯比、补酸策略及上游分离精度逐步改进。\n"
|
||||
"---EXAMPLE---\n"
|
||||
"达标评价结论:烷基化单元在标定中出现辛烷值、硫含量、酸耗偏差,"
|
||||
"废酸再生单元处理能力达标但能耗偏高。"
|
||||
"总体上项目工艺可用、装置可稳态运行,"
|
||||
"需持续优化操作和设备防腐管理以提升长周期绩效。",
|
||||
),
|
||||
(
|
||||
"5",
|
||||
"主要经济指标对比示例(宁夏石化项目):\n\n"
|
||||
"表5-1 主要经济指标对比表\n"
|
||||
"| 指标 | 单位 | 可研值 | 后评价值 | 差值 |\n"
|
||||
"| --- | --- | --- | --- | --- |\n"
|
||||
"| 报批总投资 | 万元 | 22812 | 32486 | +9674 |\n"
|
||||
"| 年均税后利润 | 万元 | 20652 | 13283 | -7369 |\n"
|
||||
"| 税后内部收益率 | % | 85.12 | 35.46 | -49.66 |\n"
|
||||
"| 静态投资回收期 | 年 | 2.29 | 4.38 | +2.09 |\n"
|
||||
"\n"
|
||||
"示例结论:项目收益率虽较可研明显下降,但仍高于基准收益率,"
|
||||
"基本实现效益目标;投资控制偏弱是主要短板。\n"
|
||||
"---EXAMPLE---\n"
|
||||
"5.2.2 投资水平分析正文参考(勿输出为表5-2;表5-2 仅用于 5.2.1 投资变动情况表):\n\n"
|
||||
"同类烷基化装置单位工程费对标(撰写段落时参考,非模版表号):\n"
|
||||
"| 项目 | 规模(万吨/年) | 工程费(万元) | 单位造价(元/吨) |\n"
|
||||
"| --- | --- | --- | --- |\n"
|
||||
"| 宁夏石化 | 16 | 15159 | 947 |\n"
|
||||
"| 乌石化 | 20 | 21286 | 1064 |\n"
|
||||
"| 锦州石化 | 25 | 23401 | 936 |\n"
|
||||
"| 兰州石化 | 20 | 14377 | 719 |",
|
||||
),
|
||||
(
|
||||
"6",
|
||||
"影响评价示例:项目落实环评及安全“三同时”,废气、废水、噪声监测达标,"
|
||||
"投产以来未发生重大安全事故,环境与安全风险总体可控。\n"
|
||||
"---EXAMPLE---\n"
|
||||
"持续性评价示例(宁夏石化项目):\n\n"
|
||||
"表6-1 装置技术经济指标对比表\n"
|
||||
"| 项目名称 | 技术来源 | 规模(万吨/年) | 物耗(Wt)% | 能耗(kgEo/t) | 产品质量 | 产品收率(Wt)% | 排名 |\n"
|
||||
"| --- | --- | --- | --- | --- | --- | --- | --- |\n"
|
||||
"| 宁夏石化烷基化 | 自主硫酸法烷基化 | 16 | 待补充 | 待补充 | 国VI调和组分 | 86以上 | 待补充 |\n"
|
||||
"| 同类装置A | … | … | … | … | … | … | … |",
|
||||
),
|
||||
(
|
||||
"7",
|
||||
"综合结论示例:宁夏石化16万吨/年烷基化项目完成了国VI汽油质量升级目标,"
|
||||
"生产运行总体平稳,效益指标虽低于可研预测但仍高于基准收益率。"
|
||||
"项目综合评分8.62分,评级为“良”。\n"
|
||||
"---EXAMPLE---\n"
|
||||
"建议示例:"
|
||||
"1)持续优化自主技术工程化能力,重点治理腐蚀与高温环节选材问题;"
|
||||
"2)加强设计-采购-施工一体化协同,减少大额变更;"
|
||||
"3)围绕原料品质与上游协同运行,进一步提升装置长期经济性。",
|
||||
),
|
||||
]
|
||||
@ -1,23 +1,8 @@
|
||||
# Web 框架
|
||||
fastapi
|
||||
uvicorn[standard]
|
||||
pydantic
|
||||
pydantic-settings
|
||||
|
||||
# 数据库(MySQL)
|
||||
sqlalchemy
|
||||
pymysql
|
||||
cryptography
|
||||
|
||||
# HTTP(LLM / Embedding 调用)
|
||||
requests
|
||||
|
||||
# 附图提取(解析项目 .docx 内嵌图片)
|
||||
python-docx
|
||||
|
||||
# 向量检索(Milvus + Embeddings + BM25)
|
||||
langchain-core
|
||||
langchain-text-splitters
|
||||
langchain-openai
|
||||
langchain-milvus
|
||||
pymilvus
|
||||
fastapi>=0.115.6
|
||||
uvicorn[standard]>=0.34.0
|
||||
python-multipart>=0.0.20
|
||||
pydantic>=2.11
|
||||
pydantic-settings>=2.7.1
|
||||
SQLAlchemy>=2.0.36
|
||||
PyMySQL>=1.1.1
|
||||
requests>=2.32.3
|
||||
|
||||
346
routers/template.py
Normal file
346
routers/template.py
Normal file
@ -0,0 +1,346 @@
|
||||
"""
|
||||
routers/template.py
|
||||
报告模板管理:上传文档 → 远程解析为 Markdown → 抽取目录并为每个目录生成声明
|
||||
→ 创建模板(目录 + 声明)→ 按章节拆分正文并入库远程 MySQL。
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import os
|
||||
import re
|
||||
import tempfile
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
|
||||
from fastapi import APIRouter, Depends, File, HTTPException, Query, UploadFile
|
||||
from sqlalchemy import or_
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from database import get_db
|
||||
from database.models import (
|
||||
ReportSectionReference,
|
||||
ReportTemplate,
|
||||
ReportTemplateSection,
|
||||
)
|
||||
from schemas.template import (
|
||||
SectionReferenceItem,
|
||||
TemplateItem,
|
||||
TemplateSectionItem,
|
||||
UploadTemplateResult,
|
||||
)
|
||||
from services.template_prompt_mapper import resolve_uploaded_template_prompts
|
||||
from services.desensitize_service import count_masked_numbers, desensitize_content
|
||||
from services.file_parse_client import parse_file_to_markdown
|
||||
from services.section_extractor import (
|
||||
clamp_text_bytes,
|
||||
extract_sections_from_text,
|
||||
normalize_section_key,
|
||||
parse_section_order,
|
||||
split_markdown_into_sections,
|
||||
)
|
||||
|
||||
from config import settings
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
router = APIRouter(prefix="/templates", tags=["报告模板管理"])
|
||||
|
||||
ALLOWED_SUFFIXES = {".doc", ".docx", ".pdf", ".txt", ".md", ".html", ".htm", ".rtf"}
|
||||
|
||||
|
||||
def _clamp_title(value: str | None) -> str:
|
||||
return str(value or "").strip()[:255]
|
||||
|
||||
|
||||
def _build_description(source_file: str | None) -> str:
|
||||
sf = str(source_file or "").strip()
|
||||
base = "通过文件上传导入"
|
||||
return f"{base}\n来源文件:{sf}" if sf else base
|
||||
|
||||
|
||||
def _find_duplicates(db: Session, filename: str) -> dict:
|
||||
"""按来源文件名查找已导入的模板与章节范文,用于重复检查。"""
|
||||
template_ids = [
|
||||
t.id
|
||||
for t in db.query(ReportTemplate.id)
|
||||
.filter(ReportTemplate.description.like(f"%来源文件:{filename}%"))
|
||||
.all()
|
||||
]
|
||||
ref_count = (
|
||||
db.query(ReportSectionReference)
|
||||
.filter(ReportSectionReference.source_file == filename)
|
||||
.count()
|
||||
)
|
||||
return {"template_ids": template_ids, "ref_count": ref_count}
|
||||
|
||||
|
||||
def _delete_by_source(db: Session, filename: str, template_ids: list[str]) -> None:
|
||||
"""删除指定来源文件已导入的模板(含章节)与章节范文。"""
|
||||
for tid in template_ids:
|
||||
db.query(ReportTemplateSection).filter(
|
||||
ReportTemplateSection.template_id == tid
|
||||
).delete(synchronize_session=False)
|
||||
db.query(ReportTemplate).filter(ReportTemplate.id == tid).delete(
|
||||
synchronize_session=False
|
||||
)
|
||||
db.query(ReportSectionReference).filter(
|
||||
ReportSectionReference.source_file == filename
|
||||
).delete(synchronize_session=False)
|
||||
db.commit()
|
||||
|
||||
|
||||
def _extract_source_file(description: str | None) -> str | None:
|
||||
"""从模板描述中解析来源文件名(上传时写入 '来源文件:xxx')。"""
|
||||
m = re.search(r"来源文件\s*[::]\s*(.+)$", str(description or ""))
|
||||
if not m:
|
||||
return None
|
||||
return (m.group(1) or "").strip() or None
|
||||
|
||||
|
||||
def _serialize_template(db: Session, template_id: str) -> TemplateItem:
|
||||
t = db.query(ReportTemplate).filter(ReportTemplate.id == template_id).first()
|
||||
if not t:
|
||||
raise HTTPException(status_code=404, detail="模板不存在")
|
||||
sections = (
|
||||
db.query(ReportTemplateSection)
|
||||
.filter(ReportTemplateSection.template_id == t.id)
|
||||
.order_by(ReportTemplateSection.section_order.asc())
|
||||
.all()
|
||||
)
|
||||
src = _extract_source_file(t.description)
|
||||
return TemplateItem(
|
||||
id=t.id,
|
||||
name=t.name,
|
||||
description=t.description,
|
||||
sourceFile=src,
|
||||
createdAt=t.created_at.strftime("%Y-%m-%d %H:%M:%S") if t.created_at else None,
|
||||
updatedAt=t.updated_at.strftime("%Y-%m-%d %H:%M:%S") if t.updated_at else None,
|
||||
isDefault=t.is_default,
|
||||
isActive=t.is_active,
|
||||
sections=[
|
||||
TemplateSectionItem(
|
||||
id=s.id,
|
||||
sectionKey=s.section_key,
|
||||
sectionTitle=s.section_title,
|
||||
sectionPrompt=s.section_prompt,
|
||||
sectionOutputContract=s.section_output_contract,
|
||||
sectionOrder=s.section_order,
|
||||
examples=s.examples,
|
||||
)
|
||||
for s in sections
|
||||
],
|
||||
)
|
||||
|
||||
|
||||
@router.post("/upload", response_model=UploadTemplateResult, summary="上传文档并解析为模板(目录+声明)与章节内容")
|
||||
async def upload_template_route(
|
||||
file: UploadFile = File(...),
|
||||
force: bool = Query(False, description="为 true 时覆盖同名来源文件的已有模板与章节后重新导入"),
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
filename = (file.filename or "").strip()
|
||||
suffix = os.path.splitext(filename)[1].lower()
|
||||
if suffix not in ALLOWED_SUFFIXES:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"不支持的文件格式({suffix or '未知'});支持:{', '.join(sorted(ALLOWED_SUFFIXES))}",
|
||||
)
|
||||
|
||||
# 0) 重复检查:同名来源文件已导入则拒绝(force=true 则先删除旧数据再重导)
|
||||
dup = _find_duplicates(db, filename)
|
||||
if dup["template_ids"] or dup["ref_count"]:
|
||||
if not force:
|
||||
raise HTTPException(
|
||||
status_code=409,
|
||||
detail=(
|
||||
f"文件「{filename}」已导入(模板 {len(dup['template_ids'])} 个、"
|
||||
f"章节 {dup['ref_count']} 条),不可重复入库;"
|
||||
f"如需覆盖请使用 force=true 重新上传。"
|
||||
),
|
||||
)
|
||||
logger.info(
|
||||
"重复导入覆盖 | file=%s | 删除旧模板=%s | 旧章节=%s",
|
||||
filename, len(dup["template_ids"]), dup["ref_count"],
|
||||
)
|
||||
_delete_by_source(db, filename, dup["template_ids"])
|
||||
|
||||
# 1) 落盘临时文件
|
||||
with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as tmp:
|
||||
tmp_path = tmp.name
|
||||
tmp.write(await file.read())
|
||||
|
||||
try:
|
||||
# 2) 远程解析为 Markdown
|
||||
try:
|
||||
markdown = parse_file_to_markdown(tmp_path)
|
||||
except Exception as e: # noqa: BLE001
|
||||
raise HTTPException(status_code=400, detail=f"文档解析失败:{e}")
|
||||
|
||||
if not markdown or not markdown.strip():
|
||||
raise HTTPException(status_code=400, detail="解析结果为空")
|
||||
|
||||
# 3) 抽取目录
|
||||
sections = extract_sections_from_text(markdown)
|
||||
if not sections:
|
||||
raise HTTPException(status_code=400, detail="未从文档中识别到章节标题(目录)")
|
||||
|
||||
# 4) 按标题拆分正文,对每段正文脱敏(去精确数字等);跳过无正文的章节
|
||||
raw_sections = split_markdown_into_sections(markdown)
|
||||
max_bytes = int(getattr(settings, "SECTION_CONTENT_MAX_BYTES", 60000) or 60000)
|
||||
ref_sections: list[dict] = []
|
||||
masked_total = 0
|
||||
skipped_empty = 0
|
||||
truncated = 0
|
||||
for s in raw_sections:
|
||||
filtered = desensitize_content(s["content"])
|
||||
if not filtered.strip():
|
||||
skipped_empty += 1
|
||||
continue
|
||||
masked_total += max(count_masked_numbers(s["content"], filtered), 0)
|
||||
clamped = clamp_text_bytes(filtered, max_bytes)
|
||||
if clamped is not filtered and len(clamped) != len(filtered):
|
||||
truncated += 1
|
||||
ref_sections.append({**s, "content": clamped})
|
||||
logger.info(
|
||||
"解析结果 | md_len=%s | toc=%s | 入库章节=%s | 跳过空正文=%s | 截断超长=%s | 脱敏数字串=%s",
|
||||
len(markdown), len(sections), len(ref_sections), skipped_empty, truncated, masked_total,
|
||||
)
|
||||
|
||||
# 5) 复刻 eval_report:将上传目录匹配默认模板,得到每章节 提示词/输出合同/示例
|
||||
# 放到工作线程执行:内部含并行 LLM 调用,避免阻塞事件循环(上传期间仍可并发处理其它请求)
|
||||
resolved = await asyncio.to_thread(resolve_uploaded_template_prompts, sections)
|
||||
logger.info(
|
||||
"提示词匹配完成 | 章节=%s | 命中提示词=%s",
|
||||
len(resolved), sum(1 for r in resolved if (r.get("sectionPrompt") or "").strip()),
|
||||
)
|
||||
|
||||
now = datetime.now()
|
||||
|
||||
# 6) 创建模板(目录 + 提示词/输出合同/示例)
|
||||
template = ReportTemplate(
|
||||
id=uuid.uuid4().hex,
|
||||
name=os.path.splitext(filename)[0] or "上传模板",
|
||||
description=_build_description(filename),
|
||||
is_default=False,
|
||||
is_active=True,
|
||||
created_at=now,
|
||||
updated_at=now,
|
||||
)
|
||||
db.add(template)
|
||||
db.flush()
|
||||
|
||||
for i, sec in enumerate(sections):
|
||||
r = resolved[i] if i < len(resolved) else {}
|
||||
db.add(
|
||||
ReportTemplateSection(
|
||||
id=uuid.uuid4().hex,
|
||||
template_id=template.id,
|
||||
section_key=normalize_section_key(sec["sectionKey"], sec["sectionTitle"]),
|
||||
section_title=_clamp_title(sec["sectionTitle"]),
|
||||
section_prompt=(r.get("sectionPrompt") or None),
|
||||
section_output_contract=(r.get("sectionOutputContract") or None),
|
||||
section_order=i,
|
||||
examples="",
|
||||
created_at=now,
|
||||
updated_at=now,
|
||||
)
|
||||
)
|
||||
|
||||
# 7) 章节内容入库(report_section_references 格式)
|
||||
saved_refs: list[SectionReferenceItem] = []
|
||||
for sec in ref_sections:
|
||||
ref = ReportSectionReference(
|
||||
id=uuid.uuid4().hex,
|
||||
template_id=template.id,
|
||||
source_file=filename,
|
||||
section_key=sec["section_key"],
|
||||
section_title=_clamp_title(sec["section_title"]),
|
||||
section_order=parse_section_order(sec["section_key"]),
|
||||
content=sec["content"],
|
||||
created_at=now,
|
||||
updated_at=now,
|
||||
)
|
||||
db.add(ref)
|
||||
saved_refs.append(
|
||||
SectionReferenceItem(
|
||||
id=ref.id,
|
||||
templateId=ref.template_id,
|
||||
sourceFile=ref.source_file,
|
||||
sectionKey=ref.section_key,
|
||||
sectionTitle=ref.section_title,
|
||||
sectionOrder=ref.section_order,
|
||||
contentLength=len(ref.content or ""),
|
||||
content=ref.content or "",
|
||||
)
|
||||
)
|
||||
|
||||
db.commit()
|
||||
logger.info(
|
||||
"模板上传完成 | file=%s | toc=%s | refs=%s",
|
||||
filename, len(sections), len(saved_refs),
|
||||
)
|
||||
|
||||
return UploadTemplateResult(
|
||||
template=_serialize_template(db, template.id),
|
||||
sourceFile=filename,
|
||||
markdownLength=len(markdown),
|
||||
totalSections=len(sections),
|
||||
totalReferences=len(saved_refs),
|
||||
references=saved_refs,
|
||||
parseWarnings=[],
|
||||
)
|
||||
except HTTPException:
|
||||
db.rollback()
|
||||
raise
|
||||
except Exception as e: # noqa: BLE001
|
||||
db.rollback()
|
||||
raise HTTPException(status_code=400, detail=f"模板创建失败:{e}")
|
||||
finally:
|
||||
try:
|
||||
os.remove(tmp_path)
|
||||
except OSError:
|
||||
pass
|
||||
|
||||
|
||||
@router.get("", response_model=list[TemplateItem], summary="获取模板列表")
|
||||
def list_templates_route(db: Session = Depends(get_db)):
|
||||
rows = (
|
||||
db.query(ReportTemplate)
|
||||
.order_by(ReportTemplate.created_at.desc())
|
||||
.all()
|
||||
)
|
||||
return [_serialize_template(db, r.id) for r in rows]
|
||||
|
||||
|
||||
@router.get("/{template_id}", response_model=TemplateItem, summary="获取模板详情")
|
||||
def get_template_route(template_id: str, db: Session = Depends(get_db)):
|
||||
return _serialize_template(db, template_id)
|
||||
|
||||
|
||||
@router.delete("/{template_id}", status_code=204, summary="删除模板(含其来源文件的章节范文)")
|
||||
def delete_template_route(template_id: str, db: Session = Depends(get_db)):
|
||||
t = db.query(ReportTemplate).filter(ReportTemplate.id == template_id).first()
|
||||
if not t:
|
||||
raise HTTPException(status_code=404, detail="模板不存在")
|
||||
source_file = _extract_source_file(t.description)
|
||||
db.query(ReportTemplateSection).filter(
|
||||
ReportTemplateSection.template_id == t.id
|
||||
).delete(synchronize_session=False)
|
||||
db.delete(t)
|
||||
# 同步删除该模板在 report_section_references 中的章节内容:
|
||||
# 优先按 template_id 精确删除;同时按 source_file 兜底清理历史数据(template_id 为空的旧记录)
|
||||
conditions = [ReportSectionReference.template_id == t.id]
|
||||
if source_file:
|
||||
conditions.append(ReportSectionReference.source_file == source_file)
|
||||
ref_deleted = (
|
||||
db.query(ReportSectionReference)
|
||||
.filter(or_(*conditions))
|
||||
.delete(synchronize_session=False)
|
||||
)
|
||||
db.commit()
|
||||
logger.info(
|
||||
"模板删除完成 | template_id=%s | source_file=%s | 删除章节范文=%s",
|
||||
template_id, source_file, ref_deleted,
|
||||
)
|
||||
51
schemas/template.py
Normal file
51
schemas/template.py
Normal file
@ -0,0 +1,51 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import List, Optional
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class TemplateSectionItem(BaseModel):
|
||||
id: str
|
||||
sectionKey: str
|
||||
sectionTitle: str
|
||||
# 复刻 eval_report:章节提示词 / 输出合同 / 示例
|
||||
sectionPrompt: Optional[str] = None
|
||||
sectionOutputContract: Optional[str] = None
|
||||
sectionOrder: int = 0
|
||||
examples: Optional[str] = None
|
||||
|
||||
|
||||
class TemplateItem(BaseModel):
|
||||
id: str
|
||||
name: str
|
||||
description: Optional[str] = None
|
||||
sourceFile: Optional[str] = None
|
||||
createdAt: Optional[str] = None
|
||||
updatedAt: Optional[str] = None
|
||||
isDefault: bool = False
|
||||
isActive: bool = True
|
||||
sections: List[TemplateSectionItem] = []
|
||||
|
||||
|
||||
class SectionReferenceItem(BaseModel):
|
||||
id: str
|
||||
templateId: Optional[str] = None
|
||||
sourceFile: str
|
||||
sectionKey: str
|
||||
sectionTitle: str
|
||||
sectionOrder: int = 0
|
||||
contentLength: int = 0
|
||||
content: str = ""
|
||||
|
||||
|
||||
class UploadTemplateResult(BaseModel):
|
||||
"""上传解析结果:模板(目录 + 声明)+ 入库的章节内容。"""
|
||||
|
||||
template: TemplateItem
|
||||
sourceFile: str
|
||||
markdownLength: int
|
||||
totalSections: int
|
||||
totalReferences: int
|
||||
references: List[SectionReferenceItem] = []
|
||||
parseWarnings: List[str] = []
|
||||
126
services/declaration_service.py
Normal file
126
services/declaration_service.py
Normal file
@ -0,0 +1,126 @@
|
||||
"""
|
||||
services/declaration_service.py
|
||||
为每个目录(章节)生成一个"声明"。
|
||||
|
||||
声明:一段说明该章节应写什么、结构与约束的撰写指引,存入
|
||||
report_template_sections.section_prompt。
|
||||
|
||||
优先用 LLM(结合章节标题 + 该章节正文)生成;未配置或失败时
|
||||
回退到确定性模板,保证流程稳定可用。
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import re
|
||||
from concurrent.futures import ThreadPoolExecutor, as_completed
|
||||
|
||||
from config import settings
|
||||
from services.llm_client import chat_completions_json, llm_configured
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_SYSTEM_PROMPT = (
|
||||
"你是报告模板专家。任务:阅读给定章节的范文(参考正文),总结这一章应该怎么写,"
|
||||
"作为后续报告撰写该章节的写作指引。需提炼:①内容要点(写哪些事项);"
|
||||
"②组织结构(应有的小节/条目顺序);③数据与口径要求(需引用的对比/指标/表格等);"
|
||||
"④写作约束(先事实后评价、缺失写「待补充」、不得编造)。"
|
||||
"严格要求:不要输出任何思考过程或解释;只输出 JSON 对象 {\"guide\": \"...\"};"
|
||||
"guide 为 300 字以内的写作指引纯文本(不含 markdown 标题);"
|
||||
"范文缺失或过短时,按章节标题给出通用写作指引。"
|
||||
)
|
||||
|
||||
|
||||
def _strip_number_prefix(title: str) -> str:
|
||||
t = str(title or "").strip()
|
||||
t = re.sub(r"^(?:\d+(?:\.\d+)*|[一二三四五六七八九十]+[、..])\s*", "", t).strip()
|
||||
return t
|
||||
|
||||
|
||||
def _fallback_declaration(section_title: str) -> str:
|
||||
label = _strip_number_prefix(section_title) or "本章节"
|
||||
return (
|
||||
f"本章节为「{label}」。撰写时应紧扣标题主题,先陈述事实与数据,再给出分析与评价;"
|
||||
f"结构需与标题保持一致,条理清晰、用语规范;"
|
||||
f"所有结论须有依据,缺失信息写「待补充」,禁止编造。"
|
||||
)
|
||||
|
||||
|
||||
def _build_user_prompt(section_title: str, content: str) -> str:
|
||||
body = (content or "").strip()
|
||||
if len(body) > 2500:
|
||||
body = body[:2500]
|
||||
body_block = f"\n\n该章节范文(参考正文,节选):\n```\n{body}\n```" if body else ""
|
||||
return (
|
||||
f"章节标题:{section_title}{body_block}\n\n"
|
||||
f"请根据上述范文,总结该章节应该怎么写,并只返回 JSON:{{\"guide\": \"300字以内的写作指引\"}}。"
|
||||
)
|
||||
|
||||
|
||||
def generate_declaration(section_title: str, content: str = "") -> str:
|
||||
"""根据范文为单个章节生成"怎么写"的写作指引(JSON 取 guide,自动剔除思考过程)。"""
|
||||
use_llm = bool(getattr(settings, "DECLARATION_USE_LLM", True)) and llm_configured()
|
||||
if not use_llm:
|
||||
return _fallback_declaration(section_title)
|
||||
try:
|
||||
data = chat_completions_json(
|
||||
system_prompt=_SYSTEM_PROMPT,
|
||||
user_prompt=_build_user_prompt(section_title, content),
|
||||
temperature=0.2,
|
||||
max_tokens=2048,
|
||||
)
|
||||
guide = str((data or {}).get("guide") or "").strip()
|
||||
if guide:
|
||||
return guide
|
||||
except Exception as e: # noqa: BLE001 - 兜底,保证主流程不被 LLM 影响
|
||||
logger.warning("生成章节声明失败,使用兜底模板 | title=%s | err=%s", section_title, e)
|
||||
return _fallback_declaration(section_title)
|
||||
|
||||
|
||||
def _content_for_section(s: dict, content_by_key: dict[str, str]) -> str:
|
||||
"""目录键可能是 canonical 形式,优先用标题中的编号前缀去匹配正文。"""
|
||||
title = str(s.get("sectionTitle") or "")
|
||||
m = re.match(r"^(\d+(?:\.\d+)*)", title.strip())
|
||||
num = m.group(1) if m else ""
|
||||
return content_by_key.get(num, "") or content_by_key.get(str(s.get("sectionKey") or ""), "")
|
||||
|
||||
|
||||
def generate_declarations(sections: list[dict], content_by_key: dict[str, str] | None = None) -> list[str]:
|
||||
"""
|
||||
为目录中每个章节并发生成"怎么写"的写作指引(基于范文)。
|
||||
sections: [{sectionKey, sectionTitle}, ...]
|
||||
content_by_key: 章节编号/键 -> 范文正文,用于为指引提供上下文(可选)。
|
||||
|
||||
每章一次 LLM 调用,多线程并发以打满 GPU(LLM 为网络 I/O,线程下真正并行)。
|
||||
"""
|
||||
content_by_key = content_by_key or {}
|
||||
tasks = [(str(s.get("sectionTitle") or ""), _content_for_section(s, content_by_key)) for s in sections]
|
||||
if not tasks:
|
||||
return []
|
||||
|
||||
use_llm = bool(getattr(settings, "DECLARATION_USE_LLM", True)) and llm_configured()
|
||||
if not use_llm:
|
||||
return [_fallback_declaration(title) for title, _ in tasks]
|
||||
|
||||
max_workers = max(int(getattr(settings, "TEMPLATE_UPLOAD_LLM_MAX_WORKERS", 8) or 8), 1)
|
||||
results: list[str] = [""] * len(tasks)
|
||||
|
||||
if len(tasks) == 1:
|
||||
results[0] = generate_declaration(*tasks[0])
|
||||
return results
|
||||
|
||||
workers = min(max_workers, len(tasks))
|
||||
with ThreadPoolExecutor(max_workers=workers) as executor:
|
||||
future_to_idx = {
|
||||
executor.submit(generate_declaration, title, content): i
|
||||
for i, (title, content) in enumerate(tasks)
|
||||
}
|
||||
for fut in as_completed(future_to_idx):
|
||||
idx = future_to_idx[fut]
|
||||
try:
|
||||
results[idx] = fut.result()
|
||||
except Exception as e: # noqa: BLE001
|
||||
logger.warning("章节声明并发生成失败,使用兜底 | idx=%s | err=%s", idx, e)
|
||||
results[idx] = _fallback_declaration(tasks[idx][0])
|
||||
logger.info("章节声明生成 | 章节=%s | 线程=%s", len(tasks), workers)
|
||||
return results
|
||||
80
services/desensitize_service.py
Normal file
80
services/desensitize_service.py
Normal file
@ -0,0 +1,80 @@
|
||||
"""
|
||||
services/desensitize_service.py
|
||||
章节内容脱敏:把范文正文中的"精确数据"过滤掉,得到可复用的模板化内容。
|
||||
|
||||
规则(默认):
|
||||
- 阿拉伯数字串(含小数、千分位、全角数字)→ 占位符(默认 "X"),
|
||||
如 "总投资10.5亿元" → "总投资X亿元"、"2020年3月" → "X年X月"、"85.3%" → "X%"。
|
||||
- 标题行(以 # 开头)整行保留,不动章节编号/标题。
|
||||
- 行首的列表/枚举序号(如 "1)" "1." "(2)")保留,仅脱敏正文中的数字。
|
||||
- 单位与符号(万元/亿元/%/吨/年 等)保留,仅去掉其中的精确数值。
|
||||
|
||||
可通过 config 调整占位符、是否脱敏表格数字、是否启用。
|
||||
中文数字(一二三…)通常用于序数/层级,默认保留。
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import re
|
||||
|
||||
from config import settings
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# 阿拉伯数字(含全角)串,允许小数点/千分位分隔
|
||||
_NUMBER_RE = re.compile(r"[0-90-9]+(?:[..,,][0-90-9]+)*")
|
||||
|
||||
# 行首枚举序号:1) / 1. / (2) / 2、 等(这些是结构标记,保留)
|
||||
_LEADING_ENUM_RE = re.compile(r"^(\s*(?:[((]\s*[0-90-9]+\s*[))]|[0-90-9]+\s*[)).、.]))")
|
||||
|
||||
_HEADING_RE = re.compile(r"^\s*#{1,6}\s")
|
||||
_TABLE_ROW_RE = re.compile(r"^\s*\|.*\|\s*$")
|
||||
_TABLE_SEP_RE = re.compile(r"^\s*\|?[\s:\-|]+\|?\s*$")
|
||||
|
||||
|
||||
def _mask_numbers(segment: str, placeholder: str) -> str:
|
||||
return _NUMBER_RE.sub(placeholder, segment)
|
||||
|
||||
|
||||
def _desensitize_line(line: str, placeholder: str, mask_table_numbers: bool) -> str:
|
||||
# 标题行整行保留(不动章节编号/标题)
|
||||
if _HEADING_RE.match(line):
|
||||
return line
|
||||
|
||||
# 表格行
|
||||
if _TABLE_ROW_RE.match(line):
|
||||
if _TABLE_SEP_RE.match(line): # 分隔行 |---|---|
|
||||
return line
|
||||
if not mask_table_numbers:
|
||||
return line
|
||||
return _mask_numbers(line, placeholder)
|
||||
|
||||
# 普通正文:保留行首枚举序号,仅脱敏其余部分
|
||||
m = _LEADING_ENUM_RE.match(line)
|
||||
if m:
|
||||
prefix = m.group(1)
|
||||
rest = line[len(prefix):]
|
||||
return prefix + _mask_numbers(rest, placeholder)
|
||||
|
||||
return _mask_numbers(line, placeholder)
|
||||
|
||||
|
||||
def desensitize_content(text: str) -> str:
|
||||
"""对单个章节正文脱敏。未启用时原样返回。"""
|
||||
if not text:
|
||||
return text
|
||||
if not bool(getattr(settings, "DESENSITIZE_ENABLED", True)):
|
||||
return text
|
||||
|
||||
placeholder = str(getattr(settings, "DESENSITIZE_PLACEHOLDER", "X") or "X")
|
||||
mask_table = bool(getattr(settings, "DESENSITIZE_MASK_TABLE_NUMBERS", True))
|
||||
|
||||
lines = text.splitlines()
|
||||
out = [_desensitize_line(ln, placeholder, mask_table) for ln in lines]
|
||||
return "\n".join(out)
|
||||
|
||||
|
||||
def count_masked_numbers(original: str, filtered: str) -> int:
|
||||
"""粗略统计脱敏掉的数字串数量(用于日志)。"""
|
||||
return len(_NUMBER_RE.findall(original or "")) - len(_NUMBER_RE.findall(filtered or ""))
|
||||
194
services/file_parse_client.py
Normal file
194
services/file_parse_client.py
Normal file
@ -0,0 +1,194 @@
|
||||
"""
|
||||
services/file_parse_client.py
|
||||
调用远程解析服务(默认 http://192.168.4.194:8000/convert):
|
||||
上传文件(multipart:文件字段默认 "file",并附带 engine=auto 表单字段)→ 返回 Markdown。
|
||||
|
||||
响应解析:JSON 中按 results / md_content / mdcontent / markdown / content
|
||||
逐层提取;若响应非 JSON 则整体作为 Markdown 返回。
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
import mimetypes
|
||||
import time
|
||||
import uuid
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
from urllib import error as urlerror
|
||||
from urllib import request as urlrequest
|
||||
|
||||
from config import settings
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
MD_CONTENT_KEYS = ("md_content", "mdcontent", "markdown", "content")
|
||||
|
||||
|
||||
class FileParseApiError(RuntimeError):
|
||||
def __init__(self, message: str, *, status_code: int | None = None, api_url: str = "") -> None:
|
||||
super().__init__(message)
|
||||
self.status_code = status_code
|
||||
self.api_url = api_url
|
||||
|
||||
|
||||
def _build_multipart_body(
|
||||
file_path: Path,
|
||||
field_name: str,
|
||||
extra_fields: dict[str, str] | None = None,
|
||||
) -> tuple[bytes, str]:
|
||||
boundary = uuid.uuid4().hex
|
||||
mime_type = mimetypes.guess_type(file_path.name)[0] or "application/octet-stream"
|
||||
file_bytes = file_path.read_bytes()
|
||||
|
||||
parts: list[bytes] = []
|
||||
# 普通表单字段(如 engine=auto)
|
||||
for key, value in (extra_fields or {}).items():
|
||||
if value is None or str(value).strip() == "":
|
||||
continue
|
||||
parts.append(f"--{boundary}\r\n".encode("utf-8"))
|
||||
parts.append(
|
||||
f'Content-Disposition: form-data; name="{key}"\r\n\r\n'.encode("utf-8")
|
||||
)
|
||||
parts.append(f"{value}\r\n".encode("utf-8"))
|
||||
|
||||
# 文件字段
|
||||
parts.append(f"--{boundary}\r\n".encode("utf-8"))
|
||||
parts.append(
|
||||
(
|
||||
f'Content-Disposition: form-data; name="{field_name}"; filename="{file_path.name}"\r\n'
|
||||
f"Content-Type: {mime_type}\r\n\r\n"
|
||||
).encode("utf-8")
|
||||
)
|
||||
parts.append(file_bytes)
|
||||
parts.append(f"\r\n--{boundary}--\r\n".encode("utf-8"))
|
||||
return b"".join(parts), boundary
|
||||
|
||||
|
||||
def _extract_md_contents(payload: Any) -> list[str]:
|
||||
if isinstance(payload, str):
|
||||
return [payload]
|
||||
if isinstance(payload, list):
|
||||
out: list[str] = []
|
||||
for item in payload:
|
||||
out.extend(_extract_md_contents(item))
|
||||
return out
|
||||
if not isinstance(payload, dict):
|
||||
return []
|
||||
|
||||
for key in MD_CONTENT_KEYS:
|
||||
value = payload.get(key)
|
||||
if isinstance(value, str):
|
||||
return [value]
|
||||
|
||||
results = payload.get("results")
|
||||
if results is not None:
|
||||
return _extract_md_contents(results)
|
||||
|
||||
out = []
|
||||
for value in payload.values():
|
||||
out.extend(_extract_md_contents(value))
|
||||
return out
|
||||
|
||||
|
||||
def _response_to_markdown(text: str) -> str:
|
||||
try:
|
||||
payload = json.loads(text)
|
||||
except json.JSONDecodeError:
|
||||
# 非 JSON 直接当作 Markdown 返回
|
||||
return text
|
||||
contents = _extract_md_contents(payload)
|
||||
if not contents:
|
||||
raise ValueError("解析服务响应中未找到 md_content/markdown/content 字段")
|
||||
return "\n\n".join(c.strip() for c in contents if c and c.strip())
|
||||
|
||||
|
||||
def _request_once(
|
||||
api_url: str,
|
||||
file_path: Path,
|
||||
field_name: str,
|
||||
*,
|
||||
timeout_sec: int,
|
||||
extra_fields: dict[str, str] | None = None,
|
||||
) -> str:
|
||||
body, boundary = _build_multipart_body(file_path, field_name, extra_fields)
|
||||
req = urlrequest.Request(
|
||||
api_url,
|
||||
data=body,
|
||||
method="POST",
|
||||
headers={"content-type": f"multipart/form-data; boundary={boundary}"},
|
||||
)
|
||||
try:
|
||||
with urlrequest.urlopen(req, timeout=timeout_sec) as resp:
|
||||
raw = resp.read()
|
||||
encoding = resp.headers.get_content_charset() or "utf-8"
|
||||
return raw.decode(encoding, errors="replace")
|
||||
except urlerror.HTTPError as exc:
|
||||
body_text = ""
|
||||
try:
|
||||
body_text = (exc.read() or b"").decode("utf-8", errors="replace")[:1000]
|
||||
except Exception:
|
||||
pass
|
||||
raise FileParseApiError(
|
||||
f"解析服务 HTTP {exc.code}({api_url}):{body_text or exc.reason}",
|
||||
status_code=int(exc.code or 0),
|
||||
api_url=api_url,
|
||||
) from exc
|
||||
except urlerror.URLError as exc:
|
||||
raise FileParseApiError(
|
||||
f"无法连接解析服务({api_url}):{exc.reason}",
|
||||
status_code=0,
|
||||
api_url=api_url,
|
||||
) from exc
|
||||
|
||||
|
||||
def parse_file_to_markdown(file_path: str | Path) -> str:
|
||||
"""
|
||||
将上传文件通过远程 file_parse 服务转换为 Markdown。
|
||||
失败时对 5xx 做有限重试。
|
||||
"""
|
||||
path = Path(file_path)
|
||||
if not path.is_file():
|
||||
raise FileNotFoundError(f"文件不存在: {path}")
|
||||
|
||||
api_url = str(settings.FILE_PARSE_API_URL or "").strip()
|
||||
if not api_url:
|
||||
raise ValueError("FILE_PARSE_API_URL 未配置")
|
||||
|
||||
field_name = str(settings.FILE_PARSE_FIELD_NAME or "file").strip() or "file"
|
||||
timeout_sec = max(int(settings.FILE_PARSE_HTTP_TIMEOUT_SEC or 600), 30)
|
||||
retry_count = max(int(settings.FILE_PARSE_RETRY_COUNT or 1), 1)
|
||||
backoff_sec = max(float(settings.FILE_PARSE_RETRY_BACKOFF_SEC or 1.0), 1.0)
|
||||
retryable_status = {500, 502, 503, 504}
|
||||
|
||||
extra_fields: dict[str, str] = {}
|
||||
engine = str(getattr(settings, "FILE_PARSE_ENGINE", "") or "").strip()
|
||||
if engine:
|
||||
extra_fields["engine"] = engine
|
||||
|
||||
last_error: Exception | None = None
|
||||
for attempt in range(1, retry_count + 1):
|
||||
try:
|
||||
raw = _request_once(
|
||||
api_url, path, field_name, timeout_sec=timeout_sec, extra_fields=extra_fields
|
||||
)
|
||||
markdown = _response_to_markdown(raw)
|
||||
if not markdown.strip():
|
||||
raise ValueError("解析服务返回的 Markdown 为空")
|
||||
return markdown
|
||||
except FileParseApiError as exc:
|
||||
last_error = exc
|
||||
status = int(exc.status_code or 0)
|
||||
if attempt >= retry_count or status not in retryable_status:
|
||||
raise
|
||||
wait = backoff_sec * attempt
|
||||
logger.warning(
|
||||
"file_parse 重试 %s/%s status=%s wait=%ss file=%s",
|
||||
attempt, retry_count, status, wait, path.name,
|
||||
)
|
||||
time.sleep(wait)
|
||||
|
||||
if last_error:
|
||||
raise last_error
|
||||
raise RuntimeError("file_parse 请求失败")
|
||||
@ -1,724 +1,118 @@
|
||||
"""
|
||||
services/llm_client.py
|
||||
极简 OpenAI 兼容 Chat Completions 客户端(仅用于生成章节声明,可选)。
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
import random
|
||||
import re
|
||||
import time
|
||||
import threading
|
||||
from typing import Any, Optional
|
||||
|
||||
import requests
|
||||
from requests import RequestException
|
||||
from requests.exceptions import ChunkedEncodingError
|
||||
|
||||
from config import settings
|
||||
|
||||
_logger = logging.getLogger(__name__)
|
||||
# 生成全过程追踪:完整记录输入 prompt / 调用模型 / 模型输出,写入 logs/generation_trace.log
|
||||
_trace_logger = logging.getLogger("generation.trace")
|
||||
|
||||
_LLM_MAX_CONCURRENCY = 5
|
||||
_llm_slots = threading.BoundedSemaphore(_LLM_MAX_CONCURRENCY)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class _RetryableLLMError(RuntimeError):
|
||||
"""用于标记可安全重试的 LLM 调用异常。"""
|
||||
|
||||
|
||||
class _ContentFieldStreamExtractor:
|
||||
"""从流式 JSON 文本中增量提取 content 字段的已解码正文。"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self._raw = ""
|
||||
self._content_started = False
|
||||
self._content_done = False
|
||||
self._value_start = -1
|
||||
self._consumed_pos = 0
|
||||
|
||||
def feed(self, chunk: str) -> tuple[str, bool]:
|
||||
if not chunk:
|
||||
return "", False
|
||||
self._raw += chunk
|
||||
emitted = ""
|
||||
done_now = False
|
||||
|
||||
if not self._content_started:
|
||||
marker = '"content"'
|
||||
idx = self._raw.find(marker)
|
||||
if idx == -1:
|
||||
return "", False
|
||||
colon = self._raw.find(":", idx + len(marker))
|
||||
if colon == -1:
|
||||
return "", False
|
||||
quote = self._raw.find('"', colon + 1)
|
||||
if quote == -1:
|
||||
return "", False
|
||||
self._content_started = True
|
||||
self._value_start = quote + 1
|
||||
self._consumed_pos = self._value_start
|
||||
|
||||
if self._content_started and not self._content_done:
|
||||
emitted, consumed_pos, done_now = self._decode_partial_json_string(
|
||||
self._raw,
|
||||
self._consumed_pos,
|
||||
def llm_configured() -> bool:
|
||||
return bool(
|
||||
str(settings.LLM_API_BASE or "").strip()
|
||||
and str(settings.LLM_API_KEY or "").strip()
|
||||
and str(settings.LLM_MODEL_NAME or "").strip()
|
||||
)
|
||||
self._consumed_pos = consumed_pos
|
||||
if done_now:
|
||||
self._content_done = True
|
||||
return emitted, done_now
|
||||
|
||||
@staticmethod
|
||||
def _decode_partial_json_string(src: str, start: int) -> tuple[str, int, bool]:
|
||||
out: list[str] = []
|
||||
i = start
|
||||
n = len(src)
|
||||
while i < n:
|
||||
ch = src[i]
|
||||
if ch == '"':
|
||||
if i == 0 or src[i - 1] != "\\" or _ContentFieldStreamExtractor._is_escaped(src, i):
|
||||
return "".join(out), i, True
|
||||
if ch != "\\":
|
||||
out.append(ch)
|
||||
i += 1
|
||||
continue
|
||||
if i + 1 >= n:
|
||||
break
|
||||
esc = src[i + 1]
|
||||
mapping = {
|
||||
'"': '"',
|
||||
"\\": "\\",
|
||||
"/": "/",
|
||||
"b": "\b",
|
||||
"f": "\f",
|
||||
"n": "\n",
|
||||
"r": "\r",
|
||||
"t": "\t",
|
||||
}
|
||||
if esc == "u":
|
||||
if i + 6 > n:
|
||||
break
|
||||
hex_part = src[i + 2 : i + 6]
|
||||
try:
|
||||
out.append(chr(int(hex_part, 16)))
|
||||
except Exception:
|
||||
pass
|
||||
i += 6
|
||||
continue
|
||||
if esc in mapping:
|
||||
out.append(mapping[esc])
|
||||
i += 2
|
||||
continue
|
||||
out.append(esc)
|
||||
i += 2
|
||||
return "".join(out), i, False
|
||||
|
||||
@staticmethod
|
||||
def _is_escaped(src: str, quote_index: int) -> bool:
|
||||
backslashes = 0
|
||||
i = quote_index - 1
|
||||
while i >= 0 and src[i] == "\\":
|
||||
backslashes += 1
|
||||
i -= 1
|
||||
return backslashes % 2 == 0
|
||||
|
||||
|
||||
def _format_exc_raw(e: Exception) -> str:
|
||||
"""统一输出最直接的异常原文(类型 + repr)。"""
|
||||
return f"{type(e).__name__}: {e!r}"
|
||||
_THINK_BLOCK_RE = re.compile(r"<think>.*?</think>", re.DOTALL | re.IGNORECASE)
|
||||
|
||||
|
||||
def _chat_completions_stream_text(
|
||||
def _strip_reasoning(text: str) -> str:
|
||||
"""去掉思考模型的思维链:成对 <think>…</think>,以及截断/前导的残留标签。"""
|
||||
s = text or ""
|
||||
s = _THINK_BLOCK_RE.sub("", s)
|
||||
# 仅剩结束标签时,说明前面是未配对的思考段,取最后一个 </think> 之后的正文
|
||||
if "</think>" in s:
|
||||
s = s.rsplit("</think>", 1)[-1]
|
||||
s = re.sub(r"</?think>", "", s, flags=re.IGNORECASE)
|
||||
return s.strip()
|
||||
|
||||
|
||||
def chat_completion_text(
|
||||
*,
|
||||
api_base: str,
|
||||
api_key: str,
|
||||
model_name: str,
|
||||
system_prompt: str,
|
||||
user_prompt: str,
|
||||
temperature: float,
|
||||
max_tokens: int,
|
||||
extra_payload: dict[str, Any],
|
||||
connect_timeout_sec: int,
|
||||
read_timeout_sec: int = 300,
|
||||
on_content_delta: Optional[callable] = None,
|
||||
temperature: float = 0.2,
|
||||
max_tokens: int = 512,
|
||||
timeout_sec: int | None = None,
|
||||
) -> str:
|
||||
"""
|
||||
以 OpenAI-compat SSE 流式读取模型输出文本。
|
||||
- connect timeout 保留,避免连接阶段长时间卡死
|
||||
- read timeout 防止流式读取无限挂起(默认 300s)
|
||||
"""
|
||||
_logger.info(
|
||||
"LLM 流式调用开始 | model=%s | temperature=%s | max_tokens=%s | timeout_connect=%s timeout_read=%s",
|
||||
model_name, temperature, max_tokens, connect_timeout_sec, read_timeout_sec,
|
||||
)
|
||||
resp = requests.post(
|
||||
f"{api_base}/chat/completions",
|
||||
"""调用 LLM 返回纯文本。失败抛出异常,由调用方决定是否兜底。"""
|
||||
base = str(settings.LLM_API_BASE or "").strip().rstrip("/")
|
||||
url = f"{base}/chat/completions"
|
||||
headers = {
|
||||
"Authorization": f"Bearer {api_key}",
|
||||
"Authorization": f"Bearer {settings.LLM_API_KEY}",
|
||||
"Content-Type": "application/json",
|
||||
},
|
||||
json={
|
||||
"model": model_name,
|
||||
}
|
||||
payload = {
|
||||
"model": settings.LLM_MODEL_NAME,
|
||||
"messages": [
|
||||
{"role": "system", "content": system_prompt},
|
||||
{"role": "user", "content": user_prompt},
|
||||
],
|
||||
"temperature": temperature,
|
||||
"max_tokens": max_tokens,
|
||||
"response_format": {"type": "json_object"},
|
||||
"stream": True,
|
||||
**extra_payload,
|
||||
},
|
||||
stream=True,
|
||||
timeout=(connect_timeout_sec, max(60, read_timeout_sec)),
|
||||
}
|
||||
# 关闭思考模型的思维链(vLLM/Qwen3 等支持该扩展字段;不支持的服务会忽略)
|
||||
if bool(getattr(settings, "LLM_DISABLE_THINKING", False)):
|
||||
payload["chat_template_kwargs"] = {"enable_thinking": False}
|
||||
resp = requests.post(
|
||||
url,
|
||||
headers=headers,
|
||||
data=json.dumps(payload, ensure_ascii=False).encode("utf-8"),
|
||||
timeout=timeout_sec or int(settings.LLM_HTTP_TIMEOUT_SEC or 120),
|
||||
)
|
||||
resp.raise_for_status()
|
||||
data = resp.json()
|
||||
return _strip_reasoning((data["choices"][0]["message"]["content"] or "").strip())
|
||||
|
||||
if resp.status_code in (408, 429, 500, 502, 503, 504):
|
||||
raise _RetryableLLMError(f"LLM HTTP {resp.status_code}: {(resp.text or '')[:300]}")
|
||||
if resp.status_code != 200:
|
||||
raise RuntimeError(f"LLM HTTP {resp.status_code}: {(resp.text or '')[:800]}")
|
||||
|
||||
resp.encoding = "utf-8"
|
||||
chunks: list[str] = []
|
||||
extractor = _ContentFieldStreamExtractor()
|
||||
def _extract_json(text: str) -> dict:
|
||||
"""从模型输出中解析 JSON object(容忍 ```json``` 代码块包裹)。"""
|
||||
s = (text or "").strip()
|
||||
if s.startswith("```"):
|
||||
s = re.sub(r"^```[a-zA-Z]*\s*", "", s)
|
||||
s = re.sub(r"\s*```$", "", s).strip()
|
||||
try:
|
||||
for line in resp.iter_lines(decode_unicode=True):
|
||||
if not line:
|
||||
continue
|
||||
s = line.strip()
|
||||
if not s.startswith("data:"):
|
||||
continue
|
||||
payload = s[5:].strip()
|
||||
if not payload or payload == "[DONE]":
|
||||
break
|
||||
obj = json.loads(s)
|
||||
except json.JSONDecodeError:
|
||||
m = re.search(r"\{.*\}", s, flags=re.DOTALL)
|
||||
if not m:
|
||||
return {}
|
||||
try:
|
||||
obj = json.loads(payload)
|
||||
except Exception:
|
||||
continue
|
||||
choices = obj.get("choices")
|
||||
if not isinstance(choices, list) or not choices:
|
||||
continue
|
||||
first = choices[0] if isinstance(choices[0], dict) else {}
|
||||
delta = first.get("delta") if isinstance(first.get("delta"), dict) else {}
|
||||
content = delta.get("content")
|
||||
if isinstance(content, str) and content:
|
||||
chunks.append(content)
|
||||
# print(content, end="", flush=True)
|
||||
if on_content_delta:
|
||||
delta_text, done_now = extractor.feed(content)
|
||||
if delta_text:
|
||||
try:
|
||||
on_content_delta("delta", delta_text)
|
||||
except Exception:
|
||||
pass
|
||||
if done_now:
|
||||
try:
|
||||
on_content_delta("finalizing", "")
|
||||
except Exception:
|
||||
pass
|
||||
# 兼容部分实现把最终结果放在 message.content
|
||||
message = first.get("message") if isinstance(first.get("message"), dict) else {}
|
||||
msg_content = message.get("content")
|
||||
if isinstance(msg_content, str) and msg_content:
|
||||
chunks.append(msg_content)
|
||||
# print(msg_content, end="", flush=True)
|
||||
if on_content_delta:
|
||||
delta_text, done_now = extractor.feed(msg_content)
|
||||
if delta_text:
|
||||
try:
|
||||
on_content_delta("delta", delta_text)
|
||||
except Exception:
|
||||
pass
|
||||
if done_now:
|
||||
try:
|
||||
on_content_delta("finalizing", "")
|
||||
except Exception:
|
||||
pass
|
||||
except ChunkedEncodingError as e:
|
||||
partial_text = "".join(chunks).strip()
|
||||
# 若流提前结束但已收到完整 JSON,直接使用,避免无谓重试失败。
|
||||
if partial_text:
|
||||
try:
|
||||
parse_json_object_from_text(partial_text)
|
||||
return partial_text
|
||||
except Exception:
|
||||
pass
|
||||
raise _RetryableLLMError(f"LLM 流中断: {_format_exc_raw(e)}") from e
|
||||
|
||||
text = "".join(chunks).strip()
|
||||
if not text:
|
||||
raise _RetryableLLMError("LLM 返回空内容")
|
||||
print()
|
||||
return text
|
||||
obj = json.loads(m.group(0))
|
||||
except json.JSONDecodeError:
|
||||
return {}
|
||||
return obj if isinstance(obj, dict) else {}
|
||||
|
||||
|
||||
def chat_completions_json(
|
||||
*,
|
||||
system_prompt: str,
|
||||
user_prompt: str,
|
||||
temperature: float = 0.2,
|
||||
temperature: float = 0.1,
|
||||
max_tokens: int = 4096,
|
||||
timeout_sec: int = 180,
|
||||
on_content_delta: Optional[callable] = None,
|
||||
log_context: str = "",
|
||||
) -> dict[str, Any]:
|
||||
"""
|
||||
统一的 OpenAI-compat chat/completions 调用,强制返回 JSON object。
|
||||
复用项目现有配置:LLM_API_BASE/LLM_API_KEY/LLM_MODEL_NAME。
|
||||
|
||||
log_context: 调用来源标签(如章节编号),用于在 generation_trace.log 中区分各次生成调用。
|
||||
"""
|
||||
api_base = (settings.LLM_API_BASE or "").rstrip("/")
|
||||
api_key = settings.LLM_API_KEY or ""
|
||||
model_name = settings.LLM_MODEL_NAME or ""
|
||||
|
||||
if not api_base or not api_key or not model_name:
|
||||
raise RuntimeError("LLM 未配置:请设置 LLM_API_BASE/LLM_API_KEY/LLM_MODEL_NAME")
|
||||
|
||||
ctx = log_context or "-"
|
||||
_trace_logger.info(
|
||||
"[输入] context=%s | model=%s | temperature=%s | max_tokens=%s\n"
|
||||
"----- SYSTEM PROMPT -----\n%s\n"
|
||||
"----- USER PROMPT -----\n%s\n"
|
||||
"----- END INPUT -----",
|
||||
ctx, model_name, temperature, max_tokens, system_prompt, user_prompt,
|
||||
)
|
||||
|
||||
extra_payload: dict[str, Any] = {}
|
||||
# SiliconFlow 的部分 Qwen 模型默认把输出写到 reasoning_content,导致 content 为空;
|
||||
# 显式关闭 thinking,确保最终输出进入 content,避免下游解析失败。
|
||||
if "siliconflow" in api_base.lower() and "qwen" in model_name.lower():
|
||||
extra_payload["enable_thinking"] = False
|
||||
|
||||
final_timeout_sec = int(timeout_sec or 0)
|
||||
if final_timeout_sec <= 0:
|
||||
final_timeout_sec = int(getattr(settings, "LLM_HTTP_TIMEOUT_SEC", 90) or 90)
|
||||
retry_count = int(getattr(settings, "LLM_RETRY_COUNT", 2) or 2)
|
||||
if retry_count < 1:
|
||||
retry_count = 1
|
||||
retry_backoff = float(getattr(settings, "LLM_RETRY_BACKOFF_SEC", 1.0) or 1.0)
|
||||
retry_backoff_max = float(getattr(settings, "LLM_RETRY_BACKOFF_MAX_SEC", 12.0) or 12.0)
|
||||
connect_timeout_sec = int(getattr(settings, "LLM_CONNECT_TIMEOUT_SEC", 20) or 20)
|
||||
if connect_timeout_sec <= 0:
|
||||
connect_timeout_sec = 20
|
||||
use_stream = True
|
||||
|
||||
_logger.info(
|
||||
"chat_completions_json 调用 | model=%s | temperature=%s | max_tokens=%s | timeout=%s | retry=%s",
|
||||
model_name, temperature, max_tokens, final_timeout_sec, retry_count,
|
||||
)
|
||||
|
||||
with _llm_slots:
|
||||
last_err: Optional[Exception] = None
|
||||
for attempt in range(retry_count):
|
||||
timeout_sec: int | None = None,
|
||||
) -> dict:
|
||||
"""调用 LLM 并将返回解析为 JSON object(dict)。失败返回 {}。"""
|
||||
try:
|
||||
if use_stream:
|
||||
content = _chat_completions_stream_text(
|
||||
api_base=api_base,
|
||||
api_key=api_key,
|
||||
model_name=model_name,
|
||||
text = chat_completion_text(
|
||||
system_prompt=system_prompt,
|
||||
user_prompt=user_prompt,
|
||||
temperature=temperature,
|
||||
max_tokens=max_tokens,
|
||||
extra_payload=extra_payload,
|
||||
connect_timeout_sec=connect_timeout_sec,
|
||||
read_timeout_sec=final_timeout_sec,
|
||||
on_content_delta=on_content_delta,
|
||||
timeout_sec=timeout_sec,
|
||||
)
|
||||
else:
|
||||
# 分离连接超时与读超时:长生成阶段只应占用「读」时间,避免与连接握手混在一个上限里过早超时
|
||||
read_timeout = max(int(connect_timeout_sec) + 5, int(final_timeout_sec))
|
||||
resp = requests.post(
|
||||
f"{api_base}/chat/completions",
|
||||
headers={
|
||||
"Authorization": f"Bearer {api_key}",
|
||||
"Content-Type": "application/json",
|
||||
},
|
||||
json={
|
||||
"model": model_name,
|
||||
"messages": [
|
||||
{"role": "system", "content": system_prompt},
|
||||
{"role": "user", "content": user_prompt},
|
||||
],
|
||||
"temperature": temperature,
|
||||
"max_tokens": max_tokens,
|
||||
"response_format": {"type": "json_object"},
|
||||
**extra_payload,
|
||||
},
|
||||
timeout=(connect_timeout_sec, read_timeout),
|
||||
)
|
||||
if resp.status_code in (408, 429, 500, 502, 503, 504):
|
||||
raise _RetryableLLMError(
|
||||
f"LLM HTTP {resp.status_code}: {(resp.text or '')[:300]}"
|
||||
)
|
||||
if resp.status_code != 200:
|
||||
raise RuntimeError(f"LLM HTTP {resp.status_code}: {(resp.text or '')[:800]}")
|
||||
|
||||
data = resp.json()
|
||||
content = (
|
||||
(data.get("choices") or [{}])[0]
|
||||
.get("message", {})
|
||||
.get("content", "")
|
||||
)
|
||||
if not isinstance(content, str) or not content.strip():
|
||||
raise _RetryableLLMError("LLM 返回空内容")
|
||||
try:
|
||||
obj = parse_json_object_from_text(content)
|
||||
_logger.info(
|
||||
"chat_completions_json 成功 | model=%s | attempt=%d/%d | content_len=%d | keys=%s",
|
||||
model_name, attempt + 1, retry_count, len(content), list(obj.keys())[:8],
|
||||
)
|
||||
_trace_logger.info(
|
||||
"[输出] context=%s | model=%s | attempt=%d/%d | output_len=%d\n"
|
||||
"----- MODEL OUTPUT -----\n%s\n"
|
||||
"----- END OUTPUT -----",
|
||||
ctx, model_name, attempt + 1, retry_count, len(content), content,
|
||||
)
|
||||
return obj
|
||||
except ValueError as e:
|
||||
raise _RetryableLLMError(f"LLM JSON 解析失败: {e}") from e
|
||||
except (
|
||||
requests.ReadTimeout,
|
||||
requests.ConnectTimeout,
|
||||
requests.ConnectionError,
|
||||
ChunkedEncodingError,
|
||||
) as e:
|
||||
last_err = e
|
||||
if attempt >= retry_count - 1:
|
||||
raise RuntimeError(
|
||||
"LLM 请求超时/连接失败"
|
||||
f"(已重试{retry_count}次,timeout={final_timeout_sec}s)"
|
||||
f",endpoint={api_base}/chat/completions"
|
||||
f",model={model_name}"
|
||||
f",raw={_format_exc_raw(e)}"
|
||||
) from e
|
||||
sleep_sec = min(retry_backoff * (2 ** attempt), retry_backoff_max)
|
||||
sleep_sec += random.uniform(0, min(0.5, sleep_sec * 0.2))
|
||||
time.sleep(sleep_sec)
|
||||
except _RetryableLLMError as e:
|
||||
last_err = e
|
||||
if attempt >= retry_count - 1:
|
||||
raise RuntimeError(
|
||||
f"{e}(已重试{retry_count}次,timeout={final_timeout_sec}s)"
|
||||
f",endpoint={api_base}/chat/completions"
|
||||
f",model={model_name}"
|
||||
f",raw={_format_exc_raw(e)}"
|
||||
) from e
|
||||
sleep_sec = min(retry_backoff * (2 ** attempt), retry_backoff_max)
|
||||
sleep_sec += random.uniform(0, min(0.5, sleep_sec * 0.2))
|
||||
time.sleep(sleep_sec)
|
||||
except RequestException as e:
|
||||
resp = getattr(e, "response", None)
|
||||
status = getattr(resp, "status_code", None)
|
||||
body = ""
|
||||
if resp is not None:
|
||||
try:
|
||||
body = (resp.text or "")[:800]
|
||||
except Exception:
|
||||
body = ""
|
||||
raise RuntimeError(
|
||||
"LLM 请求失败"
|
||||
f",endpoint={api_base}/chat/completions"
|
||||
f",model={model_name}"
|
||||
f",status={status}"
|
||||
+ (f",body={body}" if body else "")
|
||||
+ f",raw={_format_exc_raw(e)}"
|
||||
) from e
|
||||
else:
|
||||
raw = _format_exc_raw(last_err) if isinstance(last_err, Exception) else str(last_err)
|
||||
raise RuntimeError(
|
||||
"LLM 请求失败"
|
||||
f",endpoint={api_base}/chat/completions"
|
||||
f",model={model_name}"
|
||||
f",raw={raw}"
|
||||
)
|
||||
|
||||
|
||||
def _repair_loose_json_object(s: str) -> str:
|
||||
"""常见模型输出问题:尾随逗号(, 后紧跟 } 或 ])。"""
|
||||
return re.sub(r",(\s*[}\]])", r"\1", s)
|
||||
|
||||
|
||||
def _extract_balanced_json_prefix(s: str) -> str:
|
||||
"""
|
||||
提取以 `{` 开始的最长“可能完整”的 JSON 对象前缀。
|
||||
会忽略字符串内的花括号,避免误判。
|
||||
"""
|
||||
start = s.find("{")
|
||||
if start == -1:
|
||||
return s
|
||||
in_string = False
|
||||
escaped = False
|
||||
depth = 0
|
||||
end_idx = -1
|
||||
for i, ch in enumerate(s[start:], start=start):
|
||||
if in_string:
|
||||
if escaped:
|
||||
escaped = False
|
||||
elif ch == "\\":
|
||||
escaped = True
|
||||
elif ch == '"':
|
||||
in_string = False
|
||||
continue
|
||||
if ch == '"':
|
||||
in_string = True
|
||||
elif ch == "{":
|
||||
depth += 1
|
||||
elif ch == "}":
|
||||
depth -= 1
|
||||
if depth == 0:
|
||||
end_idx = i
|
||||
break
|
||||
if end_idx != -1:
|
||||
return s[start : end_idx + 1]
|
||||
return s[start:]
|
||||
|
||||
|
||||
def _close_truncated_json_object(s: str) -> str:
|
||||
"""
|
||||
处理模型截断导致的 JSON 残缺:
|
||||
- 若字符串未闭合,补一个 `"`
|
||||
- 按栈补齐缺失的 `}` / `]`
|
||||
"""
|
||||
out: list[str] = []
|
||||
stack: list[str] = []
|
||||
in_string = False
|
||||
escaped = False
|
||||
|
||||
for ch in s:
|
||||
out.append(ch)
|
||||
if in_string:
|
||||
if escaped:
|
||||
escaped = False
|
||||
elif ch == "\\":
|
||||
escaped = True
|
||||
elif ch == '"':
|
||||
in_string = False
|
||||
continue
|
||||
if ch == '"':
|
||||
in_string = True
|
||||
continue
|
||||
if ch == "{":
|
||||
stack.append("}")
|
||||
elif ch == "[":
|
||||
stack.append("]")
|
||||
elif ch in ("}", "]"):
|
||||
if stack and stack[-1] == ch:
|
||||
stack.pop()
|
||||
|
||||
if in_string:
|
||||
out.append('"')
|
||||
while stack:
|
||||
out.append(stack.pop())
|
||||
return "".join(out)
|
||||
|
||||
|
||||
def parse_json_object_from_text(text: str) -> dict[str, Any]:
|
||||
"""从模型输出里提取并解析 { ... } JSON 对象。"""
|
||||
s = (text or "").strip()
|
||||
s = re.sub(r"```(?:json)?", "", s, flags=re.IGNORECASE).replace("```", "").strip()
|
||||
start = s.find("{")
|
||||
if start == -1:
|
||||
raise ValueError("未找到 JSON 对象")
|
||||
chunk = s[start:]
|
||||
balanced_chunk = _extract_balanced_json_prefix(chunk)
|
||||
decoder = json.JSONDecoder()
|
||||
last_err: Optional[Exception] = None
|
||||
for candidate in (
|
||||
balanced_chunk,
|
||||
_repair_loose_json_object(balanced_chunk),
|
||||
_close_truncated_json_object(_repair_loose_json_object(balanced_chunk)),
|
||||
_close_truncated_json_object(_repair_loose_json_object(chunk)),
|
||||
):
|
||||
try:
|
||||
obj, _ = decoder.raw_decode(candidate)
|
||||
if not isinstance(obj, dict):
|
||||
raise ValueError("JSON 根节点不是对象(dict)")
|
||||
return obj
|
||||
except json.JSONDecodeError as e:
|
||||
last_err = e
|
||||
raise ValueError(f"JSON 解析失败:{last_err}") from last_err
|
||||
|
||||
|
||||
def safe_get_str(v: Any) -> Optional[str]:
|
||||
if v is None:
|
||||
return None
|
||||
s = str(v).strip()
|
||||
return s if s else None
|
||||
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Agent 多轮对话 + Tool Calling(流式生成器)
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def _iter_chat_stream_events(
|
||||
*,
|
||||
api_base: str,
|
||||
api_key: str,
|
||||
model_name: str,
|
||||
messages: list[dict[str, Any]],
|
||||
tools: list[dict[str, Any]] | None = None,
|
||||
temperature: float = 0.3,
|
||||
max_tokens: int = 4096,
|
||||
extra_payload: dict[str, Any] | None = None,
|
||||
connect_timeout_sec: int = 20,
|
||||
read_timeout_sec: int = 300,
|
||||
):
|
||||
"""
|
||||
流式调用 OpenAI-compat /chat/completions,逐步 yield 事件:
|
||||
("delta", str) — 文本增量
|
||||
("tool_calls", list) — 完整 tool_calls 列表 [{id, function:{name, arguments}}]
|
||||
("done", dict) — 最终 usage 等元信息
|
||||
"""
|
||||
payload: dict[str, Any] = {
|
||||
"model": model_name,
|
||||
"messages": messages,
|
||||
"temperature": temperature,
|
||||
"max_tokens": max_tokens,
|
||||
"stream": True,
|
||||
**(extra_payload or {}),
|
||||
}
|
||||
if tools:
|
||||
payload["tools"] = tools
|
||||
payload["tool_choice"] = "auto"
|
||||
|
||||
resp = requests.post(
|
||||
f"{api_base}/chat/completions",
|
||||
headers={
|
||||
"Authorization": f"Bearer {api_key}",
|
||||
"Content-Type": "application/json",
|
||||
},
|
||||
json=payload,
|
||||
stream=True,
|
||||
timeout=(connect_timeout_sec, max(60, read_timeout_sec)),
|
||||
)
|
||||
|
||||
if resp.status_code in (408, 429, 500, 502, 503, 504):
|
||||
raise _RetryableLLMError(f"LLM HTTP {resp.status_code}: {(resp.text or '')[:300]}")
|
||||
if resp.status_code != 200:
|
||||
raise RuntimeError(f"LLM HTTP {resp.status_code}: {(resp.text or '')[:800]}")
|
||||
|
||||
resp.encoding = "utf-8"
|
||||
content_parts: list[str] = []
|
||||
tool_calls_map: dict[int, dict] = {}
|
||||
|
||||
for line in resp.iter_lines(decode_unicode=True):
|
||||
if not line:
|
||||
continue
|
||||
s = line.strip()
|
||||
if not s.startswith("data:"):
|
||||
continue
|
||||
data_str = s[5:].strip()
|
||||
if not data_str or data_str == "[DONE]":
|
||||
break
|
||||
try:
|
||||
obj = json.loads(data_str)
|
||||
except Exception:
|
||||
continue
|
||||
choices = obj.get("choices")
|
||||
if not isinstance(choices, list) or not choices:
|
||||
continue
|
||||
first = choices[0] if isinstance(choices[0], dict) else {}
|
||||
delta = first.get("delta") if isinstance(first.get("delta"), dict) else {}
|
||||
|
||||
# 仅输出正文 content;忽略 reasoning_content,避免思考过程展示给用户
|
||||
_ = delta.get("reasoning_content")
|
||||
|
||||
content = delta.get("content")
|
||||
if isinstance(content, str) and content:
|
||||
content_parts.append(content)
|
||||
yield ("delta", content)
|
||||
|
||||
# tool_calls delta (streamed incrementally)
|
||||
tc_deltas = delta.get("tool_calls")
|
||||
if isinstance(tc_deltas, list):
|
||||
for tc in tc_deltas:
|
||||
if not isinstance(tc, dict):
|
||||
continue
|
||||
idx = tc.get("index", 0)
|
||||
if idx not in tool_calls_map:
|
||||
tool_calls_map[idx] = {
|
||||
"id": tc.get("id", ""),
|
||||
"type": "function",
|
||||
"function": {"name": "", "arguments": ""},
|
||||
}
|
||||
entry = tool_calls_map[idx]
|
||||
if tc.get("id"):
|
||||
entry["id"] = tc["id"]
|
||||
fn = tc.get("function") if isinstance(tc.get("function"), dict) else {}
|
||||
if fn.get("name"):
|
||||
entry["function"]["name"] += fn["name"]
|
||||
if fn.get("arguments"):
|
||||
entry["function"]["arguments"] += fn["arguments"]
|
||||
|
||||
# finish_reason
|
||||
finish = first.get("finish_reason")
|
||||
if finish == "tool_calls" and tool_calls_map:
|
||||
ordered = [tool_calls_map[k] for k in sorted(tool_calls_map.keys())]
|
||||
yield ("tool_calls", ordered)
|
||||
tool_calls_map = {}
|
||||
|
||||
if tool_calls_map:
|
||||
ordered = [tool_calls_map[k] for k in sorted(tool_calls_map.keys())]
|
||||
yield ("tool_calls", ordered)
|
||||
|
||||
yield ("done", {"content": "".join(content_parts)})
|
||||
|
||||
|
||||
def _default_disable_thinking_payload(model_name: str) -> dict[str, Any]:
|
||||
"""Qwen 等推理模型:关闭 thinking,仅将最终答案写入 content。"""
|
||||
if not model_name or "qwen" not in str(model_name).lower():
|
||||
except Exception as e: # noqa: BLE001
|
||||
logger.warning("chat_completions_json 调用失败: %s", e)
|
||||
return {}
|
||||
return {
|
||||
"enable_thinking": False,
|
||||
# vLLM / 部分 OpenAI 兼容网关使用 chat_template_kwargs
|
||||
"chat_template_kwargs": {"enable_thinking": False},
|
||||
}
|
||||
|
||||
|
||||
def chat_completions_with_tools(
|
||||
*,
|
||||
messages: list[dict[str, Any]],
|
||||
tools: list[dict[str, Any]] | None = None,
|
||||
temperature: float = 0.3,
|
||||
max_tokens: int = 4096,
|
||||
timeout_sec: int = 180,
|
||||
extra_payload: dict[str, Any] | None = None,
|
||||
):
|
||||
"""
|
||||
Agent 用多轮对话 + tool calling。返回生成器,yield 事件元组。
|
||||
调用方负责工具循环编排。
|
||||
"""
|
||||
api_base = (settings.LLM_API_BASE or "").rstrip("/")
|
||||
api_key = settings.LLM_API_KEY or ""
|
||||
model_name = settings.LLM_MODEL_NAME or ""
|
||||
|
||||
if not api_base or not api_key or not model_name:
|
||||
raise RuntimeError("LLM 未配置:请设置 LLM_API_BASE/LLM_API_KEY/LLM_MODEL_NAME")
|
||||
|
||||
merged_extra: dict[str, Any] = dict(_default_disable_thinking_payload(model_name))
|
||||
if extra_payload:
|
||||
merged_extra.update(extra_payload)
|
||||
|
||||
connect_timeout_sec = int(getattr(settings, "LLM_CONNECT_TIMEOUT_SEC", 20) or 20)
|
||||
if connect_timeout_sec <= 0:
|
||||
connect_timeout_sec = 20
|
||||
|
||||
final_timeout_sec = int(timeout_sec or 0)
|
||||
if final_timeout_sec <= 0:
|
||||
final_timeout_sec = int(getattr(settings, "LLM_HTTP_TIMEOUT_SEC", 90) or 90)
|
||||
|
||||
with _llm_slots:
|
||||
yield from _iter_chat_stream_events(
|
||||
api_base=api_base,
|
||||
api_key=api_key,
|
||||
model_name=model_name,
|
||||
messages=messages,
|
||||
tools=tools,
|
||||
temperature=temperature,
|
||||
max_tokens=max_tokens,
|
||||
extra_payload=merged_extra or None,
|
||||
connect_timeout_sec=connect_timeout_sec,
|
||||
read_timeout_sec=final_timeout_sec,
|
||||
)
|
||||
return _extract_json(text)
|
||||
|
||||
406
services/section_extractor.py
Normal file
406
services/section_extractor.py
Normal file
@ -0,0 +1,406 @@
|
||||
"""
|
||||
services/section_extractor.py
|
||||
从 Markdown 中:
|
||||
1) 抽取目录(章节标题层级)-> 用于生成模板章节(目录)
|
||||
2) 按标题拆分正文 -> 每个章节的内容(用于入库 report_section_references)
|
||||
|
||||
抽取/过滤逻辑参考 eval_report/routers/template.py 与 routers/reference.py。
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import hashlib
|
||||
import re
|
||||
|
||||
_MAX_SECTION_TITLE_LEN = 200
|
||||
|
||||
|
||||
# ────────────────────────────── 通用过滤/清洗 ──────────────────────────────
|
||||
|
||||
|
||||
def _segment_looks_like_year(segment: str) -> bool:
|
||||
if not segment.isdigit() or len(segment) != 4:
|
||||
return False
|
||||
year = int(segment)
|
||||
return 1900 <= year <= 2099
|
||||
|
||||
|
||||
def _is_valid_section_number(num: str) -> bool:
|
||||
"""章节编号形如 1 / 1.1 / 2.3.4,排除正文年份(2017、2019 等)。"""
|
||||
parts = [p for p in str(num or "").strip().split(".") if p]
|
||||
if not parts or not all(p.isdigit() for p in parts):
|
||||
return False
|
||||
if any(_segment_looks_like_year(p) for p in parts):
|
||||
return False
|
||||
if len(parts) == 1:
|
||||
return 1 <= int(parts[0]) <= 20
|
||||
return all(1 <= int(p) <= 99 for p in parts)
|
||||
|
||||
|
||||
def _heading_title_core(rest: str) -> str:
|
||||
return re.sub(r"^\d+(?:\.\d+)*\s*", "", str(rest or "").strip()).strip()
|
||||
|
||||
|
||||
def _rest_looks_like_body_text(rest: str) -> bool:
|
||||
"""过滤日期句、长段落、数据说明句等被误识别为标题的正文。"""
|
||||
t = _heading_title_core(rest) or str(rest or "").strip()
|
||||
if not t:
|
||||
return True
|
||||
if re.match(r"^[月日]", t):
|
||||
return True
|
||||
if re.search(r"月\d", t):
|
||||
return True
|
||||
if re.match(r"^\d{4}\s*年", t) or re.match(r"^\d{4}[、,]", t):
|
||||
return True
|
||||
if re.search(r"\d{4}\s*[-~~—至]\s*\d{4}", t):
|
||||
return True
|
||||
if t.count("。") >= 2 or t.count(";") >= 2:
|
||||
return True
|
||||
if len(t) > 80 and re.search(r"[,。;:]", t):
|
||||
return True
|
||||
if len(t) > 45 and any(
|
||||
k in t
|
||||
for k in (
|
||||
"运营数据", "预测数据", "实际运营", "根据公司",
|
||||
"发展规划", "工况下", "万吨", "有项目", "无项目",
|
||||
)
|
||||
):
|
||||
return True
|
||||
if len(t) > 45 and not re.search(
|
||||
r"(评价|分析|结论|概况|说明|措施|建议|对比|控制|实现|状况|情况|程序|模式|评价结论)$",
|
||||
t.rstrip("。;,"),
|
||||
):
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def _looks_like_real_heading_title(title: str) -> bool:
|
||||
if not str(title or "").strip():
|
||||
return False
|
||||
return not _rest_looks_like_body_text(title)
|
||||
|
||||
|
||||
def _clean_heading_title(s: str) -> str:
|
||||
t = str(s or "").strip()
|
||||
t = re.sub(r"\s+", " ", t)
|
||||
t = re.sub(r"\s+\d+$", "", t).strip() # 去掉目录行尾页码
|
||||
m_note = re.search(r"[((]([^))]{20,})[))]", t)
|
||||
if m_note and re.search(r"[,。;:]", m_note.group(1)):
|
||||
t = re.sub(r"\s*[((][^))]{20,}[))]\s*$", "", t).strip()
|
||||
return t
|
||||
|
||||
|
||||
def _section_dict(section_key: str, section_title: str) -> dict:
|
||||
return {"sectionKey": section_key, "sectionTitle": section_title}
|
||||
|
||||
|
||||
def _canonical_to_section_key(canonical: str, order: int) -> str:
|
||||
return (
|
||||
re.sub(r"[^a-z0-9\u4e00-\u9fa5]+", "-", canonical).strip("-")
|
||||
or f"section-{order}"
|
||||
)
|
||||
|
||||
|
||||
def normalize_section_key(raw_key: str | None, title: str | None) -> str:
|
||||
"""生成稳定且可入库的 section_key(<=64),超长追加短哈希。"""
|
||||
base = (raw_key or "").strip().lower()
|
||||
if not base:
|
||||
base = (title or "").strip().lower()
|
||||
base = re.sub(r"[^a-z0-9\u4e00-\u9fa5]+", "-", base).strip("-")
|
||||
if not base:
|
||||
base = "section"
|
||||
if len(base) <= 64:
|
||||
return base
|
||||
digest = hashlib.md5(base.encode("utf-8")).hexdigest()[:10]
|
||||
prefix = base[:53].rstrip("-")
|
||||
return f"{prefix}-{digest}"
|
||||
|
||||
|
||||
# ────────────────────────────── 目录(TOC)抽取 ──────────────────────────────
|
||||
|
||||
|
||||
def _walk_markdown_heading_sections(text: str) -> list[dict]:
|
||||
"""
|
||||
单次遍历 Markdown,按标题(# ~ ######)切分章节并捕获正文(不含本节标题行)。
|
||||
标题层级自动编号(## 项目概况 -> 1.1 项目概况),无显式编号也可处理。
|
||||
被判定为"非真实标题"的 # 行视为正文内容,不另起章节。
|
||||
|
||||
正文范围:
|
||||
- 默认(SECTION_CONTENT_INCLUDE_SUBSECTIONS=True):聚合整棵子树,
|
||||
即本节标题之后、直到下一个"层级 <= 本节"的标题之前的全部内容
|
||||
(含下级小节标题与正文),保证父章节正文非空。
|
||||
- 关闭时:仅取到下一个任意标题之前(本节自身正文)。
|
||||
|
||||
返回每节:{number, title, full_title, canonical, section_key(canonical), level, content}
|
||||
目录抽取与正文拆分共用此函数,确保目录与内容一一对应。
|
||||
"""
|
||||
from config import settings
|
||||
|
||||
include_sub = bool(getattr(settings, "SECTION_CONTENT_INCLUDE_SUBSECTIONS", True))
|
||||
|
||||
lines = str(text or "").splitlines()
|
||||
counters: list[int] = []
|
||||
accepted: list[dict] = []
|
||||
seen: set[str] = set()
|
||||
|
||||
for idx, raw in enumerate(lines):
|
||||
m = re.match(r"^(#{1,6})\s+(.+)$", str(raw or "").strip())
|
||||
if not m:
|
||||
continue
|
||||
level = len(m.group(1))
|
||||
title = _clean_heading_title(m.group(2).strip())
|
||||
is_valid = (
|
||||
bool(title)
|
||||
and len(title) <= _MAX_SECTION_TITLE_LEN
|
||||
and _looks_like_real_heading_title(title)
|
||||
)
|
||||
if not is_valid:
|
||||
continue
|
||||
if len(counters) < level:
|
||||
counters.extend([0] * (level - len(counters)))
|
||||
else:
|
||||
counters = counters[:level]
|
||||
counters[level - 1] += 1
|
||||
for i in range(level, len(counters)):
|
||||
counters[i] = 0
|
||||
num = ".".join(str(counters[i]) for i in range(level))
|
||||
full_title = f"{num} {title}"
|
||||
canonical = f"{num}|{title}".lower()
|
||||
if canonical in seen:
|
||||
continue
|
||||
seen.add(canonical)
|
||||
accepted.append(
|
||||
{
|
||||
"number": num,
|
||||
"title": title,
|
||||
"full_title": full_title,
|
||||
"canonical": canonical,
|
||||
"section_key": _canonical_to_section_key(canonical, len(accepted) + 1),
|
||||
"level": level,
|
||||
"start_idx": idx,
|
||||
}
|
||||
)
|
||||
|
||||
total = len(lines)
|
||||
for i, sec in enumerate(accepted):
|
||||
body_start = sec["start_idx"] + 1 # 排除本节标题行
|
||||
end = total
|
||||
for j in range(i + 1, len(accepted)):
|
||||
nxt = accepted[j]
|
||||
if include_sub:
|
||||
if nxt["level"] <= sec["level"]:
|
||||
end = nxt["start_idx"]
|
||||
break
|
||||
else:
|
||||
end = nxt["start_idx"]
|
||||
break
|
||||
sec["content"] = "\n".join(lines[body_start:end]).strip()
|
||||
sec.pop("start_idx", None)
|
||||
|
||||
return accepted
|
||||
|
||||
|
||||
def _extract_sections_from_markdown_headings(text: str) -> list[dict]:
|
||||
"""
|
||||
从 Markdown 标题(# / ## / ###)构建模板章节目录。
|
||||
复刻 eval_report 报告模板管理模块 services/template_service.py 的同名逻辑:
|
||||
标题层级自动编号(## 项目概况 -> 1.1 项目概况),并过滤非真实标题行。
|
||||
"""
|
||||
lines = str(text or "").splitlines()
|
||||
counters: list[int] = []
|
||||
out: list[dict] = []
|
||||
seen: set[str] = set()
|
||||
|
||||
for raw in lines:
|
||||
m = re.match(r"^(#{1,6})\s+(.+)$", str(raw or "").strip())
|
||||
if not m:
|
||||
continue
|
||||
level = len(m.group(1))
|
||||
title = _clean_heading_title(m.group(2).strip())
|
||||
if not title or len(title) > _MAX_SECTION_TITLE_LEN:
|
||||
continue
|
||||
if not _looks_like_real_heading_title(title):
|
||||
continue
|
||||
if len(counters) < level:
|
||||
counters.extend([0] * (level - len(counters)))
|
||||
else:
|
||||
counters = counters[:level]
|
||||
counters[level - 1] += 1
|
||||
for i in range(level, len(counters)):
|
||||
counters[i] = 0
|
||||
num = ".".join(str(counters[i]) for i in range(level))
|
||||
full_title = f"{num} {title}"
|
||||
canonical = f"{num}|{title}".lower()
|
||||
if canonical in seen:
|
||||
continue
|
||||
seen.add(canonical)
|
||||
out.append(
|
||||
_section_dict(
|
||||
_canonical_to_section_key(canonical, len(out) + 1),
|
||||
full_title,
|
||||
)
|
||||
)
|
||||
return out
|
||||
|
||||
|
||||
def extract_sections_from_text(text: str) -> list[dict]:
|
||||
"""抽取模板章节目录(入库 report_template_sections)。
|
||||
|
||||
复刻 eval_report 报告模板管理模块的逻辑:优先按 Markdown 标题层级识别,
|
||||
命中数 >= 8 时直接采用;否则回退到目录/编号行识别。"""
|
||||
md_sections = _extract_sections_from_markdown_headings(text)
|
||||
if len(md_sections) >= 8:
|
||||
return md_sections
|
||||
|
||||
lines = str(text or "").splitlines()
|
||||
out: list[dict] = []
|
||||
seen: set[str] = set()
|
||||
candidates: list[dict] = []
|
||||
|
||||
for raw in lines:
|
||||
line = str(raw or "").strip()
|
||||
if not line:
|
||||
continue
|
||||
line = re.sub(r"^#{1,6}\s*", "", line).strip()
|
||||
line = line.replace("\u3000", " ")
|
||||
line = re.sub(r"\s+", " ", line).strip()
|
||||
|
||||
if re.match(r"^20\d{2}\s*年\s*\d{1,2}\s*月$", line):
|
||||
continue
|
||||
if line in {"目次", "目录"}:
|
||||
continue
|
||||
if re.match(r"^\d+\s*[)\)]\s*.+$", line):
|
||||
continue
|
||||
|
||||
has_page_no = bool(re.search(r"\s+\d+\s*$", line))
|
||||
m = re.match(r"^((?:\d+(?:\.\d+){0,5}))\s*([^\s].*)$", line)
|
||||
if m:
|
||||
num = m.group(1).strip()
|
||||
if not _is_valid_section_number(num):
|
||||
continue
|
||||
rest = _clean_heading_title(m.group(2).strip())
|
||||
if not rest or rest.startswith(")") or rest.startswith(")"):
|
||||
continue
|
||||
if _rest_looks_like_body_text(rest):
|
||||
continue
|
||||
if len(rest) > _MAX_SECTION_TITLE_LEN:
|
||||
continue
|
||||
full_title = f"{num} {rest}"[:_MAX_SECTION_TITLE_LEN].rstrip()
|
||||
canonical = f"{num}|{rest}".lower()
|
||||
else:
|
||||
m2 = re.match(r"^([一二三四五六七八九十]+[、..])\s*([^\s].*)$", line)
|
||||
if not m2:
|
||||
continue
|
||||
rest2 = _clean_heading_title(m2.group(2).strip())
|
||||
if not rest2 or _rest_looks_like_body_text(rest2) or len(rest2) > _MAX_SECTION_TITLE_LEN:
|
||||
continue
|
||||
full_title = f"{m2.group(1)} {rest2}"[:_MAX_SECTION_TITLE_LEN].rstrip()
|
||||
canonical = f"{m2.group(1)}|{rest2}".lower()
|
||||
candidates.append({"canonical": canonical, "title": full_title, "has_page_no": has_page_no})
|
||||
|
||||
use_toc_only = False
|
||||
toc_rows = [c for c in candidates if c["has_page_no"]]
|
||||
toc_nums = set()
|
||||
for c in toc_rows:
|
||||
m_num = re.match(r"^(\d+)", c["title"])
|
||||
if m_num:
|
||||
toc_nums.add(m_num.group(1))
|
||||
if len(toc_rows) >= 20 and {"1", "2", "3", "4", "5", "6", "7"}.issubset(toc_nums):
|
||||
use_toc_only = True
|
||||
|
||||
picked = toc_rows if use_toc_only else candidates
|
||||
for c in picked:
|
||||
canonical = c["canonical"]
|
||||
if canonical in seen:
|
||||
continue
|
||||
if not _looks_like_real_heading_title(c["title"]):
|
||||
continue
|
||||
seen.add(canonical)
|
||||
out.append(_section_dict(_canonical_to_section_key(canonical, len(out) + 1), c["title"]))
|
||||
return out
|
||||
|
||||
|
||||
# ────────────────────────────── 正文按标题拆分 ──────────────────────────────
|
||||
|
||||
|
||||
def split_markdown_into_sections(text: str) -> list[dict[str, str]]:
|
||||
"""
|
||||
按 Markdown 标题切分正文,与目录抽取共用同一套标题识别与自动编号,
|
||||
保证每个目录章节都能拿到对应正文。section_key 为自动编号(如 1.1)。
|
||||
|
||||
若文档不含 Markdown 标题,则回退到"带编号标题"的拆分方式。
|
||||
返回 [{section_key, section_title, content}, ...]。
|
||||
"""
|
||||
walk = _walk_markdown_heading_sections(text)
|
||||
if walk:
|
||||
return [
|
||||
{
|
||||
"section_key": s["number"],
|
||||
"section_title": s["full_title"],
|
||||
"content": s["content"],
|
||||
}
|
||||
for s in walk
|
||||
]
|
||||
return split_markdown_by_headings(text)
|
||||
|
||||
|
||||
def split_markdown_by_headings(text: str) -> list[dict[str, str]]:
|
||||
"""
|
||||
按 Markdown 标题(# ~ ####,且带章节编号,如 ## 1.1 标题)拆分正文。
|
||||
返回 [{section_key, section_title, content}, ...],section_key 为编号(如 1.1)。
|
||||
"""
|
||||
lines = str(text or "").splitlines()
|
||||
heading_pattern = re.compile(r"^#{1,4}\s+(\d+(?:\.\d+)*)\s+(.+)")
|
||||
|
||||
sections: list[dict[str, str]] = []
|
||||
current_key: str | None = None
|
||||
current_title: str | None = None
|
||||
current_lines: list[str] = []
|
||||
|
||||
for line in lines:
|
||||
m = heading_pattern.match(line)
|
||||
if m:
|
||||
if current_key and current_lines:
|
||||
sections.append({
|
||||
"section_key": current_key,
|
||||
"section_title": current_title or "",
|
||||
"content": "\n".join(current_lines).strip(),
|
||||
})
|
||||
current_key = m.group(1)
|
||||
current_title = m.group(2).strip()
|
||||
current_lines = [line]
|
||||
else:
|
||||
if current_key:
|
||||
current_lines.append(line)
|
||||
|
||||
if current_key and current_lines:
|
||||
sections.append({
|
||||
"section_key": current_key,
|
||||
"section_title": current_title or "",
|
||||
"content": "\n".join(current_lines).strip(),
|
||||
})
|
||||
|
||||
return sections
|
||||
|
||||
|
||||
def parse_section_order(section_key: str) -> int:
|
||||
"""将 '1.2.1' 转为整数 121 用于排序。"""
|
||||
digits = str(section_key or "").replace(".", "")
|
||||
return int(digits) if digits.isdigit() else 0
|
||||
|
||||
|
||||
def clamp_text_bytes(text: str, max_bytes: int, *, suffix: str = "\n…(内容过长,已截断)") -> str:
|
||||
"""
|
||||
将文本按 UTF-8 字节数截断到 max_bytes 以内,且不会截断到半个字符。
|
||||
用于适配 MySQL TEXT 列(最大 65535 字节)。
|
||||
"""
|
||||
if not text or max_bytes <= 0:
|
||||
return text
|
||||
data = text.encode("utf-8")
|
||||
if len(data) <= max_bytes:
|
||||
return text
|
||||
suffix_bytes = len(suffix.encode("utf-8"))
|
||||
budget = max(max_bytes - suffix_bytes, 0)
|
||||
# errors="ignore" 会丢弃末尾被切断的不完整字符,保证是合法 UTF-8
|
||||
truncated = data[:budget].decode("utf-8", errors="ignore").rstrip()
|
||||
return truncated + suffix
|
||||
1239
services/template_prompt_mapper.py
Normal file
1239
services/template_prompt_mapper.py
Normal file
File diff suppressed because it is too large
Load Diff
524
services/template_service.py
Normal file
524
services/template_service.py
Normal file
@ -0,0 +1,524 @@
|
||||
"""
|
||||
services/template_service.py
|
||||
复刻自 eval_report:report_template_sections 数据的获取方式。
|
||||
|
||||
- DEFAULT_TEMPLATE_SECTIONS:系统默认后评价报告章节目录(key, title)
|
||||
- default_section_prompt / default_section_output_contract / default_section_examples:
|
||||
按章节标题/编号取对应提示词、输出合同、示例
|
||||
- build_default_template_catalog:默认目录 + 提示词/合同(供上传模版匹配)
|
||||
|
||||
说明:eval_report 会额外从《编制细则》与《模版》Word 文档抽取更细的提示词/示例;
|
||||
本项目默认不含这两个 .doc 文件与 DocParser,故相关函数在缺文件时优雅降级,
|
||||
回退到 SECTION_PROMPT_RULES / SECTION_EXAMPLE_RULES。
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import re
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
from functools import lru_cache
|
||||
from pathlib import Path
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from database.models import ReportTemplate, ReportTemplateSection
|
||||
from prompts.report_generation.section_output_contracts import (
|
||||
DEFAULT_SECTION_OUTPUT_CONTRACT,
|
||||
SECTION_OUTPUT_CONTRACTS,
|
||||
)
|
||||
from prompts.report_generation.template_prompt_rules import (
|
||||
DEFAULT_SECTION_PROMPT,
|
||||
SECTION_EXAMPLE_RULES,
|
||||
SECTION_PROMPT_RULES,
|
||||
)
|
||||
|
||||
SYSTEM_DEFAULT_TEMPLATE_NAME = "后评价默认模板"
|
||||
GUIDELINE_BASENAME = "炼油化工建设项目后评价报告编制细则(修订)"
|
||||
PROJECT_EXAMPLE_BASENAME = "模版"
|
||||
MAX_SECTION_EXAMPLE_CHARS = 12000
|
||||
|
||||
DEFAULT_TEMPLATE_SECTIONS: list[tuple[str, str]] = [
|
||||
("1", "1 项目概况"),
|
||||
("1-1", "1.1 项目基本情况"),
|
||||
("1-2", "1.2 项目决策要点"),
|
||||
("1-3", "1.3 项目实施情况"),
|
||||
("1-4", "1.4 项目运行情况"),
|
||||
("2", "2 前期工作评价"),
|
||||
("2-1", "2.1 项目要素评价"),
|
||||
("2-1-1", "2.1.1 资源与原料评价"),
|
||||
("2-1-2", "2.1.2 产品方案及市场评价"),
|
||||
("2-1-2-1", "2.1.2.1 产品方案评价"),
|
||||
("2-1-2-2", "2.1.2.2 产品市场评价"),
|
||||
("2-1-3", "2.1.3 工艺方案评价"),
|
||||
("2-1-3-1", "2.1.3.1 总加工方案评价"),
|
||||
("2-1-3-2", "2.1.3.2 建设规模及工艺技术方案评价"),
|
||||
("2-1-3-3", "2.1.3.3 主要设备方案评价"),
|
||||
("2-1-4", "2.1.4 厂址选择及外部条件评价"),
|
||||
("2-1-5", "2.1.5 总图及系统配套工程评价"),
|
||||
("2-1-6", "2.1.6 主要技术指标评价"),
|
||||
("2-1-7", "2.1.7 风险分析评价"),
|
||||
("2-2", "2.2 工作程序评价"),
|
||||
("2-2-1", "2.2.1 编制单位资质及选择方式评价"),
|
||||
("2-2-2", "2.2.2 编制进度评价"),
|
||||
("2-2-3", "2.2.3 与专项评价的结合情况"),
|
||||
("2-2-4", "2.2.4 可行性研究报告的质量评价"),
|
||||
("2-3", "2.3 前评估工作评价"),
|
||||
("2-4", "2.4 初步设计评价"),
|
||||
("2-4-1", "2.4.1 设计单位资质及选择方式评价"),
|
||||
("2-4-2", "2.4.2 初步设计进度评价"),
|
||||
("2-4-3", "2.4.3 初步设计质量评价"),
|
||||
("2-4-4", "2.4.4 初步设计审查工作评价"),
|
||||
("2-5", "2.5 前期决策程序评价"),
|
||||
("2-6", "2.6 前期工作评价结论"),
|
||||
("3", "3 建设实施评价"),
|
||||
("3-1", "3.1 工程建设管理模式评价"),
|
||||
("3-2", "3.2 招投标评价"),
|
||||
("3-3", "3.3 施工图设计评价"),
|
||||
("3-3-1", "3.3.1 与批复后初步设计符合性评价"),
|
||||
("3-3-2", "3.3.2 设计进度评价"),
|
||||
("3-3-3", "3.3.3 施工图设计水平及质量评价"),
|
||||
("3-3-4", "3.3.4 施工图设计变更管理评价"),
|
||||
("3-4", "3.4 工程承包商或施工单位评价"),
|
||||
("3-4-1", "3.4.1 施工准备评价"),
|
||||
("3-4-2", "3.4.2 施工计划的执行情况"),
|
||||
("3-5", "3.5 采购工作评价"),
|
||||
("3-6", "3.6 工程监理评价"),
|
||||
("3-7", "3.7 工程质量评价"),
|
||||
("3-8", "3.8 HSE管理评价"),
|
||||
("3-9", "3.9 三查四定及中间交接"),
|
||||
("3-10", "3.10 工程竣工验收评价"),
|
||||
("3-11", "3.11 建设实施评价结论"),
|
||||
("4", "4 生产运行评价"),
|
||||
("4-1", "4.1 生产准备评价"),
|
||||
("4-2", "4.2 联合试运与试生产情况评价"),
|
||||
("4-3", "4.3 生产运行评价"),
|
||||
("4-3-1", "4.3.1 原料供应评价"),
|
||||
("4-3-2", "4.3.2 生产运行总体情况评价"),
|
||||
("4-3-3", "4.3.3 达标评价"),
|
||||
("4-3-4", "4.3.4 生产工艺技术评价"),
|
||||
("4-3-5", "4.3.5 设备运行评价"),
|
||||
("4-3-6", "4.3.6 公用工程及辅助设施合理性评价"),
|
||||
("4-4", "4.4 生产运行评价结论"),
|
||||
("5", "5 投资与经济效益评价"),
|
||||
("5-1", "5.1 主要经济指标实现程度评价"),
|
||||
("5-2", "5.2 投资和执行情况评价"),
|
||||
("5-2-1", "5.2.1 投资控制及变动原因分析"),
|
||||
("5-2-2", "5.2.2 投资水平分析"),
|
||||
("5-2-3", "5.2.3 资金来源及到位评价"),
|
||||
("5-2-4", "5.2.4 投资控制的经验和教训"),
|
||||
("5-3", "5.3 经济效益分析"),
|
||||
("5-3-1", "5.3.1 项目投产以来生产经营及效益状况"),
|
||||
("5-3-2", "5.3.2 项目经济效益后评价"),
|
||||
("5-4", "5.4 不确定性分析"),
|
||||
("5-5", "5.5 投资与经济效益评价结论"),
|
||||
("6", "6 影响与持续性评价"),
|
||||
("6-1", "6.1 影响评价"),
|
||||
("6-1-1", "6.1.1 环境影响评价"),
|
||||
("6-1-2", "6.1.2 安全影响评价"),
|
||||
("6-1-3", "6.1.3 科技进步影响"),
|
||||
("6-1-4", "6.1.4 项目社会影响评价"),
|
||||
("6-1-5", "6.1.5 项目影响评价结论"),
|
||||
("6-2", "6.2 持续性评价"),
|
||||
("6-2-1", "6.2.1 资源分析"),
|
||||
("6-2-2", "6.2.2 产品分析"),
|
||||
("6-2-3", "6.2.3 主要技术及经济指标对比"),
|
||||
("6-2-4", "6.2.4 项目持续性评价结论"),
|
||||
("7", "7 综合评价结论"),
|
||||
("7-1", "7.1 综合评价结论"),
|
||||
("7-1-1", "7.1.1 总体评价结论"),
|
||||
("7-1-2", "7.1.2 成功度评价"),
|
||||
("7-2", "7.2 主要经验"),
|
||||
("7-3", "7.3 问题与建议"),
|
||||
]
|
||||
|
||||
|
||||
def default_section_output_contract(section_title: str, section_key: str | None = None) -> str:
|
||||
section_no = _extract_number_prefix(section_title) or _section_key_to_number(section_key)
|
||||
if section_no and section_no in SECTION_OUTPUT_CONTRACTS:
|
||||
return SECTION_OUTPUT_CONTRACTS[section_no]
|
||||
return DEFAULT_SECTION_OUTPUT_CONTRACT
|
||||
|
||||
|
||||
def default_section_prompt(section_title: str, section_key: str | None = None) -> str:
|
||||
guideline_prompt = _guideline_prompt_for(section_title, section_key)
|
||||
if guideline_prompt:
|
||||
return guideline_prompt
|
||||
|
||||
title = _normalize_section_identity(section_title)
|
||||
key = str(section_key or "").strip().lower()
|
||||
for pattern, prompt in SECTION_PROMPT_RULES:
|
||||
p = pattern.lower()
|
||||
if title.startswith(p):
|
||||
return prompt
|
||||
if p.isdigit() and (title.startswith(f"{p} ") or key.startswith(f"{p}-") or key == p):
|
||||
return prompt
|
||||
return DEFAULT_SECTION_PROMPT
|
||||
|
||||
|
||||
def build_default_template_catalog() -> list[dict[str, str]]:
|
||||
"""系统默认模板章节目录及对应提示词、输出合同(供上传模版匹配)。"""
|
||||
out: list[dict[str, str]] = []
|
||||
for key, title in DEFAULT_TEMPLATE_SECTIONS:
|
||||
out.append(
|
||||
{
|
||||
"sectionKey": key,
|
||||
"sectionTitle": title,
|
||||
"sectionNumber": _extract_number_prefix(title) or _section_key_to_number(key),
|
||||
"sectionPrompt": default_section_prompt(title, key),
|
||||
"sectionOutputContract": default_section_output_contract(title, key),
|
||||
}
|
||||
)
|
||||
return out
|
||||
|
||||
|
||||
def default_section_examples(section_title: str, section_key: str | None = None) -> str:
|
||||
project_example = _project_example_for(section_title, section_key)
|
||||
if project_example:
|
||||
return project_example
|
||||
|
||||
title = _normalize_section_identity(section_title)
|
||||
key = str(section_key or "").strip().lower()
|
||||
num = _extract_number_prefix(section_title) or _section_key_to_number(section_key)
|
||||
chapter_no = ""
|
||||
if num:
|
||||
chapter_no = num.split(".")[0]
|
||||
elif key:
|
||||
chapter_no = key.split("-")[0]
|
||||
for prefix, examples in SECTION_EXAMPLE_RULES:
|
||||
p = str(prefix).strip().lower()
|
||||
if chapter_no == p:
|
||||
return examples
|
||||
if title.startswith(f"{p} "):
|
||||
return examples
|
||||
if key.startswith(f"{p}-") or key == p:
|
||||
return examples
|
||||
return ""
|
||||
|
||||
|
||||
def _normalize_section_identity(value: str | None) -> str:
|
||||
text = str(value or "").strip().lower()
|
||||
text = text.replace(".", ".").replace("。", ".")
|
||||
text = re.sub(r"\s+", " ", text)
|
||||
return text
|
||||
|
||||
|
||||
def _section_key_to_number(section_key: str | None) -> str:
|
||||
key = str(section_key or "").strip()
|
||||
if not key:
|
||||
return ""
|
||||
if re.fullmatch(r"\d+(?:-\d+)*", key):
|
||||
return key.replace("-", ".")
|
||||
return ""
|
||||
|
||||
|
||||
def _extract_number_prefix(title: str) -> str:
|
||||
m = re.match(r"^\s*(\d+(?:\.\d+)*)\s*", str(title or ""))
|
||||
return m.group(1) if m else ""
|
||||
|
||||
|
||||
def _normalize_heading_key(value: str) -> str:
|
||||
s = str(value or "").strip().lower()
|
||||
s = s.replace(".", ".").replace("。", ".")
|
||||
s = re.sub(r"\s+", "", s)
|
||||
return s
|
||||
|
||||
|
||||
def _tuple_from_number(number_str: str) -> tuple[int, ...]:
|
||||
if not number_str:
|
||||
return tuple()
|
||||
parts = []
|
||||
for p in number_str.split("."):
|
||||
if p.isdigit():
|
||||
parts.append(int(p))
|
||||
else:
|
||||
return tuple()
|
||||
return tuple(parts)
|
||||
|
||||
|
||||
def _read_doc_text(path: str) -> str:
|
||||
"""读取 .doc/.docx 文本。本项目无 DocParser 时返回空串(优雅降级)。"""
|
||||
try:
|
||||
from function.documents.doc_parser import DocParser # type: ignore
|
||||
except Exception:
|
||||
return ""
|
||||
try:
|
||||
return DocParser(path).read()
|
||||
except Exception:
|
||||
return ""
|
||||
|
||||
|
||||
@lru_cache(maxsize=1)
|
||||
def _guideline_section_prompt_map() -> dict[str, str]:
|
||||
guideline_path = _resolve_guideline_path()
|
||||
if not guideline_path:
|
||||
return {}
|
||||
raw_text = _read_doc_text(guideline_path)
|
||||
if not raw_text:
|
||||
return {}
|
||||
return _build_guideline_prompt_map(raw_text)
|
||||
|
||||
|
||||
def _resolve_guideline_path() -> str | None:
|
||||
root = Path(__file__).resolve().parents[1]
|
||||
candidates = [
|
||||
root / f"{GUIDELINE_BASENAME}.doc",
|
||||
root / f"{GUIDELINE_BASENAME}.docx",
|
||||
]
|
||||
for p in candidates:
|
||||
if p.is_file():
|
||||
return str(p)
|
||||
return None
|
||||
|
||||
|
||||
def _resolve_project_example_path() -> str | None:
|
||||
root = Path(__file__).resolve().parents[1]
|
||||
candidates = [
|
||||
root / f"{PROJECT_EXAMPLE_BASENAME}.doc",
|
||||
root / f"{PROJECT_EXAMPLE_BASENAME}.docx",
|
||||
]
|
||||
for p in candidates:
|
||||
if p.is_file():
|
||||
return str(p)
|
||||
return None
|
||||
|
||||
|
||||
@lru_cache(maxsize=1)
|
||||
def _project_example_entries() -> list[tuple[str, str]]:
|
||||
path = _resolve_project_example_path()
|
||||
if not path:
|
||||
return []
|
||||
raw_text = _read_doc_text(path)
|
||||
if not raw_text:
|
||||
return []
|
||||
return _build_project_example_entries(raw_text)
|
||||
|
||||
|
||||
def _build_project_example_entries(text: str) -> list[tuple[str, str]]:
|
||||
lines = str(text or "").splitlines()
|
||||
headings: list[tuple[int, int, str]] = []
|
||||
for idx, raw in enumerate(lines):
|
||||
line = str(raw or "").strip()
|
||||
m = re.match(r"^\s*(#{1,6})\s*(.+?)\s*$", line)
|
||||
if not m:
|
||||
continue
|
||||
level = len(m.group(1))
|
||||
heading_title = m.group(2).strip()
|
||||
if not heading_title:
|
||||
continue
|
||||
headings.append((idx, level, heading_title))
|
||||
|
||||
out: list[tuple[str, str]] = []
|
||||
for i, (start_idx, level, title) in enumerate(headings):
|
||||
end_idx = len(lines)
|
||||
for j in range(i + 1, len(headings)):
|
||||
next_idx, next_level, _ = headings[j]
|
||||
if next_level <= level:
|
||||
end_idx = next_idx
|
||||
break
|
||||
body = "\n".join(lines[start_idx + 1 : end_idx]).strip()
|
||||
body = re.sub(r"\n{3,}", "\n\n", body)
|
||||
if not body:
|
||||
continue
|
||||
out.append((title, body))
|
||||
return out
|
||||
|
||||
|
||||
def _project_example_for(section_title: str, section_key: str | None = None) -> str:
|
||||
entries = _project_example_entries()
|
||||
if not entries:
|
||||
return ""
|
||||
|
||||
target_title = _clean_section_title(section_title)
|
||||
target_key = _section_key_to_number(section_key)
|
||||
target_core = _core_title(target_title or target_key)
|
||||
if not target_core:
|
||||
return ""
|
||||
|
||||
best_title = ""
|
||||
best_body = ""
|
||||
best_score = -1
|
||||
for heading, body in entries:
|
||||
heading_clean = _clean_section_title(heading)
|
||||
heading_core = _core_title(heading_clean)
|
||||
score = _title_match_score(target_core, heading_core)
|
||||
if score > best_score:
|
||||
best_score = score
|
||||
best_title = heading_clean
|
||||
best_body = body
|
||||
|
||||
if best_score < 4 or not best_body:
|
||||
return ""
|
||||
|
||||
text = f"### {best_title}\n\n{best_body}".strip()
|
||||
if len(text) > MAX_SECTION_EXAMPLE_CHARS:
|
||||
text = text[:MAX_SECTION_EXAMPLE_CHARS].rstrip() + "\n\n(示例过长,已截断)"
|
||||
return text
|
||||
|
||||
|
||||
def _clean_section_title(value: str | None) -> str:
|
||||
s = str(value or "").strip()
|
||||
s = re.sub(r"^\s*\d+(?:[.\-]\d+)*\s*", "", s)
|
||||
return s.strip()
|
||||
|
||||
|
||||
def _core_title(value: str | None) -> str:
|
||||
s = str(value or "").strip()
|
||||
s = s.replace("(", "(").replace(")", ")")
|
||||
s = re.sub(r"\([^)]*\)", "", s)
|
||||
s = re.sub(r"[、,。;::()()\-\s]", "", s)
|
||||
s = s.replace("项目", "")
|
||||
s = s.replace("情况", "")
|
||||
s = s.replace("工作", "")
|
||||
return s.strip().lower()
|
||||
|
||||
|
||||
def _title_match_score(target: str, candidate: str) -> int:
|
||||
if not target or not candidate:
|
||||
return 0
|
||||
if target == candidate:
|
||||
return 100
|
||||
score = 0
|
||||
if target in candidate or candidate in target:
|
||||
score += 40
|
||||
tks_t = re.findall(r"[\u4e00-\u9fa5]{2,8}|[a-z]{2,12}", target)
|
||||
tks_c = re.findall(r"[\u4e00-\u9fa5]{2,8}|[a-z]{2,12}", candidate)
|
||||
if tks_t and tks_c:
|
||||
overlap = len(set(tks_t) & set(tks_c))
|
||||
score += overlap * 8
|
||||
ch_overlap = len(set(target) & set(candidate))
|
||||
score += min(ch_overlap, 20)
|
||||
return score
|
||||
|
||||
|
||||
def _build_guideline_prompt_map(text: str) -> dict[str, str]:
|
||||
lines = str(text or "").splitlines()
|
||||
headings: list[tuple[int, str, str, tuple[int, ...]]] = []
|
||||
for idx, raw in enumerate(lines):
|
||||
line = str(raw or "").strip()
|
||||
m = re.match(r"^\s*#{1,6}\s*(.+?)\s*$", line)
|
||||
if not m:
|
||||
continue
|
||||
heading_title = m.group(1).strip()
|
||||
number = _extract_number_prefix(heading_title)
|
||||
number_tuple = _tuple_from_number(number)
|
||||
if not number_tuple:
|
||||
continue
|
||||
headings.append((idx, heading_title, number, number_tuple))
|
||||
|
||||
prompt_map: dict[str, str] = {}
|
||||
for i, (start_idx, heading_title, number, number_tuple) in enumerate(headings):
|
||||
end_idx = len(lines)
|
||||
for j in range(i + 1, len(headings)):
|
||||
next_start, _, _, next_tuple = headings[j]
|
||||
if len(next_tuple) < len(number_tuple) or next_tuple[: len(number_tuple)] != number_tuple:
|
||||
end_idx = next_start
|
||||
break
|
||||
body = "\n".join(lines[start_idx + 1 : end_idx]).strip()
|
||||
body = re.sub(r"\n{3,}", "\n\n", body)
|
||||
if not body:
|
||||
continue
|
||||
key_title = _normalize_heading_key(heading_title)
|
||||
key_number = _normalize_heading_key(number)
|
||||
prompt_map[key_title] = body
|
||||
prompt_map[key_number] = body
|
||||
return prompt_map
|
||||
|
||||
|
||||
def _guideline_prompt_for(section_title: str, section_key: str | None = None) -> str:
|
||||
mapping = _guideline_section_prompt_map()
|
||||
if not mapping:
|
||||
return ""
|
||||
title = str(section_title or "").strip()
|
||||
number = _extract_number_prefix(title) or _section_key_to_number(section_key)
|
||||
candidates = [
|
||||
_normalize_heading_key(title),
|
||||
_normalize_heading_key(number),
|
||||
]
|
||||
for key in candidates:
|
||||
if key and key in mapping:
|
||||
return mapping[key]
|
||||
return ""
|
||||
|
||||
|
||||
def list_templates(db: Session) -> list[ReportTemplate]:
|
||||
return (
|
||||
db.query(ReportTemplate)
|
||||
.order_by(ReportTemplate.is_default.desc(), ReportTemplate.updated_at.desc())
|
||||
.all()
|
||||
)
|
||||
|
||||
|
||||
def ensure_default_template(db: Session) -> None:
|
||||
now = datetime.now()
|
||||
system_default = (
|
||||
db.query(ReportTemplate)
|
||||
.filter(ReportTemplate.name == SYSTEM_DEFAULT_TEMPLATE_NAME)
|
||||
.first()
|
||||
)
|
||||
if not system_default:
|
||||
system_default = ReportTemplate(
|
||||
id=uuid.uuid4().hex,
|
||||
name=SYSTEM_DEFAULT_TEMPLATE_NAME,
|
||||
description="系统预置模板(细则完整章节)",
|
||||
is_default=True,
|
||||
is_active=True,
|
||||
created_at=now,
|
||||
updated_at=now,
|
||||
)
|
||||
db.add(system_default)
|
||||
db.flush()
|
||||
|
||||
current_rows = (
|
||||
db.query(ReportTemplateSection)
|
||||
.filter(ReportTemplateSection.template_id == system_default.id)
|
||||
.order_by(ReportTemplateSection.section_order.asc())
|
||||
.all()
|
||||
)
|
||||
current_pairs = [(r.section_key, r.section_title) for r in current_rows]
|
||||
expected_pairs = list(DEFAULT_TEMPLATE_SECTIONS)
|
||||
|
||||
db.query(ReportTemplate).update({ReportTemplate.is_default: False})
|
||||
system_default.is_default = True
|
||||
system_default.is_active = True
|
||||
system_default.updated_at = now
|
||||
|
||||
if current_pairs == expected_pairs:
|
||||
has_changed = False
|
||||
for row in current_rows:
|
||||
current_examples = str(row.examples or "").strip()
|
||||
new_examples = default_section_examples(row.section_title, row.section_key).strip()
|
||||
if new_examples and current_examples != new_examples:
|
||||
row.examples = new_examples
|
||||
row.updated_at = now
|
||||
has_changed = True
|
||||
current_out = str(getattr(row, "section_output_contract", None) or "").strip()
|
||||
new_out = default_section_output_contract(row.section_title, row.section_key).strip()
|
||||
if not current_out and new_out:
|
||||
row.section_output_contract = new_out
|
||||
row.updated_at = now
|
||||
has_changed = True
|
||||
if has_changed:
|
||||
system_default.updated_at = now
|
||||
db.commit()
|
||||
return
|
||||
|
||||
db.query(ReportTemplateSection).filter(
|
||||
ReportTemplateSection.template_id == system_default.id
|
||||
).delete()
|
||||
for i, (key, title) in enumerate(DEFAULT_TEMPLATE_SECTIONS):
|
||||
db.add(
|
||||
ReportTemplateSection(
|
||||
id=uuid.uuid4().hex,
|
||||
template_id=system_default.id,
|
||||
section_key=key,
|
||||
section_title=title,
|
||||
section_prompt="",
|
||||
section_output_contract=default_section_output_contract(title, key),
|
||||
section_order=i,
|
||||
examples=default_section_examples(title, key),
|
||||
created_at=now,
|
||||
updated_at=now,
|
||||
)
|
||||
)
|
||||
db.commit()
|
||||
Loading…
x
Reference in New Issue
Block a user